from collections import namedtuple
from collections import deque
import random
import math
import torch

# Glue layer
from celeste import Celeste

compute_device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)


# Epsilon-greedy parameters
#
# Original docs:
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon,
# higher means a slower decay
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000


# Outline our network
class DQN(torch.nn.Module):
    def __init__(self, n_observations: int, n_actions: int):
        super(DQN, self).__init__()
        self.layer1 = torch.nn.Linear(n_observations, 128)
        self.layer2 = torch.nn.Linear(128, 128)
        self.layer3 = torch.nn.Linear(128, n_actions)

    # Can be called with one input, or with a batch.
    #
    # Returns one row per input state:
    #     tensor([ [ Q(s, a) for each action a ], ... ])
    #
    # Recall that Q(s, a) is the (expected) return of taking
    # action `a` at state `s`.
    def forward(self, x):
        x = torch.nn.functional.relu(self.layer1(x))
        x = torch.nn.functional.relu(self.layer2(x))
        return self.layer3(x)


# Celeste env properties
n_observations = 4
n_actions = len(Celeste.action_space)

policy_net = DQN(
    n_observations,
    n_actions
).to(compute_device)


def select_action(state, steps_done):
    """
    Select an action using an epsilon-greedy policy.

    Sometimes use our model, sometimes sample one uniformly.

    P(random action) starts at EPS_START and decays to EPS_END.
    Decay rate is controlled by EPS_DECAY.
    """

    # Random number 0 <= x < 1
    sample = random.random()

    # Calculate random action threshold
    eps_threshold = (
        EPS_END + (EPS_START - EPS_END) *
        math.exp(
            -1.0 * steps_done /
            EPS_DECAY
        )
    )

    if sample > eps_threshold:
        with torch.no_grad():
            # max(1) returns (values, indices) along each row;
            # [1] picks the indices, i.e. the argmax action,
            # so we take the action with the larger expected return.
            return policy_net(state).max(1)[1].view(1, 1).item()
    else:
        return random.randint(0, n_actions - 1)


last_state = None

Transition = namedtuple(
    "Transition",
    (
        "state",
        "action",
        "next_state",
        "reward"
    )
)


def on_state(celeste):
    global last_state

    s = celeste.status

    # First frame: nothing to compare against yet
    if last_state is None:
        last_state = s
        return

    s_next = s["next_point"]
    s_dist = s["dist"]
    l_next = last_state["next_point"]
    l_dist = last_state["dist"]

    # Reward shaping: if we're still heading for the same checkpoint,
    # reward progress toward it. If the checkpoint changed, we reached
    # one, which earns a flat bonus.
    if l_next == s_next:
        reward = l_dist - s_dist
    else:
        reward = 10

    dead = s["deaths"] != 0
    frame_count = s["frame_count"]

    # Values at this point
    # reward: reward for last action
    # dead: true if game over

    state_number_map = [
        "xpos",
        "ypos",
        "xvel",
        "yvel"
    ]

    tf_state = torch.tensor(
        [s[x] for x in state_number_map],
        dtype=torch.float32,
        device=compute_device
    ).unsqueeze(0)

    # Previous state as a tensor. Unused for now; it becomes the
    # `state` field of a Transition once we store experience.
    tf_last = torch.tensor(
        [last_state[x] for x in state_number_map],
        dtype=torch.float32,
        device=compute_device
    ).unsqueeze(0)

    action_index = select_action(
        tf_state,
        frame_count
    )

    # Turn number into action string
    action = Celeste.action_space[action_index]

    celeste.act(action)

    # Update previous state
    last_state = s


c = Celeste(on_state)
c.update_loop()
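
# --------------------------------------------------------------------
# Sketch: sanity-checking the epsilon schedule.
#
# `update_loop()` above blocks, so nothing past this point runs as part
# of the agent; the definitions below are illustrative sketches. This
# helper is not in the original script; it just evaluates the same
# formula select_action uses, so you can see the decay at a glance:
# eps_at(0) == 0.9, eps_at(1000) ~= 0.363, eps_at(5000) ~= 0.056.
def eps_at(steps_done: int) -> float:
    """Epsilon value select_action would use at a given step count."""
    return EPS_END + (EPS_START - EPS_END) * math.exp(
        -1.0 * steps_done / EPS_DECAY
    )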
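
# --------------------------------------------------------------------
# Sketch: a replay buffer.
#
# The otherwise-unused `deque` import and `Transition` tuple point at
# experience replay. A minimal buffer in the standard DQN style might
# look like this; the class name and the uniform-sampling policy are
# assumptions, not part of the original script:
class ReplayMemory:
    def __init__(self, capacity: int):
        # Old transitions fall off the left end once we hit capacity
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        # Store a Transition(state, action, next_state, reward)
        self.memory.append(Transition(*args))

    def sample(self, batch_size: int):
        # Uniform random minibatch for training
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)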
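
# --------------------------------------------------------------------
# Sketch: storing one step.
#
# on_state computes `tf_last`, `reward`, and `dead` but never stores
# them. Under the assumption of a ReplayMemory like the one above, a
# step would be recorded roughly like this (this helper and its
# signature are hypothetical; `action_index` is the integer returned
# by select_action, before it is mapped to an action string):
def push_transition(memory, tf_last, action_index, tf_state, reward, dead):
    memory.push(
        tf_last,
        torch.tensor([[action_index]], device=compute_device),
        # Terminal steps store next_state = None, matching the usual
        # DQN convention for episode ends
        None if dead else tf_state,
        torch.tensor([reward], dtype=torch.float32, device=compute_device),
    )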