import math
import random
from collections import namedtuple
from collections import deque

import torch

# Glue layer
from celeste import Celeste

compute_device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)

# Keys read from `celeste.status`, in this order, to build the
# observation vector fed to the network.
state_number_map = [
    "xpos",
    "ypos",
    "xvel",
    "yvel",
    "next_point",
]

# Celeste env properties
n_observations = len(state_number_map)
n_actions = len(Celeste.action_space)

# Epsilon-greedy parameters
#
# EPS_START is the starting value of epsilon (probability of a random action).
# EPS_END is the final value of epsilon.
# EPS_DECAY controls the rate of exponential decay of epsilon;
# higher means a slower decay.
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000

# Number of transitions sampled from replay memory per optimization step.
BATCH_SIZE = 128

# Learning rate of target_net: controls how soft the soft update is.
#
# Each update sets
#     target = TAU * policy + (1 - TAU) * target
# so TAU should be between 0 and 1.
#   - Larger values make target_net follow policy_net more closely.
#   - Smaller values make target_net change more slowly.
#   - A value of one makes target_net an exact copy of policy_net
#     after every update.
#   - A value of zero makes target_net not change at all.
TAU = 0.005

# Discount factor: weight given to the next state's value in the
# training target  reward + GAMMA * V(next_state).
GAMMA = 0.99


class DQN(torch.nn.Module):
    """Fully connected Q-network: observation vector -> one Q-value per action.

    Three linear layers with ReLU between them; the output layer has
    `n_actions` units and no activation.
    """

    def __init__(self, n_observations: int, n_actions: int, hidden_size: int = 128):
        """Build the network.

        Args:
            n_observations: length of the input observation vector.
            n_actions: number of discrete actions (size of the output).
            hidden_size: width of both hidden layers (default 128, the
                original fixed value, so existing checkpoints still load).
        """
        super().__init__()
        self.layer1 = torch.nn.Linear(n_observations, hidden_size)
        self.layer2 = torch.nn.Linear(hidden_size, hidden_size)
        self.layer3 = torch.nn.Linear(hidden_size, n_actions)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Estimate Q(s, a) for every action.

        Can be called with a single observation or with a batch; the last
        dimension of `x` must be `n_observations`.

        Returns:
            A tensor whose last dimension has size `n_actions`:
            [Q(s, a_0), Q(s, a_1), ..., Q(s, a_{n_actions-1})] for each input.

        Recall that Q(s, a) is the (expected) return of taking action `a`
        at state `s`.
        """
        x = torch.nn.functional.relu(self.layer1(x))
        x = torch.nn.functional.relu(self.layer2(x))
        return self.layer3(x)


# Number of actions chosen so far; passed to select_action to decay epsilon.
steps_done = 0
# NOTE(review): not referenced anywhere else in this file.
num_episodes = 100

# Create replay memory.
#
# Transition: a namedtuple (defined below in this file) naming the fields
#             of one step of experience.
# Memory:     a deque holding recent Transitions. It has a fixed length and
#             drops the oldest element when maxlen is exceeded.
# Bounded replay buffer of Transition tuples; the oldest entry is evicted first.
memory = deque([], maxlen=10_000)

policy_net = DQN(n_observations, n_actions).to(compute_device)
target_net = DQN(n_observations, n_actions).to(compute_device)
# Start the target network as an exact copy of the policy network.
target_net.load_state_dict(policy_net.state_dict())

optimizer = torch.optim.AdamW(
    policy_net.parameters(),
    lr=1e-4,  # Hyperparameter: learning rate
    amsgrad=True,
)


def select_action(state, steps_done):
    """Select an action using an epsilon-greedy policy.

    With probability epsilon a uniformly random action is returned;
    otherwise the action with the highest Q-value under policy_net.
    Epsilon starts at EPS_START and decays exponentially towards EPS_END
    at a rate controlled by EPS_DECAY.

    Args:
        state: observation tensor for a single state, shape
            (1, n_observations) as built by _status_to_tensor.
        steps_done: number of actions taken so far (drives the decay).

    Returns:
        An int action index in [0, n_actions - 1].
    """
    # Random number 0 <= x < 1
    sample = random.random()

    # Current probability of taking a random action.
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(
        -1.0 * steps_done / EPS_DECAY
    )

    if sample > eps_threshold:
        with torch.no_grad():
            # max(1) reduces over the action dimension; `.indices` holds the
            # column of the largest Q-value, i.e. the greedy action.
            return policy_net(state).max(1).indices.item()

    return random.randint(0, n_actions - 1)


# NOTE(review): not read anywhere in this file.
last_state = None

# One step of experience. `next_state` is None when the step ended the episode.
Transition = namedtuple(
    "Transition",
    ("state", "action", "next_state", "reward"),
)


def optimize_model():
    """Run one DQN gradient step on a random minibatch from replay memory.

    The loss is the Huber (SmoothL1) loss between Q(s_t, a_t) from
    policy_net and the target r_t + GAMMA * max_a Q_target(s_{t+1}, a),
    where the max term is 0 for terminal next states.

    Raises:
        ValueError: if memory holds fewer than BATCH_SIZE transitions.
    """
    if len(memory) < BATCH_SIZE:
        raise ValueError(f"Not enough elements in memory for a batch of {BATCH_SIZE}")

    # Get a random sample of transitions.
    transitions = random.sample(memory, BATCH_SIZE)

    # Transpose the batch: a list of Transitions becomes a Transition of
    # tuples (all states together, all actions together, ...).
    batch = Transition(*zip(*transitions))

    # Each stored state is (1, n_observations), action is (1, 1) and reward
    # is (1,), so concatenation gives (B, n_observations), (B, 1) and (B,).
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Mask with one element per transition: True if its next state is
    # NOT terminal (i.e. next_state is not None). Built directly on the
    # compute device so it can index next_state_values below.
    non_final_mask = torch.tensor(
        [s is not None for s in batch.next_state],
        device=compute_device,
        dtype=torch.bool,
    )
    non_final_next_states = [s for s in batch.next_state if s is not None]

    # Compute Q(s_t, a_t) for the action actually taken in each transition.
    # policy_net(state_batch) has one column per action; action_batch holds
    # the taken action's index per row, so
    #     out[i, 0] = Q[i, action_batch[i, 0]]
    # (gather semantics: out[i][j] = input[i][index[i][j]] along dim 1).
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # V(s_{t+1}) = max_a Q_target(s_{t+1}, a). Terminal next states keep 0.
    next_state_values = torch.zeros(BATCH_SIZE, device=compute_device)
    with torch.no_grad():
        # Targets come from the slowly-updated target_net and are not
        # differentiated through. Skip the call entirely when every sampled
        # transition was terminal: torch.cat([]) would raise.
        if non_final_next_states:
            next_state_values[non_final_mask] = target_net(
                torch.cat(non_final_next_states)
            ).max(1).values

    # Bellman target: r_t + GAMMA * V(s_{t+1}).
    expected_state_action_values = reward_batch + (next_state_values * GAMMA)

    # Huber loss between predicted and target values; unsqueeze to match the
    # (B, 1) shape of state_action_values. `loss` is a single-element tensor.
    criterion = torch.nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Gradients accumulate across .backward() calls, so clear them first.
    optimizer.zero_grad()

    # Populate .grad on every policy_net parameter with d(loss)/d(param);
    # optimizer.step() below reads those .grad values.
    loss.backward()

    # Guard against exploding gradients: clamp every gradient element to
    # [-clip_value, +clip_value].
    torch.nn.utils.clip_grad_value_(  # type: ignore
        policy_net.parameters(),
        clip_value=100,
    )

    # Apply one optimizer update using the (clipped) gradients.
    optimizer.step()


def _status_to_tensor(status):
    """Turn a Celeste status mapping into a (1, n_observations) float32 tensor."""
    return torch.tensor(
        [status[key] for key in state_number_map],
        dtype=torch.float32,
        device=compute_device,
    ).unsqueeze(0)


def _compute_reward(state, next_state):
    """Reward for a non-terminal step from `state` to `next_state`.

    Returns 10 when `next_point` changed (a point was reached); otherwise
    +1 / -1 / 0 depending on whether `dist` decreased, increased, or did
    not change.
    """
    if state["next_point"] == next_state["next_point"]:
        progress = state["dist"] - next_state["dist"]
        if progress > 0:
            return 1
        if progress < 0:
            return -1
        return 0

    # Score for reaching a point
    return 10


def _soft_update_target_net():
    """Blend policy_net into target_net: target = TAU * policy + (1 - TAU) * target."""
    target_net_state = target_net.state_dict()
    policy_net_state = policy_net.state_dict()
    for key, policy_value in policy_net_state.items():
        target_net_state[key] = policy_value * TAU + target_net_state[key] * (1 - TAU)
    target_net.load_state_dict(target_net_state)


def on_state_before(celeste):
    """Choose and send the next action.

    Passed to Celeste.update_loop as the "before" callback (presumably called
    once per step before the game advances — TODO confirm against celeste).

    Returns:
        (state, action): the raw status read from `celeste.status` and the
        int action index; this tuple is what on_state_after unpacks.
    """
    global steps_done

    # NOTE(review): if celeste.status returns a live object that the game
    # mutates in place, `state` here and `next_state` in on_state_after would
    # be the same object — confirm it returns a fresh snapshot.
    state = celeste.status

    action = select_action(_status_to_tensor(state), steps_done)
    steps_done += 1

    # Turn number into action string
    celeste.act(Celeste.action_space[action])

    return state, action


def on_state_after(celeste, before_out):
    """Record the transition, train, and reset on death.

    Args:
        celeste: the Celeste instance.
        before_out: the (state, action) tuple returned by on_state_before.
    """
    state, action = before_out

    pt_state = _status_to_tensor(state)
    pt_action = torch.tensor([[action]], device=compute_device, dtype=torch.long)

    next_state = celeste.status
    # NOTE(review): assumes `deaths` goes back to 0 after celeste.reset();
    # otherwise every later step would be treated as terminal — TODO confirm.
    terminal = next_state["deaths"] != 0

    if terminal:
        pt_next_state = None
        # NOTE(review): dying scores 0, which is better than the -1 given for
        # moving away from the target; consider a negative terminal reward.
        reward = 0
    else:
        pt_next_state = _status_to_tensor(next_state)
        reward = _compute_reward(state, next_state)

    # float32 so it matches the float targets built in optimize_model.
    pt_reward = torch.tensor([reward], device=compute_device, dtype=torch.float32)

    # Add this state transition to memory.
    memory.append(
        Transition(
            pt_state,       # last state
            pt_action,
            pt_next_state,  # next state (None if terminal)
            pt_reward,
        )
    )

    # Only train the network once memory holds at least one full batch.
    if len(memory) >= BATCH_SIZE:
        optimize_model()
        _soft_update_target_net()

    # Move on to the next episode once we reach a terminal state.
    if terminal:
        print("State over, resetting")
        celeste.reset()


c = Celeste()
c.update_loop(on_state_before, on_state_after)