Reorganized celeste code

This commit is contained in:
2023-02-19 20:57:19 -08:00
parent d648d896c0
commit 55ac62dc47
11 changed files with 201 additions and 157 deletions

@@ -0,0 +1,6 @@
from .network import DQN
from .network import Transition
from .celeste import Celeste
from .celeste import CelesteError
from .celeste import CelesteState

celeste/celeste_ai/celeste.py Executable file (+340 lines)

@@ -0,0 +1,340 @@
from typing import NamedTuple
import subprocess
import time
import math
class CelesteError(Exception):
pass
class CelesteState(NamedTuple):
# Stage number
stage: int
# Player position
xpos: int
ypos: int
# Player velocity
xvel: float
yvel: float
# Number of deaths since game start
deaths: int
# Distance to next point
dist: float
# Index of next point
next_point: int
# Coordinates of next point
next_point_x: int
next_point_y: int
# Number of states received since restart
state_count: int
# True if Madeline can dash
can_dash: bool
class Celeste:
action_space = [
"left", # move left
"right", # move right
"jump", # jump
"dash-u", # dash up
"dash-r", # dash right
"dash-l", # dash left
"dash-ru", # dash right-up
"dash-lu" # dash left-up
]
# Map integers to state values.
# This also determines what data is fed to the model.
state_number_map = [
"xpos",
"ypos",
"next_point_x",
"next_point_y"
]
def __init__(
self,
pico_path,
*,
state_timeout = 30,
cart_name = "hackcel.p8",
):
# Start pico-8
self._process = subprocess.Popen(
pico_path,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT
)
# Wait for window to open and get window id
time.sleep(2)
winid = subprocess.check_output([
"xdotool",
"search",
"--class",
"pico8"
]).decode("utf-8").strip().split("\n")
if len(winid) != 1:
raise Exception("Could not find unique PICO-8 window id")
self._winid = winid[0]
# Load cartridge
self._keystring(f"load {cart_name}")
self._keypress("Enter")
self._keystring("run")
self._keypress("Enter", post = 1000)
# Parameters
self.state_timeout = state_timeout # If we run this many states without reaching a checkpoint, count a death so the episode ends.
self.cart_name = cart_name # Name of cart to load. Not used anywhere, but saved for convenience.
# Internal variables
self._internal_state = {} # Raw data read from stdout
self._before_out = None # Output of "before" callback in update loop
self._last_checkpoint_state = 0 # Index of frame at which we reached the last checkpoint
self._state_counter = 0 # Number of frames we've run since last reset
self._next_checkpoint_idx = 0 # Index of next point
self._dist = 0 # Distance to next point
self._resetting = False # True between a call to .reset() and the first state message from pico.
self._keys = {} # Dictionary of "key": bool
# Targets the agent tries to reach.
# The last target MUST be outside the frame.
self.target_checkpoints = [
[ # Stage 1
#(28, 88), # Start pillar
(60, 80), # Middle pillar
(105, 64), # Right ledge
(25, 40), # Left ledge
(110, 16), # End ledge
(110, -2), # Next stage
]
]
def act(self, action: str):
"""
Specify what keys should be down. This does NOT send key events.
Celeste._apply_keys() does that at the right time.
Args:
action (str): key name, as in Celeste.action_space
"""
self._keys = {}
if action is None:
return
elif action == "left":
self._keys["Left"] = True
elif action == "right":
self._keys["Right"] = True
elif action == "jump":
self._keys["c"] = True
elif action == "dash-u":
self._keys["Up"] = True
self._keys["x"] = True
elif action == "dash-r":
self._keys["Right"] = True
self._keys["x"] = True
elif action == "dash-l":
self._keys["Left"] = True
self._keys["x"] = True
elif action == "dash-ru":
self._keys["Up"] = True
self._keys["Right"] = True
self._keys["x"] = True
elif action == "dash-lu":
self._keys["Up"] = True
self._keys["Left"] = True
self._keys["x"] = True
def _apply_keys(self):
for i in [
"x", "c",
"Left", "Right",
"Down", "Up"
]:
if self._keys.get(i):
self._keydown(i)
else:
self._keyup(i)
@property
def state(self):
try:
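# Map the room coordinates (rx, ry) reported by the game
# to a stage index. Only the first row of rooms is mapped.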
stage = (
[
[0, 1, 2, 3, 4]
]
[int(self._internal_state["ry"])]
[int(self._internal_state["rx"])]
)
if stage >= len(self.target_checkpoints):
next_point_x = None
next_point_y = None
else:
next_point_x = self.target_checkpoints[stage][self._next_checkpoint_idx][0]
next_point_y = self.target_checkpoints[stage][self._next_checkpoint_idx][1]
return CelesteState(
stage = stage,
xpos = int(self._internal_state["px"]),
ypos = int(self._internal_state["py"]),
xvel = float(self._internal_state["vx"]),
yvel = float(self._internal_state["vy"]),
deaths = int(self._internal_state["dc"]),
dist = self._dist,
next_point = self._next_checkpoint_idx,
next_point_x = next_point_x,
next_point_y = next_point_y,
state_count = self._state_counter,
can_dash = self._internal_state["ds"] == "t"
)
except KeyError:
raise CelesteError("Not enough data to get state.")
def _keypress(self, key: str, *, post = 200):
subprocess.run([
"xdotool",
"key",
"--window", self._winid,
key
])
time.sleep(post / 1000)
def _keydown(self, key: str):
subprocess.run([
"xdotool",
"keydown",
"--window", self._winid,
key
])
def _keyup(self, key: str):
subprocess.run([
"xdotool",
"keyup",
"--window", self._winid,
key
])
def _keystring(self, string, *, delay = 100, post = 200):
subprocess.run([
"xdotool",
"type",
"--window", self._winid,
"--delay", str(delay),
string
])
time.sleep(post / 1000)
def reset(self):
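"""
Restart the current cartridge and clear all per-episode state.
Returns once the game prints "!RESTART".
"""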
# Make sure all keys are released
self.act(None)
self._apply_keys()
self._internal_state = {}
self._next_checkpoint_idx = 0
self._state_counter = 0
self._before_out = None
self._resetting = True
self._last_checkpoint_state = 0
self._keypress("Escape")
self._keystring("run")
self._keypress("Enter", post = 1000)
# Clear all old stdout messages and
# wait for the game to restart.
for k in iter(self._process.stdout.readline, b""):
k = k.decode("utf-8")[:-1]
if k == "!RESTART":
break
def update_loop(self, before, after):
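"""
Main loop. Reads state messages from the PICO-8 process and
calls the `before` and `after` callbacks around each step.
`before(celeste)` should call celeste.act() to pick the next
input; its return value is passed to `after(celeste, before_out)`
on the following state.
"""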
# Waits for stdout from pico-8 process
for line in iter(self._process.stdout.readline, b""):
l = line.decode("utf-8")[:-1].strip()
# Release all keys
self.act(None)
self._apply_keys()
# Clear reset state
self._resetting = False
# This should only occur at game start
if l in ["!RESTART"]:
continue
self._state_counter += 1
# Parse state string
for entry in l.split(";"):
if entry == "":
continue
key, val = entry.split(":")
self._internal_state[key] = val
# Update checkpoints
tx, ty = self.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
x = self.state.xpos
y = self.state.ypos
dist = math.sqrt(
(x-tx)*(x-tx) +
((y-ty)*(y-ty))/2
# Note: the y-term is halved, so horizontal distance
# counts for more than vertical distance.
)
if dist <= 5:
print(f"Got point {self._next_checkpoint_idx}")
self._next_checkpoint_idx += 1
self._last_checkpoint_state = self._state_counter
# Recalculate distance to new point
tx, ty = self.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
dist = math.sqrt(
(x-tx)*(x-tx) +
((y-ty)*(y-ty))/2
)
# Timeout if we spend too long between points: count a death so the training loop resets.
elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)
self._dist = dist
# Call step callbacks
# These should call celeste.act() to set next input
if self._before_out is not None:
after(self, self._before_out)
# Do not run before callback if after() triggered a reset.
if not self._resetting:
self._before_out = before(self)
self._apply_keys()

@@ -0,0 +1,36 @@
import torch
from collections import namedtuple
Transition = namedtuple(
"Transition",
(
"state",
"action",
"next_state",
"reward"
)
)
class DQN(torch.nn.Module):
def __init__(self, n_observations: int, n_actions: int):
super(DQN, self).__init__()
self.layers = torch.nn.Sequential(
torch.nn.Linear(n_observations, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, n_actions)
)
def forward(self, x):
return self.layers(x)

@@ -0,0 +1,2 @@
from .plot_actual_reward import actual_reward
from .plot_predicted_reward import predicted_reward

@@ -0,0 +1,81 @@
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from multiprocessing import Pool
# All of the following are required to load
# a pickled model.
from celeste_ai.celeste import Celeste
from celeste_ai.network import DQN
from celeste_ai.network import Transition
def actual_reward(
model_file: Path,
target_point: tuple[int, int],
out_filename: Path,
*,
device = torch.device("cpu")
):
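"""
Plot, for each action, the net reward recorded in the replay
memory of `model_file` at every (x, y) position, restricted to
transitions whose target is `target_point`.
"""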
if not model_file.is_file():
raise Exception(f"Bad model file {model_file}")
out_filename.parent.mkdir(exist_ok = True, parents = True)
checkpoint = torch.load(
model_file,
map_location = device
)
memory = checkpoint["memory"]
r = np.zeros((128, 128, 8), dtype=np.float32)
for m in memory:
x, y, x_target, y_target = list(m.state[0])
action = m.action[0].item()
x = int(x.item())
y = int(y.item())
x_target = int(x_target.item())
y_target = int(y_target.item())
# Only plot memory related to this point
if (x_target, y_target) != target_point:
continue
if m.reward[0].item() == 1:
r[y][x][action] += 1
else:
r[y][x][action] -= 1
fig, axs = plt.subplots(2, 4, figsize = (20, 10))
for a in range(len(axs.ravel())):
ax = axs.ravel()[a]
ax.set(
adjustable = "box",
aspect = "equal",
title = Celeste.action_space[a]
)
plot = ax.pcolor(
r[:,:,a],
cmap = "seismic_r",
vmin = -10,
vmax = 10
)
# Draw target point on plot
ax.plot(
target_point[0],
target_point[1],
"k."
)
ax.invert_yaxis()
fig.colorbar(plot)
fig.savefig(out_filename)
plt.close()

@@ -0,0 +1,77 @@
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
# All of the following are required to load
# a pickled model.
from celeste_ai.celeste import Celeste
from celeste_ai.network import DQN
from celeste_ai.network import Transition
def predicted_reward(
model_file: Path,
out_filename: Path,
*,
device = torch.device("cpu")
):
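"""
Plot, for each action, the Q-value the policy network in
`model_file` predicts at every (x, y) position, with the next
checkpoint fixed at (60, 80).
"""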
if not model_file.is_file():
raise Exception(f"Bad model file {model_file}")
out_filename.parent.mkdir(exist_ok = True, parents = True)
# Create and load model
policy_net = DQN(
len(Celeste.state_number_map),
len(Celeste.action_space)
).to(device)
checkpoint = torch.load(
model_file,
map_location = device
)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
# Compute predictions
p = np.zeros((128, 128, 8), dtype=np.float32)
with torch.no_grad():
for r in range(len(p)):
for c in range(len(p[r])):
k = np.asarray(policy_net(
torch.tensor(
[c, r, 60, 80],
dtype = torch.float32,
device = device
).unsqueeze(0)
)[0])
p[r][c] = k
# Plot predictions
fig, axs = plt.subplots(2, 4, figsize = (20, 10))
for a in range(len(axs.ravel())):
ax = axs.ravel()[a]
ax.set(
adjustable = "box",
aspect = "equal",
title = Celeste.action_space[a]
)
plot = ax.pcolor(
p[:,:,a],
cmap = "Greens",
vmin = 0,
)
ax.invert_yaxis()
fig.colorbar(plot)
fig.savefig(out_filename)
plt.close()

celeste/celeste_ai/train.py Normal file (+449 lines)

@@ -0,0 +1,449 @@
from collections import namedtuple
from collections import deque
from pathlib import Path
import random
import math
import json
import torch
from celeste_ai import Celeste
from celeste_ai import DQN
from celeste_ai import Transition
if __name__ == "__main__":
# Where to read/write model data.
model_data_root = Path("model_data/current")
model_save_path = model_data_root / "model.torch"
model_archive_dir = model_data_root / "model_archive"
model_train_log = model_data_root / "train_log"
screenshot_dir = model_data_root / "screenshots"
model_data_root.mkdir(parents = True, exist_ok = True)
model_archive_dir.mkdir(parents = True, exist_ok = True)
screenshot_dir.mkdir(parents = True, exist_ok = True)
compute_device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
# Celeste env properties
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)
# Epsilon-greedy parameters
#
# Original docs:
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 4000
BATCH_SIZE = 1_000
# Learning rate of target_net.
# Controls how soft our soft update is.
#
# Should be between 0 and 1.
# Large values make target_net track policy_net more closely.
# Small values do the opposite.
#
# A value of one makes target_net
# change at the same rate as policy_net.
#
# A value of zero makes target_net
# not change at all.
TAU = 0.005
# GAMMA is the discount factor for future rewards.
GAMMA = 0.9
steps_done = 0
num_episodes = 100
episode_number = 0
archive_interval = 10
# Create replay memory.
#
# Transition: a container for naming data (defined in network.py)
# Memory: a deque that holds recent states as Transitions
# Has a fixed length, drops oldest
# element if maxlen is exceeded.
memory = deque([], maxlen=50_000)
policy_net = DQN(
n_observations,
n_actions
).to(compute_device)
target_net = DQN(
n_observations,
n_actions
).to(compute_device)
target_net.load_state_dict(policy_net.state_dict())
optimizer = torch.optim.AdamW(
policy_net.parameters(),
lr = 0.01, # Hyperparameter: learning rate
amsgrad = True
)
if model_save_path.is_file():
# Load model if one exists
checkpoint = torch.load(
model_save_path,
map_location = compute_device
)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
target_net.load_state_dict(checkpoint["target_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
memory = checkpoint["memory"]
episode_number = checkpoint["episode_number"] + 1
steps_done = checkpoint["steps_done"]
def select_action(state, steps_done):
"""
Select an action using an epsilon-greedy policy.
Sometimes use our model, sometimes sample one uniformly.
P(random action) starts at EPS_START and decays to EPS_END.
Decay rate is controlled by EPS_DECAY.
"""
# Random number 0 <= x < 1
sample = random.random()
# Calculate the current epsilon threshold
eps_threshold = (
EPS_END + (EPS_START - EPS_END) *
math.exp(
-1.0 * steps_done /
EPS_DECAY
)
)
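# With EPS_START = 0.9, EPS_END = 0.05, EPS_DECAY = 4000:
# eps ~= 0.90 at step 0, ~= 0.36 at step 4000, and it
# approaches 0.05 as steps_done grows.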
if sample > eps_threshold:
with torch.no_grad():
# policy_net(state).max(1) returns (values, indices) along the
# action dimension. [1] selects the indices, i.e. the action
# with the largest predicted Q-value.
return policy_net(state).max(1)[1].view(1, 1).item()
else:
return random.randint( 0, n_actions-1 )
def optimize_model():
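"""
Run one optimization step on policy_net using a random batch
of transitions sampled from replay memory. Returns the loss
as a single-element tensor.
"""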
if len(memory) < BATCH_SIZE:
raise Exception(f"Not enough elements in memory for a batch of {BATCH_SIZE}")
# Get a random sample of transitions
batch = random.sample(memory, BATCH_SIZE)
# Conversion.
# Transposes batch, turning an array of Transitions
# into a Transition of arrays.
batch = Transition(*zip(*batch))
# Conversion.
# Combine states, actions, and rewards into their own tensors.
state_batch = torch.cat(batch.state)
action_batch = torch.cat(batch.action)
reward_batch = torch.cat(batch.reward)
# Compute a mask of non_final_states.
# Each element of this tensor corresponds to an element in the batch.
# True if the next state is non-final, False if it is final.
#
# We use this to select non-final states later.
non_final_mask = torch.tensor(
tuple(map(
lambda s: s is not None,
batch.next_state
)),
dtype = torch.bool,
device = compute_device
)
non_final_next_states = torch.cat(
[s for s in batch.next_state if s is not None]
)
# How .gather works:
# if out = a.gather(1, b),
# out[i, j] = a[ i ][ b[i,j] ]
#
# a is "input," b is "index"
# If this doesn't make sense, RTFD.
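# Example:
#   a = [[10, 20], [30, 40]], b = [[1], [0]]
#   a.gather(1, b) == [[20], [30]]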
# Compute Q(s_t, a).
# - Use policy_net to compute Q(s_t) for each state in the batch.
# This gives a tensor of Q-values, one per action in Celeste.action_space.
#
# - Action batch is a tensor that looks like [ [0], [1], [1], ... ],
# listing the index (into Celeste.action_space) of the action
# that was taken in each transition.
#
# This aligns nicely with the output of policy_net. We use
# action_batch to index the output of policy_net's prediction.
#
# This gives us a tensor that contains the return we expect to get
# at that state if we follow the model's advice.
state_action_values = policy_net(state_batch).gather(1, action_batch)
# Compute V(s_t+1) for all next states.
# V(s_t+1) = max_a ( Q(s_t+1, a) )
# = the maximum reward over all possible actions at state s_t+1.
next_state_values = torch.zeros(BATCH_SIZE, device = compute_device)
# Don't compute gradient for operations in this block.
# If you don't understand what this means, RTFD.
with torch.no_grad():
# Note the use of non_final_mask here.
# States that are final do not have their value set by the line
# below, so their value stays at zero.
#
# States that are not final get their predicted value
# set to the best value the model predicts.
#
#
# Expected values of action are selected with the "older" target net,
# and their best reward (over possible actions) is selected with max(1)[0].
next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
# Compute the expected Q values (the Bellman target):
# Q_expected(s_t, a_t) = reward + GAMMA * max_a Q(s_{t+1}, a)
expected_state_action_values = reward_batch + (next_state_values * GAMMA)
# Compute Huber loss between predicted and expected Q values.
# PyTorch records this operation so the gradient of loss can be computed below.
#
# loss is a single-element tensor (i.e, a scalar).
criterion = torch.nn.SmoothL1Loss()
loss = criterion(
state_action_values,
expected_state_action_values.unsqueeze(1)
)
# We can now run a step of backpropagation on our model.
#
# Calling .backward() multiple times will accumulate parameter gradients.
# Thus, we reset the gradients before each step.
optimizer.zero_grad()
# Compute the gradient of loss with respect to every parameter of
# policy_net. The gradients are stored in each parameter's .grad
# attribute, which optimizer.step() reads below; this is why loss
# itself is never used again.
loss.backward()
# Prevent vanishing and exploding gradients.
# Forces gradients to be in [-clip_value, +clip_value]
torch.nn.utils.clip_grad_value_( # type: ignore
policy_net.parameters(),
clip_value = 100
)
# Perform a single optimizer step.
#
# Uses the current gradient, which is stored
# in the .grad attribute of the parameter.
optimizer.step()
return loss
def on_state_before(celeste):
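"""
Called before each game step. Picks an action for the current
state (restricted to "left"/"right" when Madeline cannot dash),
sends it to the game, and returns (state, action) so that
on_state_after can build the transition.
"""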
global steps_done
# Conversion to pytorch
state = celeste.state
pt_state = torch.tensor(
[getattr(state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0)
action = None
str_action = None
# Resample until we pick a legal action: when Madeline
# cannot dash, only "left" and "right" are allowed.
while action is None or (not state.can_dash and str_action not in ["left", "right"]):
action = select_action(
pt_state,
steps_done
)
str_action = Celeste.action_space[action]
steps_done += 1
# For manual testing
#str_action = ""
#while str_action not in Celeste.action_space:
# str_action = input("action> ")
#action = Celeste.action_space.index(str_action)
print(str_action)
celeste.act(str_action)
return state, action
def on_state_after(celeste, before_out):
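"""
Called after each game step. Computes the reward for the
transition, stores it in replay memory, trains policy_net,
soft-updates target_net, and handles end-of-episode saving
and reset.
"""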
global episode_number
state, action = before_out
next_state = celeste.state
pt_state = torch.tensor(
[getattr(state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0)
pt_action = torch.tensor(
[[ action ]],
device = compute_device,
dtype = torch.long
)
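# Compute the reward for this transition:
# - dying is a terminal state with reward 0,
# - reaching the next checkpoint gives reward 1,
# - otherwise reward 1 if we moved meaningfully closer
#   to the checkpoint, 0 if we did not.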
if next_state.deaths != 0:
pt_next_state = None
reward = 0
else:
pt_next_state = torch.tensor(
[getattr(next_state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0)
if state.next_point == next_state.next_point:
reward = state.dist - next_state.dist
# Binarize the reward: 1 if we moved more than one unit closer, 0 otherwise.
if reward > 1:
reward = 1
else:
reward = 0
else:
# Reward for reaching a point
reward = 1
pt_reward = torch.tensor([reward], device = compute_device)
# Add this state transition to memory.
memory.append(
Transition(
pt_state, # last state
pt_action,
pt_next_state, # next state
pt_reward
)
)
print("==> ", int(reward))
print("")
loss = None
# Only train the network if we have enough
# transitions in memory to do so.
if len(memory) >= BATCH_SIZE:
loss = optimize_model()
# Soft update target_net weights
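# target = TAU * policy + (1 - TAU) * target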
target_net_state = target_net.state_dict()
policy_net_state = policy_net.state_dict()
for key in policy_net_state:
target_net_state[key] = (
policy_net_state[key] * TAU +
target_net_state[key] * (1-TAU)
)
target_net.load_state_dict(target_net_state)
# Move on to the next episode once we reach
# a terminal state.
if (next_state.deaths != 0):
s = celeste.state
with model_train_log.open("a") as f:
f.write(json.dumps({
"checkpoints": s.next_point,
"state_count": s.state_count,
"loss": None if loss is None else loss.item()
}) + "\n")
# Save model
torch.save({
"policy_state_dict": policy_net.state_dict(),
"target_state_dict": target_net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"memory": memory,
"episode_number": episode_number,
"steps_done": steps_done
}, model_save_path)
# Clean up screenshots
shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
target = screenshot_dir / Path(f"{episode_number}")
target.mkdir(parents = True)
for s in shots:
s.rename(target / s.name)
# Archive the model every archive_interval episodes
if episode_number % archive_interval == 0:
torch.save({
"policy_state_dict": policy_net.state_dict(),
"target_state_dict": target_net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"memory": memory,
"episode_number": episode_number,
"steps_done": steps_done
}, model_archive_dir / f"{episode_number}.torch")
print("Game over. Resetting.")
episode_number += 1
celeste.reset()
if __name__ == "__main__":
c = Celeste(
"resources/pico-8/linux/pico8"
)
c.update_loop(
on_state_before,
on_state_after
)