"""PICO-8 Celeste glue layer plus an epsilon-greedy DQN driver.

Reconstructed from a flattened git patch that adds ``celeste/celeste.py``
(the ``Celeste`` wrapper) and rewrites ``celeste/main.py`` (the DQN driver).
Both post-patch modules are laid out here in patch order; lines the patch
removes are not reproduced.
"""

from collections import deque
from collections import namedtuple
import math
import random
import subprocess
import threading
import time

import torch


# ---------------------------------------------------------------------------
# celeste/celeste.py
# ---------------------------------------------------------------------------

class CelesteError(Exception):
    """Raised when the game wrapper cannot provide what was asked of it."""


class Celeste:
    """Drive a PICO-8 process with xdotool and parse the state it prints.

    The game is expected to print one status line per frame on stdout, made
    of ``key:value`` entries separated by ``;``. After each parsed line the
    ``on_get_state`` callback is invoked with this instance.
    """

    action_space = [
        "left",     # move left
        "right",    # move right
        "jump",     # jump

        "dash-u",   # dash up
        "dash-r",   # dash right
        "dash-l",   # dash left
        "dash-ru",  # dash right-up
        "dash-lu",  # dash left-up
    ]

    # Keys released at the start of every act() call, in this order.
    _RELEASE_KEYS = ("x", "c", "Left", "Right", "Down", "Up")

    # Keys pressed (in order) for each action name.
    _ACTION_KEYS = {
        "left": ("Left",),
        "right": ("Right",),
        "jump": ("c",),
        "dash-u": ("Up", "x"),
        "dash-r": ("Right", "x"),
        "dash-l": ("Left", "x"),
        "dash-ru": ("Up", "Right", "x"),
        "dash-lu": ("Up", "Left", "x"),
    }

    # Stage number lookup, indexed as [ry][rx] with the values the game reports.
    _STAGE_GRID = [
        [0, 1, 2, 3, 4],
    ]

    def __init__(self, on_get_state):
        """Start PICO-8, find its window, load and run the cartridge.

        Args:
            on_get_state: callable invoked as ``on_get_state(self)`` by
                ``update_loop`` after every parsed status line.

        Raises:
            CelesteError: if xdotool does not report exactly one PICO-8 window.
        """
        self.on_get_state = on_get_state

        # Start pico-8. An argument list is used so no shell is involved.
        self.process = subprocess.Popen(
            ["bin/pico-8/linux/pico8"],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        )

        # Wait for window to open and get window id
        time.sleep(2)
        window_ids = subprocess.check_output(
            ["xdotool", "search", "--class", "pico8"]
        ).decode("utf-8").strip().split("\n")
        if len(window_ids) != 1:
            # CelesteError subclasses Exception, so existing handlers still match.
            raise CelesteError("Could not find unique PICO-8 window id")
        self.winid = window_ids[0]

        # Load cartridge
        self.keystring("load hackcel.p8")
        self.keypress("Enter")
        self.keystring("run")
        self.keypress("Enter", post=1000)

        # Initialize variables
        self.internal_status = {}
        self.dead = False

        # Score system
        self.frame_counter = 0
        self.next_point = 0
        self.dist = 0  # distance to next point
        self.target_points = [
            [  # Stage 1
                (28, 88),   # Start pillar
                (60, 80),   # Middle pillar
                (105, 64),  # Right ledge
                (25, 40),   # Left ledge
                (110, 16),  # End ledge
                (110, -2),  # Next stage
            ],
        ]

    def act(self, action):
        """Release every control key, then hold the keys for ``action``.

        Args:
            action: one of ``action_space``, or ``None`` to only release keys.
                Any other value presses nothing.
        """
        for key in self._RELEASE_KEYS:
            self.keyup(key)

        if action is None:
            return

        for key in self._ACTION_KEYS.get(action, ()):
            self.keydown(key)

    @property
    def status(self):
        """Return the latest parsed game state as a dict of typed values.

        Raises:
            CelesteError: if a required key has not been received yet.
        """
        raw = self.internal_status
        try:
            return {
                "stage": self._STAGE_GRID[int(raw["ry"])][int(raw["rx"])],

                "xpos": int(raw["px"]),
                "ypos": int(raw["py"]),
                "xvel": float(raw["vx"]),
                "yvel": float(raw["vy"]),
                "deaths": int(raw["dc"]),

                "dist": self.dist,
                "next_point": self.next_point,
                "frame_count": self.frame_counter,
            }
        except KeyError as err:
            raise CelesteError("Not enough data to get status.") from err

    def keypress(self, key: str, *, post=200):
        """Send a key press to the game window, then sleep ``post`` ms."""
        subprocess.run(["xdotool", "key", "--window", self.winid, key])
        time.sleep(post / 1000)

    def keydown(self, key: str):
        """Hold ``key`` down in the game window."""
        subprocess.run(["xdotool", "keydown", "--window", self.winid, key])

    def keyup(self, key: str):
        """Release ``key`` in the game window."""
        subprocess.run(["xdotool", "keyup", "--window", self.winid, key])

    def keystring(self, string, *, delay=100, post=200):
        """Type ``string`` into the game window, then sleep ``post`` ms.

        ``delay`` is passed to xdotool's ``--delay`` option.
        """
        subprocess.run([
            "xdotool", "type",
            "--window", self.winid,
            "--delay", str(delay),
            string,
        ])
        time.sleep(post / 1000)

    def reset(self):
        """Clear the tracked state and re-run the cartridge."""
        self.internal_status = {}
        self.next_point = 0
        self.frame_counter = 0
        self.dist = 0  # previously left stale across resets

        self.keypress("Escape")
        self.keystring("run")
        self.keypress("Enter", post=1000)
        self.dead = False

    def flush_reader(self):
        """Discard game output up to and including the next ``!RESTART`` line."""
        # The pipe yields bytes, so end-of-stream is b"" (not "").
        for raw_line in iter(self.process.stdout.readline, b""):
            if raw_line.decode("utf-8").strip() == "!RESTART":
                break

    def update_loop(self):
        """Read status lines until the game's stdout closes.

        One line => one frame. For each frame: parse the entries into
        ``internal_status``, update the checkpoint tracking, then call the
        ``on_get_state`` callback.
        """
        # The pipe yields bytes, so end-of-stream is b"" (not "").
        for raw_line in iter(self.process.stdout.readline, b""):
            # strip() removes the newline without eating a real character
            # when the final line has no trailing newline.
            line = raw_line.decode("utf-8").strip()

            # This should only occur at game start
            if line == "!RESTART":
                continue

            self.frame_counter += 1

            # Parse status string
            for entry in line.split(";"):
                if entry == "":
                    continue
                key, val = entry.split(":", 1)
                self.internal_status[key] = val

            self._update_checkpoint()

            # Call step callback
            self.on_get_state(self)

    def _update_checkpoint(self):
        """Advance ``next_point`` when its target is reached and refresh ``dist``."""
        status = self.status  # built once per frame
        x = status["xpos"]
        y = status["ypos"]
        stage = status["stage"]

        # Stages without a target list are treated as having no points left.
        if 0 <= stage < len(self.target_points):
            points = self.target_points[stage]
        else:
            points = []

        if self.next_point < len(points):
            tx, ty = points[self.next_point]
            if math.hypot(x - tx, y - ty) <= 4 and y == ty:
                self.next_point += 1

        if self.next_point < len(points):
            # Distance to the (possibly new) current target.
            tx, ty = points[self.next_point]
            self.dist = math.hypot(x - tx, y - ty)
        else:
            # Every point on this stage has been reached.
            self.dist = 0.0


# ---------------------------------------------------------------------------
# celeste/main.py
# ---------------------------------------------------------------------------

# Glue layer
# NOTE: in the patch, main.py obtains the class with
# `from celeste import Celeste`; in this combined listing it is defined above.

compute_device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)


# Epsilon-greedy parameters
#
# Original docs:
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon,
# higher means a slower decay
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000


# Outline our network
class DQN(torch.nn.Module):
    """Three fully connected layers with ReLU after the first two."""

    def __init__(self, n_observations: int, n_actions: int):
        super().__init__()
        self.layer1 = torch.nn.Linear(n_observations, 128)
        self.layer2 = torch.nn.Linear(128, 128)
        self.layer3 = torch.nn.Linear(128, n_actions)

    def forward(self, x):
        """Return one value per action for each input row.

        Can be called with one input, or with a batch. Each output row is
        [ Q(s, a_0), ..., Q(s, a_{n_actions-1}) ], where Q(s, a) is the
        (expected) return of taking action `a` at state `s`.
        """
        x = torch.nn.functional.relu(self.layer1(x))
        x = torch.nn.functional.relu(self.layer2(x))
        return self.layer3(x)


# Celeste env properties
n_observations = 4
n_actions = len(Celeste.action_space)

policy_net = DQN(
    n_observations,
    n_actions
).to(compute_device)


def select_action(state, steps_done):
    """
    Select an action using an epsilon-greedy policy.

    Sometimes use our model, sometimes sample one uniformly.

    P(random action) starts at EPS_START and decays to EPS_END.
    Decay rate is controlled by EPS_DECAY.

    Returns the chosen action as an int index into the action space.
    """
    # Random number 0 <= x < 1
    sample = random.random()

    # Calculate random step threshold
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(
        -1.0 * steps_done / EPS_DECAY
    )

    if sample > eps_threshold:
        with torch.no_grad():
            # t.max(1) will return the largest column value of each row.
            # second column on max result is index of where max element was
            # found, so we pick action with the larger expected reward.
            return policy_net(state).max(1)[1].view(1, 1).item()

    return random.randint(0, n_actions - 1)


last_state = None


Transition = namedtuple(
    "Transition",
    (
        "state",
        "action",
        "next_state",
        "reward"
    )
)


# Status fields fed to the network, in input order.
_STATE_FEATURES = ("xpos", "ypos", "xvel", "yvel")


def _state_tensor(status):
    """Pack the observed status fields into a (1, n) float32 tensor."""
    return torch.tensor(
        [status[name] for name in _STATE_FEATURES],
        dtype=torch.float32,
        device=compute_device
    ).unsqueeze(0)


def _reward(previous, current):
    """Reward progress toward the current point; 10 when the point changed."""
    if previous["next_point"] == current["next_point"]:
        return previous["dist"] - current["dist"]
    return 10


def on_state(celeste):
    """Per-frame callback: pick an action for the current state and apply it.

    The first call only records the state, since a previous state is needed
    to compute a reward.
    """
    global last_state

    s = celeste.status

    if last_state is None:
        last_state = s
        return

    # Values at this point
    # reward: reward for last action
    # dead: true if game over
    # NOTE: reward, dead and tf_last are computed but not consumed yet in
    # this function.
    reward = _reward(last_state, s)
    dead = s["deaths"] != 0

    tf_state = _state_tensor(s)
    tf_last = _state_tensor(last_state)

    action_index = select_action(tf_state, s["frame_count"])

    # Turn number into action string
    celeste.act(Celeste.action_space[action_index])

    # Update previous state
    last_state = s


def main():
    """Start the game and run the agent loop until the game's output ends."""
    c = Celeste(on_state)
    c.update_loop()


# Guarded so that importing this module does not launch PICO-8.
if __name__ == "__main__":
    main()