Mark / celeste-ai

Cleaned up celeste wrapper

master
Mark 2023-02-18 19:28:02 -08:00
parent 85d8c7a300
commit 610e5eef92
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
2 changed files with 319 additions and 154 deletions

File 1 of 2

@@ -1,12 +1,44 @@
+from typing import NamedTuple
 import subprocess
 import time
-import threading
 import math
-from tqdm import tqdm

 class CelesteError(Exception):
     pass

+class CelesteState(NamedTuple):
+    # Stage number
+    stage: int
+
+    # Player position
+    xpos: int
+    ypos: int
+
+    # Player velocity
+    xvel: float
+    yvel: float
+
+    # Number of deaths since game start
+    deaths: int
+
+    # Distance to next point
+    dist: float
+
+    # Index of next point
+    next_point: int
+
+    # Coordinates of next point
+    next_point_x: int
+    next_point_y: int
+
+    # Number of states received since restart
+    state_count: int
+
+    # True if Madeline can dash
+    can_dash: bool
+
 class Celeste:
     action_space = [
         "left",     # move left
@@ -20,10 +52,25 @@ class Celeste:
         "dash-lu"   # dash left-up
     ]

-    def __init__(self):
+    # Map integers to state values.
+    # This also determines what data is fed to the model.
+    state_number_map = [
+        "xpos",
+        "ypos",
+        "next_point_x",
+        "next_point_y"
+    ]
+
+    def __init__(
+        self,
+        *,
+        state_timeout = 30,
+        cart_name = "hackcel.p8"
+    ):
         # Start pico-8
-        self.process = subprocess.Popen(
-            "bin/pico-8/linux/pico8",
+        self._process = subprocess.Popen(
+            "resources/pico-8/linux/pico8",
             shell=True,
             stdout=subprocess.PIPE,
             stderr=subprocess.STDOUT
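
Note: state_number_map is what decides which fields of a CelesteState are fed to the model. A minimal sketch of that conversion, using the same getattr pattern the training script uses below (the helper name and the example values are illustrative, not part of the commit):

    import torch
    from celeste import Celeste

    def state_to_tensor(state, device):
        # Pick the fields named in Celeste.state_number_map, in order,
        # and pack them into a (1, n_observations) float tensor.
        return torch.tensor(
            [getattr(state, x) for x in Celeste.state_number_map],
            dtype = torch.float32,
            device = device
        ).unsqueeze(0)

    # A state with xpos=28, ypos=88, next_point_x=60, next_point_y=80
    # becomes tensor([[28., 88., 60., 80.]]).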
@@ -39,26 +86,34 @@ class Celeste:
         ]).decode("utf-8").strip().split("\n")
         if len(winid) != 1:
             raise Exception("Could not find unique PICO-8 window id")
-        self.winid = winid[0]
+        self._winid = winid[0]

         # Load cartridge
-        self.keystring("load hackcel.p8")
-        self.keypress("Enter")
-        self.keystring("run")
-        self.keypress("Enter", post = 1000)
+        self._keystring(f"load {cart_name}")
+        self._keypress("Enter")
+        self._keystring("run")
+        self._keypress("Enter", post = 1000)

-        # Initialize variables
-        self.internal_status = {}
-        self.before_out = None
-        self.last_point_frame = 0
-
-        # Score system
-        self.frame_counter = 0
-        self.next_point = 0
-        self.dist = 0  # distance to next point
-
-        self.target_points = [
+        # Parameters
+        self.state_timeout = state_timeout  # If we run this many states without getting a checkpoint, reset.
+        self.cart_name = cart_name          # Name of cart to load. Not used anywhere, but saved for convenience.
+
+        # Internal variables
+        self._internal_state = {}        # Raw data read from stdout
+        self._before_out = None          # Output of "before" callback in update loop
+        self._last_checkpoint_state = 0  # Index of frame at which we reached the last checkpoint
+        self._state_counter = 0          # Number of frames we've run since last reset
+        self._next_checkpoint_idx = 0    # Index of next point
+        self._dist = 0                   # Distance to next point
+        self._resetting = False          # True between a call to .reset() and the first state message from pico.
+        self._keys = {}                  # Dictionary of "key": bool
+
+        # Targets the agent tries to reach.
+        # The last target MUST be outside the frame.
+        self.target_checkpoints = [
             [ # Stage 1
-                (28, 88),   # Start pillar
+                #(28, 88),  # Start pillar
                 (60, 80),   # Middle pillar
                 (105, 64),  # Right ledge
                 (25, 40),   # Left ledge
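
Note: the xdotool call that fills winid sits just above this hunk and is not shown in the diff. A plausible sketch, assuming the window is located by name with xdotool search (the "pico8" pattern is a guess):

    import subprocess

    # Hypothetical reconstruction of the window lookup.
    winid = subprocess.check_output([
        "xdotool", "search", "--name", "pico8"
    ]).decode("utf-8").strip().split("\n")

    if len(winid) != 1:
        raise Exception("Could not find unique PICO-8 window id")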
@@ -67,119 +122,150 @@ class Celeste:
         ]
     ]

-    def act(self, action):
-        self.keyup("x")
-        self.keyup("c")
-        self.keyup("Left")
-        self.keyup("Right")
-        self.keyup("Down")
-        self.keyup("Up")
+    def act(self, action: str):
+        """
+        Specify what keys should be down. This does NOT send key events.
+        Celeste._apply_keys() does that at the right time.
+
+        Args:
+            action (str): key name, as in Celeste.action_space
+        """
+
+        self._keys = {}

         if action is None:
             return
         elif action == "left":
-            self.keydown("Left")
+            self._keys["Left"] = True
         elif action == "right":
-            self.keydown("Right")
+            self._keys["Right"] = True
         elif action == "jump":
-            self.keydown("c")
+            self._keys["c"] = True
         elif action == "dash-u":
-            self.keydown("Up")
-            self.keydown("x")
+            self._keys["Up"] = True
+            self._keys["x"] = True
         elif action == "dash-r":
-            self.keydown("Right")
-            self.keydown("x")
+            self._keys["Right"] = True
+            self._keys["x"] = True
         elif action == "dash-l":
-            self.keydown("Left")
-            self.keydown("x")
+            self._keys["Left"] = True
+            self._keys["x"] = True
         elif action == "dash-ru":
-            self.keydown("Up")
-            self.keydown("Right")
-            self.keydown("x")
+            self._keys["Up"] = True
+            self._keys["Right"] = True
+            self._keys["x"] = True
        elif action == "dash-lu":
-            self.keydown("Up")
-            self.keydown("Left")
-            self.keydown("x")
+            self._keys["Up"] = True
+            self._keys["Left"] = True
+            self._keys["x"] = True
+
+    def _apply_keys(self):
+        for i in [
+            "x", "c",
+            "Left", "Right",
+            "Down", "Up"
+        ]:
+            if self._keys.get(i):
+                self._keydown(i)
+            else:
+                self._keyup(i)

     @property
-    def status(self):
+    def state(self):
         try:
-            return {
-                "stage": (
-                    [
-                        [0, 1, 2, 3, 4]
-                    ]
-                    [int(self.internal_status["ry"])]
-                    [int(self.internal_status["rx"])]
-                ),
-                "xpos": int(self.internal_status["px"]),
-                "ypos": int(self.internal_status["py"]),
-                "xvel": float(self.internal_status["vx"]),
-                "yvel": float(self.internal_status["vy"]),
-                "deaths": int(self.internal_status["dc"]),
-                "dist": self.dist,
-                "next_point": self.next_point,
-                "frame_count": self.frame_counter
-            }
+            stage = (
+                [
+                    [0, 1, 2, 3, 4]
+                ]
+                [int(self._internal_state["ry"])]
+                [int(self._internal_state["rx"])]
+            )
+
+            if len(self.target_checkpoints) < stage:
+                next_point_x = None
+                next_point_y = None
+            else:
+                next_point_x = self.target_checkpoints[stage][self._next_checkpoint_idx][0]
+                next_point_y = self.target_checkpoints[stage][self._next_checkpoint_idx][1]
+
+            return CelesteState(
+                stage = stage,
+                xpos = int(self._internal_state["px"]),
+                ypos = int(self._internal_state["py"]),
+                xvel = float(self._internal_state["vx"]),
+                yvel = float(self._internal_state["vy"]),
+                deaths = int(self._internal_state["dc"]),
+                dist = self._dist,
+                next_point = self._next_checkpoint_idx,
+                next_point_x = next_point_x,
+                next_point_y = next_point_y,
+                state_count = self._state_counter,
+                can_dash = self._internal_state["ds"] == "t"
+            )
         except KeyError:
-            raise CelesteError("Not enough data to get status.")
+            raise CelesteError("Not enough data to get state.")

-    def keypress(self, key: str, *, post = 200):
+    def _keypress(self, key: str, *, post = 200):
         subprocess.run([
             "xdotool",
             "key",
-            "--window", self.winid,
+            "--window", self._winid,
             key
         ])
         time.sleep(post / 1000)

-    def keydown(self, key: str):
+    def _keydown(self, key: str):
         subprocess.run([
             "xdotool",
             "keydown",
-            "--window", self.winid,
+            "--window", self._winid,
             key
         ])

-    def keyup(self, key: str):
+    def _keyup(self, key: str):
         subprocess.run([
             "xdotool",
             "keyup",
-            "--window", self.winid,
+            "--window", self._winid,
             key
         ])

-    def keystring(self, string, *, delay = 100, post = 200):
+    def _keystring(self, string, *, delay = 100, post = 200):
         subprocess.run([
             "xdotool",
             "type",
-            "--window", self.winid,
+            "--window", self._winid,
             "--delay", str(delay),
             string
         ])
         time.sleep(post / 1000)

     def reset(self):
-        self.internal_status = {}
-        self.next_point = 0
-        self.frame_counter = 0
-        self.before_out = None
-        self.resetting = True
-        self.last_point_frame = 0
+        # Make sure all keys are released
+        self.act(None)
+        self._apply_keys()
+
+        self._internal_state = {}
+        self._next_checkpoint_idx = 0
+        self._state_counter = 0
+        self._before_out = None
+        self._resetting = True
+        self._last_checkpoint_state = 0

-        self.keypress("Escape")
-        self.keystring("run")
-        self.keypress("Enter", post = 1000)
+        self._keypress("Escape")
+        self._keystring("run")
+        self._keypress("Enter", post = 1000)

-        self.flush_reader()
-
-    def flush_reader(self):
-        for k in iter(self.process.stdout.readline, ""):
+        # Clear all old stdout messages and
+        # wait for the game to restart.
+        for k in iter(self._process.stdout.readline, ""):
             k = k.decode("utf-8")[:-1]
             if k == "!RESTART":
                 break
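
Note: after this commit, act() only records which keys should be held; nothing reaches the game until _apply_keys() runs inside the update loop. A small usage sketch (interactive use like this is an illustration, not something the trainer does):

    env = Celeste()

    env.act("dash-ru")    # request Up + Right + x
    env._apply_keys()     # xdotool holds Up, Right, x and releases c, Left, Down

    env.act(None)         # request no keys
    env._apply_keys()     # releases everything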
@@ -187,61 +273,68 @@ class Celeste:
     def update_loop(self, before, after):

-        # Get state, call callback, wait for state
-        # One line => one frame.
-        it = iter(self.process.stdout.readline, "")
-        for line in it:
+        # Waits for stdout from pico-8 process
+        for line in iter(self._process.stdout.readline, ""):
             l = line.decode("utf-8")[:-1].strip()

-            self.resetting = False
+            # Release all keys
+            self.act(None)
+            self._apply_keys()
+
+            # Clear reset state
+            self._resetting = False

             # This should only occur at game start
             if l in ["!RESTART"]:
                 continue

-            self.frame_counter += 1
+            self._state_counter += 1

-            # Parse status string
+            # Parse state string
             for entry in l.split(";"):
                 if entry == "":
                     continue
                 key, val = entry.split(":")
-                self.internal_status[key] = val
+                self._internal_state[key] = val

             # Update checkpoints
-            tx, ty = self.target_points[self.status["stage"]][self.next_point]
-            x = self.status["xpos"]
-            y = self.status["ypos"]
+            tx, ty = self.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
+            x = self.state.xpos
+            y = self.state.ypos
             dist = math.sqrt(
                 (x-tx)*(x-tx) +
-                (y-ty)*(y-ty)
+                ((y-ty)*(y-ty))/2
+                # Possible modification:
+                # make x-distance twice as valuable as y-distance
             )

-            if dist <= 4 and y == ty:
-                print(f"Got point {self.next_point}")
-                self.next_point += 1
-                self.last_point_frame = self.frame_counter
+            if dist <= 5:
+                print(f"Got point {self._next_checkpoint_idx}")
+                self._next_checkpoint_idx += 1
+                self._last_checkpoint_state = self._state_counter

                 # Recalculate distance to new point
-                tx, ty = self.target_points[self.status["stage"]][self.next_point]
+                tx, ty = self.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
                 dist = math.sqrt(
                     (x-tx)*(x-tx) +
-                    (y-ty)*(y-ty)
+                    ((y-ty)*(y-ty))/2
                 )

             # Timeout if we spend too long between points
-            elif self.frame_counter - self.last_point_frame > 40:
-                self.internal_status["dc"] = str(int(self.internal_status["dc"]) + 1)
+            elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
+                self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)

-            self.dist = dist
+            self._dist = dist

             # Call step callbacks
-            if self.before_out is not None:
-                after(self, self.before_out)
-            if not self.resetting:
-                self.before_out = before(self)
+            # These should call celeste.act() to set next input
+            if self._before_out is not None:
+                after(self, self._before_out)
+
+            # Do not run before callback if after() triggered a reset.
+            if not self._resetting:
+                self._before_out = before(self)
+
+            self._apply_keys()
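
Note: update_loop() is the whole driver: each stdout line from PICO-8 is one state, the before callback chooses the next action, and the after callback sees the outcome of the previous one. The new distance metric is sqrt((x - tx)^2 + (y - ty)^2 / 2), so vertical error counts half as much as horizontal error. A minimal sketch of a driver, with placeholder callbacks (the real ones live in the training script below):

    env = Celeste()

    def before(celeste):
        # Choose and request an action for the upcoming state.
        celeste.act("right")
        return celeste.state          # handed back to after() on the next state

    def after(celeste, before_out):
        # Compare before_out with celeste.state, compute a reward,
        # and call celeste.reset() when an episode ends.
        pass

    env.reset()
    env.update_loop(before, after)    # blocks, reading states forever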

File 2 of 2

@@ -1,30 +1,24 @@
 from collections import namedtuple
 from collections import deque
+from pathlib import Path
 import random
 import math
+import json
 import torch

+# Glue layer
 from celeste import Celeste

+run_data_path = Path("out")
+run_data_path.mkdir(parents = True, exist_ok = True)
+
 compute_device = torch.device(
     "cuda" if torch.cuda.is_available() else "cpu"
 )

-state_number_map = [
-    "xpos",
-    "ypos",
-    "xvel",
-    "yvel",
-    "next_point"
-]
-
 # Celeste env properties
-n_observations = len(state_number_map)
+n_observations = len(Celeste.state_number_map)
 n_actions = len(Celeste.action_space)
@@ -39,7 +33,7 @@ EPS_END = 0.05
 EPS_DECAY = 1000

-BATCH_SIZE = 128
+BATCH_SIZE = 1_000

 # Learning rate of target_net.
 # Controls how soft our soft update is.
 #
@@ -64,9 +58,19 @@ GAMMA = 0.99
 class DQN(torch.nn.Module):
     def __init__(self, n_observations: int, n_actions: int):
         super(DQN, self).__init__()
-        self.layer1 = torch.nn.Linear(n_observations, 128)
-        self.layer2 = torch.nn.Linear(128, 128)
-        self.layer3 = torch.nn.Linear(128, n_actions)
+
+        self.layers = torch.nn.Sequential(
+            torch.nn.Linear(n_observations, 128),
+            torch.nn.ReLU(),
+            torch.nn.Linear(128, 128),
+            torch.nn.ReLU(),
+            torch.nn.Linear(128, 128),
+            torch.nn.ReLU(),
+            torch.torch.nn.Linear(128, n_actions)
+        )

     # Can be called with one input, or with a batch.
     #
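
Note: with the four-entry Celeste.state_number_map above, the new Sequential network is a 4 -> 128 -> 128 -> 128 -> n_actions MLP. A quick sanity-check sketch (assumes this file's DQN class and the celeste module are importable):

    import torch
    from celeste import Celeste

    net = DQN(len(Celeste.state_number_map), len(Celeste.action_space))

    # One dummy observation: (batch, n_observations) -> (batch, n_actions) Q-values.
    q = net(torch.zeros(1, len(Celeste.state_number_map)))
    print(q.shape)    # torch.Size([1, len(Celeste.action_space)])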
@@ -77,9 +81,7 @@ class DQN(torch.nn.Module):
     # Recall that Q(s, a) is the (expected) return of taking
     # action `a` at state `s`
     def forward(self, x):
-        x = torch.nn.functional.relu(self.layer1(x))
-        x = torch.nn.functional.relu(self.layer2(x))
-        return self.layer3(x)
+        return self.layers(x)
@@ -94,7 +96,7 @@ num_episodes = 100
 # Memory: a deque that holds recent states as Transitions
 # Has a fixed length, drops oldest
 # element if maxlen is exceeded.
-memory = deque([], maxlen=10_000)
+memory = deque([], maxlen=100_000)

 policy_net = DQN(
@@ -112,11 +114,10 @@ target_net.load_state_dict(policy_net.state_dict())
 optimizer = torch.optim.AdamW(
     policy_net.parameters(),
-    lr = 1e-4,     # Hyperparameter: learning rate
+    lr = 0.01,     # Hyperparameter: learning rate
     amsgrad = True
 )

 def select_action(state, steps_done):
     """
     Select an action using an epsilon-greedy policy.
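
Note: select_action's body is outside this hunk. Its epsilon-greedy schedule is presumably built from EPS_START, EPS_END and EPS_DECAY defined nearby; the usual exponential decay looks like the sketch below (an assumption about this code, not shown in the diff):

    import math, random

    def epsilon(steps_done):
        # Exploration probability decays from EPS_START toward EPS_END.
        return EPS_END + (EPS_START - EPS_END) * math.exp(-steps_done / EPS_DECAY)

    def greedy_or_random(q_values, steps_done):
        # Exploit with probability 1 - epsilon, otherwise explore.
        if random.random() > epsilon(steps_done):
            return int(q_values.argmax())
        return random.randrange(n_actions)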
@@ -303,39 +304,68 @@ def optimize_model():
     optimizer.step()

+episode_number = 0
+
+if (run_data_path/"checkpoint.torch").is_file():
+    # Load model if one exists
+    checkpoint = torch.load((run_data_path/"checkpoint.torch"))
+    policy_net.load_state_dict(checkpoint["policy_state_dict"])
+    target_net.load_state_dict(checkpoint["target_state_dict"])
+    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+    memory = checkpoint["memory"]
+    episode_number = checkpoint["episode_number"] + 1
+    steps_done = checkpoint["steps_done"]

 def on_state_before(celeste):
     global steps_done

     # Conversion to pytorch
-    state = celeste.status
+    state = celeste.state
     pt_state = torch.tensor(
-        [state[x] for x in state_number_map],
+        [getattr(state, x) for x in Celeste.state_number_map],
         dtype = torch.float32,
         device = compute_device
     ).unsqueeze(0)

-    action = select_action(
-        pt_state,
-        steps_done
-    )
+    action = None
+    while (action) is None or ((not state.can_dash) and (str_action not in ["left", "right"])):
+        action = select_action(
+            pt_state,
+            steps_done
+        )
+        str_action = Celeste.action_space[action]

     steps_done += 1

-    # Turn number into action string
-    str_action = Celeste.action_space[action]
+    # For manual testing
+    #str_action = ""
+    #while str_action not in Celeste.action_space:
+    #    str_action = input("action> ")
+    #action = Celeste.action_space.index(str_action)
+
+    print(str_action)
     celeste.act(str_action)

     return state, action

+image_interval = 10

 def on_state_after(celeste, before_out):
+    global episode_number
+    global image_count

     state, action = before_out
+    next_state = celeste.state

     pt_state = torch.tensor(
-        [state[x] for x in state_number_map],
+        [getattr(state, x) for x in Celeste.state_number_map],
         dtype = torch.float32,
         device = compute_device
     ).unsqueeze(0)
@@ -346,33 +376,30 @@ def on_state_after(celeste, before_out):
         dtype = torch.long
     )

-    next_state = celeste.status
-
-    if next_state["deaths"] != 0:
+    if next_state.deaths != 0:
         pt_next_state = None
         reward = 0
     else:
         pt_next_state = torch.tensor(
-            [next_state[x] for x in state_number_map],
+            [getattr(next_state, x) for x in Celeste.state_number_map],
             dtype = torch.float32,
             device = compute_device
         ).unsqueeze(0)

-        if state["next_point"] == next_state["next_point"]:
-            reward = state["dist"] - next_state["dist"]
-            if reward > 0:
+        if state.next_point == next_state.next_point:
+            reward = state.dist - next_state.dist
+
+            # Clip rewards that are too large
+            if reward > 1:
                 reward = 1
+            elif reward < 0:
+                reward = -1
             else:
                 reward = 0
         else:
             # Score for reaching a point
-            reward = 10
+            reward = 1

     pt_reward = torch.tensor([reward], device = compute_device)
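
Note: the shaped reward is now the drop in distance to the current checkpoint, clipped, plus a flat bonus for actually reaching one. A worked example under the new rules (the numbers are made up):

    # Same checkpoint before and after the step:
    #   state.dist = 37.0, next_state.dist = 36.4  ->  raw reward 0.6,
    #   which the clipping above maps to 0 (only gains larger than 1 become +1).
    #   state.dist = 37.0, next_state.dist = 39.0  ->  raw reward -2.0 -> clipped to -1.
    # Checkpoint index advanced during the step: reward = 1 (was 10 before this commit).
    # Death: pt_next_state = None and reward = 0.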
@@ -387,6 +414,8 @@ def on_state_after(celeste, before_out):
         )
     )

+    print("==> ", int(reward))
+    print("\n")

     # Only train the network if we have enough
@@ -406,8 +435,51 @@ def on_state_after(celeste, before_out):
     # Move on to the next episode once we reach
     # a terminal state.
-    if (next_state["deaths"] != 0):
+    if (next_state.deaths != 0):
+        s = celeste.state
+        with open(run_data_path / "train.log", "a") as f:
+            f.write(json.dumps({
+                "checkpoints": s.next_point,
+                "state_count": s.state_count
+            }) + "\n")
+
+        # Save model
+        torch.save({
+            "policy_state_dict": policy_net.state_dict(),
+            "target_state_dict": target_net.state_dict(),
+            "optimizer_state_dict": optimizer.state_dict(),
+            "memory": memory,
+            "episode_number": episode_number,
+            "steps_done": steps_done
+        }, run_data_path / "checkpoint.torch")
+
+        # Clean up screenshots
+        shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
+
+        target = run_data_path / Path(f"screenshots/{episode_number}")
+        target.mkdir(parents = True)
+
+        for s in shots:
+            s.rename(target / s.name)
+
+        # Save a prediction graph
+        if episode_number % image_interval == 0:
+            p = run_data_path / Path("model_images")
+            p.mkdir(parents = True, exist_ok = True)
+            torch.save({
+                "policy_state_dict": policy_net.state_dict(),
+                "target_state_dict": target_net.state_dict(),
+                "optimizer_state_dict": optimizer.state_dict(),
+                "memory": memory,
+                "episode_number": episode_number,
+                "steps_done": steps_done
+            }, p / f"{episode_number}.torch")
+
         print("State over, resetting")
+        episode_number += 1
         celeste.reset()
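
Note: each episode now ends by appending one JSON line to out/train.log and writing a full checkpoint (network weights, optimizer state, replay memory, episode and step counters) to out/checkpoint.torch, which the loading block near the top of this file picks up on the next run. A sample log line (the values are illustrative):

    {"checkpoints": 2, "state_count": 137}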