374 lines
8.5 KiB
Python
Executable File
374 lines
8.5 KiB
Python
Executable File
from typing import NamedTuple
|
|
import subprocess
|
|
import time
|
|
import math
|
|
|
|
import numpy as np
|
|
|
|
class CelesteError(Exception):
    """Error raised for problems specific to the Celeste game wrapper."""
|
|
|
|
|
|
class CelesteState(NamedTuple):
    """
    Immutable snapshot of the game state.

    Field order is part of the tuple interface and must not change.

    Attributes:
        stage: Stage number.
        xpos, ypos: Player position; 0,0 is the top-left corner.
        xpos_scaled, ypos_scaled: Player position as a float
            (Celeste.state fills these with the position divided by 128).
        xvel, yvel: Player velocity.
        deaths: Number of deaths since game start.
        dist: Distance to the next point.
        next_point: Index of the next point.
        next_point_x, next_point_y: Coordinates of the next point.
            Celeste.state may set these to None when there is no next point.
        state_count: Number of states received since restart.
        can_dash: True if Madeline can dash.
        can_dash_int: Same flag as an integer (1 or 0).
    """

    stage: int

    xpos: int
    ypos: int
    xpos_scaled: float
    ypos_scaled: float

    xvel: float
    yvel: float

    deaths: int

    dist: float

    next_point: int

    next_point_x: int
    next_point_y: int

    state_count: int

    can_dash: bool
    can_dash_int: int
|
|
|
|
|
|
class Celeste:
    """
    Wrapper around a PICO-8 process running a Celeste cart.

    The game is controlled by sending key events to the PICO-8 window
    with xdotool, and observed by reading the "key:value;key:value" state
    lines the cart prints to stdout.
    """

    action_space = [
        "left",     # move left         0
        "right",    # move right        1
        #"jump",    # jump
        "jump-l",   # jump left         2
        "jump-r",   # jump right        3

        "dash-u",   # dash up           4
        "dash-r",   # dash right        5
        "dash-l",   # dash left         6
        "dash-ru",  # dash right-up     7
        "dash-lu"   # dash left-up      8
    ]

    # Map integers to state values.
    # This also determines what data is fed to the model.
    state_number_map = [
        #"xpos",
        #"ypos",
        "xpos_scaled",
        "ypos_scaled",
        "can_dash_int"
        #"next_point_x",
        #"next_point_y"
    ]

    # Targets the agent tries to reach.
    # The last target MUST be outside the frame.
    target_checkpoints = [
        [ # Stage 1
            #(28, 88),   # Start pillar
            (60, 80),    # Middle pillar
            (105, 64),   # Right ledge
            (25, 40),    # Left ledge
            (110, 16),   # End ledge
            (110, -2),   # Next stage
        ]
    ]

    # Keys held down for each action name. "jump" is still understood by
    # act() even though it is not offered in action_space.
    _action_keys = {
        "left": ("Left",),
        "right": ("Right",),
        "jump": ("c",),
        "jump-l": ("c", "Left"),
        "jump-r": ("c", "Right"),
        "dash-u": ("Up", "x"),
        "dash-r": ("Right", "x"),
        "dash-l": ("Left", "x"),
        "dash-ru": ("Up", "Right", "x"),
        "dash-lu": ("Up", "Left", "x"),
    }

    # Every key this class presses or releases, in the order events are sent.
    _managed_keys = ("x", "c", "Left", "Right", "Down", "Up")

    # Stage number, indexed by the cart's "ry" value and then its "rx" value.
    _stage_layout = [[0, 1, 2, 3, 4]]

    # A checkpoint counts as reached at or below this (weighted) distance.
    _checkpoint_radius = 8

    # Placeholder distance given to checkpoints that were already reached,
    # so they never win the "closest checkpoint" search.
    _passed_distance = 1000

    def __init__(
        self,
        pico_path,
        *,
        state_timeout = 30,
        cart_name = "hackcel.p8",
    ):
        """
        Start PICO-8, find its window, then load and run the cartridge.

        Args:
            pico_path: command used to start PICO-8 (run through the shell)
            state_timeout (int): if this many states pass without reaching
                a checkpoint, the death counter is incremented
            cart_name (str): name of the cart to load

        Raises:
            CelesteError: if exactly one PICO-8 window cannot be found
        """

        # Start pico-8
        self._process = subprocess.Popen(
            pico_path,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT
        )

        # Wait for window to open and get window id
        time.sleep(2)
        found = subprocess.check_output([
            "xdotool",
            "search",
            "--class",
            "pico8"
        ]).decode("utf-8").split()
        # split() drops blank output, so an empty result is not mistaken
        # for a single (empty) window id.
        if len(found) != 1:
            raise CelesteError("Could not find unique PICO-8 window id")
        self._winid = found[0]

        # Load cartridge
        self._keystring(f"load {cart_name}")
        self._keypress("Enter")
        self._keystring("run")
        self._keypress("Enter", post = 1000)

        # Parameters
        self.state_timeout = state_timeout  # If we run this many states without getting a checkpoint, count a death.
        self.cart_name = cart_name          # Name of the cart that was loaded, saved for convenience.

        # Internal variables
        self._internal_state = {}           # Raw data read from stdout
        self._before_out = None             # Output of "before" callback in update loop
        self._last_checkpoint_state = 0     # Index of frame at which we reached the last checkpoint
        self._state_counter = 0             # Number of frames we've run since last reset
        self._next_checkpoint_idx = 0       # Index of next point
        self._dist = 0                      # Distance to next point
        self._resetting = False             # True between a call to .reset() and the first state message from pico.
        self._keys = {}                     # Dictionary of "key": bool

    def act(self, action: str):
        """
        Specify what keys should be down. This does NOT send key events.
        Celeste._apply_keys() does that at the right time.

        None, or a name this class does not know, releases every key.

        Args:
            action (str): key name, as in Celeste.action_space
        """

        self._keys = {key: True for key in Celeste._action_keys.get(action, ())}

    def _apply_keys(self):
        """Send a keydown or keyup event for every managed key, as set by act()."""
        for key in Celeste._managed_keys:
            if self._keys.get(key):
                self._keydown(key)
            else:
                self._keyup(key)

    @staticmethod
    def _checkpoints_for(stage):
        """Return the checkpoint list of a stage, or [] if none is defined."""
        if 0 <= stage < len(Celeste.target_checkpoints):
            return Celeste.target_checkpoints[stage]
        return []

    @staticmethod
    def _weighted_distance(x, y, tx, ty):
        """Return the distance from (x, y) to (tx, ty), with the squared y offset halved."""
        return math.sqrt(
            (x-tx)*(x-tx) +
            ((y-ty)*(y-ty))/2
        )

    @property
    def state(self):
        """
        Build a CelesteState from the most recent data sent by the cart.

        next_point_x and next_point_y are None when the current stage has
        no checkpoint list or every checkpoint was already reached.

        Raises:
            CelesteError: if a required value has not been received yet
        """
        raw = self._internal_state
        try:
            stage = Celeste._stage_layout[int(raw["ry"])][int(raw["rx"])]

            checkpoints = Celeste._checkpoints_for(stage)
            if self._next_checkpoint_idx < len(checkpoints):
                next_point_x, next_point_y = checkpoints[self._next_checkpoint_idx]
            else:
                next_point_x = None
                next_point_y = None

            xpos = int(raw["px"])
            ypos = int(raw["py"])
            can_dash = raw["ds"] == "t"

            return CelesteState(
                stage = stage,

                xpos = xpos,
                ypos = ypos,
                xpos_scaled = xpos / 128.0,
                ypos_scaled = ypos / 128.0,
                xvel = float(raw["vx"]),
                yvel = float(raw["vy"]),
                deaths = int(raw["dc"]),

                dist = self._dist,
                next_point = self._next_checkpoint_idx,
                next_point_x = next_point_x,
                next_point_y = next_point_y,
                state_count = self._state_counter,
                can_dash = can_dash,
                can_dash_int = 1 if can_dash else 0
            )

        except KeyError as err:
            raise CelesteError("Not enough data to get state.") from err

    def _xdotool(self, command, *args):
        """Run one xdotool command against the PICO-8 window."""
        subprocess.run([
            "xdotool",
            command,
            "--window", self._winid,
            *args
        ])

    def _keypress(self, key: str, *, post = 200):
        """Press and release a key, then wait `post` milliseconds."""
        self._xdotool("key", key)
        time.sleep(post / 1000)

    def _keydown(self, key: str):
        """Hold a key down."""
        self._xdotool("keydown", key)

    def _keyup(self, key: str):
        """Release a key."""
        self._xdotool("keyup", key)

    def _keystring(self, string, *, delay = 100, post = 200):
        """
        Type a string, `delay` milliseconds between keys,
        then wait `post` milliseconds.
        """
        self._xdotool("type", "--delay", str(delay), string)
        time.sleep(post / 1000)

    def _read_lines(self):
        """Yield decoded, stripped stdout lines until PICO-8 closes its output."""
        # The pipe is opened in binary mode, so end-of-stream is b"", not "".
        for raw_line in iter(self._process.stdout.readline, b""):
            yield raw_line.decode("utf-8").strip()

    def reset(self):
        """
        Release all keys, clear per-run bookkeeping and restart the cart.

        Blocks until the cart prints "!RESTART".

        Raises:
            CelesteError: if PICO-8's output ends before the restart message
        """
        # Make sure all keys are released
        self.act(None)
        self._apply_keys()

        self._internal_state = {}
        self._next_checkpoint_idx = 0
        self._state_counter = 0
        self._before_out = None
        self._resetting = True
        self._last_checkpoint_state = 0

        self._keypress("Escape")
        self._keystring("run")
        self._keypress("Enter", post = 1000)

        # Clear all old stdout messages and
        # wait for the game to restart.
        for line in self._read_lines():
            if line == "!RESTART":
                break
        else:
            raise CelesteError("PICO-8 exited before the game restarted.")

    def _update_checkpoints(self):
        """Advance checkpoint progress for the current state and store the distance."""
        current = self.state
        x = current.xpos
        y = current.ypos
        checkpoints = Celeste._checkpoints_for(current.stage)

        # With nothing left to reach, the distance is reported as 0
        # and only the timeout below can trigger.
        dist = 0
        reached = False
        if self._next_checkpoint_idx < len(checkpoints):
            # Calculate distance to each point
            distances = np.zeros(len(checkpoints), dtype=np.float16)
            for i, (tx, ty) in enumerate(checkpoints):
                if i < self._next_checkpoint_idx:
                    distances[i] = Celeste._passed_distance
                else:
                    distances[i] = Celeste._weighted_distance(x, y, tx, ty)
            min_idx = int(distances.argmin())
            dist = int(distances[min_idx])
            reached = dist <= Celeste._checkpoint_radius

        if reached:
            print(f"Got point {min_idx}")
            self._next_checkpoint_idx = min_idx + 1
            self._last_checkpoint_state = self._state_counter

            # Recalculate distance to new point, if there is one
            if self._next_checkpoint_idx < len(checkpoints):
                tx, ty = checkpoints[self._next_checkpoint_idx]
                dist = Celeste._weighted_distance(x, y, tx, ty)
            else:
                dist = 0

        # Timeout if we spend too long between points
        elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
            # Reported to callers as an extra death.
            self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)

        self._dist = dist

    def update_loop(self, before, after):
        """
        Process state messages from PICO-8 until its output ends.

        For every state line: release all keys, parse the line, update
        checkpoint progress, then run the callbacks and apply the keys
        they selected with act().

        Args:
            before: called as before(self); its return value is kept and
                handed to `after` on the next state
            after: called as after(self, previous_before_output). It is
                skipped while the stored output of `before` is None.
        """
        # Waits for stdout from pico-8 process
        for line in self._read_lines():

            # Release all keys
            self.act(None)
            self._apply_keys()

            # Clear reset state
            self._resetting = False

            # This should only occur at game start
            if line == "!RESTART":
                continue

            self._state_counter += 1

            # Parse state string
            for entry in line.split(";"):
                if entry == "":
                    continue

                key, val = entry.split(":")
                self._internal_state[key] = val

            self._update_checkpoints()

            # Call step callbacks
            # These should call celeste.act() to set next input
            if self._before_out is not None:
                after(self, self._before_out)

            # Do not run before callback if after() triggered a reset.
            if not self._resetting:
                self._before_out = before(self)
                self._apply_keys()
|
|
|