2023-02-15 22:24:40 -08:00
|
|
|
from collections import namedtuple
|
|
|
|
from collections import deque
|
|
|
|
import random
|
2023-02-15 19:24:19 -08:00
|
|
|
import math
|
|
|
|
|
2023-02-15 22:24:40 -08:00
|
|
|
import torch
|
2023-02-15 19:24:19 -08:00
|
|
|
|
|
|
|
|
2023-02-15 22:24:40 -08:00
|
|
|
# Glue layer
|
|
|
|
from celeste import Celeste
|
|
|
|
|
|
|
|
|
|
|
|
# Device the policy network and its input tensors live on: the GPU when torch
# can see one, otherwise the CPU.
if torch.cuda.is_available():
    compute_device = torch.device("cuda")
else:
    compute_device = torch.device("cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Epsilon-greedy exploration schedule.
#
# select_action takes a uniformly random action with probability
#
#     eps(steps) = EPS_END + (EPS_START - EPS_END) * exp(-steps / EPS_DECAY)
#
# and otherwise the action with the highest predicted Q-value.
#
#   EPS_START - epsilon at step 0 (early on, mostly random actions)
#   EPS_END   - floor that epsilon approaches as the step count grows
#   EPS_DECAY - time constant of the exponential decay, in steps;
#               a larger value means a slower decay
EPS_START: float = 0.9
EPS_END: float = 0.05
EPS_DECAY: int = 1000
|
|
|
|
|
|
|
|
|
|
|
|
# Q-network used by the agent.
class DQN(torch.nn.Module):
    """Fully connected Q-network: observation vector -> one Q-value per action.

    Architecture: ``Linear(n_observations, hidden_size) -> ReLU ->
    Linear(hidden_size, hidden_size) -> ReLU -> Linear(hidden_size, n_actions)``.

    Args:
        n_observations: length of the observation vector fed to the network.
        n_actions: number of discrete actions, i.e. the output width.
        hidden_size: width of both hidden layers (defaults to the original 128).
    """

    def __init__(self, n_observations: int, n_actions: int, hidden_size: int = 128):
        super().__init__()
        # Attribute names stay layer1..layer3 so state_dict keys are unchanged.
        self.layer1 = torch.nn.Linear(n_observations, hidden_size)
        self.layer2 = torch.nn.Linear(hidden_size, hidden_size)
        self.layer3 = torch.nn.Linear(hidden_size, n_actions)

    def forward(self, x):
        """Return Q-value estimates for ``x``.

        ``x`` may be a single observation of shape ``(n_observations,)`` or a
        batch of shape ``(batch, n_observations)``. The result has the same
        leading shape with ``n_actions`` values in the last dimension: entry
        ``a`` holds Q(s, a), the expected return of taking action ``a`` in
        state ``s``. There is one entry per action, not just two.
        """
        x = torch.nn.functional.relu(self.layer1(x))
        x = torch.nn.functional.relu(self.layer2(x))
        return self.layer3(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Celeste environment dimensions.
# Four inputs: the four status fields on_state packs into the network input
# (xpos, ypos, xvel, yvel). Keep this in sync with that list.
n_observations = 4
# One output per entry of the glue layer's action table.
n_actions = len(Celeste.action_space)

# The network used both for acting (select_action) and, presumably, training.
policy_net = DQN(n_observations, n_actions).to(compute_device)
|
|
|
|
|
|
|
|
|
|
|
|
def select_action(state, steps_done):
    """Choose an action index for ``state`` using an epsilon-greedy policy.

    With probability epsilon the index is drawn uniformly from
    ``range(n_actions)``. Otherwise the action whose Q-value from
    ``policy_net`` is largest is returned. Epsilon starts at EPS_START and
    decays exponentially towards EPS_END, with time constant EPS_DECAY:

        eps = EPS_END + (EPS_START - EPS_END) * exp(-steps_done / EPS_DECAY)

    Args:
        state: network input. Because of ``.max(1)`` and the final
            ``.view(1, 1)``, it must be 2-D and yield exactly one row,
            i.e. shape ``(1, n_observations)``.
        steps_done: step counter driving the epsilon decay.

    Returns:
        int: index into the action space.
    """
    # Uniform draw in [0, 1) deciding between exploring and exploiting.
    roll = random.random()

    # Current exploration threshold.
    decay = math.exp(-1.0 * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay

    if roll > eps_threshold:
        # Exploit: take the action with the largest predicted return.
        with torch.no_grad():
            q_values = policy_net(state)
            best = q_values.max(1).indices
            return best.view(1, 1).item()

    # Explore: any action, uniformly.
    return random.randrange(n_actions)
|
|
|
|
|
|
|
|
|
|
|
|
# Previous ``celeste.status`` seen by on_state; None until the first callback.
last_state = None

# One step of experience: (state, action, next_state, reward). Nothing in the
# visible code builds one yet; presumably intended for a replay memory.
Transition = namedtuple("Transition", "state action next_state reward")
|
|
|
|
|
|
|
|
|
|
|
|
def on_state(celeste):
    """Per-update callback given to ``Celeste``: observe, pick an action, act.

    Reads ``celeste.status`` (indexed like a mapping). The very first call only
    records it in ``last_state`` and takes no action. Every later call does
    the following:

    * computes the reward for the previous action: ``10`` if ``"next_point"``
      changed since the last status, otherwise the drop in ``"dist"``;
    * derives ``dead`` (``"deaths"`` is non-zero) and reads ``"frame_count"``;
    * packs ``xpos``, ``ypos``, ``xvel`` and ``yvel`` of the current and the
      previous status into ``(1, 4)`` float32 tensors on ``compute_device``;
    * picks an action with ``select_action``, using ``frame_count`` as the
      step count, maps it through ``Celeste.action_space`` and sends it with
      ``celeste.act``;
    * stores the current status in ``last_state``.

    ``reward``, ``dead`` and the previous-state tensor are computed but not yet
    consumed; presumably they are groundwork for a training step.

    NOTE(review): ``last_state`` keeps the ``celeste.status`` object itself, not
    a copy. If that object is mutated in place between calls, the reward
    would always be 0. Confirm against the Celeste glue layer.
    """
    global last_state

    # Status fields fed to the network, in input order.
    feature_keys = ("xpos", "ypos", "xvel", "yvel")

    def to_input(status):
        # One-row float32 batch on the compute device.
        row = [status[key] for key in feature_keys]
        return torch.tensor(
            row, dtype=torch.float32, device=compute_device
        ).unsqueeze(0)

    current = celeste.status

    # First update: nothing to compare against yet.
    if last_state is None:
        last_state = current
        return

    previous = last_state

    cur_target = current["next_point"]
    cur_dist = current["dist"]
    prev_target = previous["next_point"]
    prev_dist = previous["dist"]

    # Flat bonus when the target point changes (presumably a checkpoint was
    # reached); otherwise reward the distance closed toward the target.
    reward = prev_dist - cur_dist if prev_target == cur_target else 10

    dead = current["deaths"] != 0
    frame_count = current["frame_count"]

    state_tensor = to_input(current)
    previous_tensor = to_input(previous)

    # select_action returns an index; the glue layer expects the action value.
    action_index = select_action(state_tensor, frame_count)
    celeste.act(Celeste.action_space[action_index])

    last_state = current
|
|
|
|
|
2023-02-15 19:24:19 -08:00
|
|
|
|
|
|
|
|
2023-02-15 22:24:40 -08:00
|
|
|
# Script entry: hand on_state to the Celeste glue layer and start its loop.
# Presumably update_loop calls on_state for each new status; see celeste.py.
c = Celeste(on_state)
c.update_loop()
|