Reorganized celeste code
This commit is contained in:
6
celeste/celeste_ai/__init__.py
Normal file
6
celeste/celeste_ai/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
from .network import DQN
|
||||
from .network import Transition
|
||||
|
||||
from .celeste import Celeste
|
||||
from .celeste import CelesteError
|
||||
from .celeste import CelesteState
|
340
celeste/celeste_ai/celeste.py
Executable file
340
celeste/celeste_ai/celeste.py
Executable file
@ -0,0 +1,340 @@
|
||||
from typing import NamedTuple
|
||||
import subprocess
|
||||
import time
|
||||
import math
|
||||
|
||||
class CelesteError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CelesteState(NamedTuple):
|
||||
# Stage number
|
||||
stage: int
|
||||
|
||||
# Player position
|
||||
xpos: int
|
||||
ypos: int
|
||||
|
||||
# Player velocity
|
||||
xvel: float
|
||||
yvel: float
|
||||
|
||||
# Number of deaths since game start
|
||||
deaths: int
|
||||
|
||||
# Distance to next point
|
||||
dist: float
|
||||
|
||||
# Index of next point
|
||||
next_point: int
|
||||
|
||||
# Coordinates of next point
|
||||
next_point_x: int
|
||||
next_point_y: int
|
||||
|
||||
# Number of states recieved since restart
|
||||
state_count: int
|
||||
|
||||
# True if Madeline can dash
|
||||
can_dash: bool
|
||||
|
||||
|
||||
class Celeste:
|
||||
action_space = [
|
||||
"left", # move left
|
||||
"right", # move right
|
||||
"jump", # jump
|
||||
|
||||
"dash-u", # dash up
|
||||
"dash-r", # dash right
|
||||
"dash-l", # dash left
|
||||
"dash-ru", # dash right-up
|
||||
"dash-lu" # dash left-up
|
||||
]
|
||||
|
||||
# Map integers to state values.
|
||||
# This also determines what data is fed to the model.
|
||||
state_number_map = [
|
||||
"xpos",
|
||||
"ypos",
|
||||
"next_point_x",
|
||||
"next_point_y"
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pico_path,
|
||||
*,
|
||||
state_timeout = 30,
|
||||
cart_name = "hackcel.p8",
|
||||
):
|
||||
|
||||
# Start pico-8
|
||||
self._process = subprocess.Popen(
|
||||
pico_path,
|
||||
shell=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT
|
||||
)
|
||||
|
||||
# Wait for window to open and get window id
|
||||
time.sleep(2)
|
||||
winid = subprocess.check_output([
|
||||
"xdotool",
|
||||
"search",
|
||||
"--class",
|
||||
"pico8"
|
||||
]).decode("utf-8").strip().split("\n")
|
||||
if len(winid) != 1:
|
||||
raise Exception("Could not find unique PICO-8 window id")
|
||||
self._winid = winid[0]
|
||||
|
||||
# Load cartridge
|
||||
self._keystring(f"load {cart_name}")
|
||||
self._keypress("Enter")
|
||||
self._keystring("run")
|
||||
self._keypress("Enter", post = 1000)
|
||||
|
||||
|
||||
# Parameters
|
||||
self.state_timeout = state_timeout # If we run this many states without getting a checkpoint, reset.
|
||||
self.cart_name = cart_name # Name of cart to load. Not used anywhere, but saved for convenience.
|
||||
|
||||
# Internal variables
|
||||
self._internal_state = {} # Raw data read from stdout
|
||||
self._before_out = None # Output of "before" callback in update loop
|
||||
self._last_checkpoint_state = 0 # Index of frame at which we reached the last checkpoint
|
||||
self._state_counter = 0 # Number of frames we've run since last reset
|
||||
self._next_checkpoint_idx = 0 # Index of next point
|
||||
self._dist = 0 # Distance to next point
|
||||
self._resetting = False # True between a call to .reset() and the first state message from pico.
|
||||
self._keys = {} # Dictionary of "key": bool
|
||||
|
||||
# Targets the agent tries to reach.
|
||||
# The last target MUST be outside the frame.
|
||||
self.target_checkpoints = [
|
||||
[ # Stage 1
|
||||
#(28, 88), # Start pillar
|
||||
(60, 80), # Middle pillar
|
||||
(105, 64), # Right ledge
|
||||
(25, 40), # Left ledge
|
||||
(110, 16), # End ledge
|
||||
(110, -2), # Next stage
|
||||
]
|
||||
]
|
||||
|
||||
def act(self, action: str):
|
||||
"""
|
||||
Specify what keys should be down. This does NOT send key events.
|
||||
Celeste._apply_keys() does that at the right time.
|
||||
|
||||
Args:
|
||||
action (str): key name, as in Celeste.action_space
|
||||
"""
|
||||
|
||||
self._keys = {}
|
||||
if action is None:
|
||||
return
|
||||
elif action == "left":
|
||||
self._keys["Left"] = True
|
||||
elif action == "right":
|
||||
self._keys["Right"] = True
|
||||
elif action == "jump":
|
||||
self._keys["c"] = True
|
||||
|
||||
elif action == "dash-u":
|
||||
self._keys["Up"] = True
|
||||
self._keys["x"] = True
|
||||
elif action == "dash-r":
|
||||
self._keys["Right"] = True
|
||||
self._keys["x"] = True
|
||||
elif action == "dash-l":
|
||||
self._keys["Left"] = True
|
||||
self._keys["x"] = True
|
||||
elif action == "dash-ru":
|
||||
self._keys["Up"] = True
|
||||
self._keys["Right"] = True
|
||||
self._keys["x"] = True
|
||||
elif action == "dash-lu":
|
||||
self._keys["Up"] = True
|
||||
self._keys["Left"] = True
|
||||
self._keys["x"] = True
|
||||
|
||||
|
||||
def _apply_keys(self):
|
||||
for i in [
|
||||
"x", "c",
|
||||
"Left", "Right",
|
||||
"Down", "Up"
|
||||
]:
|
||||
if self._keys.get(i):
|
||||
self._keydown(i)
|
||||
else:
|
||||
self._keyup(i)
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
try:
|
||||
stage = (
|
||||
[
|
||||
[0, 1, 2, 3, 4]
|
||||
]
|
||||
[int(self._internal_state["ry"])]
|
||||
[int(self._internal_state["rx"])]
|
||||
)
|
||||
|
||||
if len(self.target_checkpoints) < stage:
|
||||
next_point_x = None
|
||||
next_point_y = None
|
||||
else:
|
||||
next_point_x = self.target_checkpoints[stage][self._next_checkpoint_idx][0]
|
||||
next_point_y = self.target_checkpoints[stage][self._next_checkpoint_idx][1]
|
||||
|
||||
|
||||
return CelesteState(
|
||||
stage = stage,
|
||||
|
||||
xpos = int(self._internal_state["px"]),
|
||||
ypos = int(self._internal_state["py"]),
|
||||
xvel = float(self._internal_state["vx"]),
|
||||
yvel = float(self._internal_state["vy"]),
|
||||
deaths = int(self._internal_state["dc"]),
|
||||
|
||||
dist = self._dist,
|
||||
next_point = self._next_checkpoint_idx,
|
||||
next_point_x = next_point_x,
|
||||
next_point_y = next_point_y,
|
||||
state_count = self._state_counter,
|
||||
can_dash = self._internal_state["ds"] == "t"
|
||||
)
|
||||
|
||||
except KeyError:
|
||||
raise CelesteError("Not enough data to get state.")
|
||||
|
||||
def _keypress(self, key: str, *, post = 200):
|
||||
subprocess.run([
|
||||
"xdotool",
|
||||
"key",
|
||||
"--window", self._winid,
|
||||
key
|
||||
])
|
||||
time.sleep(post / 1000)
|
||||
|
||||
def _keydown(self, key: str):
|
||||
subprocess.run([
|
||||
"xdotool",
|
||||
"keydown",
|
||||
"--window", self._winid,
|
||||
key
|
||||
])
|
||||
|
||||
def _keyup(self, key: str):
|
||||
subprocess.run([
|
||||
"xdotool",
|
||||
"keyup",
|
||||
"--window", self._winid,
|
||||
key
|
||||
])
|
||||
|
||||
def _keystring(self, string, *, delay = 100, post = 200):
|
||||
subprocess.run([
|
||||
"xdotool",
|
||||
"type",
|
||||
"--window", self._winid,
|
||||
"--delay", str(delay),
|
||||
string
|
||||
])
|
||||
time.sleep(post / 1000)
|
||||
|
||||
def reset(self):
|
||||
# Make sure all keys are released
|
||||
self.act(None)
|
||||
self._apply_keys()
|
||||
|
||||
self._internal_state = {}
|
||||
self._next_checkpoint_idx = 0
|
||||
self._state_counter = 0
|
||||
self._before_out = None
|
||||
self._resetting = True
|
||||
self._last_checkpoint_state = 0
|
||||
|
||||
self._keypress("Escape")
|
||||
self._keystring("run")
|
||||
self._keypress("Enter", post = 1000)
|
||||
|
||||
|
||||
|
||||
# Clear all old stdout messages and
|
||||
# wait for the game to restart.
|
||||
for k in iter(self._process.stdout.readline, ""):
|
||||
k = k.decode("utf-8")[:-1]
|
||||
if k == "!RESTART":
|
||||
break
|
||||
|
||||
|
||||
def update_loop(self, before, after):
|
||||
# Waits for stdout from pico-8 process
|
||||
for line in iter(self._process.stdout.readline, ""):
|
||||
l = line.decode("utf-8")[:-1].strip()
|
||||
|
||||
# Release all keys
|
||||
self.act(None)
|
||||
self._apply_keys()
|
||||
|
||||
# Clear reset state
|
||||
self._resetting = False
|
||||
|
||||
# This should only occur at game start
|
||||
if l in ["!RESTART"]:
|
||||
continue
|
||||
|
||||
self._state_counter += 1
|
||||
|
||||
# Parse state string
|
||||
for entry in l.split(";"):
|
||||
if entry == "":
|
||||
continue
|
||||
|
||||
key, val = entry.split(":")
|
||||
self._internal_state[key] = val
|
||||
|
||||
|
||||
# Update checkpoints
|
||||
tx, ty = self.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
|
||||
x = self.state.xpos
|
||||
y = self.state.ypos
|
||||
dist = math.sqrt(
|
||||
(x-tx)*(x-tx) +
|
||||
((y-ty)*(y-ty))/2
|
||||
# Possible modification:
|
||||
# make x-distance twice as valuable as y-distance
|
||||
)
|
||||
|
||||
if dist <= 5:
|
||||
print(f"Got point {self._next_checkpoint_idx}")
|
||||
self._next_checkpoint_idx += 1
|
||||
self._last_checkpoint_state = self._state_counter
|
||||
|
||||
# Recalculate distance to new point
|
||||
tx, ty = self.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
|
||||
dist = math.sqrt(
|
||||
(x-tx)*(x-tx) +
|
||||
((y-ty)*(y-ty))/2
|
||||
)
|
||||
|
||||
# Timeout if we spend too long between points
|
||||
elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
|
||||
self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)
|
||||
|
||||
self._dist = dist
|
||||
|
||||
# Call step callbacks
|
||||
# These should call celeste.act() to set next input
|
||||
if self._before_out is not None:
|
||||
after(self, self._before_out)
|
||||
|
||||
# Do not run before callback if after() triggered a reset.
|
||||
if not self._resetting:
|
||||
self._before_out = before(self)
|
||||
self._apply_keys()
|
||||
|
36
celeste/celeste_ai/network.py
Normal file
36
celeste/celeste_ai/network.py
Normal file
@ -0,0 +1,36 @@
|
||||
import torch
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
Transition = namedtuple(
|
||||
"Transition",
|
||||
(
|
||||
"state",
|
||||
"action",
|
||||
"next_state",
|
||||
"reward"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class DQN(torch.nn.Module):
|
||||
def __init__(self, n_observations: int, n_actions: int):
|
||||
super(DQN, self).__init__()
|
||||
|
||||
self.layers = torch.nn.Sequential(
|
||||
torch.nn.Linear(n_observations, 128),
|
||||
torch.nn.ReLU(),
|
||||
|
||||
torch.nn.Linear(128, 128),
|
||||
torch.nn.ReLU(),
|
||||
|
||||
torch.nn.Linear(128, 128),
|
||||
torch.nn.ReLU(),
|
||||
|
||||
torch.torch.nn.Linear(128, n_actions)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
|
2
celeste/celeste_ai/plotting/__init__.py
Normal file
2
celeste/celeste_ai/plotting/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .plot_actual_reward import actual_reward
|
||||
from .plot_predicted_reward import predicted_reward
|
81
celeste/celeste_ai/plotting/plot_actual_reward.py
Normal file
81
celeste/celeste_ai/plotting/plot_actual_reward.py
Normal file
@ -0,0 +1,81 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import matplotlib.pyplot as plt
|
||||
from multiprocessing import Pool
|
||||
|
||||
# All of the following are required to load
|
||||
# a pickled model.
|
||||
from celeste_ai.celeste import Celeste
|
||||
from celeste_ai.network import DQN
|
||||
from celeste_ai.network import Transition
|
||||
|
||||
def actual_reward(
|
||||
model_file: Path,
|
||||
target_point: tuple[int, int],
|
||||
out_filename: Path,
|
||||
*,
|
||||
device = torch.device("cpu")
|
||||
):
|
||||
if not model_file.is_file():
|
||||
raise Exception(f"Bad model file {model_file}")
|
||||
out_filename.parent.mkdir(exist_ok = True, parents = True)
|
||||
|
||||
|
||||
checkpoint = torch.load(
|
||||
model_file,
|
||||
map_location = device
|
||||
)
|
||||
memory = checkpoint["memory"]
|
||||
|
||||
|
||||
r = np.zeros((128, 128, 8), dtype=np.float32)
|
||||
for m in memory:
|
||||
x, y, x_target, y_target = list(m.state[0])
|
||||
|
||||
action = m.action[0].item()
|
||||
x = int(x.item())
|
||||
y = int(y.item())
|
||||
x_target = int(x_target.item())
|
||||
y_target = int(y_target.item())
|
||||
|
||||
# Only plot memory related to this point
|
||||
if (x_target, y_target) != target_point:
|
||||
continue
|
||||
|
||||
if m.reward[0].item() == 1:
|
||||
r[y][x][action] += 1
|
||||
else:
|
||||
r[y][x][action] -= 1
|
||||
|
||||
|
||||
fig, axs = plt.subplots(2, 4, figsize = (20, 10))
|
||||
|
||||
|
||||
for a in range(len(axs.ravel())):
|
||||
ax = axs.ravel()[a]
|
||||
ax.set(
|
||||
adjustable = "box",
|
||||
aspect = "equal",
|
||||
title = Celeste.action_space[a]
|
||||
)
|
||||
|
||||
plot = ax.pcolor(
|
||||
r[:,:,a],
|
||||
cmap = "seismic_r",
|
||||
vmin = -10,
|
||||
vmax = 10
|
||||
)
|
||||
|
||||
# Draw target point on plot
|
||||
ax.plot(
|
||||
target_point[0],
|
||||
target_point[1],
|
||||
"k."
|
||||
)
|
||||
|
||||
ax.invert_yaxis()
|
||||
fig.colorbar(plot)
|
||||
|
||||
fig.savefig(out_filename)
|
||||
plt.close()
|
77
celeste/celeste_ai/plotting/plot_predicted_reward.py
Normal file
77
celeste/celeste_ai/plotting/plot_predicted_reward.py
Normal file
@ -0,0 +1,77 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# All of the following are required to load
|
||||
# a pickled model.
|
||||
from celeste_ai.celeste import Celeste
|
||||
from celeste_ai.network import DQN
|
||||
from celeste_ai.network import Transition
|
||||
|
||||
|
||||
def predicted_reward(
|
||||
model_file: Path,
|
||||
out_filename: Path,
|
||||
*,
|
||||
device = torch.device("cpu")
|
||||
):
|
||||
if not model_file.is_file():
|
||||
raise Exception(f"Bad model file {model_file}")
|
||||
out_filename.parent.mkdir(exist_ok = True, parents = True)
|
||||
|
||||
# Create and load model
|
||||
policy_net = DQN(
|
||||
len(Celeste.state_number_map),
|
||||
len(Celeste.action_space)
|
||||
).to(device)
|
||||
checkpoint = torch.load(
|
||||
model_file,
|
||||
map_location = device
|
||||
)
|
||||
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
||||
|
||||
|
||||
|
||||
# Compute preditions
|
||||
p = np.zeros((128, 128, 8), dtype=np.float32)
|
||||
with torch.no_grad():
|
||||
for r in range(len(p)):
|
||||
for c in range(len(p[r])):
|
||||
k = np.asarray(policy_net(
|
||||
torch.tensor(
|
||||
[c, r, 60, 80],
|
||||
dtype = torch.float32,
|
||||
device = device
|
||||
).unsqueeze(0)
|
||||
)[0])
|
||||
p[r][c] = k
|
||||
|
||||
|
||||
# Plot predictions
|
||||
fig, axs = plt.subplots(2, 4, figsize = (20, 10))
|
||||
for a in range(len(axs.ravel())):
|
||||
ax = axs.ravel()[a]
|
||||
ax.set(
|
||||
adjustable = "box",
|
||||
aspect = "equal",
|
||||
title = Celeste.action_space[a]
|
||||
)
|
||||
|
||||
plot = ax.pcolor(
|
||||
p[:,:,a],
|
||||
cmap = "Greens",
|
||||
vmin = 0,
|
||||
)
|
||||
|
||||
ax.invert_yaxis()
|
||||
fig.colorbar(plot)
|
||||
|
||||
fig.savefig(out_filename)
|
||||
plt.close()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
449
celeste/celeste_ai/train.py
Normal file
449
celeste/celeste_ai/train.py
Normal file
@ -0,0 +1,449 @@
|
||||
from collections import namedtuple
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
import random
|
||||
import math
|
||||
import json
|
||||
import torch
|
||||
|
||||
from celeste_ai import Celeste
|
||||
from celeste_ai import DQN
|
||||
from celeste_ai import Transition
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Where to read/write model data.
|
||||
model_data_root = Path("model_data/current")
|
||||
|
||||
model_save_path = model_data_root / "model.torch"
|
||||
model_archive_dir = model_data_root / "model_archive"
|
||||
model_train_log = model_data_root / "train_log"
|
||||
screenshot_dir = model_data_root / "screenshots"
|
||||
model_data_root.mkdir(parents = True, exist_ok = True)
|
||||
model_archive_dir.mkdir(parents = True, exist_ok = True)
|
||||
screenshot_dir.mkdir(parents = True, exist_ok = True)
|
||||
|
||||
|
||||
compute_device = torch.device(
|
||||
"cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
|
||||
|
||||
# Celeste env properties
|
||||
n_observations = len(Celeste.state_number_map)
|
||||
n_actions = len(Celeste.action_space)
|
||||
|
||||
|
||||
# Epsilon-greedy parameters
|
||||
#
|
||||
# Original docs:
|
||||
# EPS_START is the starting value of epsilon
|
||||
# EPS_END is the final value of epsilon
|
||||
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
|
||||
EPS_START = 0.9
|
||||
EPS_END = 0.05
|
||||
EPS_DECAY = 4000
|
||||
|
||||
|
||||
BATCH_SIZE = 1_000
|
||||
# Learning rate of target_net.
|
||||
# Controls how soft our soft update is.
|
||||
#
|
||||
# Should be between 0 and 1.
|
||||
# Large values
|
||||
# Small values do the opposite.
|
||||
#
|
||||
# A value of one makes target_net
|
||||
# change at the same rate as policy_net.
|
||||
#
|
||||
# A value of zero makes target_net
|
||||
# not change at all.
|
||||
TAU = 0.005
|
||||
|
||||
|
||||
# GAMMA is the discount factor as mentioned in the previous section
|
||||
GAMMA = 0.9
|
||||
|
||||
steps_done = 0
|
||||
num_episodes = 100
|
||||
episode_number = 0
|
||||
archive_interval = 10
|
||||
|
||||
# Create replay memory.
|
||||
#
|
||||
# Transition: a container for naming data (defined in util.py)
|
||||
# Memory: a deque that holds recent states as Transitions
|
||||
# Has a fixed length, drops oldest
|
||||
# element if maxlen is exceeded.
|
||||
memory = deque([], maxlen=50_000)
|
||||
|
||||
policy_net = DQN(
|
||||
n_observations,
|
||||
n_actions
|
||||
).to(compute_device)
|
||||
|
||||
target_net = DQN(
|
||||
n_observations,
|
||||
n_actions
|
||||
).to(compute_device)
|
||||
|
||||
target_net.load_state_dict(policy_net.state_dict())
|
||||
|
||||
|
||||
optimizer = torch.optim.AdamW(
|
||||
policy_net.parameters(),
|
||||
lr = 0.01, # Hyperparameter: learning rate
|
||||
amsgrad = True
|
||||
)
|
||||
|
||||
|
||||
if model_save_path.is_file():
|
||||
# Load model if one exists
|
||||
checkpoint = torch.load(
|
||||
model_save_path,
|
||||
map_location = compute_device
|
||||
)
|
||||
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
||||
target_net.load_state_dict(checkpoint["target_state_dict"])
|
||||
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||
memory = checkpoint["memory"]
|
||||
episode_number = checkpoint["episode_number"] + 1
|
||||
steps_done = checkpoint["steps_done"]
|
||||
|
||||
def select_action(state, steps_done):
|
||||
"""
|
||||
Select an action using an epsilon-greedy policy.
|
||||
|
||||
Sometimes use our model, sometimes sample one uniformly.
|
||||
|
||||
P(random action) starts at EPS_START and decays to EPS_END.
|
||||
Decay rate is controlled by EPS_DECAY.
|
||||
"""
|
||||
|
||||
# Random number 0 <= x < 1
|
||||
sample = random.random()
|
||||
|
||||
# Calculate random step threshhold
|
||||
eps_threshold = (
|
||||
EPS_END + (EPS_START - EPS_END) *
|
||||
math.exp(
|
||||
-1.0 * steps_done /
|
||||
EPS_DECAY
|
||||
)
|
||||
)
|
||||
|
||||
if sample > eps_threshold:
|
||||
with torch.no_grad():
|
||||
# t.max(1) will return the largest column value of each row.
|
||||
# second column on max result is index of where max element was
|
||||
# found, so we pick action with the larger expected reward.
|
||||
return policy_net(state).max(1)[1].view(1, 1).item()
|
||||
|
||||
else:
|
||||
return random.randint( 0, n_actions-1 )
|
||||
|
||||
|
||||
def optimize_model():
|
||||
|
||||
if len(memory) < BATCH_SIZE:
|
||||
raise Exception(f"Not enough elements in memory for a batch of {BATCH_SIZE}")
|
||||
|
||||
|
||||
|
||||
# Get a random sample of transitions
|
||||
batch = random.sample(memory, BATCH_SIZE)
|
||||
|
||||
# Conversion.
|
||||
# Transposes batch, turning an array of Transitions
|
||||
# into a Transition of arrays.
|
||||
batch = Transition(*zip(*batch))
|
||||
|
||||
# Conversion.
|
||||
# Combine states, actions, and rewards into their own tensors.
|
||||
state_batch = torch.cat(batch.state)
|
||||
action_batch = torch.cat(batch.action)
|
||||
reward_batch = torch.cat(batch.reward)
|
||||
|
||||
|
||||
|
||||
# Compute a mask of non_final_states.
|
||||
# Each element of this tensor corresponds to an element in the batch.
|
||||
# True if this is a final state, False if it isn't.
|
||||
#
|
||||
# We use this to select non-final states later.
|
||||
non_final_mask = torch.tensor(
|
||||
tuple(map(
|
||||
lambda s: s is not None,
|
||||
batch.next_state
|
||||
))
|
||||
)
|
||||
|
||||
non_final_next_states = torch.cat(
|
||||
[s for s in batch.next_state if s is not None]
|
||||
)
|
||||
|
||||
|
||||
|
||||
# How .gather works:
|
||||
# if out = a.gather(1, b),
|
||||
# out[i, j] = a[ i ][ b[i,j] ]
|
||||
#
|
||||
# a is "input," b is "index"
|
||||
# If this doesn't make sense, RTFD.
|
||||
|
||||
# Compute Q(s_t, a).
|
||||
# - Use policy_net to compute Q(s_t) for each state in the batch.
|
||||
# This gives a tensor of [ Q(state, left), Q(state, right) ]
|
||||
#
|
||||
# - Action batch is a tensor that looks like [ [0], [1], [1], ... ]
|
||||
# listing the action that was taken in each transition.
|
||||
# 0 => we went left, 1 => we went right.
|
||||
#
|
||||
# This aligns nicely with the output of policy_net. We use
|
||||
# action_batch to index the output of policy_net's prediction.
|
||||
#
|
||||
# This gives us a tensor that contains the return we expect to get
|
||||
# at that state if we follow the model's advice.
|
||||
|
||||
state_action_values = policy_net(state_batch).gather(1, action_batch)
|
||||
|
||||
|
||||
|
||||
# Compute V(s_t+1) for all next states.
|
||||
# V(s_t+1) = max_a ( Q(s_t+1, a) )
|
||||
# = the maximum reward over all possible actions at state s_t+1.
|
||||
next_state_values = torch.zeros(BATCH_SIZE, device = compute_device)
|
||||
|
||||
# Don't compute gradient for operations in this block.
|
||||
# If you don't understand what this means, RTFD.
|
||||
with torch.no_grad():
|
||||
|
||||
# Note the use of non_final_mask here.
|
||||
# States that are final do not have their reward set by the line
|
||||
# below, so their reward stays at zero.
|
||||
#
|
||||
# States that are not final get their predicted value
|
||||
# set to the best value the model predicts.
|
||||
#
|
||||
#
|
||||
# Expected values of action are selected with the "older" target net,
|
||||
# and their best reward (over possible actions) is selected with max(1)[0].
|
||||
|
||||
next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
|
||||
|
||||
|
||||
# TODO: What does this mean?
|
||||
# "Compute expected Q values"
|
||||
expected_state_action_values = reward_batch + (next_state_values * GAMMA)
|
||||
|
||||
|
||||
|
||||
# Compute Huber loss between predicted reward and expected reward.
|
||||
# Pytorch is will account for this when we compute the gradient of loss.
|
||||
#
|
||||
# loss is a single-element tensor (i.e, a scalar).
|
||||
criterion = torch.nn.SmoothL1Loss()
|
||||
loss = criterion(
|
||||
state_action_values,
|
||||
expected_state_action_values.unsqueeze(1)
|
||||
)
|
||||
|
||||
|
||||
# We can now run a step of backpropagation on our model.
|
||||
|
||||
# TODO: what does this do?
|
||||
#
|
||||
# Calling .backward() multiple times will accumulate parameter gradients.
|
||||
# Thus, we reset the gradient before each step.
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Compute the gradient of loss wrt... something?
|
||||
# TODO: what does this do, we never use loss again?!
|
||||
loss.backward()
|
||||
|
||||
|
||||
# Prevent vanishing and exploding gradients.
|
||||
# Forces gradients to be in [-clip_value, +clip_value]
|
||||
torch.nn.utils.clip_grad_value_( # type: ignore
|
||||
policy_net.parameters(),
|
||||
clip_value = 100
|
||||
)
|
||||
|
||||
# Perform a single optimizer step.
|
||||
#
|
||||
# Uses the current gradient, which is stored
|
||||
# in the .grad attribute of the parameter.
|
||||
optimizer.step()
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def on_state_before(celeste):
|
||||
global steps_done
|
||||
|
||||
# Conversion to pytorch
|
||||
|
||||
state = celeste.state
|
||||
|
||||
pt_state = torch.tensor(
|
||||
[getattr(state, x) for x in Celeste.state_number_map],
|
||||
dtype = torch.float32,
|
||||
device = compute_device
|
||||
).unsqueeze(0)
|
||||
|
||||
action = None
|
||||
while (action) is None or ((not state.can_dash) and (str_action not in ["left", "right"])):
|
||||
action = select_action(
|
||||
pt_state,
|
||||
steps_done
|
||||
)
|
||||
str_action = Celeste.action_space[action]
|
||||
steps_done += 1
|
||||
|
||||
|
||||
# For manual testing
|
||||
#str_action = ""
|
||||
#while str_action not in Celeste.action_space:
|
||||
# str_action = input("action> ")
|
||||
#action = Celeste.action_space.index(str_action)
|
||||
|
||||
print(str_action)
|
||||
celeste.act(str_action)
|
||||
|
||||
return state, action
|
||||
|
||||
|
||||
def on_state_after(celeste, before_out):
|
||||
global episode_number
|
||||
|
||||
state, action = before_out
|
||||
next_state = celeste.state
|
||||
|
||||
pt_state = torch.tensor(
|
||||
[getattr(state, x) for x in Celeste.state_number_map],
|
||||
dtype = torch.float32,
|
||||
device = compute_device
|
||||
).unsqueeze(0)
|
||||
|
||||
pt_action = torch.tensor(
|
||||
[[ action ]],
|
||||
device = compute_device,
|
||||
dtype = torch.long
|
||||
)
|
||||
|
||||
if next_state.deaths != 0:
|
||||
pt_next_state = None
|
||||
reward = 0
|
||||
|
||||
else:
|
||||
pt_next_state = torch.tensor(
|
||||
[getattr(next_state, x) for x in Celeste.state_number_map],
|
||||
dtype = torch.float32,
|
||||
device = compute_device
|
||||
).unsqueeze(0)
|
||||
|
||||
|
||||
if state.next_point == next_state.next_point:
|
||||
reward = state.dist - next_state.dist
|
||||
|
||||
# Clip rewards that are too large
|
||||
if reward > 1:
|
||||
reward = 1
|
||||
else:
|
||||
reward = 0
|
||||
|
||||
else:
|
||||
# Reward for reaching a point
|
||||
reward = 1
|
||||
|
||||
pt_reward = torch.tensor([reward], device = compute_device)
|
||||
|
||||
|
||||
# Add this state transition to memory.
|
||||
memory.append(
|
||||
Transition(
|
||||
pt_state, # last state
|
||||
pt_action,
|
||||
pt_next_state, # next state
|
||||
pt_reward
|
||||
)
|
||||
)
|
||||
|
||||
print("==> ", int(reward))
|
||||
print("")
|
||||
|
||||
|
||||
loss = None
|
||||
# Only train the network if we have enough
|
||||
# transitions in memory to do so.
|
||||
if len(memory) >= BATCH_SIZE:
|
||||
loss = optimize_model()
|
||||
|
||||
# Soft update target_net weights
|
||||
target_net_state = target_net.state_dict()
|
||||
policy_net_state = policy_net.state_dict()
|
||||
for key in policy_net_state:
|
||||
target_net_state[key] = (
|
||||
policy_net_state[key] * TAU +
|
||||
target_net_state[key] * (1-TAU)
|
||||
)
|
||||
target_net.load_state_dict(target_net_state)
|
||||
|
||||
# Move on to the next episode once we reach
|
||||
# a terminal state.
|
||||
if (next_state.deaths != 0):
|
||||
s = celeste.state
|
||||
with model_train_log.open("a") as f:
|
||||
f.write(json.dumps({
|
||||
"checkpoints": s.next_point,
|
||||
"state_count": s.state_count,
|
||||
"loss": None if loss is None else loss.item()
|
||||
}) + "\n")
|
||||
|
||||
|
||||
# Save model
|
||||
torch.save({
|
||||
"policy_state_dict": policy_net.state_dict(),
|
||||
"target_state_dict": target_net.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"memory": memory,
|
||||
"episode_number": episode_number,
|
||||
"steps_done": steps_done
|
||||
}, model_save_path)
|
||||
|
||||
|
||||
# Clean up screenshots
|
||||
shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
|
||||
|
||||
target = screenshot_dir / Path(f"{episode_number}")
|
||||
target.mkdir(parents = True)
|
||||
|
||||
for s in shots:
|
||||
s.rename(target / s.name)
|
||||
|
||||
# Save a prediction graph
|
||||
if episode_number % archive_interval == 0:
|
||||
torch.save({
|
||||
"policy_state_dict": policy_net.state_dict(),
|
||||
"target_state_dict": target_net.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"memory": memory,
|
||||
"episode_number": episode_number,
|
||||
"steps_done": steps_done
|
||||
}, model_archive_dir / f"{episode_number}.torch")
|
||||
|
||||
|
||||
print("Game over. Resetting.")
|
||||
episode_number += 1
|
||||
celeste.reset()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
c = Celeste(
|
||||
"resources/pico-8/linux/pico8"
|
||||
)
|
||||
|
||||
c.update_loop(
|
||||
on_state_before,
|
||||
on_state_after
|
||||
)
|
Reference in New Issue
Block a user