diff --git a/celeste/celeste.py b/celeste/celeste.py
index 851c30d..9b1e254 100755
--- a/celeste/celeste.py
+++ b/celeste/celeste.py
@@ -1,12 +1,44 @@
+from typing import NamedTuple
 import subprocess
 import time
-import threading
 import math
-from tqdm import tqdm
 
 class CelesteError(Exception):
 	pass
 
+
+class CelesteState(NamedTuple):
+	# Stage number
+	stage: int
+
+	# Player position
+	xpos: int
+	ypos: int
+
+	# Player velocity
+	xvel: float
+	yvel: float
+
+	# Number of deaths since game start
+	deaths: int
+
+	# Distance to next point
+	dist: float
+
+	# Index of next point
+	next_point: int
+
+	# Coordinates of next point
+	next_point_x: int
+	next_point_y: int
+
+	# Number of states received since restart
+	state_count: int
+
+	# True if Madeline can dash
+	can_dash: bool
+
+
 class Celeste:
 	action_space = [
 		"left",		# move left
@@ -20,10 +52,25 @@ class Celeste:
 		"dash-lu"	# dash left-up
 	]
 
-	def __init__(self):
+	# Map integers to state values.
+	# This also determines what data is fed to the model.
+	state_number_map = [
+		"xpos",
+		"ypos",
+		"next_point_x",
+		"next_point_y"
+	]
+
+	def __init__(
+		self,
+		*,
+		state_timeout = 30,
+		cart_name = "hackcel.p8"
+	):
+
 		# Start pico-8
-		self.process = subprocess.Popen(
-			"bin/pico-8/linux/pico8",
+		self._process = subprocess.Popen(
+			"resources/pico-8/linux/pico8",
 			shell=True,
 			stdout=subprocess.PIPE,
 			stderr=subprocess.STDOUT
@@ -39,26 +86,34 @@ class Celeste:
 		]).decode("utf-8").strip().split("\n")
 		if len(winid) != 1:
 			raise Exception("Could not find unique PICO-8 window id")
-		self.winid = winid[0]
+		self._winid = winid[0]
 
 		# Load cartridge
-		self.keystring("load hackcel.p8")
-		self.keypress("Enter")
-		self.keystring("run")
-		self.keypress("Enter", post = 1000)
+		self._keystring(f"load {cart_name}")
+		self._keypress("Enter")
+		self._keystring("run")
+		self._keypress("Enter", post = 1000)
 
-		# Initialize variables
-		self.internal_status = {}
-		self.before_out = None
-		self.last_point_frame = 0
-
-		# Score system
-		self.frame_counter = 0
-		self.next_point = 0
-		self.dist = 0 # distance to next point
-		self.target_points = [
+		# Parameters
+		self.state_timeout = state_timeout	# If we run this many states without getting a checkpoint, reset.
+		self.cart_name = cart_name	# Name of cart to load. Not used anywhere, but saved for convenience.
+
+		# Internal variables
+		self._internal_state = {}	# Raw data read from stdout
+		self._before_out = None	# Output of "before" callback in update loop
+		self._last_checkpoint_state = 0	# Index of frame at which we reached the last checkpoint
+		self._state_counter = 0	# Number of frames we've run since last reset
+		self._next_checkpoint_idx = 0	# Index of next point
+		self._dist = 0	# Distance to next point
+		self._resetting = False	# True between a call to .reset() and the first state message from pico.
+		self._keys = {}	# Dictionary of "key": bool
+
+		# Targets the agent tries to reach.
+		# The last target MUST be outside the frame.
+		self.target_checkpoints = [
 			[ # Stage 1
-				(28, 88), # Start pillar
+				#(28, 88), # Start pillar
 				(60, 80), # Middle pillar
 				(105, 64), # Right ledge
 				(25, 40), # Left ledge
@@ -67,119 +122,150 @@ class Celeste:
 			]
 		]
 
-	def act(self, action):
-		self.keyup("x")
-		self.keyup("c")
-		self.keyup("Left")
-		self.keyup("Right")
-		self.keyup("Down")
-		self.keyup("Up")
+	def act(self, action: str):
+		"""
+		Specify what keys should be down. This does NOT send key events.
+		Celeste._apply_keys() does that at the right time.
+
+		Args:
+			action (str): key name, as in Celeste.action_space
+		"""
+
+		self._keys = {}
 
 		if action is None:
 			return
 		elif action == "left":
-			self.keydown("Left")
+			self._keys["Left"] = True
 		elif action == "right":
-			self.keydown("Right")
+			self._keys["Right"] = True
 		elif action == "jump":
-			self.keydown("c")
+			self._keys["c"] = True
 		elif action == "dash-u":
-			self.keydown("Up")
-			self.keydown("x")
+			self._keys["Up"] = True
+			self._keys["x"] = True
 		elif action == "dash-r":
-			self.keydown("Right")
-			self.keydown("x")
+			self._keys["Right"] = True
+			self._keys["x"] = True
 		elif action == "dash-l":
-			self.keydown("Left")
-			self.keydown("x")
+			self._keys["Left"] = True
+			self._keys["x"] = True
 		elif action == "dash-ru":
-			self.keydown("Up")
-			self.keydown("Right")
-			self.keydown("x")
+			self._keys["Up"] = True
+			self._keys["Right"] = True
+			self._keys["x"] = True
 		elif action == "dash-lu":
-			self.keydown("Up")
-			self.keydown("Left")
-			self.keydown("x")
+			self._keys["Up"] = True
+			self._keys["Left"] = True
+			self._keys["x"] = True
 
+	def _apply_keys(self):
+		for i in [
+			"x", "c",
+			"Left", "Right",
+			"Down", "Up"
+		]:
+			if self._keys.get(i):
+				self._keydown(i)
+			else:
+				self._keyup(i)
+
 	@property
-	def status(self):
+	def state(self):
 		try:
-			return {
-				"stage": (
-					[
-						[0, 1, 2, 3, 4]
-					]
-					[int(self.internal_status["ry"])]
-					[int(self.internal_status["rx"])]
-				),
+			stage = (
+				[
+					[0, 1, 2, 3, 4]
+				]
+				[int(self._internal_state["ry"])]
+				[int(self._internal_state["rx"])]
+			)
 
-				"xpos": int(self.internal_status["px"]),
-				"ypos": int(self.internal_status["py"]),
-				"xvel": float(self.internal_status["vx"]),
-				"yvel": float(self.internal_status["vy"]),
-				"deaths": int(self.internal_status["dc"]),
+			# No checkpoints are defined past the last stage in
+			# target_checkpoints. (<= guards stage == len as well.)
+			if len(self.target_checkpoints) <= stage:
+				next_point_x = None
+				next_point_y = None
+			else:
+				next_point_x = self.target_checkpoints[stage][self._next_checkpoint_idx][0]
+				next_point_y = self.target_checkpoints[stage][self._next_checkpoint_idx][1]
+
+
+			return CelesteState(
+				stage = stage,
+
+				xpos = int(self._internal_state["px"]),
+				ypos = int(self._internal_state["py"]),
+				xvel = float(self._internal_state["vx"]),
+				yvel = float(self._internal_state["vy"]),
+				deaths = int(self._internal_state["dc"]),
+
+				dist = self._dist,
+				next_point = self._next_checkpoint_idx,
+				next_point_x = next_point_x,
+				next_point_y = next_point_y,
+				state_count = self._state_counter,
+				can_dash = self._internal_state["ds"] == "t"
+			)
 
-				"dist": self.dist,
-				"next_point": self.next_point,
-				"frame_count": self.frame_counter
-			}
 		except KeyError:
-			raise CelesteError("Not enough data to get status.")
+			raise CelesteError("Not enough data to get state.")
 
-
-	def keypress(self, key: str, *, post = 200):
+	def _keypress(self, key: str, *, post = 200):
 		subprocess.run([
 			"xdotool", "key",
-			"--window", self.winid,
+			"--window", self._winid,
 			key
 		])
 		time.sleep(post / 1000)
 
-	def keydown(self, key: str):
+	def _keydown(self, key: str):
 		subprocess.run([
 			"xdotool", "keydown",
-			"--window", self.winid,
+			"--window", self._winid,
 			key
 		])
 
-	def keyup(self, key: str):
+	def _keyup(self, key: str):
 		subprocess.run([
 			"xdotool", "keyup",
-			"--window", self.winid,
+			"--window", self._winid,
 			key
 		])
 
-	def keystring(self, string, *, delay = 100, post = 200):
+	def _keystring(self, string, *, delay = 100, post = 200):
 		subprocess.run([
 			"xdotool", "type",
-			"--window", self.winid,
+			"--window", self._winid,
 			"--delay", str(delay),
 			string
 		])
 		time.sleep(post / 1000)
 
 	def reset(self):
-		self.internal_status = {}
-		self.next_point = 0
-		self.frame_counter = 0
-		self.before_out = None
-		self.resetting = True
-		self.last_point_frame = 0
+		# Make sure all keys are released
+		self.act(None)
+		self._apply_keys()
 
-		self.keypress("Escape")
-		self.keystring("run")
-		self.keypress("Enter", post = 1000)
+		self._internal_state = {}
+		self._next_checkpoint_idx = 0
+		self._state_counter = 0
+		self._before_out = None
+		self._resetting = True
+		self._last_checkpoint_state = 0
 
-		self.flush_reader()
+		self._keypress("Escape")
+		self._keystring("run")
+		self._keypress("Enter", post = 1000)
 
-	def flush_reader(self):
-		for k in iter(self.process.stdout.readline, ""):
+
+
+		# Clear all old stdout messages and
+		# wait for the game to restart.
+		for k in iter(self._process.stdout.readline, ""):
 			k = k.decode("utf-8")[:-1]
 			if k == "!RESTART":
 				break
 
@@ -187,61 +273,68 @@ class Celeste:
 
 	def update_loop(self, before, after):
-		# Get state, call callback, wait for state
-		# One line => one frame.
-
-		it = iter(self.process.stdout.readline, "")
-
-
-		for line in it:
+		# Waits for stdout from pico-8 process
+		for line in iter(self._process.stdout.readline, ""):
 			l = line.decode("utf-8")[:-1].strip()
-			self.resetting = False
+
+			# Release all keys
+			self.act(None)
+			self._apply_keys()
+
+			# Clear reset state
+			self._resetting = False
 
 			# This should only occur at game start
			if l in ["!RESTART"]:
 				continue
 
-			self.frame_counter += 1
+			self._state_counter += 1
 
-			# Parse status string
+			# Parse state string
 			for entry in l.split(";"):
 				if entry == "":
 					continue
 				key, val = entry.split(":")
-				self.internal_status[key] = val
+				self._internal_state[key] = val
 
 			# Update checkpoints
-
-			tx, ty = self.target_points[self.status["stage"]][self.next_point]
-			x = self.status["xpos"]
-			y = self.status["ypos"]
+			tx, ty = self.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
+			x = self.state.xpos
+			y = self.state.ypos
 			dist = math.sqrt(
 				(x-tx)*(x-tx) +
-				(y-ty)*(y-ty)
+				((y-ty)*(y-ty))/2
+				# Halving the y term makes x-distance
+				# count roughly twice as much as y-distance.
 			)
 
-			if dist <= 4 and y == ty:
-				print(f"Got point {self.next_point}")
-				self.next_point += 1
-				self.last_point_frame = self.frame_counter
+			if dist <= 5:
+				print(f"Got point {self._next_checkpoint_idx}")
+				self._next_checkpoint_idx += 1
+				self._last_checkpoint_state = self._state_counter
 
 				# Recalculate distance to new point
-				tx, ty = self.target_points[self.status["stage"]][self.next_point]
+				tx, ty = self.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
 				dist = math.sqrt(
 					(x-tx)*(x-tx) +
-					(y-ty)*(y-ty)
+					((y-ty)*(y-ty))/2
 				)
 
 			# Timeout if we spend too long between points
-			elif self.frame_counter - self.last_point_frame > 40:
-				self.internal_status["dc"] = str(int(self.internal_status["dc"]) + 1)
+			elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
+				self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)
 
-			self.dist = dist
+			self._dist = dist
 
 			# Call step callbacks
-			if self.before_out is not None:
-				after(self, self.before_out)
-			if not self.resetting:
-				self.before_out = before(self)
\ No newline at end of file
+			# These should call celeste.act() to set next input
+			if self._before_out is not None:
+				after(self, self._before_out)
+
+			# Do not run before callback if after() triggered a reset.
+			if not self._resetting:
+				self._before_out = before(self)
+				self._apply_keys()
+
\ No newline at end of file
diff --git a/celeste/main.py b/celeste/main.py
index 9ad288d..5b27be5 100644
--- a/celeste/main.py
+++ b/celeste/main.py
@@ -1,30 +1,24 @@
 from collections import namedtuple
 from collections import deque
+from pathlib import Path
 import random
 import math
-
+import json
 import torch
 
-# Glue layer
 from celeste import Celeste
 
+run_data_path = Path("out")
+run_data_path.mkdir(parents = True, exist_ok = True)
+
 compute_device = torch.device(
 	"cuda" if torch.cuda.is_available() else "cpu"
 )
 
-state_number_map = [
-	"xpos",
-	"ypos",
-	"xvel",
-	"yvel",
-	"next_point"
-]
-
-
 # Celeste env properties
-n_observations = len(state_number_map)
+n_observations = len(Celeste.state_number_map)
 n_actions = len(Celeste.action_space)
 
@@ -39,7 +33,7 @@ EPS_END = 0.05
 EPS_DECAY = 1000
 
-BATCH_SIZE = 128
+BATCH_SIZE = 1_000
 # Learning rate of target_net.
 # Controls how soft our soft update is.
 #
@@ -64,9 +58,19 @@ GAMMA = 0.99
 
 class DQN(torch.nn.Module):
 	def __init__(self, n_observations: int, n_actions: int):
 		super(DQN, self).__init__()
-		self.layer1 = torch.nn.Linear(n_observations, 128)
-		self.layer2 = torch.nn.Linear(128, 128)
-		self.layer3 = torch.nn.Linear(128, n_actions)
+
+		self.layers = torch.nn.Sequential(
+			torch.nn.Linear(n_observations, 128),
+			torch.nn.ReLU(),
+
+			torch.nn.Linear(128, 128),
+			torch.nn.ReLU(),
+
+			torch.nn.Linear(128, 128),
+			torch.nn.ReLU(),
+
+			torch.nn.Linear(128, n_actions)
+		)
 
 	# Can be called with one input, or with a batch.
 	#
@@ -77,9 +81,7 @@ class DQN(torch.nn.Module):
 	# Recall that Q(s, a) is the (expected) return of taking
 	# action `a` at state `s`
 	def forward(self, x):
-		x = torch.nn.functional.relu(self.layer1(x))
-		x = torch.nn.functional.relu(self.layer2(x))
-		return self.layer3(x)
+		return self.layers(x)
 
 
@@ -94,7 +96,7 @@ num_episodes = 100
 # Memory: a deque that holds recent states as Transitions
 # Has a fixed length, drops oldest
 # element if maxlen is exceeded.
-memory = deque([], maxlen=10_000)
+memory = deque([], maxlen=100_000)
 
 
 policy_net = DQN(
@@ -112,11 +114,10 @@ target_net.load_state_dict(policy_net.state_dict())
 
 optimizer = torch.optim.AdamW(
 	policy_net.parameters(),
-	lr = 1e-4, # Hyperparameter: learning rate
+	lr = 0.01, # Hyperparameter: learning rate
 	amsgrad = True
 )
 
-
 def select_action(state, steps_done):
 	"""
 	Select an action using an epsilon-greedy policy.
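
Note: the body of select_action is elided between these hunks. For reference, a minimal sketch of the epsilon-greedy rule its docstring describes, written in terms of names visible above (EPS_END, EPS_DECAY, n_actions, policy_net, compute_device; EPS_START is assumed to be defined alongside EPS_END). This is an assumption about the standard form, not necessarily the author's exact code:

	# Epsilon decays exponentially from EPS_START toward EPS_END.
	eps_threshold = EPS_END + (EPS_START - EPS_END) * \
		math.exp(-steps_done / EPS_DECAY)

	if random.random() > eps_threshold:
		# Exploit: take the action with the highest predicted Q-value.
		with torch.no_grad():
			return policy_net(state).max(1)[1].view(1, 1)
	else:
		# Explore: sample a uniformly random action index.
		return torch.tensor(
			[[random.randrange(n_actions)]],
			device = compute_device,
			dtype = torch.long
		)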
@@ -303,39 +304,68 @@ def optimize_model():
 	optimizer.step()
 
 
+episode_number = 0
+
+
+if (run_data_path/"checkpoint.torch").is_file():
+	# Load model if one exists
+	checkpoint = torch.load(run_data_path / "checkpoint.torch")
+	policy_net.load_state_dict(checkpoint["policy_state_dict"])
+	target_net.load_state_dict(checkpoint["target_state_dict"])
+	optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+	memory = checkpoint["memory"]
+	episode_number = checkpoint["episode_number"] + 1
+	steps_done = checkpoint["steps_done"]
+
+
 def on_state_before(celeste):
 	global steps_done
 
 	# Conversion to pytorch
-	state = celeste.status
+	state = celeste.state
 	pt_state = torch.tensor(
-		[state[x] for x in state_number_map],
+		[getattr(state, x) for x in Celeste.state_number_map],
 		dtype = torch.float32,
 		device = compute_device
 	).unsqueeze(0)
 
-	action = select_action(
-		pt_state,
-		steps_done
-	)
+	# Resample until we get an action that is legal right now:
+	# dashes are only allowed when Madeline can dash.
+	action = None
+	while action is None or ((not state.can_dash) and (str_action not in ["left", "right"])):
+		action = select_action(
+			pt_state,
+			steps_done
+		)
+		str_action = Celeste.action_space[action]
 	steps_done += 1
 
-	# Turn number into action string
-	str_action = Celeste.action_space[action]
 
+	# For manual testing
+	#str_action = ""
+	#while str_action not in Celeste.action_space:
+	#	str_action = input("action> ")
+	#action = Celeste.action_space.index(str_action)
 
 	print(str_action)
 	celeste.act(str_action)
 
 	return state, action
+
+
+image_interval = 10
+
+
 def on_state_after(celeste, before_out):
+	global episode_number
 
 	state, action = before_out
+	next_state = celeste.state
 
 	pt_state = torch.tensor(
-		[state[x] for x in state_number_map],
+		[getattr(state, x) for x in Celeste.state_number_map],
 		dtype = torch.float32,
 		device = compute_device
 	).unsqueeze(0)
@@ -346,33 +376,30 @@ def on_state_after(celeste, before_out):
 		dtype = torch.long
 	)
 
-	next_state = celeste.status
-
-	if next_state["deaths"] != 0:
+	if next_state.deaths != 0:
 		pt_next_state = None
 		reward = 0
 
 	else:
 		pt_next_state = torch.tensor(
-			[next_state[x] for x in state_number_map],
+			[getattr(next_state, x) for x in Celeste.state_number_map],
 			dtype = torch.float32,
 			device = compute_device
 		).unsqueeze(0)
 
-		if state["next_point"] == next_state["next_point"]:
-			reward = state["dist"] - next_state["dist"]
+		if state.next_point == next_state.next_point:
+			reward = state.dist - next_state.dist
 
-			if reward > 0:
+			# Clip rewards that are too large
+			if reward > 1:
 				reward = 1
-			elif reward < 0:
-				reward = -1
 			else:
 				reward = 0
+
 		else:
 			# Score for reaching a point
-			reward = 10
-
+			reward = 1
 
 	pt_reward = torch.tensor([reward], device = compute_device)
 
@@ -387,6 +414,8 @@ def on_state_after(celeste, before_out):
 		)
 	)
 
+	print("==> ", int(reward))
+
 	print("\n")
 
 	# Only train the network if we have enough
@@ -406,8 +435,51 @@ def on_state_after(celeste, before_out):
 
 	# Move on to the next episode once we reach
 	# a terminal state.
-	if (next_state["deaths"] != 0):
+	if (next_state.deaths != 0):
+		s = celeste.state
+		with open(run_data_path / "train.log", "a") as f:
+			f.write(json.dumps({
+				"checkpoints": s.next_point,
+				"state_count": s.state_count
+			}) + "\n")
+
+
+		# Save model
+		torch.save({
+			"policy_state_dict": policy_net.state_dict(),
+			"target_state_dict": target_net.state_dict(),
+			"optimizer_state_dict": optimizer.state_dict(),
+			"memory": memory,
+			"episode_number": episode_number,
+			"steps_done": steps_done
+		}, run_data_path / "checkpoint.torch")
+
+
+		# Clean up screenshots
+		shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
+
+		target = run_data_path / Path(f"screenshots/{episode_number}")
+		target.mkdir(parents = True)
+
+		for s in shots:
+			s.rename(target / s.name)
+
+		# Save a model snapshot for prediction graphs
+		if episode_number % image_interval == 0:
+			p = run_data_path / Path("model_images")
+			p.mkdir(parents = True, exist_ok = True)
+			torch.save({
+				"policy_state_dict": policy_net.state_dict(),
+				"target_state_dict": target_net.state_dict(),
+				"optimizer_state_dict": optimizer.state_dict(),
+				"memory": memory,
+				"episode_number": episode_number,
+				"steps_done": steps_done
+			}, p / f"{episode_number}.torch")
+
+		print("Episode over, resetting")
+		episode_number += 1
 		celeste.reset()
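
Note: the "Learning rate of target_net / how soft our soft update is" comment kept as context in the BATCH_SIZE hunk refers to a Polyak-style soft update of target_net toward policy_net. The update itself is not shown in this diff; a minimal sketch of the standard form, assuming the constant that comment documents is named TAU:

	# theta_target <- TAU * theta_policy + (1 - TAU) * theta_target
	target_sd = target_net.state_dict()
	policy_sd = policy_net.state_dict()
	for key in policy_sd:
		target_sd[key] = TAU * policy_sd[key] + (1 - TAU) * target_sd[key]
	target_net.load_state_dict(target_sd)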