Archived
1
0
This commit is contained in:
2023-03-08 16:08:24 -08:00
parent 058292c0bd
commit 571a337ff4
41 changed files with 0 additions and 1380 deletions

6
celeste_ai/__init__.py Normal file
View File

@ -0,0 +1,6 @@
from .network import DQN
from .network import Transition
from .celeste import Celeste
from .celeste import CelesteError
from .celeste import CelesteState

407
celeste_ai/celeste.py Executable file
View File

@ -0,0 +1,407 @@
from typing import NamedTuple
import subprocess
import time
import math
import numpy as np
class CelesteError(Exception):
pass
class CelesteState(NamedTuple):
# Stage number
stage: int
# Player position
# Regular position has 0,0 in top-left,
# centered position has 0,0 in center.
xpos: int
ypos: int
xpos_scaled: float
ypos_scaled: float
# Player velocity
xvel: float
yvel: float
# Number of deaths since game start
deaths: int
# If an index is true, we got a strawberry on that stage.
berries: list[bool]
# Distance to next point
dist: float
# Index of next point
next_point: int
# Coordinates of next point
next_point_x: int
next_point_y: int
# Number of states recieved since restart
state_count: int
# True if Madeline can dash
can_dash: bool
can_dash_int: int
class Celeste:
action_space = [
"left", # move left 0
"right", # move right 1
"jump-l", # jump left 2
"jump-r", # jump right 3
"dash-u", # dash up 4
"dash-r", # dash right 5
"dash-l", # dash left 6
"dash-ru", # dash right-up 7
"dash-lu" # dash left-up 8
]
# Map integers to state values.
# This also determines what data is fed to the model.
state_number_map = [
#"xpos",
#"ypos",
"xpos_scaled",
"ypos_scaled",
#"can_dash_int"
#"next_point_x",
#"next_point_y"
]
# Targets the agent tries to reach.
# The last target MUST be outside the frame.
# Format is X, Y, range, force_y
# force_y is optional. If true, y_value MUST match perfectly.
target_checkpoints = [
[ # Stage 1
#(28, 88, 8), # Start pillar
(60, 80, 8), # Middle pillar
(105, 64, 8), # Right ledge
(25, 40, 8), # Left ledge
(97, 24, 5, True), # Small end ledge
(110, 16, 8), # End ledge
(110, -20, 8), # Next stage
]
]
# Maps room_x, room_y coordinates to stage number.
stage_map = [
[0, 1, 2, 3, 4]
]
def __init__(
self,
pico_path,
*,
state_timeout = 20,
cart_name = "hackcel.p8",
):
# Start pico-8
self._process = subprocess.Popen(
pico_path,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT
)
# Wait for window to open and get window id
time.sleep(2)
winid = subprocess.check_output([
"xdotool",
"search",
"--class",
"pico8"
]).decode("utf-8").strip().split("\n")
if len(winid) != 1:
raise Exception("Could not find unique PICO-8 window id")
self._winid = winid[0]
# Load cartridge
self._keystring(f"load {cart_name}")
self._keypress("Enter")
self._keystring("run")
self._keypress("Enter", post = 1000)
# Parameters
self.state_timeout = state_timeout # If we run this many states without getting a checkpoint, reset.
self.cart_name = cart_name # Name of cart to load. Not used anywhere, but saved for convenience.
# Internal variables
self._internal_state = {} # Raw data read from stdout
self._before_out = None # Output of "before" callback in update loop
self._last_checkpoint_state = 0 # Index of frame at which we reached the last checkpoint
self._state_counter = 0 # Number of frames we've run since last reset
self._next_checkpoint_idx = 0 # Index of next point
self._dist = 0 # Distance to next point
self._resetting = False # True between a call to .reset() and the first state message from pico.
self._keys = {} # Dictionary of "key": bool
def act(self, action: str | int):
"""
Specify what keys should be down. This does NOT send key events.
Celeste._apply_keys() does that at the right time.
Args:
action (str): key name, as in Celeste.action_space
"""
if isinstance(action, int):
action = Celeste.action_space[action]
self._keys = {}
if action is None:
return
elif action == "left":
self._keys["Left"] = True
elif action == "right":
self._keys["Right"] = True
elif action == "jump":
self._keys["c"] = True
elif action == "jump-l":
self._keys["c"] = True
self._keys["Left"] = True
elif action == "jump-r":
self._keys["c"] = True
self._keys["Right"] = True
elif action == "dash-u":
self._keys["Up"] = True
self._keys["x"] = True
elif action == "dash-r":
self._keys["Right"] = True
self._keys["x"] = True
elif action == "dash-l":
self._keys["Left"] = True
self._keys["x"] = True
elif action == "dash-ru":
self._keys["Up"] = True
self._keys["Right"] = True
self._keys["x"] = True
elif action == "dash-lu":
self._keys["Up"] = True
self._keys["Left"] = True
self._keys["x"] = True
def _apply_keys(self):
for i in [
"x", "c",
"Left", "Right",
"Down", "Up"
]:
if self._keys.get(i):
self._keydown(i)
else:
self._keyup(i)
@property
def state(self):
try:
stage = (
Celeste.stage_map
[int(self._internal_state["ry"])]
[int(self._internal_state["rx"])]
)
if len(Celeste.target_checkpoints) <= stage:
next_point_x = 0
next_point_y = 0
else:
next_point_x = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][0]
next_point_y = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][1]
return CelesteState(
stage = stage,
xpos = int(self._internal_state["px"]),
ypos = int(self._internal_state["py"]),
xpos_scaled = int(self._internal_state["px"]) / 128.0,
ypos_scaled = int(self._internal_state["py"]) / 128.0,
xvel = float(self._internal_state["vx"]),
yvel = float(self._internal_state["vy"]),
deaths = int(self._internal_state["dc"]),
berries = [x == "t" for x in self._internal_state["fr"][1:]],
dist = self._dist,
next_point = self._next_checkpoint_idx,
next_point_x = next_point_x,
next_point_y = next_point_y,
state_count = self._state_counter,
can_dash = self._internal_state["ds"] == "t",
can_dash_int = 1 if self._internal_state["ds"] == "t" else 0
)
except KeyError:
raise CelesteError("Not enough data to get state.")
def _keypress(self, key: str, *, post = 200):
subprocess.run([
"xdotool",
"key",
"--window", self._winid,
key
])
time.sleep(post / 1000)
def _keydown(self, key: str):
subprocess.run([
"xdotool",
"keydown",
"--window", self._winid,
key
])
def _keyup(self, key: str):
subprocess.run([
"xdotool",
"keyup",
"--window", self._winid,
key
])
def _keystring(self, string, *, delay = 100, post = 200):
subprocess.run([
"xdotool",
"type",
"--window", self._winid,
"--delay", str(delay),
string
])
time.sleep(post / 1000)
def reset(self):
# Make sure all keys are released
self.act(None)
self._apply_keys()
self._internal_state = {}
self._next_checkpoint_idx = 0
self._state_counter = 0
self._before_out = None
self._resetting = True
self._last_checkpoint_state = 0
self._keypress("Escape")
self._keystring("run")
self._keypress("Enter", post = 1000)
# Clear all old stdout messages and
# wait for the game to restart.
for k in iter(self._process.stdout.readline, ""):
k = k.decode("utf-8")[:-1]
if k == "!RESTART":
break
def update_loop(self, before, after):
# Waits for stdout from pico-8 process
for line in iter(self._process.stdout.readline, ""):
l = line.decode("utf-8")[:-1].strip()
# Release all keys
self.act(None)
self._apply_keys()
# Clear reset state
self._resetting = False
# This should only occur at game start
if l in ["!RESTART"]:
continue
self._state_counter += 1
# Parse state string
for entry in l.split(";"):
if entry == "":
continue
key, val = entry.split(":")
self._internal_state[key] = val
if self.state.stage <= 0:
# Calculate distance to each point
x = self.state.xpos
y = self.state.ypos
dist = np.zeros(len(Celeste.target_checkpoints[self.state.stage]), dtype=np.float16)
for i, c in enumerate(Celeste.target_checkpoints[self.state.stage]):
if i < self._next_checkpoint_idx:
dist[i] = 1000
continue
# Update checkpoints
tx, ty = c[:2]
dist[i] = (math.sqrt(
(x-tx)*(x-tx) +
((y-ty)*(y-ty))/2
# Possible modification:
# make x-distance twice as valuable as y-distance
))
min_idx = int(dist.argmin())
dist = int(dist[min_idx])
t = Celeste.target_checkpoints[self.state.stage][min_idx]
range = t[2]
if len(t) == 3:
force_y = False
else:
force_y = t[3]
if force_y:
got_point = (
dist <= range and
y == t[1]
)
else:
got_point = dist <= range
if got_point:
self._next_checkpoint_idx = min_idx + 1
self._last_checkpoint_state = self._state_counter
# Recalculate distance to new point
tx, ty = (
Celeste.target_checkpoints
[self.state.stage]
[self._next_checkpoint_idx]
[:2]
)
dist = math.sqrt(
(x-tx)*(x-tx) +
((y-ty)*(y-ty))/2
)
# Timeout if we spend too long between points
elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)
self._dist = dist
# Call step callbacks
# These should call celeste.act() to set next input
if self._before_out is not None:
after(self, self._before_out)
# Do not run before callback if after() triggered a reset.
if not self._resetting:
self._before_out = before(self)
self._apply_keys()

36
celeste_ai/network.py Normal file
View File

@ -0,0 +1,36 @@
import torch
from collections import namedtuple
Transition = namedtuple(
"Transition",
(
"last_state",
"action",
"next_state",
"reward"
)
)
class DQN(torch.nn.Module):
def __init__(self, n_observations: int, n_actions: int):
super(DQN, self).__init__()
self.layers = torch.nn.Sequential(
torch.nn.Linear(n_observations, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 128),
torch.nn.ReLU(),
torch.torch.nn.Linear(128, n_actions)
)
def forward(self, x):
return self.layers(x)

View File

@ -0,0 +1,3 @@
from .plot_predicted_reward import predicted_reward
from .plot_best_action import best_action

View File

@ -0,0 +1,121 @@
import torch
import numpy as np
from pathlib import Path
import matplotlib as mpl
import matplotlib.pyplot as plt
import json
# All of the following are required to load
# a pickled model.
from celeste_ai.celeste import Celeste
from celeste_ai.network import DQN
from celeste_ai.network import Transition
def best_action(
model_file: Path,
out_filename: Path,
*,
device = torch.device("cpu"),
draw_path = True
):
if not model_file.is_file():
raise Exception(f"Bad model file {model_file}")
out_filename.parent.mkdir(exist_ok = True, parents = True)
# Create and load model
policy_net = DQN(
len(Celeste.state_number_map),
len(Celeste.action_space)
).to(device)
checkpoint = torch.load(
model_file,
map_location = device
)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
# Compute preditions
p = np.zeros((128, 128), dtype=np.float32)
with torch.no_grad():
for r in range(len(p)):
for c in range(len(p[r])):
x = c / 128.0
y = r / 128.0
k = np.asarray(policy_net(
torch.tensor(
[x, y],
dtype = torch.float32,
device = device
).unsqueeze(0)
)[0])
p[r][c] = np.argmax(k)
cmap = mpl.colors.ListedColormap(
[
"forestgreen",
"firebrick",
"lightgreen",
"salmon",
"darkturquoise",
"sandybrown",
"olive",
"darkorchid",
"mediumvioletred"
]
)
# Plot predictions
fig, axs = plt.subplots(1, 1, figsize = (20, 20))
ax = axs
ax.set(
adjustable = "box",
aspect = "equal",
title = "Best Action"
)
plot = ax.pcolor(
p,
cmap = cmap,
vmin = 0,
vmax = 8
)
if draw_path:
d = None
with Path("model_data/solved_4layer/paths.json").open("r") as f:
for l in f.readlines():
t = json.loads(l)
if t["current_image"] == model_file.name:
break
d = t
assert d is not None
plt.plot(
[max(0,x["xpos"]) for x in d["hist"]],
[max(0,x["ypos"] + 5) for x in d["hist"]],
marker = "",
markersize = 0,
linestyle = "-",
linewidth = 5,
color = "white",
solid_capstyle = "round",
solid_joinstyle = "round"
)
ax.invert_yaxis()
cbar = fig.colorbar(plot, ticks = list(range(0, 9)))
cbar.ax.set_yticklabels(Celeste.action_space)
fig.savefig(out_filename)
plt.close()

View File

@ -0,0 +1,84 @@
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
# All of the following are required to load
# a pickled model.
from celeste_ai.celeste import Celeste
from celeste_ai.network import DQN
from celeste_ai.network import Transition
def predicted_reward(
model_file: Path,
out_filename: Path,
*,
device = torch.device("cpu")
):
if not model_file.is_file():
raise Exception(f"Bad model file {model_file}")
out_filename.parent.mkdir(exist_ok = True, parents = True)
# Create and load model
policy_net = DQN(
len(Celeste.state_number_map),
len(Celeste.action_space)
).to(device)
checkpoint = torch.load(
model_file,
map_location = device
)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
# Compute preditions
p = np.zeros((128, 128, 9), dtype=np.float32)
with torch.no_grad():
for r in range(len(p)):
for c in range(len(p[r])):
x = c / 128.0
y = r / 128.0
k = np.asarray(policy_net(
torch.tensor(
[x, y],
dtype = torch.float32,
device = device
).unsqueeze(0)
)[0])
p[r][c] = k
# Plot predictions
fig, axs = plt.subplots(2, 5, figsize = (20, 10))
for a in range(len(axs.ravel())):
if a >= len(Celeste.action_space):
continue
ax = axs.ravel()[a]
ax.set(
adjustable = "box",
aspect = "equal",
title = Celeste.action_space[a]
)
plot = ax.pcolor(
p[:,:,a],
cmap = "Greens",
vmin = 0,
#vmax = 5
)
ax.invert_yaxis()
fig.colorbar(plot)
fig.savefig(out_filename)
plt.close()

119
celeste_ai/record_paths.py Normal file
View File

@ -0,0 +1,119 @@
from pathlib import Path
import torch
import json
from celeste_ai import Celeste
from celeste_ai import DQN
model_data_root = Path("model_data/current")
compute_device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
# Celeste env properties
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)
policy_net = DQN(
n_observations,
n_actions
).to(compute_device)
k = (model_data_root / "model_archive").iterdir()
i = 0
state_history = []
current_path = None
def next_image():
global policy_net
global current_path
global i
i += 1
try:
current_path = k.__next__()
except StopIteration:
return False
print(f"Pathing {current_path} ({i})")
# Load model if one exists
checkpoint = torch.load(
current_path,
map_location = compute_device
)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
next_image()
def on_state_before(celeste):
global steps_done
state = celeste.state
pt_state = torch.tensor(
[getattr(state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0)
action = policy_net(pt_state).max(1)[1].view(1, 1).item()
str_action = Celeste.action_space[action]
celeste.act(str_action)
return state, action
def on_state_after(celeste, before_out):
global episode_number
global state_history
state, action = before_out
next_state = celeste.state
finished_stage = next_state.stage >= 1
state_history.append({
"xpos": state.xpos,
"ypos": state.ypos,
"action": Celeste.action_space[action]
})
# Move on to the next episode once we reach
# a terminal state.
if (next_state.deaths != 0 or finished_stage):
with (model_data_root / "paths.json").open("a") as f:
f.write(json.dumps(
{
"hist": state_history,
"current_image": current_path.name
}
) + "\n")
state_history = []
k = next_image()
if k is False:
raise Exception("Done.")
print("Game over. Resetting.")
celeste.reset()
c = Celeste(
"resources/pico-8/linux/pico8"
)
c.update_loop(
on_state_before,
on_state_after
)

100
celeste_ai/test.py Normal file
View File

@ -0,0 +1,100 @@
from pathlib import Path
import torch
from celeste_ai import Celeste
from celeste_ai import DQN
from celeste_ai.util.screenshots import ScreenshotManager
if __name__ == "__main__":
# Where to read/write model data.
model_data_root = Path("model_data/current")
model_save_path = model_data_root / "model.torch"
model_data_root.mkdir(parents = True, exist_ok = True)
sm = ScreenshotManager(
# Where PICO-8 saves screenshots.
# Probably your desktop.
source = Path("/home/mark/Desktop"),
pattern = "hackcel_*.png",
target = model_data_root / "screenshots_test"
).clean() # Remove old screenshots
compute_device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
episode_number = 0
# Celeste env properties
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)
policy_net = DQN(
n_observations,
n_actions
).to(compute_device)
# Load model if one exists
checkpoint = torch.load(
model_save_path,
map_location = compute_device
)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
def on_state_before(celeste):
global steps_done
state = celeste.state
pt_state = torch.tensor(
[getattr(state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0)
action = policy_net(pt_state).max(1)[1].view(1, 1).item()
str_action = Celeste.action_space[action]
print(str_action)
celeste.act(str_action)
return state, action
def on_state_after(celeste, before_out):
global episode_number
state, action = before_out
next_state = celeste.state
finished_stage = next_state.stage >= 1
# Move on to the next episode once we reach
# a terminal state.
if (next_state.deaths != 0 or finished_stage):
s = celeste.state
sm.move()
print("Game over. Resetting.")
celeste.reset()
episode_number += 1
if __name__ == "__main__":
c = Celeste(
"resources/pico-8/linux/pico8"
)
c.update_loop(
on_state_before,
on_state_after
)

492
celeste_ai/train.py Normal file
View File

@ -0,0 +1,492 @@
from collections import namedtuple
from collections import deque
from pathlib import Path
import random
import math
import json
import torch
import shutil
from celeste_ai import Celeste
from celeste_ai import DQN
from celeste_ai import Transition
from celeste_ai.util.screenshots import ScreenshotManager
if __name__ == "__main__":
# Where to read/write model data.
model_data_root = Path("model_data/current")
sm = ScreenshotManager(
# Where PICO-8 saves screenshots.
# Probably your desktop.
source = Path("/home/mark/Desktop"),
pattern = "hackcel_*.png",
target = model_data_root / "screenshots"
).clean() # Remove old screenshots
model_save_path = model_data_root / "model.torch"
model_archive_dir = model_data_root / "model_archive"
model_train_log = model_data_root / "train_log"
model_data_root.mkdir(parents = True, exist_ok = True)
model_archive_dir.mkdir(parents = True, exist_ok = True)
compute_device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
# Celeste env properties
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)
# Epsilon-greedy parameters
# Probability of choosing a random action starts at
# EPS_START and decays to EPS_END.
# EPS_DECAY controls the rate of decay.
EPS_START = 0.9
EPS_END = 0.02
EPS_DECAY = 100
# Bellman equation time-discount factor
GAMMA = 0.9
# Train on this many transitions from
# replay memory each round
BATCH_SIZE = 100
# Controls target_net soft update.
# Should be between 0 and 1.
TAU = 0.05
# Optimizer learning rate
learning_rate = 0.001
# Save a snapshot of the model every n
# episodes.
model_save_interval = 10
# How many times we've reached each point.
# This is used to compute epsilon-greedy probability.
point_counter = [0] * len(Celeste.target_checkpoints[0])
n_episodes = 0 # Number of episodes we've trained on
n_steps = 0 # Number of training steps we've completed
# Create replay memory.
#
# Holds <Transition> objects, defined in
# network.py
memory = deque([], maxlen=50_000)
policy_net = DQN(n_observations, n_actions).to(compute_device)
target_net = DQN(n_observations, n_actions).to(compute_device)
target_net.load_state_dict(policy_net.state_dict())
optimizer = torch.optim.AdamW(
policy_net.parameters(),
lr = learning_rate,
amsgrad = True
)
if model_save_path.is_file():
# Load model if one exists
checkpoint = torch.load(
model_save_path,
map_location = compute_device
)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
target_net.load_state_dict(checkpoint["target_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
memory = checkpoint["memory"]
n_episodes = checkpoint["n_episodes"]
n_steps = checkpoint["n_steps"]
point_counter = checkpoint["point_counter"]
def save_model(path):
torch.save({
# Newtorks
"policy_state_dict": policy_net.state_dict(),
"target_state_dict": target_net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
# Training data
"memory": memory,
"point_counter": point_counter,
"n_episodes": n_episodes,
"n_steps": n_steps,
# Hyperparameters,
# for reference
"eps_start": EPS_START,
"eps_end": EPS_END,
"eps_decay": EPS_DECAY,
"batch_size": BATCH_SIZE,
"tau": TAU,
"learning_rate": learning_rate,
"gamma": GAMMA
}, path
)
def select_action(state, x) -> int:
"""
Select an action using an epsilon-greedy policy.
Sometimes use our model, sometimes sample one uniformly.
P(random action) starts at EPS_START and decays to EPS_END.
Decay rate is controlled by EPS_DECAY.
"""
# Calculate random step threshhold
eps_threshold = (
EPS_END + (EPS_START - EPS_END) *
math.exp(-1.0 * x / EPS_DECAY)
)
if random.random() > eps_threshold:
with torch.no_grad():
# t.max(1) will return the largest column value of each row.
# second column on max result is index of where max element was
# found, so we pick action with the larger expected reward.
return policy_net(state).max(1)[1].view(1, 1).item()
else:
return random.randint( 0, n_actions-1 )
def optimize_model():
if len(memory) < BATCH_SIZE:
raise Exception(f"Not enough elements in memory for a batch of {BATCH_SIZE}")
# Get a random sample of transitions
batch = random.sample(memory, BATCH_SIZE)
# Conversion.
# Transposes batch, turning an array of Transitions
# into a Transition of arrays.
batch = Transition(*zip(*batch))
# Conversion.
# Combine states, actions, and rewards into their own tensors.
last_state_batch = torch.cat(batch.last_state)
action_batch = torch.cat(batch.action)
reward_batch = torch.cat(batch.reward)
# Compute a mask of non_final_states.
# Each element of this tensor corresponds to an element in the batch.
# True if this is a final state, False if it isn't.
#
# We use this to select non-final states later.
non_final_mask = torch.tensor(
tuple(map(
lambda s: s is not None,
batch.next_state
))
)
non_final_next_states = torch.cat(
[s for s in batch.next_state if s is not None]
)
# How .gather works:
# if out = a.gather(1, b),
# out[i, j] = a[ i ][ b[i,j] ]
#
# a is "input," b is "index"
# Compute Q(s_t, a).
# This gives us a tensor that contains the return we expect to get
# at that state if we follow the model's advice.
state_action_values = policy_net(last_state_batch).gather(1, action_batch)
# Compute V(s_t+1) for all next states.
# V(s_t+1) = max_a ( Q(s_t+1, a) )
# = the maximum reward over all possible actions at state s_t+1.
next_state_values = torch.zeros(BATCH_SIZE, device = compute_device)
with torch.no_grad():
# Note the use of non_final_mask here.
# States that are final do not have their reward set by the line
# below, so their reward stays at zero.
#
# States that are not final get their predicted value
# set to the best value the model predicts.
#
#
# Expected values of action are selected with the "older" target net,
# and their best reward (over possible actions) is selected with max(1)[0].
next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
# TODO: What does this mean?
# "Compute expected Q values"
expected_state_action_values = reward_batch + (next_state_values * GAMMA)
# Compute Huber loss between predicted reward and expected reward.
# Pytorch is will account for this when we compute the gradient of loss.
#
# loss is a single-element tensor (i.e, a scalar).
criterion = torch.nn.SmoothL1Loss()
loss = criterion(
state_action_values,
expected_state_action_values.unsqueeze(1)
)
# We can now run a step of backpropagation on our model.
# TODO: what does this do?
#
# Calling .backward() multiple times will accumulate parameter gradients.
# Thus, we reset the gradient before each step.
optimizer.zero_grad()
# Compute the gradient of loss wrt... something?
# TODO: what does this do, we never use loss again?!
loss.backward()
# Prevent vanishing and exploding gradients.
# Forces gradients to be in [-clip_value, +clip_value]
torch.nn.utils.clip_grad_value_( # type: ignore
policy_net.parameters(),
clip_value = 100
)
# Perform a single optimizer step.
#
# Uses the current gradient, which is stored
# in the .grad attribute of the parameter.
optimizer.step()
return loss
def on_state_before(celeste):
state = celeste.state
action = select_action(
# Put state in a tensor
torch.tensor(
[getattr(state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0),
# Random action probability is determined by
# the number of times we've reached the next point.
point_counter[state.next_point]
)
# For manual testing
#str_action = ""
#while str_action not in Celeste.action_space:
# str_action = input("action> ")
#action = Celeste.action_space.index(str_action)
print(Celeste.action_space[action])
celeste.act(action)
return (
state, # CelesteState
action # Integer
)
def compute_reward(last_state, state):
global point_counter
reward = None
# No reward if dead
if state.deaths != 0:
reward = 0
# Reward for finishing a stage
elif state.stage >= 1:
print("FINISHED STAGE!!")
# We don't set a fixed reward here because the agent may
# complete the stage before getting all points.
# The below line provides extra reward for taking shortcuts.
reward = state.next_point - last_state.next_point
reward += 1
# Add to point counter
for i in range(last_state.next_point, len(point_counter)):
point_counter[i] += 1
# Reward for reaching a checkpoint
elif last_state.next_point != state.next_point:
print(f"Got point {state.next_point}")
reward = state.next_point - last_state.next_point
# Add to point counter
for i in range(last_state.next_point, last_state.next_point + reward):
point_counter[i] += 1
# No reward otherwise
else:
reward = 0
# Strawberry reward
# (Will probably break current version of model)
#if state.berries[state.stage] and not state.berries[state.stage]:
# print(f"Got stage {state.stage} bonus")
# reward += 1
assert reward is not None
return reward * 10
def on_state_after(celeste, before_out):
global n_episodes
global n_steps
last_state, action = before_out
next_state = celeste.state
dead = next_state.deaths != 0
done = next_state.stage >= 1
reward = compute_reward(last_state, next_state)
if dead:
next_state = None
elif done:
# We don't set the next state to None because
# the optimization routine forces zero reward
# for terminal states.
# Copy last state instead. It's a hack, but it
# should work.
next_state = last_state
# Add this state transition to memory.
memory.append(
Transition(
# last state
torch.tensor(
[getattr(last_state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0),
# action
torch.tensor(
[[ action ]],
device = compute_device,
dtype = torch.long
),
# next state
# None if dead or done.
torch.tensor(
[getattr(next_state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0) if next_state is not None else None,
# reward
torch.tensor(
[reward],
device = compute_device
)
)
)
print("==> ", reward)
print("")
# Perform a training step
loss = None
if len(memory) >= BATCH_SIZE:
n_steps += 1
loss = optimize_model()
# Soft update target_net weights
target_net_state = target_net.state_dict()
policy_net_state = policy_net.state_dict()
for key in policy_net_state:
target_net_state[key] = (
policy_net_state[key] * TAU +
target_net_state[key] * (1-TAU)
)
target_net.load_state_dict(target_net_state)
# Move on to the next episode and run
# housekeeping tasks.
if (dead or done):
s = celeste.state
n_episodes += 1
# Move screenshots
sm.move(
number = n_episodes,
overwrite = True
)
# Log this episode
with model_train_log.open("a") as f:
f.write(json.dumps({
"n_episodes": n_episodes,
"n_steps": n_steps,
"checkpoints": s.next_point,
"loss": None if loss is None else loss.item(),
"done": done
}) + "\n")
# Save a snapshot
if n_episodes % model_save_interval == 0:
save_model(model_archive_dir / f"{n_episodes}.torch")
shutil.copy(model_archive_dir / f"{n_episodes}.torch", model_save_path)
print("Game over. Resetting.")
celeste.reset()
if __name__ == "__main__":
c = Celeste(
"resources/pico-8/linux/pico8"
)
c.update_loop(
on_state_before,
on_state_after
)

View File

View File

@ -0,0 +1,70 @@
from pathlib import Path
import shutil
class ScreenshotManager:
def __init__(
self,
# Where PICO-8 saves screenshots
source: Path,
# How PICO-8 names screenshots.
# Example: "celeste_*.png"
pattern: str,
# Where we want to move screenshots.
target: Path
):
self.source = source
self.pattern = pattern
self.target = target
self.target.mkdir(
parents = True,
exist_ok = True
)
def clean(self):
shots = self.source.glob(self.pattern)
for s in shots:
s.unlink()
return self
def move(self, number: int | None = None, overwrite = False):
shots = self.source.glob(self.pattern)
if number == None:
# Auto-select new directory number.
# Chooses next highest int directory name
number = 0
for f in self.target.iterdir():
try:
number = max(
int(f.name),
number
)
except ValueError:
continue
number += 1
target = self.target / str(number)
else:
target = self.target / str(number)
if target.exists():
if not overwrite:
raise Exception(f"Target \"{target}\" exists!")
else:
print(f"Target \"{target}\" exists, removing.")
shutil.rmtree(target)
target.mkdir(parents = True)
for s in shots:
s.rename(target / s.name)
return self