Compare commits
No commits in common. "24dd65ace88ca89f2a6c051043d2f9dc5c439f49" and "f40b58508e910cbf99ebf8163c91220ee9cba1aa" have entirely different histories.
24dd65ace8 ... f40b58508e

@@ -70,24 +70,21 @@ class Celeste:
     #"ypos",
     "xpos_scaled",
     "ypos_scaled",
-    #"can_dash_int"
+    "can_dash_int"
     #"next_point_x",
     #"next_point_y"
 ]

 # Targets the agent tries to reach.
 # The last target MUST be outside the frame.
-# Format is X, Y, range, force_y
-# force_y is optional. If true, y_value MUST match perfectly.
 target_checkpoints = [
     [ # Stage 1
-        #(28, 88, 8), # Start pillar
-        (60, 80, 8), # Middle pillar
-        (105, 64, 8), # Right ledge
-        (25, 40, 8), # Left ledge
-        (97, 24, 5, True), # Small end ledge
-        (110, 16, 8), # End ledge
-        (110, -20, 8), # Next stage
+        #(28, 88), # Start pillar
+        (60, 80), # Middle pillar
+        (105, 64), # Right ledge
+        (25, 40), # Left ledge
+        (110, 16), # End ledge
+        (110, -2), # Next stage
     ]
 ]

@@ -102,7 +99,7 @@ class Celeste:
     self,
     pico_path,
     *,
-    state_timeout = 20,
+    state_timeout = 30,
     cart_name = "hackcel.p8",
 ):

@@ -147,7 +144,7 @@ class Celeste:
     self._resetting = False # True between a call to .reset() and the first state message from pico.
     self._keys = {} # Dictionary of "key": bool

-def act(self, action: str | int):
+def act(self, action: str):
     """
     Specify what keys should be down. This does NOT send key events.
     Celeste._apply_keys() does that at the right time.
@@ -156,9 +153,6 @@ class Celeste:
         action (str): key name, as in Celeste.action_space
     """

-    if isinstance(action, int):
-        action = Celeste.action_space[action]
-
     self._keys = {}
     if action is None:
         return
@@ -214,9 +208,9 @@ class Celeste:
     [int(self._internal_state["rx"])]
 )

-if len(Celeste.target_checkpoints) <= stage:
-    next_point_x = 0
-    next_point_y = 0
+if len(Celeste.target_checkpoints) < stage:
+    next_point_x = None
+    next_point_y = None
 else:
     next_point_x = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][0]
     next_point_y = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][1]
@@ -335,65 +329,46 @@ class Celeste:



-if self.state.stage <= 0:
-    # Calculate distance to each point
-    x = self.state.xpos
-    y = self.state.ypos
-    dist = np.zeros(len(Celeste.target_checkpoints[self.state.stage]), dtype=np.float16)
-    for i, c in enumerate(Celeste.target_checkpoints[self.state.stage]):
-        if i < self._next_checkpoint_idx:
-            dist[i] = 1000
-            continue
-
-        # Update checkpoints
-        tx, ty = c[:2]
-        dist[i] = (math.sqrt(
-            (x-tx)*(x-tx) +
-            ((y-ty)*(y-ty))/2
-            # Possible modification:
-            # make x-distance twice as valuable as y-distance
-        ))
-    min_idx = int(dist.argmin())
-    dist = int(dist[min_idx])
-
-    t = Celeste.target_checkpoints[self.state.stage][min_idx]
-    range = t[2]
-    if len(t) == 3:
-        force_y = False
-    else:
-        force_y = t[3]
-
-    if force_y:
-        got_point = (
-            dist <= range and
-            y == t[1]
-        )
-    else:
-        got_point = dist <= range
-
-    if got_point:
-        self._next_checkpoint_idx = min_idx + 1
-        self._last_checkpoint_state = self._state_counter
-
-        # Recalculate distance to new point
-        tx, ty = (
-            Celeste.target_checkpoints
-            [self.state.stage]
-            [self._next_checkpoint_idx]
-            [:2]
-        )
-        dist = math.sqrt(
-            (x-tx)*(x-tx) +
-            ((y-ty)*(y-ty))/2
-        )
-
-    # Timeout if we spend too long between points
-    elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
-        self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)
+# Calculate distance to each point
+x = self.state.xpos
+y = self.state.ypos
+dist = np.zeros(len(Celeste.target_checkpoints[self.state.stage]), dtype=np.float16)
+for i, c in enumerate(Celeste.target_checkpoints[self.state.stage]):
+    if i < self._next_checkpoint_idx:
+        dist[i] = 1000
+        continue
+
+    # Update checkpoints
+    tx, ty = c
+    dist[i] = (math.sqrt(
+        (x-tx)*(x-tx) +
+        ((y-ty)*(y-ty))/2
+        # Possible modification:
+        # make x-distance twice as valuable as y-distance
+    ))
+min_idx = int(dist.argmin())
+dist = int(dist[min_idx])
+
+if dist <= 8:
+    print(f"Got point {min_idx}")
+    self._next_checkpoint_idx = min_idx + 1
+    self._last_checkpoint_state = self._state_counter
+
+    # Recalculate distance to new point
+    tx, ty = Celeste.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
+    dist = math.sqrt(
+        (x-tx)*(x-tx) +
+        ((y-ty)*(y-ty))/2
+    )
+
+# Timeout if we spend too long between points
+elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
+    self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)


 self._dist = dist

 # Call step callbacks
 # These should call celeste.act() to set next input
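Note: in both versions of the hunk above, the checkpoint distance halves the squared y-difference, so horizontal distance counts for roughly twice as much as vertical distance. A minimal standalone sketch of that check; the 8-pixel radius is the hard-coded threshold the new version uses, while the old version read a per-checkpoint range (and optional force_y flag) from the tuple:

import math

def checkpoint_distance(x: float, y: float, tx: float, ty: float) -> float:
    # Same weighting as the diff above: the y-term is halved inside the sqrt.
    return math.sqrt((x - tx) * (x - tx) + ((y - ty) * (y - ty)) / 2)

def reached(x: float, y: float, tx: float, ty: float, radius: float = 8) -> bool:
    # New version: fixed 8-pixel radius. Old version: radius = t[2], plus an exact-y check when force_y is set.
    return checkpoint_distance(x, y, tx, ty) <= radius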
@@ -5,7 +5,7 @@ from collections import namedtuple
 Transition = namedtuple(
     "Transition",
     (
-        "last_state",
+        "state",
         "action",
         "next_state",
         "reward"
@@ -1,7 +1,6 @@
 import torch
 import numpy as np
 from pathlib import Path
-import matplotlib as mpl
 import matplotlib.pyplot as plt

 # All of the following are required to load
@@ -35,7 +34,7 @@ def best_action(


 # Compute preditions
-p = np.zeros((128, 128), dtype=np.float32)
+p = np.zeros((128, 128, 2), dtype=np.float32)
 with torch.no_grad():
     for r in range(len(p)):
         for c in range(len(p[r])):
@@ -44,31 +43,26 @@ def best_action(

 k = np.asarray(policy_net(
     torch.tensor(
-        [x, y],
+        [x, y, 0],
         dtype = torch.float32,
         device = device
     ).unsqueeze(0)
 )[0])
-p[r][c] = np.argmax(k)
+p[r][c][0] = np.argmax(k)

+k = np.asarray(policy_net(
+    torch.tensor(
+        [x, y, 1],
+        dtype = torch.float32,
+        device = device
+    ).unsqueeze(0)
+)[0])
+p[r][c][1] = np.argmax(k)

-cmap = mpl.colors.ListedColormap(
-    [
-        "forestgreen",
-        "firebrick",
-        "lightgreen",
-        "salmon",
-        "darkturquoise",
-        "sandybrown",
-        "olive",
-        "darkorchid",
-        "mediumvioletred"
-    ]
-)

 # Plot predictions
-fig, axs = plt.subplots(1, 1, figsize = (20, 20))
-ax = axs
+fig, axs = plt.subplots(1, 2, figsize = (10, 10))
+ax = axs[0]
 ax.set(
     adjustable = "box",
     aspect = "equal",
@@ -76,16 +70,30 @@ def best_action(
 )

 plot = ax.pcolor(
-    p,
-    cmap = cmap,
+    p[:,:,0],
+    cmap = "Set1",
     vmin = 0,
     vmax = 8
 )
 ax.invert_yaxis()
-cbar = fig.colorbar(plot, ticks = list(range(0, 9)))
-cbar.ax.set_yticklabels(Celeste.action_space)
+fig.colorbar(plot)
+
+ax = axs[1]
+ax.set(
+    adjustable = "box",
+    aspect = "equal",
+    title = "Best Action"
+)
+
+plot = ax.pcolor(
+    p[:,:,0],
+    cmap = "Set1",
+    vmin = 0,
+    vmax = 8
+)
+
+ax.invert_yaxis()
+fig.colorbar(plot)

 fig.savefig(out_filename)
 plt.close()
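Note: the new plotting code above evaluates the policy network once per grid cell and per dash state, storing the argmax action in a (128, 128, 2) array. A rough sketch of that evaluation loop, assuming a 3-element input state of the form [x, y, can_dash]; the mapping from (r, c) to (x, y) is not shown in the hunk, so using c and r directly here is an assumption:

import numpy as np
import torch

def best_action_grid(policy_net, device) -> np.ndarray:
    # One argmax per cell and per dash state, mirroring the new plot code.
    p = np.zeros((128, 128, 2), dtype=np.float32)
    with torch.no_grad():
        for r in range(128):
            for c in range(128):
                for dash in (0, 1):
                    state = torch.tensor(
                        [c, r, dash],  # assumed [x, y, can_dash] ordering
                        dtype=torch.float32,
                        device=device,
                    ).unsqueeze(0)
                    p[r][c][dash] = int(policy_net(state)[0].argmax())
    return p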
@@ -43,7 +43,7 @@ def predicted_reward(

 k = np.asarray(policy_net(
     torch.tensor(
-        [x, y],
+        [x, y, 0],
         dtype = torch.float32,
         device = device
     ).unsqueeze(0)
@@ -5,31 +5,33 @@ import random
 import math
 import json
 import torch
-import shutil


 from celeste_ai import Celeste
 from celeste_ai import DQN
 from celeste_ai import Transition
-from celeste_ai.util.screenshots import ScreenshotManager

 if __name__ == "__main__":
     # Where to read/write model data.
     model_data_root = Path("model_data/current")

-    sm = ScreenshotManager(
-        # Where PICO-8 saves screenshots.
-        # Probably your desktop.
-        source = Path("/home/mark/Desktop"),
-        pattern = "hackcel_*.png",
-        target = model_data_root / "screenshots"
-    ).clean() # Remove old screenshots
+    # Where PICO-8 saves screenshots.
+    # Probably your desktop.
+    screenshot_source = Path("/home/mark/Desktop")

     model_save_path = model_data_root / "model.torch"
     model_archive_dir = model_data_root / "model_archive"
     model_train_log = model_data_root / "train_log"
+    screenshot_dir = model_data_root / "screenshots"
     model_data_root.mkdir(parents = True, exist_ok = True)
     model_archive_dir.mkdir(parents = True, exist_ok = True)
+    screenshot_dir.mkdir(parents = True, exist_ok = True)


+    # Remove old screenshots
+    shots = screenshot_source.glob("hackcel_*.png")
+    for s in shots:
+        s.unlink()


     compute_device = torch.device(
@@ -43,51 +45,66 @@ if __name__ == "__main__":


     # Epsilon-greedy parameters
-    # Probability of choosing a random action starts at
-    # EPS_START and decays to EPS_END.
-    # EPS_DECAY controls the rate of decay.
+    #
+    # Original docs:
+    # EPS_START is the starting value of epsilon
+    # EPS_END is the final value of epsilon
+    # EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
     EPS_START = 0.9
     EPS_END = 0.02
     EPS_DECAY = 100

-    # Bellman equation time-discount factor
-    GAMMA = 0.9
-
-    # Train on this many transitions from
-    # replay memory each round
-    BATCH_SIZE = 100
-
-    # Controls target_net soft update.
-    # Should be between 0 and 1.
-    TAU = 0.05
-
-    # Optimizer learning rate
-    learning_rate = 0.001
-
-    # Save a snapshot of the model every n
-    # episodes.
-    model_save_interval = 10
-

     # How many times we've reached each point.
-    # This is used to compute epsilon-greedy probability.
+    # Used to compute epsilon-greedy probability with
+    # the parameters above.
     point_counter = [0] * len(Celeste.target_checkpoints[0])

-    n_episodes = 0 # Number of episodes we've trained on
-    n_steps = 0 # Number of training steps we've completed
+    BATCH_SIZE = 100
+    # Learning rate of target_net.
+    # Controls how soft our soft update is.
+    #
+    # Should be between 0 and 1.
+    # Large values
+    # Small values do the opposite.
+    #
+    # A value of one makes target_net
+    # change at the same rate as policy_net.
+    #
+    # A value of zero makes target_net
+    # not change at all.
+    TAU = 0.05
+
+    # GAMMA is the discount factor as mentioned in the previous section
+    GAMMA = 0.9
+
+    steps_done = 0
+    num_episodes = 100
+    episode_number = 0
+    archive_interval = 10

     # Create replay memory.
     #
-    # Holds <Transition> objects, defined in
-    # network.py
+    # Transition: a container for naming data (defined in util.py)
+    # Memory: a deque that holds recent states as Transitions
+    # Has a fixed length, drops oldest
+    # element if maxlen is exceeded.
     memory = deque([], maxlen=50_000)

-    policy_net = DQN(n_observations, n_actions).to(compute_device)
-    target_net = DQN(n_observations, n_actions).to(compute_device)
+    policy_net = DQN(
+        n_observations,
+        n_actions
+    ).to(compute_device)
+
+    target_net = DQN(
+        n_observations,
+        n_actions
+    ).to(compute_device)
+
     target_net.load_state_dict(policy_net.state_dict())

+    learning_rate = 0.001
     optimizer = torch.optim.AdamW(
         policy_net.parameters(),
         lr = learning_rate,
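Note: the TAU comments added above describe the usual soft target-network update. A minimal sketch of what that update typically looks like for the same policy_net/target_net pair; the actual update appears further down, inside on_state_after:

def soft_update(policy_net, target_net, tau: float = 0.05) -> None:
    # Blend policy_net weights into target_net:
    # tau = 1 tracks policy_net exactly, tau = 0 freezes target_net.
    target_state = target_net.state_dict()
    policy_state = policy_net.state_dict()
    for key in policy_state:
        target_state[key] = policy_state[key] * tau + target_state[key] * (1 - tau)
    target_net.load_state_dict(target_state)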
@@ -105,43 +122,11 @@ if __name__ == "__main__":
         target_net.load_state_dict(checkpoint["target_state_dict"])
         optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
         memory = checkpoint["memory"]
-        n_episodes = checkpoint["n_episodes"]
-        n_steps = checkpoint["n_steps"]
+        episode_number = checkpoint["episode_number"] + 1
+        steps_done = checkpoint["steps_done"]
         point_counter = checkpoint["point_counter"]

-def save_model(path):
-    torch.save({
-        # Newtorks
-        "policy_state_dict": policy_net.state_dict(),
-        "target_state_dict": target_net.state_dict(),
-        "optimizer_state_dict": optimizer.state_dict(),
-
-        # Training data
-        "memory": memory,
-        "point_counter": point_counter,
-        "n_episodes": n_episodes,
-        "n_steps": n_steps,
-
-        # Hyperparameters,
-        # for reference
-        "eps_start": EPS_START,
-        "eps_end": EPS_END,
-        "eps_decay": EPS_DECAY,
-        "batch_size": BATCH_SIZE,
-        "tau": TAU,
-        "learning_rate": learning_rate,
-        "gamma": GAMMA
-    }, path
-    )
-
-
-def select_action(state, x) -> int:
+def select_action(state, steps_done):
     """
     Select an action using an epsilon-greedy policy.

@@ -151,13 +136,19 @@ def select_action(state, x) -> int:
     Decay rate is controlled by EPS_DECAY.
     """

+    # Random number 0 <= x < 1
+    sample = random.random()
+
     # Calculate random step threshhold
     eps_threshold = (
         EPS_END + (EPS_START - EPS_END) *
-        math.exp(-1.0 * x / EPS_DECAY)
+        math.exp(
+            -1.0 * steps_done /
+            EPS_DECAY
+        )
     )

-    if random.random() > eps_threshold:
+    if sample > eps_threshold:
         with torch.no_grad():
             # t.max(1) will return the largest column value of each row.
             # second column on max result is index of where max element was
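Note: both sides of the hunk above compute the same exponential epsilon decay; they differ only in whether the decay is driven by a per-checkpoint counter (x) or by steps_done, and in whether the random sample is drawn before or inside the comparison. A standalone sketch of the threshold, using the constants defined earlier:

import math
import random

EPS_START = 0.9
EPS_END = 0.02
EPS_DECAY = 100

def exploration_threshold(steps: int) -> float:
    # Decays from EPS_START toward EPS_END as steps grows.
    return EPS_END + (EPS_START - EPS_END) * math.exp(-1.0 * steps / EPS_DECAY)

def should_explore(steps: int) -> bool:
    # Take a random action when a uniform sample falls at or below the threshold.
    return random.random() <= exploration_threshold(steps)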
@@ -184,7 +175,7 @@ def optimize_model():

     # Conversion.
     # Combine states, actions, and rewards into their own tensors.
-    last_state_batch = torch.cat(batch.last_state)
+    state_batch = torch.cat(batch.state)
     action_batch = torch.cat(batch.action)
     reward_batch = torch.cat(batch.reward)

@@ -218,7 +209,7 @@ def optimize_model():
     # This gives us a tensor that contains the return we expect to get
     # at that state if we follow the model's advice.

-    state_action_values = policy_net(last_state_batch).gather(1, action_batch)
+    state_action_values = policy_net(state_batch).gather(1, action_batch)


@@ -291,21 +282,36 @@ def optimize_model():


 def on_state_before(celeste):
+    global steps_done

     state = celeste.state

+    pt_state = torch.tensor(
+        [getattr(state, x) for x in Celeste.state_number_map],
+        dtype = torch.float32,
+        device = compute_device
+    ).unsqueeze(0)
+
     action = select_action(
-        # Put state in a tensor
-        torch.tensor(
-            [getattr(state, x) for x in Celeste.state_number_map],
-            dtype = torch.float32,
-            device = compute_device
-        ).unsqueeze(0),
-
-        # Random action probability is determined by
-        # the number of times we've reached the next point.
+        pt_state,
         point_counter[state.next_point]
     )
+    str_action = Celeste.action_space[action]
+
+    """
+    action = None
+    while (action) is None or ((not state.can_dash) and (str_action not in ["left", "right"])):
+        action = select_action(
+            pt_state,
+            steps_done
+        )
+        str_action = Celeste.action_space[action]
+    """
+
+    steps_done += 1

     # For manual testing
     #str_action = ""
@@ -313,114 +319,86 @@ def on_state_before(celeste):
     # str_action = input("action> ")
     #action = Celeste.action_space.index(str_action)

-    print(Celeste.action_space[action])
-    celeste.act(action)
+    print(str_action)
+    celeste.act(str_action)

-    return (
-        state, # CelesteState
-        action # Integer
-    )
-
-
-def compute_reward(last_state, state):
-    global point_counter
-
-    reward = None
-
-    # No reward if dead
-    if state.deaths != 0:
-        reward = 0
-
-    # Reward for finishing a stage
-    elif state.stage >= 1:
-        print("FINISHED STAGE!!")
-
-        # We don't set a fixed reward here because the agent may
-        # complete the stage before getting all points.
-        # The below line provides extra reward for taking shortcuts.
-        reward = state.next_point - last_state.next_point
-        reward += 1
-
-        # Add to point counter
-        for i in range(last_state.next_point, len(point_counter)):
-            point_counter[i] += 1
-
-    # Reward for reaching a checkpoint
-    elif last_state.next_point != state.next_point:
-        print(f"Got point {state.next_point}")
-
-        reward = state.next_point - last_state.next_point
-
-        # Add to point counter
-        for i in range(last_state.next_point, last_state.next_point + reward):
-            point_counter[i] += 1
-
-    # No reward otherwise
-    else:
-        reward = 0
-
-    # Strawberry reward
-    # (Will probably break current version of model)
-    #if state.berries[state.stage] and not state.berries[state.stage]:
-    #    print(f"Got stage {state.stage} bonus")
-    #    reward += 1
-
-    assert reward is not None
-    return reward * 10
+    return state, action


 def on_state_after(celeste, before_out):
-    global n_episodes
-    global n_steps
+    global episode_number

-    last_state, action = before_out
+    state, action = before_out
     next_state = celeste.state
-    dead = next_state.deaths != 0
-    done = next_state.stage >= 1

-    reward = compute_reward(last_state, next_state)
+    pt_state = torch.tensor(
+        [getattr(state, x) for x in Celeste.state_number_map],
+        dtype = torch.float32,
+        device = compute_device
+    ).unsqueeze(0)
+
+    pt_action = torch.tensor(
+        [[ action ]],
+        device = compute_device,
+        dtype = torch.long
+    )
+
+    finished_stage = False
+    # No reward if dead
+    if next_state.deaths != 0:
+        pt_next_state = None
+        reward = 0
+
+    # Reward for finishing a stage
+    elif next_state.stage >= 1:
+        finished_stage = True
+        reward = next_state.next_point - state.next_point
+        reward += 1
+
+        # Add to point counter
+        for i in range(state.next_point, state.next_point + reward):
+            point_counter[i] += 1
+
+    # Regular reward
+    else:
+        pt_next_state = torch.tensor(
+            [getattr(next_state, x) for x in Celeste.state_number_map],
+            dtype = torch.float32,
+            device = compute_device
+        ).unsqueeze(0)

-    if dead:
-        next_state = None
-    elif done:
-        # We don't set the next state to None because
-        # the optimization routine forces zero reward
-        # for terminal states.
-        # Copy last state instead. It's a hack, but it
-        # should work.
-        next_state = last_state
+        if state.next_point == next_state.next_point:
+            reward = 0
+        else:
+            print(f"Got point {state.next_point}")
+            # Reward for reaching a point
+            reward = next_state.next_point - state.next_point
+
+            # Add to point counter
+            for i in range(state.next_point, state.next_point + reward):
+                point_counter[i] += 1
+
+        # Strawberry reward
+        if next_state.berries[state.stage] and not state.berries[state.stage]:
+            print(f"Got stage {state.stage} bonus")
+            reward += 1
+
+    reward = reward * 10
+    pt_reward = torch.tensor([reward], device = compute_device)

     # Add this state transition to memory.
     memory.append(
         Transition(
-            # last state
-            torch.tensor(
-                [getattr(last_state, x) for x in Celeste.state_number_map],
-                dtype = torch.float32,
-                device = compute_device
-            ).unsqueeze(0),
-
-            # action
-            torch.tensor(
-                [[ action ]],
-                device = compute_device,
-                dtype = torch.long
-            ),
-
-            # next state
-            # None if dead or done.
-            torch.tensor(
-                [getattr(next_state, x) for x in Celeste.state_number_map],
-                dtype = torch.float32,
-                device = compute_device
-            ).unsqueeze(0) if next_state is not None else None,
-
-            # reward
-            torch.tensor(
-                [reward],
-                device = compute_device
-            )
+            pt_state,
+            pt_action,
+            pt_next_state,
+            pt_reward
         )
     )

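Note: in both versions of the hunk above, the reward is the number of checkpoints gained this step (plus one for clearing the stage), scaled by 10, with zero reward on death. A condensed sketch of that logic, assuming states expose next_point, deaths, and stage as in the diff:

def checkpoint_reward(last_next_point: int, next_point: int,
                      deaths: int, stage: int) -> int:
    # Zero reward on death.
    if deaths != 0:
        return 0
    # Checkpoints gained this step, plus one for clearing the stage.
    reward = next_point - last_next_point
    if stage >= 1:
        reward += 1
    # Both versions scale the final reward by 10.
    return reward * 10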
@@ -428,10 +406,11 @@ def on_state_after(celeste, before_out):
     print("")

-    # Perform a training step
     loss = None
+    # Only train the network if we have enough
+    # transitions in memory to do so.
     if len(memory) >= BATCH_SIZE:
-        n_steps += 1
         loss = optimize_model()

     # Soft update target_net weights
@@ -444,43 +423,65 @@ def on_state_after(celeste, before_out):
     )
     target_net.load_state_dict(target_net_state)

-    # Move on to the next episode and run
-    # housekeeping tasks.
-    if (dead or done):
+    # Move on to the next episode once we reach
+    # a terminal state.
+    if (next_state.deaths != 0 or finished_stage):
         s = celeste.state
-        n_episodes += 1
-
-        # Move screenshots
-        sm.move(
-            number = n_episodes,
-            overwrite = True
-        )
-
-        # Log this episode
         with model_train_log.open("a") as f:
             f.write(json.dumps({
-                "n_episodes": n_episodes,
-                "n_steps": n_steps,
                 "checkpoints": s.next_point,
-                "loss": None if loss is None else loss.item(),
-                "done": done
+                "state_count": s.state_count,
+                "loss": None if loss is None else loss.item()
             }) + "\n")

+        # Save model
+        torch.save({
+            "policy_state_dict": policy_net.state_dict(),
+            "target_state_dict": target_net.state_dict(),
+            "optimizer_state_dict": optimizer.state_dict(),
+            "memory": memory,
+            "point_counter": point_counter,
+            "episode_number": episode_number,
+            "steps_done": steps_done,
+
+            # Hyperparameters
+            "eps_start": EPS_START,
+            "eps_end": EPS_END,
+            "eps_decay": EPS_DECAY,
+            "batch_size": BATCH_SIZE,
+            "tau": TAU,
+            "learning_rate": learning_rate,
+            "gamma": GAMMA
+        }, model_save_path)
+
+        # Clean up screenshots
+        shots = screenshot_source.glob("hackcel_*.png")
+
+        target = screenshot_dir / Path(f"{episode_number}")
+        target.mkdir(parents = True)
+
+        for s in shots:
+            s.rename(target / s.name)
+
         # Save a snapshot
-        if n_episodes % model_save_interval == 0:
-            save_model(model_archive_dir / f"{n_episodes}.torch")
-            shutil.copy(model_archive_dir / f"{n_episodes}.torch", model_save_path)
+        if episode_number % archive_interval == 0:
+            torch.save({
+                "policy_state_dict": policy_net.state_dict(),
+                "target_state_dict": target_net.state_dict(),
+                "optimizer_state_dict": optimizer.state_dict(),
+                "memory": memory,
+                "episode_number": episode_number,
+                "steps_done": steps_done
+            }, model_archive_dir / f"{episode_number}.torch")

         print("Game over. Resetting.")
+        episode_number += 1
         celeste.reset()


 if __name__ == "__main__":
     c = Celeste(
         "resources/pico-8/linux/pico8"
@@ -1,69 +0,0 @@
-from pathlib import Path
-import shutil
-
-
-class ScreenshotManager:
-    def __init__(
-        self,
-
-        # Where PICO-8 saves screenshots
-        source: Path,
-
-        # How PICO-8 names screenshots.
-        # Example: "celeste_*.png"
-        pattern: str,
-
-        # Where we want to move screenshots.
-        target: Path
-    ):
-        self.source = source
-        self.pattern = pattern
-        self.target = target
-        self.target.mkdir(
-            parents = True,
-            exist_ok = True
-        )
-
-
-    def clean(self):
-        shots = self.source.glob(self.pattern)
-        for s in shots:
-            s.unlink()
-        return self
-
-
-    def move(self, number: int | None = None, overwrite = False):
-        shots = self.source.glob(self.pattern)
-
-        if number == None:
-            # Auto-select new directory number.
-            # Chooses next highest int directory name
-            number = 0
-            for f in self.target.iterdir():
-                try:
-                    number = max(
-                        int(f.name),
-                        number
-                    )
-                except ValueError:
-                    continue
-            number += 1
-
-        else:
-            target = self.target / str(number)
-
-        if target.exists():
-            if not overwrite:
-                raise Exception(f"Target \"{target}\" exists!")
-            else:
-                print(f"Target \"{target}\" exists, removing.")
-                shutil.rmtree(target)
-
-        target.mkdir(parents = True)
-
-        for s in shots:
-            s.rename(target / s.name)
-        return self
@@ -47,14 +47,6 @@ plots = {

 if __name__ == "__main__":

-    if plots["best"]:
-        print("Making best-action plots...")
-        with Pool(5) as p:
-            p.map(
-                plot_best,
-                list((m / "model_archive").iterdir())
-            )
-
     if plots["prediction"]:
         print("Making prediction plots...")
         with Pool(5) as p:
@@ -63,6 +55,14 @@ if __name__ == "__main__":
                 list((m / "model_archive").iterdir())
             )

+    if plots["best"]:
+        print("Making best-action plots...")
+        with Pool(5) as p:
+            p.map(
+                plot_best,
+                list((m / "model_archive").iterdir())
+            )
+
     if plots["actual"]:
         print("Making actual plots...")
         with Pool(5) as p:
@@ -30,16 +30,6 @@ k_jump=4
 k_dash=5


--- Set to false while training or running the model.
--- Set to true to play the game manually with debug print.
--- (good for finding coordinates of checkpoints)
---
--- If true, disables most hack features:
--- - screenshots at every frame
--- - frame skipping
--- - waiting for input
-hack_human_mode = false
-
 -- If true, disable screensake
 hack_no_shake = true

@@ -1219,10 +1209,6 @@ end
 -- _update60 does 60 fps
 -- default for celeste is 30.
 function _update()
-    if hack_human_mode then
-        old_update()
-        return
-    end

     -- Run at full speed until ready
     if not hack_ready then
@@ -1318,10 +1304,7 @@ end
 -- Called at the same rate as _update,
 -- but not necessarily at the same time.
 function _draw()
-    if hack_human_mode then
-        old_draw()
-        return
-    end
+    --old_draw()
 end

 function old_update()