From c1379a011627df9d6b227b4d9c5581c764e58ac8 Mon Sep 17 00:00:00 2001
From: Mark
Date: Wed, 15 Feb 2023 23:38:27 -0800
Subject: [PATCH] Added RL features

---
 celeste/celeste.py |  11 +-
 celeste/main.py    | 330 +++++++++++++++++++++++++++++++++++++--------
 2 files changed, 282 insertions(+), 59 deletions(-)

diff --git a/celeste/celeste.py b/celeste/celeste.py
index 800954c..28b530a 100755
--- a/celeste/celeste.py
+++ b/celeste/celeste.py
@@ -2,6 +2,7 @@ import subprocess
 import time
 import threading
 import math
+from tqdm import tqdm
 
 class CelesteError(Exception):
 	pass
@@ -51,7 +52,6 @@ class Celeste:
 
 		# Initialize variables
 		self.internal_status = {}
-		self.dead = False
 
 		# Score system
 		self.frame_counter = 0
@@ -173,7 +173,8 @@ class Celeste:
 		self.keypress("Escape")
 		self.keystring("run")
 		self.keypress("Enter", post = 1000)
-		self.dead = False
+
+		self.flush_reader()
 
 	def flush_reader(self):
 		for k in iter(self.process.stdout.readline, ""):
@@ -186,7 +187,10 @@ class Celeste:
 
 		# Get state, call callback, wait for state
 		# One line => one frame.
-		for line in iter(self.process.stdout.readline, ""):
+		it = iter(self.process.stdout.readline, "")
+
+
+		for line in it:
 			l = line.decode("utf-8")[:-1].strip()
 
 			# This should only occur at game start
@@ -215,6 +219,7 @@ class Celeste:
 				)
 
 				if dist <= 4 and y == ty:
+					print(f"Got point {self.next_point}")
 					self.next_point += 1
 
 					# Recalculate distance to new point
diff --git a/celeste/main.py b/celeste/main.py
index 8a4c582..6379b8d 100644
--- a/celeste/main.py
+++ b/celeste/main.py
@@ -5,7 +5,6 @@
 import math
 import torch
 
-
 # Glue layer
 from celeste import Celeste
 
@@ -15,6 +14,19 @@ compute_device = torch.device(
 )
 
 
+state_number_map = [
+	"xpos",
+	"ypos",
+	"xvel",
+	"yvel",
+	"next_point"
+]
+
+
+# Celeste env properties
+n_observations = len(state_number_map)
+n_actions = len(Celeste.action_space)
+
 
 # Epsilon-greedy parameters
 #
@@ -27,6 +39,27 @@ EPS_END = 0.05
 EPS_DECAY = 1000
 
 
+BATCH_SIZE = 128
+
+# Learning rate of target_net.
+# Controls how soft our soft update is.
+#
+# Should be between 0 and 1.
+# Large values make target_net track policy_net more closely,
+# small values do the opposite.
+#
+# A value of one makes target_net
+# change at the same rate as policy_net.
+#
+# A value of zero makes target_net
+# not change at all.
+TAU = 0.005
+
+
+# GAMMA is the discount factor applied to future rewards.
+GAMMA = 0.99
+
+
 
 # Outline our network
 class DQN(torch.nn.Module):
 	def __init__(self, n_observations: int, n_actions: int):
@@ -50,15 +83,39 @@ class DQN(torch.nn.Module):
 
 
 
-# Celeste env properties
-n_observations = 4
-n_actions = len(Celeste.action_space)
+steps_done = 0
+
+num_episodes = 100
+
+
+# Create replay memory.
+#
+# Transition: a container for naming data (defined below)
+# Memory: a deque that holds recent states as Transitions.
+#	Has a fixed length, drops the oldest
+#	element if maxlen is exceeded.
+memory = deque([], maxlen=10_000)
+
 
 policy_net = DQN(
 	n_observations,
 	n_actions
 ).to(compute_device)
 
+target_net = DQN(
+	n_observations,
+	n_actions
+).to(compute_device)
+
+target_net.load_state_dict(policy_net.state_dict())
+
+
+optimizer = torch.optim.AdamW(
+	policy_net.parameters(),
+	lr = 1e-4, # Hyperparameter: learning rate
+	amsgrad = True
+)
+
 
 def select_action(state, steps_done):
 	"""
@@ -107,68 +164,229 @@ Transition = namedtuple(
 )
 
 
-def on_state(celeste):
-	global last_state
-
-	s = celeste.status
-
-	if last_state is None:
-		last_state = s
-		return
-
-	s_next = s["next_point"]
-	s_dist = s["dist"]
-	l_next = last_state["next_point"]
-	l_dist = last_state["dist"]
-
-
-	if l_next == s_next:
-		reward = l_dist - s_dist
-	else:
-		reward = 10
-
-	dead = s["deaths"] != 0
-	frame_count = s["frame_count"]
-
-	# Values at this point
-	# reward: reward for last action
-	# dead: true if game over
-
-	state_number_map = [
-		"xpos",
-		"ypos",
-		"xvel",
-		"yvel"
-	]
-
-	tf_state = torch.tensor(
-		[s[x] for x in state_number_map],
-		dtype = torch.float32,
-		device = compute_device
-	).unsqueeze(0)
-
-	tf_last = torch.tensor(
-		[last_state[x] for x in state_number_map],
-		dtype = torch.float32,
-		device = compute_device
-	).unsqueeze(0)
-	action = select_action(
-		tf_state,
-		frame_count
+
+
+def optimize_model():
+
+	if len(memory) < BATCH_SIZE:
+		raise Exception(f"Not enough elements in memory for a batch of {BATCH_SIZE}")
+
+
+
+	# Get a random sample of transitions
+	batch = random.sample(memory, BATCH_SIZE)
+
+	# Conversion.
+	# Transposes batch, turning an array of Transitions
+	# into a Transition of arrays.
+	batch = Transition(*zip(*batch))
+
+	# Conversion.
+	# Combine states, actions, and rewards into their own tensors.
+	state_batch = torch.cat(batch.state)
+	action_batch = torch.cat(batch.action)
+	reward_batch = torch.cat(batch.reward)
+
+
+
+	# Compute a mask of non-final states.
+	# Each element of this tensor corresponds to an element in the batch.
+	# True if that state is non-final, False if it is final.
+	#
+	# We use this to select non-final states later.
+	non_final_mask = torch.tensor(
+		tuple(map(
+			lambda s: s is not None,
+			batch.next_state
+		))
+	)
+
+	non_final_next_states = torch.cat(
+		[s for s in batch.next_state if s is not None]
+	)
+
+
+
+	# How .gather works:
+	# if out = a.gather(1, b),
+	# out[i, j] = a[ i ][ b[i,j] ]
+	#
+	# a is "input," b is "index"
+	# If this doesn't make sense, RTFD.
+
+	# Compute Q(s_t, a).
+	# - Use policy_net to compute Q(s_t) for each state in the batch.
+	#   This gives a tensor of Q values, one per action in Celeste.action_space.
+	#
+	# - action_batch is a tensor that looks like [ [0], [1], [1], ... ],
+	#   listing the index of the action that was taken in each transition.
+	#
+	#   This aligns nicely with the output of policy_net. We use
+	#   action_batch to index the output of policy_net's prediction.
+	#
+	# This gives us a tensor that contains the return we expect to get
+	# at that state if we follow the model's advice.
+
+	state_action_values = policy_net(state_batch).gather(1, action_batch)
+
+
+
+	# Compute V(s_t+1) for all next states.
+	# V(s_t+1) = max_a ( Q(s_t+1, a) )
+	#          = the maximum value over all possible actions at state s_t+1.
+	next_state_values = torch.zeros(BATCH_SIZE, device = compute_device)
+
+	# Don't compute gradients for operations in this block.
+	# If you don't understand what this means, RTFD.
+	with torch.no_grad():
+
+		# Note the use of non_final_mask here.
+		# States that are final do not have their value set by the line
+		# below, so their value stays at zero.
+		#
+		# States that are not final get their predicted value
+		# set to the best value the model predicts.
+		#
+		#
+		# Expected values of actions are computed with the "older" target_net,
+		# and the best value (over possible actions) is selected with max(1)[0].
+
+		next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
+
+
+
+	# Compute the expected Q values with the Bellman equation:
+	# Q(s_t, a) = reward + GAMMA * V(s_t+1)
+	expected_state_action_values = reward_batch + (next_state_values * GAMMA)
+
+
+
+	# Compute Huber loss between the predicted and expected Q values.
+	# Pytorch records this operation, so it is accounted for
+	# when we compute the gradient of loss.
+	#
+	# loss is a single-element tensor (i.e, a scalar).
+	criterion = torch.nn.SmoothL1Loss()
+	loss = criterion(
+		state_action_values,
+		expected_state_action_values.unsqueeze(1)
+	)
+
+
+
+	# We can now run a step of backpropagation on our model.
+
+	# Calling .backward() multiple times will accumulate parameter gradients.
+	# Thus, we reset the gradients before each step.
+	optimizer.zero_grad()
+
+	# Compute the gradient of loss with respect to each parameter of
+	# policy_net. The result is stored in each parameter's .grad
+	# attribute, which optimizer.step() uses below.
+	loss.backward()
+
+
+	# Prevent vanishing and exploding gradients.
+	# Forces gradients to be in [-clip_value, +clip_value]
+	torch.nn.utils.clip_grad_value_( # type: ignore
+		policy_net.parameters(),
+		clip_value = 100
+	)
+
+	# Perform a single optimizer step.
+	#
+	# Uses the current gradient, which is stored
+	# in the .grad attribute of each parameter.
+	optimizer.step()
+
+
+def on_state(celeste):
+	global steps_done
+
+	# Conversion to pytorch
+
+	state = celeste.status
+
+	pt_state = torch.tensor(
+		[state[x] for x in state_number_map],
+		dtype = torch.float32,
+		device = compute_device
+	).unsqueeze(0)
+
+	action = select_action(
+		pt_state,
+		steps_done
 	)
+	steps_done += 1
 
 	# Turn number into action string
-	action = Celeste.action_space[action]
+	str_action = Celeste.action_space[action]
+
+	pt_action = torch.tensor(
+		[[ action ]],
+		device = compute_device,
+		dtype = torch.long
+	)
 
-	celeste.act(action)
+	celeste.act(str_action)
+
+	next_state = celeste.status
+
+	if next_state["deaths"] != 0:
+		pt_next_state = None
+		reward = 0
+
+	else:
+		pt_next_state = torch.tensor(
+			[next_state[x] for x in state_number_map],
+			dtype = torch.float32,
+			device = compute_device
+		).unsqueeze(0)
+
+
+		if state["next_point"] == next_state["next_point"]:
+			reward = state["dist"] - next_state["dist"]
+		else:
+			# Score for reaching a point
+			reward = 10
+
+	pt_reward = torch.tensor([reward], device = compute_device)
+
+
+	# Add this state transition to memory.
+	memory.append(
+		Transition(
+			pt_state,       # last state
+			pt_action,
+			pt_next_state,  # next state
+			pt_reward
+		)
+	)
 
-	# Update previous state
-	last_state = s
+	# Only train the network if we have enough
+	# transitions in memory to do so.
+	if len(memory) >= BATCH_SIZE:
+		optimize_model()
+
+	# Soft update target_net weights
+	target_net_state = target_net.state_dict()
+	policy_net_state = policy_net.state_dict()
+	for key in policy_net_state:
+		target_net_state[key] = (
+			policy_net_state[key] * TAU +
+			target_net_state[key] * (1-TAU)
+		)
+	target_net.load_state_dict(target_net_state)
+
+	# Move on to the next episode once we reach
+	# a terminal state.
+	if (next_state["deaths"] != 0):
+		print("State over, resetting")
+		celeste.reset()
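
Note on select_action(): its body falls outside every hunk in this patch, so only its signature and the epsilon-greedy constants are visible above. The sketch below is one plausible implementation in the style of the standard PyTorch DQN recipe, not necessarily what main.py actually contains. It assumes an EPS_START constant defined alongside the EPS_END and EPS_DECAY shown above, reuses the module-level policy_net and n_actions, and returns an integer index into Celeste.action_space, which is how on_state() uses the result.

    import math
    import random
    import torch

    def select_action(state, steps_done):
        # Epsilon decays exponentially from EPS_START toward EPS_END;
        # EPS_DECAY controls how quickly it falls.
        eps_threshold = EPS_END + (EPS_START - EPS_END) * \
            math.exp(-1.0 * steps_done / EPS_DECAY)

        if random.random() > eps_threshold:
            # Exploit: take the action with the largest predicted Q value.
            with torch.no_grad():
                return int(policy_net(state).max(1)[1].item())
        else:
            # Explore: take a random action index.
            return random.randint(0, n_actions - 1)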
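
Also outside the patch proper: the two idioms optimize_model() leans on, Transition(*zip(*batch)) and .gather(1, action_batch), are easier to see on toy data. The snippet below is a standalone illustration with made-up numbers; the field names mirror how the patch uses batch.state, batch.action, batch.next_state and batch.reward, but the real namedtuple definition is not shown in any hunk above.

    from collections import namedtuple
    import torch

    Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))

    # Two fake transitions with two-element state vectors.
    # t2 is "final": its next_state is None, like a death in on_state().
    t1 = Transition(torch.tensor([[0.0, 1.0]]), torch.tensor([[0]]), torch.tensor([[0.5, 1.5]]), torch.tensor([1.0]))
    t2 = Transition(torch.tensor([[2.0, 3.0]]), torch.tensor([[1]]), None, torch.tensor([0.5]))

    # zip(*...) transposes: a list of Transitions becomes one Transition of tuples.
    batch = Transition(*zip(*[t1, t2]))

    state_batch  = torch.cat(batch.state)   # shape [2, 2]
    action_batch = torch.cat(batch.action)  # shape [2, 1], dtype long

    # .gather picks, for each row, the Q value of the action that was taken.
    q_all   = torch.tensor([[0.1, 0.9],
                            [0.7, 0.3]])    # stand-in for policy_net(state_batch)
    q_taken = q_all.gather(1, action_batch) # tensor([[0.1], [0.3]])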