Mark
/
celeste-ai
Archived
1
0
Fork 0

Compare commits

..

7 Commits

9 changed files with 390 additions and 288 deletions

View File

@ -70,21 +70,24 @@ class Celeste:
#"ypos",
"xpos_scaled",
"ypos_scaled",
"can_dash_int"
#"can_dash_int"
#"next_point_x",
#"next_point_y"
]
# Targets the agent tries to reach.
# The last target MUST be outside the frame.
# Format is X, Y, range, force_y
# force_y is optional. If true, y_value MUST match perfectly.
target_checkpoints = [
[ # Stage 1
#(28, 88), # Start pillar
(60, 80), # Middle pillar
(105, 64), # Right ledge
(25, 40), # Left ledge
(110, 16), # End ledge
(110, -2), # Next stage
#(28, 88, 8), # Start pillar
(60, 80, 8), # Middle pillar
(105, 64, 8), # Right ledge
(25, 40, 8), # Left ledge
(97, 24, 5, True), # Small end ledge
(110, 16, 8), # End ledge
(110, -20, 8), # Next stage
]
]
@ -99,7 +102,7 @@ class Celeste:
self,
pico_path,
*,
state_timeout = 30,
state_timeout = 20,
cart_name = "hackcel.p8",
):
@ -144,7 +147,7 @@ class Celeste:
self._resetting = False # True between a call to .reset() and the first state message from pico.
self._keys = {} # Dictionary of "key": bool
def act(self, action: str):
def act(self, action: str | int):
"""
Specify what keys should be down. This does NOT send key events.
Celeste._apply_keys() does that at the right time.
@ -153,6 +156,9 @@ class Celeste:
action (str): key name, as in Celeste.action_space
"""
if isinstance(action, int):
action = Celeste.action_space[action]
self._keys = {}
if action is None:
return
@ -208,9 +214,9 @@ class Celeste:
[int(self._internal_state["rx"])]
)
if len(Celeste.target_checkpoints) < stage:
next_point_x = None
next_point_y = None
if len(Celeste.target_checkpoints) <= stage:
next_point_x = 0
next_point_y = 0
else:
next_point_x = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][0]
next_point_y = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][1]
@ -329,7 +335,7 @@ class Celeste:
if self.state.stage <= 0:
# Calculate distance to each point
x = self.state.xpos
y = self.state.ypos
@ -340,7 +346,7 @@ class Celeste:
continue
# Update checkpoints
tx, ty = c
tx, ty = c[:2]
dist[i] = (math.sqrt(
(x-tx)*(x-tx) +
((y-ty)*(y-ty))/2
@ -351,13 +357,32 @@ class Celeste:
dist = int(dist[min_idx])
if dist <= 8:
print(f"Got point {min_idx}")
t = Celeste.target_checkpoints[self.state.stage][min_idx]
range = t[2]
if len(t) == 3:
force_y = False
else:
force_y = t[3]
if force_y:
got_point = (
dist <= range and
y == t[1]
)
else:
got_point = dist <= range
if got_point:
self._next_checkpoint_idx = min_idx + 1
self._last_checkpoint_state = self._state_counter
# Recalculate distance to new point
tx, ty = Celeste.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
tx, ty = (
Celeste.target_checkpoints
[self.state.stage]
[self._next_checkpoint_idx]
[:2]
)
dist = math.sqrt(
(x-tx)*(x-tx) +
((y-ty)*(y-ty))/2

View File

@ -5,7 +5,7 @@ from collections import namedtuple
Transition = namedtuple(
"Transition",
(
"state",
"last_state",
"action",
"next_state",
"reward"

View File

@ -1,6 +1,7 @@
import torch
import numpy as np
from pathlib import Path
import matplotlib as mpl
import matplotlib.pyplot as plt
# All of the following are required to load
@ -34,7 +35,7 @@ def best_action(
# Compute preditions
p = np.zeros((128, 128, 2), dtype=np.float32)
p = np.zeros((128, 128), dtype=np.float32)
with torch.no_grad():
for r in range(len(p)):
for c in range(len(p[r])):
@ -43,26 +44,31 @@ def best_action(
k = np.asarray(policy_net(
torch.tensor(
[x, y, 0],
[x, y],
dtype = torch.float32,
device = device
).unsqueeze(0)
)[0])
p[r][c][0] = np.argmax(k)
p[r][c] = np.argmax(k)
k = np.asarray(policy_net(
torch.tensor(
[x, y, 1],
dtype = torch.float32,
device = device
).unsqueeze(0)
)[0])
p[r][c][1] = np.argmax(k)
cmap = mpl.colors.ListedColormap(
[
"forestgreen",
"firebrick",
"lightgreen",
"salmon",
"darkturquoise",
"sandybrown",
"olive",
"darkorchid",
"mediumvioletred"
]
)
# Plot predictions
fig, axs = plt.subplots(1, 2, figsize = (10, 10))
ax = axs[0]
fig, axs = plt.subplots(1, 1, figsize = (20, 20))
ax = axs
ax.set(
adjustable = "box",
aspect = "equal",
@ -70,30 +76,16 @@ def best_action(
)
plot = ax.pcolor(
p[:,:,0],
cmap = "Set1",
p,
cmap = cmap,
vmin = 0,
vmax = 8
)
ax.invert_yaxis()
fig.colorbar(plot)
cbar = fig.colorbar(plot, ticks = list(range(0, 9)))
cbar.ax.set_yticklabels(Celeste.action_space)
ax = axs[1]
ax.set(
adjustable = "box",
aspect = "equal",
title = "Best Action"
)
plot = ax.pcolor(
p[:,:,0],
cmap = "Set1",
vmin = 0,
vmax = 8
)
ax.invert_yaxis()
fig.colorbar(plot)
fig.savefig(out_filename)
plt.close()

View File

@ -43,7 +43,7 @@ def predicted_reward(
k = np.asarray(policy_net(
torch.tensor(
[x, y, 0],
[x, y],
dtype = torch.float32,
device = device
).unsqueeze(0)

View File

@ -5,33 +5,31 @@ import random
import math
import json
import torch
import shutil
from celeste_ai import Celeste
from celeste_ai import DQN
from celeste_ai import Transition
from celeste_ai.util.screenshots import ScreenshotManager
if __name__ == "__main__":
# Where to read/write model data.
model_data_root = Path("model_data/current")
sm = ScreenshotManager(
# Where PICO-8 saves screenshots.
# Probably your desktop.
screenshot_source = Path("/home/mark/Desktop")
source = Path("/home/mark/Desktop"),
pattern = "hackcel_*.png",
target = model_data_root / "screenshots"
).clean() # Remove old screenshots
model_save_path = model_data_root / "model.torch"
model_archive_dir = model_data_root / "model_archive"
model_train_log = model_data_root / "train_log"
screenshot_dir = model_data_root / "screenshots"
model_data_root.mkdir(parents = True, exist_ok = True)
model_archive_dir.mkdir(parents = True, exist_ok = True)
screenshot_dir.mkdir(parents = True, exist_ok = True)
# Remove old screenshots
shots = screenshot_source.glob("hackcel_*.png")
for s in shots:
s.unlink()
compute_device = torch.device(
@ -45,66 +43,51 @@ if __name__ == "__main__":
# Epsilon-greedy parameters
#
# Original docs:
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# Probability of choosing a random action starts at
# EPS_START and decays to EPS_END.
# EPS_DECAY controls the rate of decay.
EPS_START = 0.9
EPS_END = 0.02
EPS_DECAY = 100
# How many times we've reached each point.
# Used to compute epsilon-greedy probability with
# the parameters above.
point_counter = [0] * len(Celeste.target_checkpoints[0])
BATCH_SIZE = 100
# Learning rate of target_net.
# Controls how soft our soft update is.
#
# Should be between 0 and 1.
# Large values
# Small values do the opposite.
#
# A value of one makes target_net
# change at the same rate as policy_net.
#
# A value of zero makes target_net
# not change at all.
TAU = 0.05
# GAMMA is the discount factor as mentioned in the previous section
# Bellman equation time-discount factor
GAMMA = 0.9
steps_done = 0
num_episodes = 100
episode_number = 0
archive_interval = 10
# Train on this many transitions from
# replay memory each round
BATCH_SIZE = 100
# Controls target_net soft update.
# Should be between 0 and 1.
TAU = 0.05
# Optimizer learning rate
learning_rate = 0.001
# Save a snapshot of the model every n
# episodes.
model_save_interval = 10
# How many times we've reached each point.
# This is used to compute epsilon-greedy probability.
point_counter = [0] * len(Celeste.target_checkpoints[0])
n_episodes = 0 # Number of episodes we've trained on
n_steps = 0 # Number of training steps we've completed
# Create replay memory.
#
# Transition: a container for naming data (defined in util.py)
# Memory: a deque that holds recent states as Transitions
# Has a fixed length, drops oldest
# element if maxlen is exceeded.
# Holds <Transition> objects, defined in
# network.py
memory = deque([], maxlen=50_000)
policy_net = DQN(
n_observations,
n_actions
).to(compute_device)
target_net = DQN(
n_observations,
n_actions
).to(compute_device)
policy_net = DQN(n_observations, n_actions).to(compute_device)
target_net = DQN(n_observations, n_actions).to(compute_device)
target_net.load_state_dict(policy_net.state_dict())
learning_rate = 0.001
optimizer = torch.optim.AdamW(
policy_net.parameters(),
lr = learning_rate,
@ -122,11 +105,43 @@ if __name__ == "__main__":
target_net.load_state_dict(checkpoint["target_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
memory = checkpoint["memory"]
episode_number = checkpoint["episode_number"] + 1
steps_done = checkpoint["steps_done"]
n_episodes = checkpoint["n_episodes"]
n_steps = checkpoint["n_steps"]
point_counter = checkpoint["point_counter"]
def select_action(state, steps_done):
def save_model(path):
torch.save({
# Newtorks
"policy_state_dict": policy_net.state_dict(),
"target_state_dict": target_net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
# Training data
"memory": memory,
"point_counter": point_counter,
"n_episodes": n_episodes,
"n_steps": n_steps,
# Hyperparameters,
# for reference
"eps_start": EPS_START,
"eps_end": EPS_END,
"eps_decay": EPS_DECAY,
"batch_size": BATCH_SIZE,
"tau": TAU,
"learning_rate": learning_rate,
"gamma": GAMMA
}, path
)
def select_action(state, x) -> int:
"""
Select an action using an epsilon-greedy policy.
@ -136,19 +151,13 @@ def select_action(state, steps_done):
Decay rate is controlled by EPS_DECAY.
"""
# Random number 0 <= x < 1
sample = random.random()
# Calculate random step threshhold
eps_threshold = (
EPS_END + (EPS_START - EPS_END) *
math.exp(
-1.0 * steps_done /
EPS_DECAY
)
math.exp(-1.0 * x / EPS_DECAY)
)
if sample > eps_threshold:
if random.random() > eps_threshold:
with torch.no_grad():
# t.max(1) will return the largest column value of each row.
# second column on max result is index of where max element was
@ -175,7 +184,7 @@ def optimize_model():
# Conversion.
# Combine states, actions, and rewards into their own tensors.
state_batch = torch.cat(batch.state)
last_state_batch = torch.cat(batch.last_state)
action_batch = torch.cat(batch.action)
reward_batch = torch.cat(batch.reward)
@ -209,7 +218,7 @@ def optimize_model():
# This gives us a tensor that contains the return we expect to get
# at that state if we follow the model's advice.
state_action_values = policy_net(state_batch).gather(1, action_batch)
state_action_values = policy_net(last_state_batch).gather(1, action_batch)
@ -282,36 +291,21 @@ def optimize_model():
def on_state_before(celeste):
global steps_done
state = celeste.state
pt_state = torch.tensor(
action = select_action(
# Put state in a tensor
torch.tensor(
[getattr(state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0)
).unsqueeze(0),
action = select_action(
pt_state,
# Random action probability is determined by
# the number of times we've reached the next point.
point_counter[state.next_point]
)
str_action = Celeste.action_space[action]
"""
action = None
while (action) is None or ((not state.can_dash) and (str_action not in ["left", "right"])):
action = select_action(
pt_state,
steps_done
)
str_action = Celeste.action_space[action]
"""
steps_done += 1
# For manual testing
#str_action = ""
@ -319,86 +313,114 @@ def on_state_before(celeste):
# str_action = input("action> ")
#action = Celeste.action_space.index(str_action)
print(str_action)
celeste.act(str_action)
print(Celeste.action_space[action])
celeste.act(action)
return state, action
def on_state_after(celeste, before_out):
global episode_number
state, action = before_out
next_state = celeste.state
pt_state = torch.tensor(
[getattr(state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0)
pt_action = torch.tensor(
[[ action ]],
device = compute_device,
dtype = torch.long
return (
state, # CelesteState
action # Integer
)
finished_stage = False
def compute_reward(last_state, state):
global point_counter
reward = None
# No reward if dead
if next_state.deaths != 0:
pt_next_state = None
if state.deaths != 0:
reward = 0
# Reward for finishing a stage
elif next_state.stage >= 1:
finished_stage = True
reward = next_state.next_point - state.next_point
elif state.stage >= 1:
print("FINISHED STAGE!!")
# We don't set a fixed reward here because the agent may
# complete the stage before getting all points.
# The below line provides extra reward for taking shortcuts.
reward = state.next_point - last_state.next_point
reward += 1
# Add to point counter
for i in range(state.next_point, state.next_point + reward):
for i in range(last_state.next_point, len(point_counter)):
point_counter[i] += 1
# Regular reward
else:
pt_next_state = torch.tensor(
[getattr(next_state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0)
if state.next_point == next_state.next_point:
reward = 0
else:
# Reward for reaching a checkpoint
elif last_state.next_point != state.next_point:
print(f"Got point {state.next_point}")
# Reward for reaching a point
reward = next_state.next_point - state.next_point
reward = state.next_point - last_state.next_point
# Add to point counter
for i in range(state.next_point, state.next_point + reward):
for i in range(last_state.next_point, last_state.next_point + reward):
point_counter[i] += 1
# No reward otherwise
else:
reward = 0
# Strawberry reward
if next_state.berries[state.stage] and not state.berries[state.stage]:
print(f"Got stage {state.stage} bonus")
reward += 1
# (Will probably break current version of model)
#if state.berries[state.stage] and not state.berries[state.stage]:
# print(f"Got stage {state.stage} bonus")
# reward += 1
assert reward is not None
return reward * 10
def on_state_after(celeste, before_out):
global n_episodes
global n_steps
last_state, action = before_out
next_state = celeste.state
dead = next_state.deaths != 0
done = next_state.stage >= 1
reward = reward * 10
pt_reward = torch.tensor([reward], device = compute_device)
reward = compute_reward(last_state, next_state)
if dead:
next_state = None
elif done:
# We don't set the next state to None because
# the optimization routine forces zero reward
# for terminal states.
# Copy last state instead. It's a hack, but it
# should work.
next_state = last_state
# Add this state transition to memory.
memory.append(
Transition(
pt_state,
pt_action,
pt_next_state,
pt_reward
# last state
torch.tensor(
[getattr(last_state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0),
# action
torch.tensor(
[[ action ]],
device = compute_device,
dtype = torch.long
),
# next state
# None if dead or done.
torch.tensor(
[getattr(next_state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0) if next_state is not None else None,
# reward
torch.tensor(
[reward],
device = compute_device
)
)
)
@ -406,11 +428,10 @@ def on_state_after(celeste, before_out):
print("")
# Perform a training step
loss = None
# Only train the network if we have enough
# transitions in memory to do so.
if len(memory) >= BATCH_SIZE:
n_steps += 1
loss = optimize_model()
# Soft update target_net weights
@ -423,65 +444,43 @@ def on_state_after(celeste, before_out):
)
target_net.load_state_dict(target_net_state)
# Move on to the next episode once we reach
# a terminal state.
if (next_state.deaths != 0 or finished_stage):
# Move on to the next episode and run
# housekeeping tasks.
if (dead or done):
s = celeste.state
n_episodes += 1
# Move screenshots
sm.move(
number = n_episodes,
overwrite = True
)
# Log this episode
with model_train_log.open("a") as f:
f.write(json.dumps({
"n_episodes": n_episodes,
"n_steps": n_steps,
"checkpoints": s.next_point,
"state_count": s.state_count,
"loss": None if loss is None else loss.item()
"loss": None if loss is None else loss.item(),
"done": done
}) + "\n")
# Save model
torch.save({
"policy_state_dict": policy_net.state_dict(),
"target_state_dict": target_net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"memory": memory,
"point_counter": point_counter,
"episode_number": episode_number,
"steps_done": steps_done,
# Hyperparameters
"eps_start": EPS_START,
"eps_end": EPS_END,
"eps_decay": EPS_DECAY,
"batch_size": BATCH_SIZE,
"tau": TAU,
"learning_rate": learning_rate,
"gamma": GAMMA
}, model_save_path)
# Clean up screenshots
shots = screenshot_source.glob("hackcel_*.png")
target = screenshot_dir / Path(f"{episode_number}")
target.mkdir(parents = True)
for s in shots:
s.rename(target / s.name)
# Save a snapshot
if episode_number % archive_interval == 0:
torch.save({
"policy_state_dict": policy_net.state_dict(),
"target_state_dict": target_net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"memory": memory,
"episode_number": episode_number,
"steps_done": steps_done
}, model_archive_dir / f"{episode_number}.torch")
if n_episodes % model_save_interval == 0:
save_model(model_archive_dir / f"{n_episodes}.torch")
shutil.copy(model_archive_dir / f"{n_episodes}.torch", model_save_path)
print("Game over. Resetting.")
episode_number += 1
celeste.reset()
if __name__ == "__main__":
c = Celeste(
"resources/pico-8/linux/pico8"

View File

View File

@ -0,0 +1,69 @@
from pathlib import Path
import shutil
class ScreenshotManager:
def __init__(
self,
# Where PICO-8 saves screenshots
source: Path,
# How PICO-8 names screenshots.
# Example: "celeste_*.png"
pattern: str,
# Where we want to move screenshots.
target: Path
):
self.source = source
self.pattern = pattern
self.target = target
self.target.mkdir(
parents = True,
exist_ok = True
)
def clean(self):
shots = self.source.glob(self.pattern)
for s in shots:
s.unlink()
return self
def move(self, number: int | None = None, overwrite = False):
shots = self.source.glob(self.pattern)
if number == None:
# Auto-select new directory number.
# Chooses next highest int directory name
number = 0
for f in self.target.iterdir():
try:
number = max(
int(f.name),
number
)
except ValueError:
continue
number += 1
else:
target = self.target / str(number)
if target.exists():
if not overwrite:
raise Exception(f"Target \"{target}\" exists!")
else:
print(f"Target \"{target}\" exists, removing.")
shutil.rmtree(target)
target.mkdir(parents = True)
for s in shots:
s.rename(target / s.name)
return self

View File

@ -47,14 +47,6 @@ plots = {
if __name__ == "__main__":
if plots["prediction"]:
print("Making prediction plots...")
with Pool(5) as p:
p.map(
plot_pred,
list((m / "model_archive").iterdir())
)
if plots["best"]:
print("Making best-action plots...")
with Pool(5) as p:
@ -63,6 +55,14 @@ if __name__ == "__main__":
list((m / "model_archive").iterdir())
)
if plots["prediction"]:
print("Making prediction plots...")
with Pool(5) as p:
p.map(
plot_pred,
list((m / "model_archive").iterdir())
)
if plots["actual"]:
print("Making actual plots...")
with Pool(5) as p:

View File

@ -30,6 +30,16 @@ k_jump=4
k_dash=5
-- Set to false while training or running the model.
-- Set to true to play the game manually with debug print.
-- (good for finding coordinates of checkpoints)
--
-- If true, disables most hack features:
-- - screenshots at every frame
-- - frame skipping
-- - waiting for input
hack_human_mode = false
-- If true, disable screensake
hack_no_shake = true
@ -1209,6 +1219,10 @@ end
-- _update60 does 60 fps
-- default for celeste is 30.
function _update()
if hack_human_mode then
old_update()
return
end
-- Run at full speed until ready
if not hack_ready then
@ -1304,7 +1318,10 @@ end
-- Called at the same rate as _update,
-- but not necessarily at the same time.
function _draw()
--old_draw()
if hack_human_mode then
old_draw()
return
end
end
function old_update()