diff --git a/celeste/celeste_ai/celeste.py b/celeste/celeste_ai/celeste.py
index 9988080..c7db9d8 100755
--- a/celeste/celeste_ai/celeste.py
+++ b/celeste/celeste_ai/celeste.py
@@ -3,6 +3,8 @@
 import subprocess
 import time
 import math
+import numpy as np
+
 
 class CelesteError(Exception):
 	pass
@@ -12,8 +14,12 @@ class CelesteState(NamedTuple):
 	stage: int
 
 	# Player position
+	# Regular position has 0,0 in top-left,
+	# centered position has 0,0 in center.
 	xpos: int
 	ypos: int
+	xpos_scaled: float
+	ypos_scaled: float
 
 	# Player velocity
 	xvel: float
@@ -37,28 +43,47 @@ class CelesteState(NamedTuple):
 	# True if Madeline can dash
 	can_dash: bool
+	can_dash_int: int
 
 
 class Celeste:
 	action_space = [
-		"left",		# move left
-		"right",	# move right
-		"jump",		# jump
+		"left",		# move left			0
+		"right",	# move right		1
+		#"jump",	# jump
+		"jump-l",	# jump left			2
+		"jump-r",	# jump right		3
 
-		"dash-u",	# dash up
-		"dash-r",	# dash right
-		"dash-l",	# dash left
-		"dash-ru",	# dash right-up
-		"dash-lu"	# dash left-up
+		"dash-u",	# dash up			4
+		"dash-r",	# dash right		5
+		"dash-l",	# dash left			6
+		"dash-ru",	# dash right-up		7
+		"dash-lu"	# dash left-up		8
 	]
 
 	# Map integers to state values.
 	# This also determines what data is fed to the model.
 	state_number_map = [
-		"xpos",
-		"ypos",
-		"next_point_x",
-		"next_point_y"
+		#"xpos",
+		#"ypos",
+		"xpos_scaled",
+		"ypos_scaled",
+		"can_dash_int"
+		#"next_point_x",
+		#"next_point_y"
+	]
+
+	# Targets the agent tries to reach.
+	# The last target MUST be outside the frame.
+	target_checkpoints = [
+		[ # Stage 1
+			#(28, 88),	# Start pillar
+			(60, 80),	# Middle pillar
+			(105, 64),	# Right ledge
+			(25, 40),	# Left ledge
+			(110, 16),	# End ledge
+			(110, -2),	# Next stage
+		]
 	]
 
 	def __init__(
@@ -110,19 +135,6 @@ class Celeste:
 		self._resetting = False	# True between a call to .reset() and the first state message from pico.
 		self._keys = {}			# Dictionary of "key": bool
 
-		# Targets the agent tries to reach.
-		# The last target MUST be outside the frame.
-		self.target_checkpoints = [
-			[ # Stage 1
-				#(28, 88),	# Start pillar
-				(60, 80),	# Middle pillar
-				(105, 64),	# Right ledge
-				(25, 40),	# Left ledge
-				(110, 16),	# End ledge
-				(110, -2),	# Next stage
-			]
-		]
-
 	def act(self, action: str):
 		"""
 		Specify what keys should be down. This does NOT send key events.
@@ -141,6 +153,12 @@ class Celeste:
 			self._keys["Right"] = True
 		elif action == "jump":
 			self._keys["c"] = True
+		elif action == "jump-l":
+			self._keys["c"] = True
+			self._keys["Left"] = True
+		elif action == "jump-r":
+			self._keys["c"] = True
+			self._keys["Right"] = True
 
 		elif action == "dash-u":
 			self._keys["Up"] = True
@@ -183,12 +201,12 @@ class Celeste:
 				[int(self._internal_state["rx"])]
 			)
 
-			if len(self.target_checkpoints) < stage:
+			if len(Celeste.target_checkpoints) < stage:
 				next_point_x = None
 				next_point_y = None
 			else:
-				next_point_x = self.target_checkpoints[stage][self._next_checkpoint_idx][0]
-				next_point_y = self.target_checkpoints[stage][self._next_checkpoint_idx][1]
+				next_point_x = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][0]
+				next_point_y = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][1]
 
 
 			return CelesteState(
@@ -196,6 +214,8 @@ class Celeste:
 				xpos = int(self._internal_state["px"]),
 				ypos = int(self._internal_state["py"]),
+				xpos_scaled = int(self._internal_state["px"]) / 128.0,
+				ypos_scaled = int(self._internal_state["py"]) / 128.0,
 
 				xvel = float(self._internal_state["vx"]),
 				yvel = float(self._internal_state["vy"]),
 				deaths = int(self._internal_state["dc"]),
@@ -205,7 +225,8 @@ class Celeste:
 				next_point_x = next_point_x,
 				next_point_y = next_point_y,
 				state_count = self._state_counter,
-				can_dash = self._internal_state["ds"] == "t"
+				can_dash = self._internal_state["ds"] == "t",
+				can_dash_int = 1 if self._internal_state["ds"] == "t" else 0
 			)
 
 		except KeyError:
@@ -299,33 +320,46 @@ class Celeste:
 			self._internal_state[key] = val
 
 
-		# Update checkpoints
-		tx, ty = self.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
+
+
+		# Calculate distance to each point
 		x = self.state.xpos
 		y = self.state.ypos
-		dist = math.sqrt(
-			(x-tx)*(x-tx) +
-			((y-ty)*(y-ty))/2
-			# Possible modification:
-			# make x-distance twice as valuable as y-distance
-		)
+		dist = np.zeros(len(Celeste.target_checkpoints[self.state.stage]), dtype=np.float16)
+		for i, c in enumerate(Celeste.target_checkpoints[self.state.stage]):
+			if i < self._next_checkpoint_idx:
+				dist[i] = 1000
+				continue
 
-		if dist <= 5:
-			print(f"Got point {self._next_checkpoint_idx}")
-			self._next_checkpoint_idx += 1
+			# Update checkpoints
+			tx, ty = c
+			dist[i] = (math.sqrt(
+				(x-tx)*(x-tx) +
+				((y-ty)*(y-ty))/2
+				# Possible modification:
+				# make x-distance twice as valuable as y-distance
+			))
+		min_idx = int(dist.argmin())
+		dist = int(dist[min_idx])
+
+
+		if dist <= 8:
+			print(f"Got point {min_idx}")
+			self._next_checkpoint_idx = min_idx + 1
 			self._last_checkpoint_state = self._state_counter
 			# Recalculate distance to new point
-			tx, ty = self.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
+			tx, ty = Celeste.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
 			dist = math.sqrt(
 				(x-tx)*(x-tx) +
 				((y-ty)*(y-ty))/2
 			)
-
+		# Timeout if we spend too long between points
 		elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
 			self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)
 
+		self._dist = dist
 
 		# Call step callbacks
diff --git a/celeste/celeste_ai/train.py b/celeste/celeste_ai/train.py
index adb093f..32ea61b 100644
--- a/celeste/celeste_ai/train.py
+++ b/celeste/celeste_ai/train.py
@@ -24,6 +24,12 @@ if __name__ == "__main__":
 	screenshot_dir.mkdir(parents = True, exist_ok = True)
 
+	# Remove old screenshots
+	shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
+	for s in shots:
+		s.unlink()
+
+
 	compute_device = torch.device(
 		"cuda" if torch.cuda.is_available() else "cpu"
 	)
 
@@ -41,11 +47,15 @@ if __name__ == "__main__":
 	# EPS_END is the final value of epsilon
 	# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
 	EPS_START = 0.9
-	EPS_END = 0.05
-	EPS_DECAY = 4000
+	EPS_END = 0.02
+	EPS_DECAY = 100
 
+	# How many times we've reached each point.
+	# Used to compute epsilon-greedy probability with
+	# the parameters above.
+	point_counter = [0] * len(Celeste.target_checkpoints[0])
 
-	BATCH_SIZE = 1_000
+	BATCH_SIZE = 100
 	# Learning rate of target_net.
 	# Controls how soft our soft update is.
 	#
@@ -58,7 +68,7 @@ if __name__ == "__main__":
 	#
 	# A value of zero makes target_net
 	# not change at all.
-	TAU = 0.005
+	TAU = 0.05
 
 
 	# GAMMA is the discount factor as mentioned in the previous section
@@ -90,9 +100,10 @@ if __name__ == "__main__":
 	target_net.load_state_dict(policy_net.state_dict())
 
+	learning_rate = 0.001
 	optimizer = torch.optim.AdamW(
 		policy_net.parameters(),
-		lr = 0.01, # Hyperparameter: learning rate
+		lr = learning_rate,
 		amsgrad = True
 	)
 
 
@@ -109,6 +120,7 @@ if __name__ == "__main__":
 		memory = checkpoint["memory"]
 		episode_number = checkpoint["episode_number"] + 1
 		steps_done = checkpoint["steps_done"]
+		point_counter = checkpoint["point_counter"]
 
 def select_action(state, steps_done):
 	"""
@@ -144,7 +156,6 @@ def select_action(state, steps_done):
 
 
 def optimize_model():
-
 	if len(memory) < BATCH_SIZE:
 		raise Exception(f"Not enough elements in memory for a batch of {BATCH_SIZE}")
 
@@ -189,19 +200,8 @@ def optimize_model():
 	# out[i, j] = a[ i ][ b[i,j] ]
 	#
 	# a is "input," b is "index"
-	# If this doesn't make sense, RTFD.
 
 	# Compute Q(s_t, a).
-	# - Use policy_net to compute Q(s_t) for each state in the batch.
-	#   This gives a tensor of [ Q(state, left), Q(state, right) ]
-	#
-	# - Action batch is a tensor that looks like [ [0], [1], [1], ... ]
-	#   listing the action that was taken in each transition.
-	#   0 => we went left, 1 => we went right.
-	#
-	#   This aligns nicely with the output of policy_net. We use
-	#   action_batch to index the output of policy_net's prediction.
-	#
 	# This gives us a tensor that contains the return we expect to get
 	# at that state if we follow the model's advice.
 
@@ -214,8 +214,7 @@ def optimize_model():
 	#   = the maximum reward over all possible actions at state s_t+1.
 	next_state_values = torch.zeros(BATCH_SIZE, device = compute_device)
 
-	# Don't compute gradient for operations in this block.
-	# If you don't understand what this means, RTFD.
+
 	with torch.no_grad():
 
 		# Note the use of non_final_mask here.
@@ -291,6 +290,15 @@ def on_state_before(celeste):
 			device = compute_device
 		).unsqueeze(0)
 
+
+	action = select_action(
+		pt_state,
+		point_counter[state.next_point]
+	)
+	str_action = Celeste.action_space[action]
+
+
+	"""
 	action = None
 	while (action) is None or ((not state.can_dash) and (str_action not in ["left", "right"])):
 		action = select_action(
@@ -298,6 +306,8 @@ def on_state_before(celeste):
 			steps_done
 		)
 		str_action = Celeste.action_space[action]
+	"""
+
 
 
 	steps_done += 1
@@ -343,37 +353,37 @@ def on_state_after(celeste, before_out):
 	).unsqueeze(0)
 
+
 	if state.next_point == next_state.next_point:
-		reward = state.dist - next_state.dist
-
-		# Clip rewards that are too large
-		if reward > 1:
-			reward = 1
-		else:
-			reward = 0
-
+		reward = 0
 	else:
 		# Reward for reaching a point
-		reward = 1
+		reward = next_state.next_point - state.next_point
+		# Add to point counter
+		for i in range(state.next_point, state.next_point + reward):
+			point_counter[i] += 1
+
+		reward = reward * 10
 
 	pt_reward = torch.tensor([reward], device = compute_device)
 
 	# Add this state transition to memory.
 	memory.append(
 		Transition(
-			pt_state,		# last state
+			pt_state,
 			pt_action,
-			pt_next_state,	# next state
+			pt_next_state,
 			pt_reward
 		)
 	)
 
-	print("==> ", int(reward))
+	print("==> ", reward)
 	print("")
 
 
 	loss = None
+
 	# Only train the network if we have enough
 	# transitions in memory to do so.
 	if len(memory) >= BATCH_SIZE:
@@ -399,7 +409,7 @@ def on_state_after(celeste, before_out):
 			"state_count": s.state_count,
 			"loss": None if loss is None else loss.item()
 		}) + "\n")
 
-	
+
 	# Save model
 	torch.save({
@@ -407,8 +417,18 @@ def on_state_after(celeste, before_out):
 		"target_state_dict": target_net.state_dict(),
 		"optimizer_state_dict": optimizer.state_dict(),
 		"memory": memory,
+		"point_counter": point_counter,
 		"episode_number": episode_number,
-		"steps_done": steps_done
+		"steps_done": steps_done,
+
+		# Hyperparameters
+		"eps_start": EPS_START,
+		"eps_end": EPS_END,
+		"eps_decay": EPS_DECAY,
+		"batch_size": BATCH_SIZE,
+		"tau": TAU,
+		"learning_rate": learning_rate,
+		"gamma": GAMMA
 	}, model_save_path)
 
 
@@ -421,7 +441,7 @@ def on_state_after(celeste, before_out):
 	for s in shots:
 		s.rename(target / s.name)
 
-	# Save a prediction graph
+	# Save a snapshot
 	if episode_number % archive_interval == 0:
 		torch.save({
 			"policy_state_dict": policy_net.state_dict(),
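
Note on the new exploration schedule: select_action is now called with point_counter[state.next_point] instead of the global steps_done, so the epsilon-greedy threshold decays separately for whichever checkpoint the agent is currently chasing. The body of select_action is outside these hunks, so the sketch below is an assumption: it uses the conventional exponential-decay threshold (as in the standard PyTorch DQN tutorial), and epsilon_threshold / explore are illustrative names, not functions from this repository.

import math
import random

# Values mirror the new hyperparameters set in train.py.
EPS_START = 0.9
EPS_END   = 0.02
EPS_DECAY = 100

def epsilon_threshold(times_reached: int) -> float:
	# Decay from EPS_START toward EPS_END as the current checkpoint
	# is reached more often (times_reached = point_counter[state.next_point]).
	return EPS_END + (EPS_START - EPS_END) * math.exp(-times_reached / EPS_DECAY)

def explore(times_reached: int) -> bool:
	# True  -> take a random action from Celeste.action_space
	# False -> take the action policy_net scores highest
	return random.random() < epsilon_threshold(times_reached)

With EPS_DECAY = 100, the threshold falls from 0.9 to roughly 0.34 after 100 visits and approaches EPS_END after a few hundred, which matches the intent of the comment added above point_counter.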
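Similarly, the TAU comment block refers to the soft update of target_net toward policy_net; only TAU's value changes in this diff and the update itself is not shown. A minimal sketch of the conventional Polyak-style blend that comment describes, assuming plain float parameters and an illustrative name (soft_update):

import torch.nn as nn

TAU = 0.05  # the new value set in train.py

def soft_update(policy_net: nn.Module, target_net: nn.Module, tau: float = TAU) -> None:
	# Blend every target_net parameter toward the matching policy_net parameter.
	# tau = 0 leaves target_net unchanged; tau = 1 copies policy_net outright.
	target_state = target_net.state_dict()
	policy_state = policy_net.state_dict()
	for key in policy_state:
		target_state[key] = tau * policy_state[key] + (1.0 - tau) * target_state[key]
	target_net.load_state_dict(target_state)

Raising TAU from 0.005 to 0.05 makes target_net track policy_net ten times faster per update.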