Mark
/
celeste-ai
Archived
1
0
Fork 0
This repository has been archived on 2023-11-28. You can view files and clone it, but cannot push or open issues/pull-requests.
celeste-ai/celeste_ai/celeste.py

407 lines
9.2 KiB
Python
Raw Normal View History

2023-02-18 19:28:02 -08:00
from typing import NamedTuple
2023-02-15 22:24:40 -08:00
import subprocess
import time
import math
import numpy as np
2023-02-15 22:24:40 -08:00
class CelesteError(Exception):
    """Error raised by the Celeste wrapper, e.g. when state is requested before enough data has arrived."""
2023-02-18 19:28:02 -08:00
class CelesteState(NamedTuple):
    """Snapshot of the game state, as built by the ``Celeste.state`` property."""

    # Stage number, looked up from the current room via Celeste.stage_map
    stage: int

    # Player position. 0,0 is the top-left corner.
    xpos: int
    ypos: int
    # Player position divided by 128.0 (see Celeste.state).
    # Note: this is only rescaled, not re-centered; 0,0 is still top-left.
    xpos_scaled: float
    ypos_scaled: float

    # Player velocity
    xvel: float
    yvel: float
    # Number of deaths since game start. Celeste.update_loop also
    # increments this when too many states pass without a checkpoint.
    deaths: int
    # If an index is true, we got a strawberry on that stage.
    berries: list[bool]

    # Distance to the next point, as tracked by Celeste.update_loop
    # (vertical offset weighted less than horizontal).
    dist: float
    # Index of next point in Celeste.target_checkpoints[stage]
    next_point: int
    # Coordinates of next point (0, 0 when no checkpoint applies)
    next_point_x: int
    next_point_y: int
    # Number of states received since restart
    state_count: int
    # True if Madeline can dash
    can_dash: bool
    # can_dash as an integer: 1 if Madeline can dash, else 0
    can_dash_int: int
2023-02-18 19:28:02 -08:00
2023-02-15 22:24:40 -08:00
class Celeste:
action_space = [
"left", # move left 0
"right", # move right 1
"jump-l", # jump left 2
"jump-r", # jump right 3
"dash-u", # dash up 4
"dash-r", # dash right 5
"dash-l", # dash left 6
"dash-ru", # dash right-up 7
"dash-lu" # dash left-up 8
2023-02-15 22:24:40 -08:00
]
2023-02-18 19:28:02 -08:00
# Map integers to state values.
# This also determines what data is fed to the model.
state_number_map = [
#"xpos",
#"ypos",
"xpos_scaled",
"ypos_scaled",
2023-02-26 12:09:05 -08:00
#"can_dash_int"
#"next_point_x",
#"next_point_y"
]
# Targets the agent tries to reach.
# The last target MUST be outside the frame.
# Format is X, Y, range, force_y
# force_y is optional. If true, y_value MUST match perfectly.
target_checkpoints = [
[ # Stage 1
#(28, 88, 8), # Start pillar
(60, 80, 8), # Middle pillar
(105, 64, 8), # Right ledge
(25, 40, 8), # Left ledge
(97, 24, 5, True), # Small end ledge
(110, 16, 8), # End ledge
(110, -20, 8), # Next stage
]
2023-02-18 19:28:02 -08:00
]
2023-02-24 17:46:07 -08:00
# Maps room_x, room_y coordinates to stage number.
stage_map = [
[0, 1, 2, 3, 4]
]
2023-02-18 19:28:02 -08:00
def __init__(
self,
2023-02-19 20:57:19 -08:00
pico_path,
2023-02-18 19:28:02 -08:00
*,
2023-02-26 12:10:27 -08:00
state_timeout = 20,
2023-02-19 20:57:19 -08:00
cart_name = "hackcel.p8",
2023-02-18 19:28:02 -08:00
):
2023-02-15 22:24:40 -08:00
# Start pico-8
2023-02-18 19:28:02 -08:00
self._process = subprocess.Popen(
2023-02-19 20:57:19 -08:00
pico_path,
2023-02-15 22:24:40 -08:00
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT
)
# Wait for window to open and get window id
time.sleep(2)
winid = subprocess.check_output([
"xdotool",
"search",
"--class",
"pico8"
]).decode("utf-8").strip().split("\n")
if len(winid) != 1:
raise Exception("Could not find unique PICO-8 window id")
2023-02-18 19:28:02 -08:00
self._winid = winid[0]
2023-02-15 22:24:40 -08:00
# Load cartridge
2023-02-18 19:28:02 -08:00
self._keystring(f"load {cart_name}")
self._keypress("Enter")
self._keystring("run")
self._keypress("Enter", post = 1000)
# Parameters
self.state_timeout = state_timeout # If we run this many states without getting a checkpoint, reset.
self.cart_name = cart_name # Name of cart to load. Not used anywhere, but saved for convenience.
# Internal variables
self._internal_state = {} # Raw data read from stdout
self._before_out = None # Output of "before" callback in update loop
self._last_checkpoint_state = 0 # Index of frame at which we reached the last checkpoint
self._state_counter = 0 # Number of frames we've run since last reset
self._next_checkpoint_idx = 0 # Index of next point
self._dist = 0 # Distance to next point
self._resetting = False # True between a call to .reset() and the first state message from pico.
self._keys = {} # Dictionary of "key": bool
2023-02-26 12:10:27 -08:00
def act(self, action: str | int):
2023-02-18 19:28:02 -08:00
"""
Specify what keys should be down. This does NOT send key events.
Celeste._apply_keys() does that at the right time.
2023-02-15 22:24:40 -08:00
2023-02-18 19:28:02 -08:00
Args:
action (str): key name, as in Celeste.action_space
"""
2023-02-26 12:10:27 -08:00
if isinstance(action, int):
action = Celeste.action_space[action]
2023-02-18 19:28:02 -08:00
self._keys = {}
2023-02-15 22:24:40 -08:00
if action is None:
return
elif action == "left":
2023-02-18 19:28:02 -08:00
self._keys["Left"] = True
2023-02-15 22:24:40 -08:00
elif action == "right":
2023-02-18 19:28:02 -08:00
self._keys["Right"] = True
2023-02-15 22:24:40 -08:00
elif action == "jump":
2023-02-18 19:28:02 -08:00
self._keys["c"] = True
elif action == "jump-l":
self._keys["c"] = True
self._keys["Left"] = True
elif action == "jump-r":
self._keys["c"] = True
self._keys["Right"] = True
2023-02-15 22:24:40 -08:00
elif action == "dash-u":
2023-02-18 19:28:02 -08:00
self._keys["Up"] = True
self._keys["x"] = True
2023-02-15 22:24:40 -08:00
elif action == "dash-r":
2023-02-18 19:28:02 -08:00
self._keys["Right"] = True
self._keys["x"] = True
2023-02-15 22:24:40 -08:00
elif action == "dash-l":
2023-02-18 19:28:02 -08:00
self._keys["Left"] = True
self._keys["x"] = True
2023-02-15 22:24:40 -08:00
elif action == "dash-ru":
2023-02-18 19:28:02 -08:00
self._keys["Up"] = True
self._keys["Right"] = True
self._keys["x"] = True
2023-02-15 22:24:40 -08:00
elif action == "dash-lu":
2023-02-18 19:28:02 -08:00
self._keys["Up"] = True
self._keys["Left"] = True
self._keys["x"] = True
def _apply_keys(self):
for i in [
"x", "c",
"Left", "Right",
"Down", "Up"
]:
if self._keys.get(i):
self._keydown(i)
else:
self._keyup(i)
2023-02-15 22:24:40 -08:00
@property
2023-02-18 19:28:02 -08:00
def state(self):
2023-02-15 22:24:40 -08:00
try:
2023-02-18 19:28:02 -08:00
stage = (
2023-02-24 17:46:07 -08:00
Celeste.stage_map
2023-02-18 19:28:02 -08:00
[int(self._internal_state["ry"])]
[int(self._internal_state["rx"])]
)
2023-02-15 22:24:40 -08:00
if len(Celeste.target_checkpoints) <= stage:
next_point_x = 0
next_point_y = 0
2023-02-18 19:28:02 -08:00
else:
next_point_x = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][0]
next_point_y = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][1]
2023-02-18 19:28:02 -08:00
return CelesteState(
stage = stage,
xpos = int(self._internal_state["px"]),
ypos = int(self._internal_state["py"]),
xpos_scaled = int(self._internal_state["px"]) / 128.0,
ypos_scaled = int(self._internal_state["py"]) / 128.0,
2023-02-18 19:28:02 -08:00
xvel = float(self._internal_state["vx"]),
yvel = float(self._internal_state["vy"]),
deaths = int(self._internal_state["dc"]),
2023-02-24 22:17:45 -08:00
berries = [x == "t" for x in self._internal_state["fr"][1:]],
2023-02-18 19:28:02 -08:00
dist = self._dist,
next_point = self._next_checkpoint_idx,
next_point_x = next_point_x,
next_point_y = next_point_y,
state_count = self._state_counter,
can_dash = self._internal_state["ds"] == "t",
can_dash_int = 1 if self._internal_state["ds"] == "t" else 0
2023-02-18 19:28:02 -08:00
)
except KeyError:
raise CelesteError("Not enough data to get state.")
2023-02-15 22:24:40 -08:00
2023-02-18 19:28:02 -08:00
def _keypress(self, key: str, *, post = 200):
2023-02-15 22:24:40 -08:00
subprocess.run([
"xdotool",
"key",
2023-02-18 19:28:02 -08:00
"--window", self._winid,
2023-02-15 22:24:40 -08:00
key
])
time.sleep(post / 1000)
2023-02-18 19:28:02 -08:00
def _keydown(self, key: str):
2023-02-15 22:24:40 -08:00
subprocess.run([
"xdotool",
"keydown",
2023-02-18 19:28:02 -08:00
"--window", self._winid,
2023-02-15 22:24:40 -08:00
key
])
2023-02-18 19:28:02 -08:00
def _keyup(self, key: str):
2023-02-15 22:24:40 -08:00
subprocess.run([
"xdotool",
"keyup",
2023-02-18 19:28:02 -08:00
"--window", self._winid,
2023-02-15 22:24:40 -08:00
key
])
2023-02-18 19:28:02 -08:00
def _keystring(self, string, *, delay = 100, post = 200):
2023-02-15 22:24:40 -08:00
subprocess.run([
"xdotool",
"type",
2023-02-18 19:28:02 -08:00
"--window", self._winid,
2023-02-15 22:24:40 -08:00
"--delay", str(delay),
string
])
time.sleep(post / 1000)
def reset(self):
2023-02-18 19:28:02 -08:00
# Make sure all keys are released
self.act(None)
self._apply_keys()
2023-02-15 22:24:40 -08:00
2023-02-18 19:28:02 -08:00
self._internal_state = {}
self._next_checkpoint_idx = 0
self._state_counter = 0
self._before_out = None
self._resetting = True
self._last_checkpoint_state = 0
2023-02-15 23:38:27 -08:00
2023-02-18 19:28:02 -08:00
self._keypress("Escape")
self._keystring("run")
self._keypress("Enter", post = 1000)
2023-02-15 22:24:40 -08:00
2023-02-18 19:28:02 -08:00
# Clear all old stdout messages and
# wait for the game to restart.
for k in iter(self._process.stdout.readline, ""):
2023-02-15 22:24:40 -08:00
k = k.decode("utf-8")[:-1]
if k == "!RESTART":
break
2023-02-16 12:11:04 -08:00
def update_loop(self, before, after):
2023-02-18 19:28:02 -08:00
# Waits for stdout from pico-8 process
for line in iter(self._process.stdout.readline, ""):
l = line.decode("utf-8")[:-1].strip()
2023-02-15 23:38:27 -08:00
2023-02-18 19:28:02 -08:00
# Release all keys
self.act(None)
self._apply_keys()
2023-02-15 23:38:27 -08:00
2023-02-18 19:28:02 -08:00
# Clear reset state
self._resetting = False
2023-02-15 22:24:40 -08:00
# This should only occur at game start
if l in ["!RESTART"]:
continue
2023-02-18 19:28:02 -08:00
self._state_counter += 1
2023-02-15 22:24:40 -08:00
2023-02-18 19:28:02 -08:00
# Parse state string
2023-02-15 22:24:40 -08:00
for entry in l.split(";"):
if entry == "":
continue
key, val = entry.split(":")
2023-02-18 19:28:02 -08:00
self._internal_state[key] = val
2023-02-16 13:52:59 -08:00
2023-02-15 22:24:40 -08:00
if self.state.stage <= 0:
# Calculate distance to each point
x = self.state.xpos
y = self.state.ypos
dist = np.zeros(len(Celeste.target_checkpoints[self.state.stage]), dtype=np.float16)
for i, c in enumerate(Celeste.target_checkpoints[self.state.stage]):
if i < self._next_checkpoint_idx:
dist[i] = 1000
continue
# Update checkpoints
tx, ty = c[:2]
dist[i] = (math.sqrt(
(x-tx)*(x-tx) +
((y-ty)*(y-ty))/2
# Possible modification:
# make x-distance twice as valuable as y-distance
))
min_idx = int(dist.argmin())
dist = int(dist[min_idx])
t = Celeste.target_checkpoints[self.state.stage][min_idx]
range = t[2]
if len(t) == 3:
force_y = False
else:
force_y = t[3]
if force_y:
got_point = (
dist <= range and
y == t[1]
)
else:
got_point = dist <= range
if got_point:
self._next_checkpoint_idx = min_idx + 1
self._last_checkpoint_state = self._state_counter
# Recalculate distance to new point
tx, ty = (
Celeste.target_checkpoints
[self.state.stage]
[self._next_checkpoint_idx]
[:2]
)
dist = math.sqrt(
(x-tx)*(x-tx) +
((y-ty)*(y-ty))/2
)
# Timeout if we spend too long between points
elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)
2023-02-16 13:52:59 -08:00
self._dist = dist
2023-02-16 13:52:59 -08:00
# Call step callbacks
2023-02-18 19:28:02 -08:00
# These should call celeste.act() to set next input
if self._before_out is not None:
after(self, self._before_out)
# Do not run before callback if after() triggered a reset.
if not self._resetting:
self._before_out = before(self)
self._apply_keys()