Compare commits: f40b58508e...master (21 commits)

Commits in this comparison:
5242c7443c, 7759db9914, 571a337ff4, 058292c0bd, 8420e719d8, 6b7abc49a6, ee232329b7, 3d25d63efe, 05d745cc07, 072867a7a3, 938398f9b1, 55a8a9d7cf, 25ad663eec, 4171fefb00, 24dd65ace8, 755495a992, 3745346c5b, 25390f5455, c185965657, 03135e2ef9, 0b61702677
README.md (new file, 40 lines)
@@ -0,0 +1,40 @@

# Celeste-AI: A Celeste Classic DQL Agent

This is an attempt to create a deep Q-learning agent that automatically solves the first stage of Celeste Classic.

A gif of the result is below. This took 4000 episodes, which amounts to about 30 hours of training time.



## Contents

- `./resources`: contains files this script requires. Notably, we have an (old) version of PICO-8 that's known to work with this script, and a version of Celeste Classic with telemetry and delays called `hackcel.p8`.

- `ffmpeg.sh`: uses game screenshots to make real-time video of the agent's attempts. Read the script, it's pretty simple.

- `plot.py`: generates plots from model snapshots. These are placed in `model_data/current/plots/`.


## Setup

Before you set up Celeste-AI, you need to prepare PICO-8. See [`resources/README.md`](./resources/README.md).

This is designed to work on Linux. You will need `xdotool` to send keypresses to the game.

1. `cd` into this directory
2. Make and enter a venv
3. `pip install -e .`

Once you're set up, you can...

- `python celeste_ai/train.py` to train a model
- `python plot.py` to make prediction plots
- `python test.py` to test a model

**Before running, be aware of the following:**

- Only one instance of PICO-8 can be running at a time. See `celeste.py`.
- `hackcel.p8` captures a screenshot of every frame. PICO-8 will probably place these on your desktop. Since this repo contains a rather old version of PICO-8, there is no way to change where it places screenshots. `train.py` will delete, move, and rename screenshots automatically during training, but you should tell it where your desktop is first (see the configuration sketch after this section).
- When you start training, a `model_data` directory will be created. It contains the following:
  - `model_archive`: history of the model. The save interval is configured inside `train.py`.
  - `screenshots`: contains subdirectories. Each subdirectory contains the frames of one episode. Use `ffmpeg.sh` to turn these into a video.
  - `plots`: generated by `plot.py`. Contains pretty plots.
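The desktop path mentioned above is set where `train.py` constructs its `ScreenshotManager` (the construction appears in the `train.py` diff later in this comparison). A minimal sketch of that configuration, using the author's default paths; adjust `source` to wherever your PICO-8 saves screenshots:

```python
from pathlib import Path
from celeste_ai.util.screenshots import ScreenshotManager

# "/home/mark/Desktop" is the default used in train.py;
# replace it with your own desktop path.
sm = ScreenshotManager(
    source = Path("/home/mark/Desktop"),
    pattern = "hackcel_*.png",
    target = Path("model_data/current") / "screenshots"
).clean()  # remove any stale screenshots before training
```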
Deleted file (appears to be the previous README, 15 lines)
@@ -1,15 +0,0 @@

# Celeste-AI: A Celeste Classic DQL Agent

This is an attempt to create an agent that learns to play Celeste Classic.

## Contents

- `./resources`: contain files these scripts require. Notably, we have an (old) version of PICO-8 that's known to work with this script, and a version of Celeste Classic with telementery and delays called `hackcel.p8`.

## Setup

This is designed to work on Linux. You will need `xdotool` to send keypresses to the game.

1. `cd` into this directory
2. Make and enter a venv
3. `pip install -e .`
Deleted file (appears to be celeste_ai/plotting/plot_actual_reward.py, matching the import removed from plotting/__init__.py below; 81 lines)
@@ -1,81 +0,0 @@

import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from multiprocessing import Pool

# All of the following are required to load
# a pickled model.
from celeste_ai.celeste import Celeste
from celeste_ai.network import DQN
from celeste_ai.network import Transition


def actual_reward(
    model_file: Path,
    target_point: tuple[int, int],
    out_filename: Path,
    *,
    device = torch.device("cpu")
):
    if not model_file.is_file():
        raise Exception(f"Bad model file {model_file}")
    out_filename.parent.mkdir(exist_ok = True, parents = True)


    checkpoint = torch.load(
        model_file,
        map_location = device
    )
    memory = checkpoint["memory"]


    r = np.zeros((128, 128, 8), dtype=np.float32)
    for m in memory:
        x, y, x_target, y_target = list(m.state[0])

        action = m.action[0].item()
        x = int(x.item())
        y = int(y.item())
        x_target = int(x_target.item())
        y_target = int(y_target.item())

        # Only plot memory related to this point
        if (x_target, y_target) != target_point:
            continue

        if m.reward[0].item() == 1:
            r[y][x][action] += 1
        else:
            r[y][x][action] -= 1


    fig, axs = plt.subplots(2, 4, figsize = (20, 10))


    for a in range(len(axs.ravel())):
        ax = axs.ravel()[a]
        ax.set(
            adjustable = "box",
            aspect = "equal",
            title = Celeste.action_space[a]
        )

        plot = ax.pcolor(
            r[:,:,a],
            cmap = "seismic_r",
            vmin = -10,
            vmax = 10
        )

        # Draw target point on plot
        ax.plot(
            target_point[0],
            target_point[1],
            "k."
        )

        ax.invert_yaxis()
        fig.colorbar(plot)

    fig.savefig(out_filename)
    plt.close()
Deleted file (appears to be the previous ffmpeg.sh, 65 lines)
@@ -1,65 +0,0 @@

#!/bin/bash


# Where screenshots are saved
SC_ROOT="model_data/current"

# WILL BE DELETED
OUTPUT_DIR="model_data/video_output"


# To make with fade in and out:
# ffmpeg -framerate 30 -i %03d.png -vf "scale=1024x1024:flags=neighbor,fade=in:0:45,fade=out:1040:45" out.webm

render_dir () {
    ffmpeg \
        -y \
        -loglevel quiet \
        -framerate 30 \
        -i $1/hackcel_%003d.png \
        -c:v libx264 \
        -pix_fmt yuv420p \
        $OUTPUT_DIR/${1##*/}.mp4
}

# Todo: error out if exists
mkdir -p $OUTPUT_DIR


echo "Making episode files..."
for D in $SC_ROOT/screenshots/*; do
    if [ -d "${D}" ]; then
        render_dir $D
    fi
done

echo "Done."


# Generate video for each run
for f in $OUTPUT_DIR/*.mp4; do
    echo file \'$f\' >> video_merge_list
done

# Merge videos
ffmpeg \
    -loglevel error -stats -y \
    -f concat \
    -safe 0 \
    -i video_merge_list \
    -vf "scale=1024x1024:flags=neighbor" \
    $SC_ROOT/1x.mp4

rm video_merge_list

# Make accelerated video
ffmpeg \
    -loglevel error -stats -y \
    -i $SC_ROOT/1x.mp4 \
    -framerate 60 \
    -filter:v "setpts=0.125*PTS" \
    $SC_ROOT/8x.mp4

echo "Cleaning up..."

rm -dr $OUTPUT_DIR
Deleted file (appears to be the previous plot.py, 72 lines)
@@ -1,72 +0,0 @@

import torch
from pathlib import Path

import celeste_ai.plotting as plotting
from multiprocessing import Pool



m = Path("model_data/current")

# Make "predicted reward" plots
def plot_pred(src_model):
    plotting.predicted_reward(
        src_model,
        m / f"plots/predicted/{src_model.stem}.png",

        device = torch.device("cpu")
    )

# Make "best action" plots
def plot_best(src_model):
    plotting.best_action(
        src_model,
        m / f"plots/best_action/{src_model.stem}.png",

        device = torch.device("cpu")
    )

# Make "actual reward" plots
def plot_act(src_model):
    plotting.actual_reward(
        src_model,
        (60, 80),
        m / f"plots/actual/{src_model.stem}.png",

        device = torch.device("cpu")
    )


# Which plots should we make?
plots = {
    "prediction": True,
    "actual": False,
    "best": True
}


if __name__ == "__main__":

    if plots["prediction"]:
        print("Making prediction plots...")
        with Pool(5) as p:
            p.map(
                plot_pred,
                list((m / "model_archive").iterdir())
            )

    if plots["best"]:
        print("Making best-action plots...")
        with Pool(5) as p:
            p.map(
                plot_best,
                list((m / "model_archive").iterdir())
            )

    if plots["actual"]:
        print("Making actual plots...")
        with Pool(5) as p:
            p.map(
                plot_act,
                list((m / "model_archive").iterdir())
            )
celeste_ai/celeste.py (modified)

@@ -70,21 +70,24 @@ class Celeste:
 		#"ypos",
 		"xpos_scaled",
 		"ypos_scaled",
-		"can_dash_int"
+		#"can_dash_int"
 		#"next_point_x",
 		#"next_point_y"
 	]

 	# Targets the agent tries to reach.
 	# The last target MUST be outside the frame.
+	# Format is X, Y, range, force_y
+	# force_y is optional. If true, y_value MUST match perfectly.
 	target_checkpoints = [
 		[ # Stage 1
-			#(28, 88), # Start pillar
-			(60, 80), # Middle pillar
-			(105, 64), # Right ledge
-			(25, 40), # Left ledge
-			(110, 16), # End ledge
-			(110, -2), # Next stage
+			#(28, 88, 8), # Start pillar
+			(60, 80, 8), # Middle pillar
+			(105, 64, 8), # Right ledge
+			(25, 40, 8), # Left ledge
+			(97, 24, 5, True), # Small end ledge
+			(110, 16, 8), # End ledge
+			(110, -20, 8), # Next stage
 		]
 	]

@@ -99,7 +102,7 @@ class Celeste:
 		self,
 		pico_path,
 		*,
-		state_timeout = 30,
+		state_timeout = 20,
 		cart_name = "hackcel.p8",
 	):

@@ -144,7 +147,7 @@ class Celeste:
 		self._resetting = False  # True between a call to .reset() and the first state message from pico.
 		self._keys = {} # Dictionary of "key": bool

-	def act(self, action: str):
+	def act(self, action: str | int):
 		"""
 		Specify what keys should be down. This does NOT send key events.
 		Celeste._apply_keys() does that at the right time.

@@ -153,6 +156,9 @@ class Celeste:
 			action (str): key name, as in Celeste.action_space
 		"""

+		if isinstance(action, int):
+			action = Celeste.action_space[action]
+
 		self._keys = {}
 		if action is None:
 			return

@@ -208,9 +214,9 @@ class Celeste:
 			[int(self._internal_state["rx"])]
 		)

-		if len(Celeste.target_checkpoints) < stage:
-			next_point_x = None
-			next_point_y = None
+		if len(Celeste.target_checkpoints) <= stage:
+			next_point_x = 0
+			next_point_y = 0
 		else:
 			next_point_x = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][0]
 			next_point_y = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][1]

@@ -329,46 +335,65 @@ class Celeste:


-		# Calculate distance to each point
-		x = self.state.xpos
-		y = self.state.ypos
-		dist = np.zeros(len(Celeste.target_checkpoints[self.state.stage]), dtype=np.float16)
-		for i, c in enumerate(Celeste.target_checkpoints[self.state.stage]):
-			if i < self._next_checkpoint_idx:
-				dist[i] = 1000
-				continue
-
-			# Update checkpoints
-			tx, ty = c
-			dist[i] = (math.sqrt(
-				(x-tx)*(x-tx) +
-				((y-ty)*(y-ty))/2
-				# Possible modification:
-				# make x-distance twice as valuable as y-distance
-			))
-		min_idx = int(dist.argmin())
-		dist = int(dist[min_idx])
-
-		if dist <= 8:
-			print(f"Got point {min_idx}")
-			self._next_checkpoint_idx = min_idx + 1
-			self._last_checkpoint_state = self._state_counter
-
-			# Recalculate distance to new point
-			tx, ty = Celeste.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
-			dist = math.sqrt(
-				(x-tx)*(x-tx) +
-				((y-ty)*(y-ty))/2
-			)
-
-		# Timeout if we spend too long between points
-		elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
-			self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)
-
-		self._dist = dist
+		if self.state.stage <= 0:
+			# Calculate distance to each point
+			x = self.state.xpos
+			y = self.state.ypos
+			dist = np.zeros(len(Celeste.target_checkpoints[self.state.stage]), dtype=np.float16)
+			for i, c in enumerate(Celeste.target_checkpoints[self.state.stage]):
+				if i < self._next_checkpoint_idx:
+					dist[i] = 1000
+					continue
+
+				# Update checkpoints
+				tx, ty = c[:2]
+				dist[i] = (math.sqrt(
+					(x-tx)*(x-tx) +
+					((y-ty)*(y-ty))/2
+					# Possible modification:
+					# make x-distance twice as valuable as y-distance
+				))
+			min_idx = int(dist.argmin())
+			dist = int(dist[min_idx])
+
+			t = Celeste.target_checkpoints[self.state.stage][min_idx]
+			range = t[2]
+			if len(t) == 3:
+				force_y = False
+			else:
+				force_y = t[3]
+
+			if force_y:
+				got_point = (
+					dist <= range and
+					y == t[1]
+				)
+			else:
+				got_point = dist <= range
+
+			if got_point:
+				self._next_checkpoint_idx = min_idx + 1
+				self._last_checkpoint_state = self._state_counter
+
+				# Recalculate distance to new point
+				tx, ty = (
+					Celeste.target_checkpoints
+					[self.state.stage]
+					[self._next_checkpoint_idx]
+					[:2]
+				)
+				dist = math.sqrt(
+					(x-tx)*(x-tx) +
+					((y-ty)*(y-ty))/2
+				)
+
+			# Timeout if we spend too long between points
+			elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
+				self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)
+
+			self._dist = dist

 		# Call step callbacks
 		# These should call celeste.act() to set next input
celeste_ai/network.py (modified)

@@ -5,7 +5,7 @@ from collections import namedtuple
 Transition = namedtuple(
 	"Transition",
 	(
-		"state",
+		"last_state",
 		"action",
 		"next_state",
 		"reward"
celeste_ai/plotting/__init__.py (modified)

@@ -1,4 +1,3 @@
-from .plot_actual_reward import actual_reward
 from .plot_predicted_reward import predicted_reward
 from .plot_best_action import best_action

|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import matplotlib as mpl
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import json
|
||||||
|
|
||||||
# All of the following are required to load
|
# All of the following are required to load
|
||||||
# a pickled model.
|
# a pickled model.
|
||||||
@ -14,7 +16,8 @@ def best_action(
|
|||||||
model_file: Path,
|
model_file: Path,
|
||||||
out_filename: Path,
|
out_filename: Path,
|
||||||
*,
|
*,
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu"),
|
||||||
|
draw_path = True
|
||||||
):
|
):
|
||||||
if not model_file.is_file():
|
if not model_file.is_file():
|
||||||
raise Exception(f"Bad model file {model_file}")
|
raise Exception(f"Bad model file {model_file}")
|
||||||
@ -34,7 +37,7 @@ def best_action(
|
|||||||
|
|
||||||
|
|
||||||
# Compute preditions
|
# Compute preditions
|
||||||
p = np.zeros((128, 128, 2), dtype=np.float32)
|
p = np.zeros((128, 128), dtype=np.float32)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for r in range(len(p)):
|
for r in range(len(p)):
|
||||||
for c in range(len(p[r])):
|
for c in range(len(p[r])):
|
||||||
@ -43,26 +46,31 @@ def best_action(
|
|||||||
|
|
||||||
k = np.asarray(policy_net(
|
k = np.asarray(policy_net(
|
||||||
torch.tensor(
|
torch.tensor(
|
||||||
[x, y, 0],
|
[x, y],
|
||||||
dtype = torch.float32,
|
dtype = torch.float32,
|
||||||
device = device
|
device = device
|
||||||
).unsqueeze(0)
|
).unsqueeze(0)
|
||||||
)[0])
|
)[0])
|
||||||
p[r][c][0] = np.argmax(k)
|
p[r][c] = np.argmax(k)
|
||||||
|
|
||||||
k = np.asarray(policy_net(
|
|
||||||
torch.tensor(
|
|
||||||
[x, y, 1],
|
|
||||||
dtype = torch.float32,
|
|
||||||
device = device
|
|
||||||
).unsqueeze(0)
|
|
||||||
)[0])
|
|
||||||
p[r][c][1] = np.argmax(k)
|
|
||||||
|
|
||||||
|
cmap = mpl.colors.ListedColormap(
|
||||||
|
[
|
||||||
|
"forestgreen",
|
||||||
|
"firebrick",
|
||||||
|
"lightgreen",
|
||||||
|
"salmon",
|
||||||
|
"darkturquoise",
|
||||||
|
"sandybrown",
|
||||||
|
"olive",
|
||||||
|
"darkorchid",
|
||||||
|
"mediumvioletred"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
# Plot predictions
|
# Plot predictions
|
||||||
fig, axs = plt.subplots(1, 2, figsize = (10, 10))
|
fig, axs = plt.subplots(1, 1, figsize = (20, 20))
|
||||||
ax = axs[0]
|
ax = axs
|
||||||
ax.set(
|
ax.set(
|
||||||
adjustable = "box",
|
adjustable = "box",
|
||||||
aspect = "equal",
|
aspect = "equal",
|
||||||
@ -70,30 +78,38 @@ def best_action(
|
|||||||
)
|
)
|
||||||
|
|
||||||
plot = ax.pcolor(
|
plot = ax.pcolor(
|
||||||
p[:,:,0],
|
p,
|
||||||
cmap = "Set1",
|
cmap = cmap,
|
||||||
vmin = 0,
|
|
||||||
vmax = 8
|
|
||||||
)
|
|
||||||
ax.invert_yaxis()
|
|
||||||
fig.colorbar(plot)
|
|
||||||
|
|
||||||
ax = axs[1]
|
|
||||||
ax.set(
|
|
||||||
adjustable = "box",
|
|
||||||
aspect = "equal",
|
|
||||||
title = "Best Action"
|
|
||||||
)
|
|
||||||
|
|
||||||
plot = ax.pcolor(
|
|
||||||
p[:,:,0],
|
|
||||||
cmap = "Set1",
|
|
||||||
vmin = 0,
|
vmin = 0,
|
||||||
vmax = 8
|
vmax = 8
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if draw_path:
|
||||||
|
d = None
|
||||||
|
with Path("model_data/solved_4layer/paths.json").open("r") as f:
|
||||||
|
for l in f.readlines():
|
||||||
|
t = json.loads(l)
|
||||||
|
if t["current_image"] == model_file.name:
|
||||||
|
break
|
||||||
|
d = t
|
||||||
|
assert d is not None
|
||||||
|
|
||||||
|
plt.plot(
|
||||||
|
[max(0,x["xpos"]) for x in d["hist"]],
|
||||||
|
[max(0,x["ypos"] + 5) for x in d["hist"]],
|
||||||
|
marker = "",
|
||||||
|
markersize = 0,
|
||||||
|
linestyle = "-",
|
||||||
|
linewidth = 5,
|
||||||
|
color = "white",
|
||||||
|
solid_capstyle = "round",
|
||||||
|
solid_joinstyle = "round"
|
||||||
|
)
|
||||||
|
|
||||||
ax.invert_yaxis()
|
ax.invert_yaxis()
|
||||||
fig.colorbar(plot)
|
cbar = fig.colorbar(plot, ticks = list(range(0, 9)))
|
||||||
|
cbar.ax.set_yticklabels(Celeste.action_space)
|
||||||
|
|
||||||
fig.savefig(out_filename)
|
fig.savefig(out_filename)
|
||||||
plt.close()
|
plt.close()
|
celeste_ai/plotting/plot_predicted_reward.py (modified)

@@ -43,7 +43,7 @@ def predicted_reward(

 				k = np.asarray(policy_net(
 					torch.tensor(
-						[x, y, 0],
+						[x, y],
 						dtype = torch.float32,
 						device = device
 					).unsqueeze(0)
celeste_ai/record_paths.py (new file, 119 lines)
@@ -0,0 +1,119 @@

from pathlib import Path
import torch
import json

from celeste_ai import Celeste
from celeste_ai import DQN



model_data_root = Path("model_data/current")

compute_device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)


# Celeste env properties
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)

policy_net = DQN(
    n_observations,
    n_actions
).to(compute_device)

k = (model_data_root / "model_archive").iterdir()
i = 0

state_history = []
current_path = None

def next_image():
    global policy_net
    global current_path
    global i
    i += 1

    try:
        current_path = k.__next__()
    except StopIteration:
        return False

    print(f"Pathing {current_path} ({i})")

    # Load model if one exists
    checkpoint = torch.load(
        current_path,
        map_location = compute_device
    )
    policy_net.load_state_dict(checkpoint["policy_state_dict"])


next_image()

def on_state_before(celeste):
    global steps_done

    state = celeste.state

    pt_state = torch.tensor(
        [getattr(state, x) for x in Celeste.state_number_map],
        dtype = torch.float32,
        device = compute_device
    ).unsqueeze(0)


    action = policy_net(pt_state).max(1)[1].view(1, 1).item()
    str_action = Celeste.action_space[action]

    celeste.act(str_action)

    return state, action


def on_state_after(celeste, before_out):
    global episode_number
    global state_history

    state, action = before_out
    next_state = celeste.state
    finished_stage = next_state.stage >= 1

    state_history.append({
        "xpos": state.xpos,
        "ypos": state.ypos,
        "action": Celeste.action_space[action]
    })

    # Move on to the next episode once we reach
    # a terminal state.
    if (next_state.deaths != 0 or finished_stage):

        with (model_data_root / "paths.json").open("a") as f:
            f.write(json.dumps(
                {
                    "hist": state_history,
                    "current_image": current_path.name
                }
            ) + "\n")

        state_history = []
        k = next_image()

        if k is False:
            raise Exception("Done.")

        print("Game over. Resetting.")
        celeste.reset()



c = Celeste(
    "resources/pico-8/linux/pico8"
)

c.update_loop(
    on_state_before,
    on_state_after
)
celeste_ai/test.py (new file, 100 lines)
@@ -0,0 +1,100 @@

from pathlib import Path
import torch

from celeste_ai import Celeste
from celeste_ai import DQN
from celeste_ai.util.screenshots import ScreenshotManager


if __name__ == "__main__":
    # Where to read/write model data.
    model_data_root = Path("model_data/current")

    model_save_path = model_data_root / "model.torch"
    model_data_root.mkdir(parents = True, exist_ok = True)


    sm = ScreenshotManager(
        # Where PICO-8 saves screenshots.
        # Probably your desktop.
        source = Path("/home/mark/Desktop"),
        pattern = "hackcel_*.png",
        target = model_data_root / "screenshots_test"
    ).clean() # Remove old screenshots


    compute_device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )

    episode_number = 0

    # Celeste env properties
    n_observations = len(Celeste.state_number_map)
    n_actions = len(Celeste.action_space)

    policy_net = DQN(
        n_observations,
        n_actions
    ).to(compute_device)


    # Load model if one exists
    checkpoint = torch.load(
        model_save_path,
        map_location = compute_device
    )
    policy_net.load_state_dict(checkpoint["policy_state_dict"])


def on_state_before(celeste):
    global steps_done

    state = celeste.state

    pt_state = torch.tensor(
        [getattr(state, x) for x in Celeste.state_number_map],
        dtype = torch.float32,
        device = compute_device
    ).unsqueeze(0)


    action = policy_net(pt_state).max(1)[1].view(1, 1).item()
    str_action = Celeste.action_space[action]

    print(str_action)
    celeste.act(str_action)

    return state, action


def on_state_after(celeste, before_out):
    global episode_number

    state, action = before_out
    next_state = celeste.state
    finished_stage = next_state.stage >= 1


    # Move on to the next episode once we reach
    # a terminal state.
    if (next_state.deaths != 0 or finished_stage):
        s = celeste.state

        sm.move()


        print("Game over. Resetting.")
        celeste.reset()
        episode_number += 1


if __name__ == "__main__":
    c = Celeste(
        "resources/pico-8/linux/pico8"
    )

    c.update_loop(
        on_state_before,
        on_state_after
    )
celeste_ai/train.py (modified)

@@ -5,33 +5,31 @@ import random
 import math
 import json
 import torch
+import shutil

 from celeste_ai import Celeste
 from celeste_ai import DQN
 from celeste_ai import Transition
+from celeste_ai.util.screenshots import ScreenshotManager

 if __name__ == "__main__":
 	# Where to read/write model data.
 	model_data_root = Path("model_data/current")

-	# Where PICO-8 saves screenshots.
-	# Probably your desktop.
-	screenshot_source = Path("/home/mark/Desktop")
+	sm = ScreenshotManager(
+		# Where PICO-8 saves screenshots.
+		# Probably your desktop.
+		source = Path("/home/mark/Desktop"),
+		pattern = "hackcel_*.png",
+		target = model_data_root / "screenshots"
+	).clean() # Remove old screenshots

 	model_save_path = model_data_root / "model.torch"
 	model_archive_dir = model_data_root / "model_archive"
 	model_train_log = model_data_root / "train_log"
-	screenshot_dir = model_data_root / "screenshots"
 	model_data_root.mkdir(parents = True, exist_ok = True)
 	model_archive_dir.mkdir(parents = True, exist_ok = True)
-	screenshot_dir.mkdir(parents = True, exist_ok = True)


-	# Remove old screenshots
-	shots = screenshot_source.glob("hackcel_*.png")
-	for s in shots:
-		s.unlink()
-

 	compute_device = torch.device(

@@ -45,66 +43,51 @@ if __name__ == "__main__":


 	# Epsilon-greedy parameters
-	#
-	# Original docs:
-	# EPS_START is the starting value of epsilon
-	# EPS_END is the final value of epsilon
-	# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
+	# Probability of choosing a random action starts at
+	# EPS_START and decays to EPS_END.
+	# EPS_DECAY controls the rate of decay.
 	EPS_START = 0.9
 	EPS_END = 0.02
 	EPS_DECAY = 100

-	# How many times we've reached each point.
-	# Used to compute epsilon-greedy probability with
-	# the parameters above.
-	point_counter = [0] * len(Celeste.target_checkpoints[0])
-
-	BATCH_SIZE = 100
-	# Learning rate of target_net.
-	# Controls how soft our soft update is.
-	#
-	# Should be between 0 and 1.
-	# Large values
-	# Small values do the opposite.
-	#
-	# A value of one makes target_net
-	# change at the same rate as policy_net.
-	#
-	# A value of zero makes target_net
-	# not change at all.
-	TAU = 0.05
-
-
-	# GAMMA is the discount factor as mentioned in the previous section
+	# Bellman equation time-discount factor
 	GAMMA = 0.9

-	steps_done = 0
-	num_episodes = 100
-	episode_number = 0
-	archive_interval = 10
+	# Train on this many transitions from
+	# replay memory each round
+	BATCH_SIZE = 100
+
+	# Controls target_net soft update.
+	# Should be between 0 and 1.
+	TAU = 0.05
+
+	# Optimizer learning rate
+	learning_rate = 0.001
+
+	# Save a snapshot of the model every n
+	# episodes.
+	model_save_interval = 10
+
+
+	# How many times we've reached each point.
+	# This is used to compute epsilon-greedy probability.
+	point_counter = [0] * len(Celeste.target_checkpoints[0])
+
+	n_episodes = 0  # Number of episodes we've trained on
+	n_steps = 0     # Number of training steps we've completed

 	# Create replay memory.
 	#
-	# Transition: a container for naming data (defined in util.py)
-	# Memory: a deque that holds recent states as Transitions
-	#	Has a fixed length, drops oldest
-	#	element if maxlen is exceeded.
+	# Holds <Transition> objects, defined in
+	# network.py
 	memory = deque([], maxlen=50_000)

-	policy_net = DQN(
-		n_observations,
-		n_actions
-	).to(compute_device)
-
-	target_net = DQN(
-		n_observations,
-		n_actions
-	).to(compute_device)
-
+	policy_net = DQN(n_observations, n_actions).to(compute_device)
+	target_net = DQN(n_observations, n_actions).to(compute_device)
 	target_net.load_state_dict(policy_net.state_dict())

-	learning_rate = 0.001
 	optimizer = torch.optim.AdamW(
 		policy_net.parameters(),
 		lr = learning_rate,
@@ -122,11 +105,43 @@ if __name__ == "__main__":
 		target_net.load_state_dict(checkpoint["target_state_dict"])
 		optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
 		memory = checkpoint["memory"]
-		episode_number = checkpoint["episode_number"] + 1
-		steps_done = checkpoint["steps_done"]
+		n_episodes = checkpoint["n_episodes"]
+		n_steps = checkpoint["n_steps"]
 		point_counter = checkpoint["point_counter"]

-def select_action(state, steps_done):
+
+def save_model(path):
+	torch.save({
+			# Newtorks
+			"policy_state_dict": policy_net.state_dict(),
+			"target_state_dict": target_net.state_dict(),
+			"optimizer_state_dict": optimizer.state_dict(),
+
+			# Training data
+			"memory": memory,
+			"point_counter": point_counter,
+			"n_episodes": n_episodes,
+			"n_steps": n_steps,
+
+			# Hyperparameters,
+			# for reference
+			"eps_start": EPS_START,
+			"eps_end": EPS_END,
+			"eps_decay": EPS_DECAY,
+			"batch_size": BATCH_SIZE,
+			"tau": TAU,
+			"learning_rate": learning_rate,
+			"gamma": GAMMA
+		}, path
+	)
+
+
+def select_action(state, x) -> int:
 	"""
 	Select an action using an epsilon-greedy policy.

@@ -136,19 +151,13 @@ def select_action(state, steps_done):
 	Decay rate is controlled by EPS_DECAY.
 	"""

-	# Random number 0 <= x < 1
-	sample = random.random()
-
 	# Calculate random step threshhold
 	eps_threshold = (
 		EPS_END + (EPS_START - EPS_END) *
-		math.exp(
-			-1.0 * steps_done /
-			EPS_DECAY
-		)
+		math.exp(-1.0 * x / EPS_DECAY)
 	)

-	if sample > eps_threshold:
+	if random.random() > eps_threshold:
 		with torch.no_grad():
 			# t.max(1) will return the largest column value of each row.
 			# second column on max result is index of where max element was
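For a quick sanity check of the new threshold expression above (epsilon now decays with `x`, the number of times the agent has already reached the next checkpoint via `point_counter`, rather than with total steps), the following standalone sketch plugs in the EPS_* values set earlier in this diff:

```python
import math

# Values taken from the hyperparameter hunk above.
EPS_START, EPS_END, EPS_DECAY = 0.9, 0.02, 100

def eps_threshold(x: int) -> float:
    # Same expression as the new select_action(): decays from
    # EPS_START toward EPS_END as x grows.
    return EPS_END + (EPS_START - EPS_END) * math.exp(-1.0 * x / EPS_DECAY)

print(eps_threshold(0))    # 0.90  -- almost always explore
print(eps_threshold(100))  # ~0.34
print(eps_threshold(300))  # ~0.06 -- mostly exploit
```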
@@ -175,7 +184,7 @@ def optimize_model():

 	# Conversion.
 	# Combine states, actions, and rewards into their own tensors.
-	state_batch = torch.cat(batch.state)
+	last_state_batch = torch.cat(batch.last_state)
 	action_batch = torch.cat(batch.action)
 	reward_batch = torch.cat(batch.reward)

@@ -209,7 +218,7 @@ def optimize_model():
 	# This gives us a tensor that contains the return we expect to get
 	# at that state if we follow the model's advice.

-	state_action_values = policy_net(state_batch).gather(1, action_batch)
+	state_action_values = policy_net(last_state_batch).gather(1, action_batch)


@@ -282,36 +291,21 @@ def optimize_model():


 def on_state_before(celeste):
-	global steps_done
-
 	state = celeste.state

-	pt_state = torch.tensor(
-		[getattr(state, x) for x in Celeste.state_number_map],
-		dtype = torch.float32,
-		device = compute_device
-	).unsqueeze(0)
-
-
 	action = select_action(
-		pt_state,
+		# Put state in a tensor
+		torch.tensor(
+			[getattr(state, x) for x in Celeste.state_number_map],
+			dtype = torch.float32,
+			device = compute_device
+		).unsqueeze(0),
+
+		# Random action probability is determined by
+		# the number of times we've reached the next point.
 		point_counter[state.next_point]
 	)
-	str_action = Celeste.action_space[action]
-
-	"""
-	action = None
-	while (action) is None or ((not state.can_dash) and (str_action not in ["left", "right"])):
-		action = select_action(
-			pt_state,
-			steps_done
-		)
-		str_action = Celeste.action_space[action]
-	"""
-
-	steps_done += 1

 	# For manual testing
 	#str_action = ""

@@ -319,86 +313,114 @@ def on_state_before(celeste):
 	#	str_action = input("action> ")
 	#action = Celeste.action_space.index(str_action)

-	print(str_action)
-	celeste.act(str_action)
+	print(Celeste.action_space[action])
+	celeste.act(action)

-	return state, action
-
-
-def on_state_after(celeste, before_out):
-	global episode_number
-
-	state, action = before_out
-	next_state = celeste.state
-
-	pt_state = torch.tensor(
-		[getattr(state, x) for x in Celeste.state_number_map],
-		dtype = torch.float32,
-		device = compute_device
-	).unsqueeze(0)
-
-	pt_action = torch.tensor(
-		[[ action ]],
-		device = compute_device,
-		dtype = torch.long
-	)
-
-	finished_stage = False
-
-	# No reward if dead
-	if next_state.deaths != 0:
-		pt_next_state = None
-		reward = 0
-
-	# Reward for finishing a stage
-	elif next_state.stage >= 1:
-		finished_stage = True
-		reward = next_state.next_point - state.next_point
-		reward += 1
-
-		# Add to point counter
-		for i in range(state.next_point, state.next_point + reward):
-			point_counter[i] += 1
-
-	# Regular reward
-	else:
-		pt_next_state = torch.tensor(
-			[getattr(next_state, x) for x in Celeste.state_number_map],
-			dtype = torch.float32,
-			device = compute_device
-		).unsqueeze(0)
-
-		if state.next_point == next_state.next_point:
-			reward = 0
-		else:
-			print(f"Got point {state.next_point}")
-
-			# Reward for reaching a point
-			reward = next_state.next_point - state.next_point
-
-			# Add to point counter
-			for i in range(state.next_point, state.next_point + reward):
-				point_counter[i] += 1
-
-		# Strawberry reward
-		if next_state.berries[state.stage] and not state.berries[state.stage]:
-			print(f"Got stage {state.stage} bonus")
-			reward += 1
-
-	reward = reward * 10
-	pt_reward = torch.tensor([reward], device = compute_device)
+	return (
+		state,  # CelesteState
+		action  # Integer
+	)
+
+
+def compute_reward(last_state, state):
+	global point_counter
+
+	reward = None
+
+	# No reward if dead
+	if state.deaths != 0:
+		reward = 0
+
+	# Reward for finishing a stage
+	elif state.stage >= 1:
+		print("FINISHED STAGE!!")
+
+		# We don't set a fixed reward here because the agent may
+		# complete the stage before getting all points.
+		# The below line provides extra reward for taking shortcuts.
+		reward = state.next_point - last_state.next_point
+		reward += 1
+
+		# Add to point counter
+		for i in range(last_state.next_point, len(point_counter)):
+			point_counter[i] += 1
+
+	# Reward for reaching a checkpoint
+	elif last_state.next_point != state.next_point:
+		print(f"Got point {state.next_point}")
+
+		reward = state.next_point - last_state.next_point
+
+		# Add to point counter
+		for i in range(last_state.next_point, last_state.next_point + reward):
+			point_counter[i] += 1
+
+	# No reward otherwise
+	else:
+		reward = 0
+
+	# Strawberry reward
+	# (Will probably break current version of model)
+	#if state.berries[state.stage] and not state.berries[state.stage]:
+	#	print(f"Got stage {state.stage} bonus")
+	#	reward += 1
+
+	assert reward is not None
+	return reward * 10
+
+
+def on_state_after(celeste, before_out):
+	global n_episodes
+	global n_steps
+
+	last_state, action = before_out
+	next_state = celeste.state
+	dead = next_state.deaths != 0
+	done = next_state.stage >= 1
+
+	reward = compute_reward(last_state, next_state)
+
+	if dead:
+		next_state = None
+	elif done:
+		# We don't set the next state to None because
+		# the optimization routine forces zero reward
+		# for terminal states.
+		# Copy last state instead. It's a hack, but it
+		# should work.
+		next_state = last_state

 	# Add this state transition to memory.
 	memory.append(
 		Transition(
-			pt_state,
-			pt_action,
-			pt_next_state,
-			pt_reward
+			# last state
+			torch.tensor(
+				[getattr(last_state, x) for x in Celeste.state_number_map],
+				dtype = torch.float32,
+				device = compute_device
+			).unsqueeze(0),
+
+			# action
+			torch.tensor(
+				[[ action ]],
+				device = compute_device,
+				dtype = torch.long
+			),
+
+			# next state
+			# None if dead or done.
+			torch.tensor(
+				[getattr(next_state, x) for x in Celeste.state_number_map],
+				dtype = torch.float32,
+				device = compute_device
+			).unsqueeze(0) if next_state is not None else None,
+
+			# reward
+			torch.tensor(
+				[reward],
+				device = compute_device
+			)
 		)
 	)

@@ -406,11 +428,10 @@ def on_state_after(celeste, before_out):
 	print("")


+	# Perform a training step
 	loss = None
-
-	# Only train the network if we have enough
-	# transitions in memory to do so.
 	if len(memory) >= BATCH_SIZE:
+		n_steps += 1
 		loss = optimize_model()

 	# Soft update target_net weights

@@ -423,65 +444,43 @@ def on_state_after(celeste, before_out):
 		)
 		target_net.load_state_dict(target_net_state)

-	# Move on to the next episode once we reach
-	# a terminal state.
-	if (next_state.deaths != 0 or finished_stage):
+	# Move on to the next episode and run
+	# housekeeping tasks.
+	if (dead or done):
 		s = celeste.state
+		n_episodes += 1

+		# Move screenshots
+		sm.move(
+			number = n_episodes,
+			overwrite = True
+		)
+
+		# Log this episode
 		with model_train_log.open("a") as f:
 			f.write(json.dumps({
+				"n_episodes": n_episodes,
+				"n_steps": n_steps,
 				"checkpoints": s.next_point,
-				"state_count": s.state_count,
-				"loss": None if loss is None else loss.item()
+				"loss": None if loss is None else loss.item(),
+				"done": done
 			}) + "\n")

-		# Save model
-		torch.save({
-			"policy_state_dict": policy_net.state_dict(),
-			"target_state_dict": target_net.state_dict(),
-			"optimizer_state_dict": optimizer.state_dict(),
-			"memory": memory,
-			"point_counter": point_counter,
-			"episode_number": episode_number,
-			"steps_done": steps_done,
-
-			# Hyperparameters
-			"eps_start": EPS_START,
-			"eps_end": EPS_END,
-			"eps_decay": EPS_DECAY,
-			"batch_size": BATCH_SIZE,
-			"tau": TAU,
-			"learning_rate": learning_rate,
-			"gamma": GAMMA
-		}, model_save_path)
-
-		# Clean up screenshots
-		shots = screenshot_source.glob("hackcel_*.png")
-
-		target = screenshot_dir / Path(f"{episode_number}")
-		target.mkdir(parents = True)
-
-		for s in shots:
-			s.rename(target / s.name)
-
 		# Save a snapshot
-		if episode_number % archive_interval == 0:
-			torch.save({
-				"policy_state_dict": policy_net.state_dict(),
-				"target_state_dict": target_net.state_dict(),
-				"optimizer_state_dict": optimizer.state_dict(),
-				"memory": memory,
-				"episode_number": episode_number,
-				"steps_done": steps_done
-			}, model_archive_dir / f"{episode_number}.torch")
+		if n_episodes % model_save_interval == 0:
+			save_model(model_archive_dir / f"{n_episodes}.torch")
+			shutil.copy(model_archive_dir / f"{n_episodes}.torch", model_save_path)

 		print("Game over. Resetting.")
-		episode_number += 1
 		celeste.reset()


 if __name__ == "__main__":
 	c = Celeste(
 		"resources/pico-8/linux/pico8"
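The "Soft update target_net weights" step that TAU controls is referenced but not fully shown in the hunks above; the surrounding comments describe a standard Polyak-style update. A minimal sketch of that idea, assuming ordinary PyTorch state_dict handling (the function and variable names here are illustrative, not copied from the repo):

```python
import torch

def soft_update(policy_net: torch.nn.Module, target_net: torch.nn.Module, tau: float = 0.05) -> None:
    # target <- tau * policy + (1 - tau) * target
    # tau = 1 copies policy_net outright; tau = 0 freezes target_net.
    target_state = target_net.state_dict()
    policy_state = policy_net.state_dict()
    for key in target_state:
        target_state[key] = tau * policy_state[key] + (1.0 - tau) * target_state[key]
    target_net.load_state_dict(target_state)
```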
celeste_ai/util/__init__.py (new file, empty)
celeste_ai/util/screenshots.py (new file, 70 lines)
@@ -0,0 +1,70 @@

from pathlib import Path
import shutil


class ScreenshotManager:
    def __init__(
        self,

        # Where PICO-8 saves screenshots
        source: Path,

        # How PICO-8 names screenshots.
        # Example: "celeste_*.png"
        pattern: str,

        # Where we want to move screenshots.
        target: Path
    ):
        self.source = source
        self.pattern = pattern
        self.target = target
        self.target.mkdir(
            parents = True,
            exist_ok = True
        )


    def clean(self):
        shots = self.source.glob(self.pattern)
        for s in shots:
            s.unlink()
        return self


    def move(self, number: int | None = None, overwrite = False):
        shots = self.source.glob(self.pattern)

        if number == None:
            # Auto-select new directory number.
            # Chooses next highest int directory name
            number = 0
            for f in self.target.iterdir():
                try:
                    number = max(
                        int(f.name),
                        number
                    )
                except ValueError:
                    continue
            number += 1

            target = self.target / str(number)
        else:
            target = self.target / str(number)

        if target.exists():
            if not overwrite:
                raise Exception(f"Target \"{target}\" exists!")
            else:
                print(f"Target \"{target}\" exists, removing.")
                shutil.rmtree(target)

        target.mkdir(parents = True)

        for s in shots:
            s.rename(target / s.name)
        return self
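Both `clean()` and `move()` return `self`, which is why `train.py` and `test.py` can construct the manager and wipe stale screenshots in a single expression. A short, self-contained usage sketch (paths here are illustrative only; the real scripts use the PICO-8 desktop directory as `source` and `model_data/current/screenshots` as `target`):

```python
from pathlib import Path
from celeste_ai.util.screenshots import ScreenshotManager

sm = ScreenshotManager(
    source = Path("/tmp/pico8_shots"),      # illustrative
    pattern = "hackcel_*.png",
    target = Path("/tmp/episodes")          # illustrative
).clean()  # returns self, so construction and cleanup chain together

# With number=None, move() scans integer-named directories under
# `target` and creates the next one (0, 1, 2, ...); train.py instead
# passes number=n_episodes with overwrite=True at the end of each episode.
sm.move()
```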
ffmpeg.sh (new file, 90 lines)
@@ -0,0 +1,90 @@

#!/bin/bash


# Where screenshots are saved.
# SC_ROOT/SC_DIR should contain episode screenshot directories
SC_ROOT="model_data/current"
SC_DIR="screenshots"


# Select a temporary working directory
# if false, uses ramdisk.
# set to true if ramdisk overflows.
if false; then
    OUTPUT_DIR="model_data/video_output"

    # This directory will be deleted.
    # Make sure it doesn't already exist.
    if [ -e "$OUTPUT_DIR" ]; then
        echo "$OUTPUT_DIR exists, exiting. Please delete it manually."
        exit 1
    fi
    mkdir -p $OUTPUT_DIR

else
    OUTPUT_DIR=$(mktemp -d)
fi


# Usage:
# render_episode <src_dir> <output_name>
#
# Turns a directory of frame screenshots into a video.
# Applies upscaling. We do it early, because upscaling
# after encoding will exaggerate artifacts.
render_episode () {
    ffmpeg \
        -y \
        -loglevel quiet \
        -framerate 30 \
        -i "$1/hackcel_%003d.png" \
        -c:v libx264 \
        -crf 20 \
        -preset slow \
        -tune animation \
        -vf "scale=512x512:flags=neighbor" \
        "$2.mp4"
}


echo "Making episode files..."
for D in "$SC_ROOT/$SC_DIR"/*; do
    if [ -d "${D}" ]; then
        render_episode "$D" "$OUTPUT_DIR/${D##*/}"
    fi
done


echo "Merging..."
for f in "$OUTPUT_DIR"/*.mp4; do
    echo file \'$f\' >> "$OUTPUT_DIR/video_merge_list"
done


# Merge videos
ffmpeg \
    -loglevel error -stats -y \
    -f concat \
    -safe 0 \
    -i "$OUTPUT_DIR/video_merge_list" \
    "$SC_ROOT/1x.mp4"
echo ""
echo "Making accelerated video..."


# Make accelerated video
ffmpeg \
    -loglevel error -stats -y \
    -i "$SC_ROOT/1x.mp4" \
    -framerate 60 \
    -filter:v "setpts=0.125*PTS" \
    "$SC_ROOT/8x.mp4"


echo "Cleaning up...."
rm -dr $OUTPUT_DIR
plot.py (new file, 37 lines)
@@ -0,0 +1,37 @@

import torch
from pathlib import Path

import celeste_ai.plotting as plotting
from multiprocessing import Pool


m = Path("model_data/current")


def plot_pred(src_model):
    plotting.predicted_reward(
        src_model,
        m / f"plots/predicted/{src_model.stem}.png",

        device = torch.device("cpu")
    )

def plot_best(src_model):
    plotting.best_action(
        src_model,
        m / f"plots/best_action/{src_model.stem}.png",

        device = torch.device("cpu")
    )


for k, v in {
    #"prediction": plot_pred,
    "best_action": plot_best,
}.items():
    print(f"Making {k} plots...")
    with Pool(5) as p:
        p.map(
            v,
            list((m / "model_archive").iterdir())
        )
@ -1,276 +0,0 @@
|
|||||||
import gymnasium as gym
|
|
||||||
import math
|
|
||||||
import random
|
|
||||||
import matplotlib
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
from itertools import count
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.optim as optim
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
import util
|
|
||||||
import optimize as optimize
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Parameter file
|
|
||||||
|
|
||||||
# TODO: What is this?
|
|
||||||
human_render = False
|
|
||||||
|
|
||||||
# TODO: What is this$
|
|
||||||
BATCH_SIZE = 128
|
|
||||||
|
|
||||||
# Learning rate of target_net.
|
|
||||||
# Controls how soft our soft update is.
|
|
||||||
#
|
|
||||||
# Should be between 0 and 1.
|
|
||||||
# Large values
|
|
||||||
# Small values do the opposite.
|
|
||||||
#
|
|
||||||
# A value of one makes target_net
|
|
||||||
# change at the same rate as policy_net.
|
|
||||||
#
|
|
||||||
# A value of zero makes target_net
|
|
||||||
# not change at all.
|
|
||||||
TAU = 0.005
|
|
||||||
|
|
||||||
|
|
||||||
# Setup game environment
|
|
||||||
if human_render:
|
|
||||||
env = gym.make("CartPole-v1", render_mode = "human")
|
|
||||||
else:
|
|
||||||
env = gym.make("CartPole-v1")
|
|
||||||
|
|
||||||
# Setup pytorch
|
|
||||||
compute_device = torch.device(
|
|
||||||
"cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Number of training episodes.
|
|
||||||
# It will take a while to process a many of these without a GPU,
|
|
||||||
# but you will not see improvement with few training episodes.
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
num_episodes = 600
|
|
||||||
else:
|
|
||||||
num_episodes = 50
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Create replay memory.
|
|
||||||
#
|
|
||||||
# Transition: a container for naming data (defined in util.py)
|
|
||||||
# Memory: a deque that holds recent states as Transitions
|
|
||||||
# Has a fixed length, drops oldest
|
|
||||||
# element if maxlen is exceeded.
|
|
||||||
memory = deque([], maxlen=10000)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Outline our network
|
|
||||||
class DQN(nn.Module):
|
|
||||||
def __init__(self, n_observations: int, n_actions: int):
|
|
||||||
super(DQN, self).__init__()
|
|
||||||
self.layer1 = nn.Linear(n_observations, 128)
|
|
||||||
self.layer2 = nn.Linear(128, 128)
|
|
||||||
self.layer3 = nn.Linear(128, n_actions)
|
|
||||||
|
|
||||||
# Can be called with one input, or with a batch.
|
|
||||||
#
|
|
||||||
# Returns tensor(
|
|
||||||
# [ Q(s, left), Q(s, right) ], ...
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# Recall that Q(s, a) is the (expected) return of taking
|
|
||||||
# action `a` at state `s`
|
|
||||||
def forward(self, x):
|
|
||||||
x = F.relu(self.layer1(x))
|
|
||||||
x = F.relu(self.layer2(x))
|
|
||||||
return self.layer3(x)



## Create networks and optimizer

# n_actions: size of action space
#	- 2 for cartpole: [0, 1] as "left" and "right"
#
# n_observations: size of observation vector
#	- 4 for cartpole:
#		position, velocity,
#		angle, angular velocity
n_actions = env.action_space.n  # type: ignore
state, _ = env.reset()
n_observations = len(state)

# TODO:
# What's the difference between these two?
# What do they do?
policy_net = DQN(n_observations, n_actions).to(compute_device)
target_net = DQN(n_observations, n_actions).to(compute_device)

# Both networks start with the same weights
target_net.load_state_dict(policy_net.state_dict())

#
optimizer = optim.AdamW(
	policy_net.parameters(),
	lr = 1e-4, # Hyperparameter: learning rate
	amsgrad = True
)



# TODO: What is this?
steps_done = 0



episode_durations = []


# TRAINING LOOP
for ep in range(num_episodes):

	# Reset environment and get game state
	state, _ = env.reset()

	# Conversion
	state = torch.tensor(
		state,
		dtype = torch.float32,
		device = compute_device
	).unsqueeze(0)


	# Iterate until game is over
	for t in count():

		# Select next action
		action = util.select_action(
			state,
			steps_done = steps_done,
			policy_net = policy_net,
			device = compute_device,
			env = env
		)
		steps_done += 1


		# Perform one step of the environment with this action.
		( next_state,	# new state
			reward,		# number: reward as a result of action
			terminated,	# bool: reached a terminal state (win or loss). If True, must reset.
			truncated,	# bool: end of time limit. If true, must reset.
			_
		) = env.step(action.item())

		# Conversion
		reward = torch.tensor([reward], device = compute_device)

		if terminated:
			# If the environment reached a terminal state,
			# observations are meaningless. Set to None.
			next_state = None
		else:
			# Conversion
			next_state = torch.tensor(
				next_state,
				dtype = torch.float32,
				device = compute_device
			).unsqueeze(0)


		# Add this state transition to memory.
		memory.append(
			util.Transition(
				state,
				action,
				next_state,
				reward
			)
		)


		state = next_state


		# Only train the network if we have enough
		# transitions in memory to do so.
		if len(memory) >= BATCH_SIZE:
			# Run optimizer
			optimize.optimize_model(
				memory,
				# Pytorch params
				compute_device = compute_device,
				policy_net = policy_net,
				target_net = target_net,
				optimizer = optimizer,
			)


		# Soft update target_net weights
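		# (Added note, not in the original script: this is the soft update rule
		# from the PyTorch DQN tutorial, θ′ ← τ·θ + (1 − τ)·θ′, where θ are
		# policy_net's weights and θ′ are target_net's weights.)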
		target_net_state = target_net.state_dict()
		policy_net_state = policy_net.state_dict()
		for key in policy_net_state:
			target_net_state[key] = (
				policy_net_state[key] * TAU +
				target_net_state[key] * (1-TAU)
			)
		target_net.load_state_dict(target_net_state)

		# Move on to the next episode once we reach
		# a terminal state.
		if (terminated or truncated):
			print(f"Episode {ep}/{num_episodes}, last duration {t+1}", end="\r" )
			episode_durations.append(t + 1)
			break

print("Complete.")

durations_t = torch.tensor(episode_durations, dtype=torch.float)
plt.xlabel('Episode')
plt.ylabel('Duration')
plt.plot(durations_t.numpy())
plt.show()


env.close()
en = gym.make("CartPole-v1", render_mode = "human")

while True:
	state, _ = en.reset()
	state = torch.tensor(
		state,
		dtype=torch.float32,
		device=compute_device
	).unsqueeze(0)

	terminated = False
	truncated = False
	while not (terminated or truncated):
		action = policy_net(state).max(1)[1].view(1, 1)

		( state,	# new state
			reward,	# reward as a result of action
			terminated,	# bool: reached a terminal state (win or loss). If True, must reset.
			truncated,	# bool: end of time limit. If true, must reset.
			_
		) = en.step(action.item())

		state = torch.tensor(
			state,
			dtype=torch.float32,
			device=compute_device
		).unsqueeze(0)

		en.render()
	en.reset()
@ -1,161 +0,0 @@
import random
from collections import deque

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import util

def optimize_model(
	memory: deque,

	# Pytorch params
	compute_device,
	policy_net: nn.Module,
	target_net: nn.Module,
	optimizer,

	# Parameters:
	#
	# BATCH_SIZE is the number of transitions sampled from the replay buffer
	# GAMMA is the discount factor
	BATCH_SIZE = 128,
	GAMMA = 0.99
):

	if len(memory) < BATCH_SIZE:
		raise Exception(f"Not enough elements in memory for a batch of {BATCH_SIZE}")



	# Get a random sample of transitions
	batch = random.sample(memory, BATCH_SIZE)

	# Conversion.
	# Transposes batch, turning an array of Transitions
	# into a Transition of arrays.
	batch = util.Transition(*zip(*batch))

	# Conversion.
	# Combine states, actions, and rewards into their own tensors.
	state_batch = torch.cat(batch.state)
	action_batch = torch.cat(batch.action)
	reward_batch = torch.cat(batch.reward)



	# Compute a mask of non-final states.
	# Each element of this tensor corresponds to an element in the batch.
	# True if this is not a final state, False if it is.
	#
	# We use this to select non-final states later.
	non_final_mask = torch.tensor(
		tuple(map(
			lambda s: s is not None,
			batch.next_state
		))
	)

	non_final_next_states = torch.cat(
		[s for s in batch.next_state if s is not None]
	)



	# How .gather works:
	# if out = a.gather(1, b),
	# out[i, j] = a[ i ][ b[i,j] ]
	#
	# a is "input," b is "index"
	# If this doesn't make sense, RTFD.
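	#
	# A small worked example (added, not in the original file):
	#   a = torch.tensor([[10, 20], [30, 40]])
	#   b = torch.tensor([[1], [0]])
	#   a.gather(1, b)  ->  tensor([[20], [30]])
	# Row 0 picks column 1, row 1 picks column 0.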

	# Compute Q(s_t, a).
	# - Use policy_net to compute Q(s_t) for each state in the batch.
	#   This gives a tensor of [ Q(state, left), Q(state, right) ]
	#
	# - Action batch is a tensor that looks like [ [0], [1], [1], ... ]
	#   listing the action that was taken in each transition.
	#   0 => we went left, 1 => we went right.
	#
	# This aligns nicely with the output of policy_net. We use
	# action_batch to index the output of policy_net's prediction.
	#
	# This gives us a tensor that contains the return we expect to get
	# at that state if we follow the model's advice.

	state_action_values = policy_net(state_batch).gather(1, action_batch)



	# Compute V(s_t+1) for all next states.
	# V(s_t+1) = max_a ( Q(s_t+1, a) )
	# = the maximum reward over all possible actions at state s_t+1.
	next_state_values = torch.zeros(BATCH_SIZE, device = compute_device)

	# Don't compute gradient for operations in this block.
	# If you don't understand what this means, RTFD.
	with torch.no_grad():

		# Note the use of non_final_mask here.
		# States that are final do not have their reward set by the line
		# below, so their reward stays at zero.
		#
		# States that are not final get their predicted value
		# set to the best value the model predicts.
		#
		# Expected values of action are selected with the "older" target net,
		# and their best reward (over possible actions) is selected with max(1)[0].

		next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]



	# TODO: What does this mean?
	# "Compute expected Q values"
	expected_state_action_values = reward_batch + (next_state_values * GAMMA)
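	# (Added note, not in the original file: this is the one-step Bellman target
	#   Q_expected(s, a) = r + GAMMA * max_a' Q_target(s', a'),
	# with the max term left at zero for terminal states. The loss below pushes
	# policy_net's Q(s, a) toward this target.)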


	# Compute Huber loss between predicted reward and expected reward.
	# PyTorch will account for this when we compute the gradient of loss.
	#
	# loss is a single-element tensor (i.e., a scalar).
	criterion = nn.SmoothL1Loss()
	loss = criterion(
		state_action_values,
		expected_state_action_values.unsqueeze(1)
	)



	# We can now run a step of backpropagation on our model.

	# TODO: what does this do?
	#
	# Calling .backward() multiple times will accumulate parameter gradients.
	# Thus, we reset the gradient before each step.
	optimizer.zero_grad()

	# Compute the gradient of loss with respect to policy_net's parameters.
	# The gradients end up in each parameter's .grad attribute, which is
	# why loss itself is never used again below.
	loss.backward()


	# Prevent vanishing and exploding gradients.
	# Forces gradients to be in [-clip_value, +clip_value]
	torch.nn.utils.clip_grad_value_( # type: ignore
		policy_net.parameters(),
		clip_value = 100
	)

	# Perform a single optimizer step.
	#
	# Uses the current gradient, which is stored
	# in the .grad attribute of the parameter.
	optimizer.step()
@ -1,77 +0,0 @@
import matplotlib
import matplotlib.pyplot as plt

import torch
import math
import random
from collections import namedtuple


Transition = namedtuple(
	"Transition",
	(
		"state",
		"action",
		"next_state",
		"reward"
	)
)


def select_action(
	state,

	*,

	# Number of steps that have been done
	steps_done: int,

	# Torch parameters
	policy_net,	# DQN policy network
	device,		# Compute device, "cuda" or "cpu"
	env,		# Gym environment instance

	# Epsilon parameters
	#
	# Original docs:
	# EPS_START is the starting value of epsilon
	# EPS_END is the final value of epsilon
	# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
	EPS_START = 0.9,
	EPS_END = 0.05,
	EPS_DECAY = 1000
):
	"""
	Given a state, select an action using an epsilon-greedy policy.

	Sometimes use our model, sometimes sample one uniformly.

	P(random action) starts at EPS_START and decays to EPS_END.
	Decay rate is controlled by EPS_DECAY.
	"""

	# Random number 0 <= x < 1
	sample = random.random()

	# Calculate the random-action threshold (epsilon)
	eps_threshold = (
		EPS_END + (EPS_START - EPS_END) *
		math.exp(
			-1.0 * steps_done /
			EPS_DECAY
		)
	)
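	# (Added note, not in the original file: with the default parameters this
	# threshold starts near EPS_START = 0.9, is roughly 0.36 after 1000 steps,
	# and decays toward EPS_END = 0.05 as steps_done grows.)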

	if sample > eps_threshold:
		with torch.no_grad():
			# t.max(1) will return the largest column value of each row.
			# second column on max result is index of where max element was
			# found, so we pick action with the larger expected reward.
			return policy_net(state).max(1)[1].view(1, 1)

	else:
		return torch.tensor(
			[ [env.action_space.sample()] ],
			device=device,
			dtype=torch.long
		)
@ -1,415 +0,0 @@
import torch
import gymnasium as gym

import random
import math
import time

from collections import deque
from itertools import count
from collections import namedtuple


Transition = namedtuple(
	"Transition",
	(
		"state",
		"action",
		"next_state",
		"reward"
	)
)


class Agent:
	def __init__(
		self,
		*,

		## Misc parameters
		#
		# Computation backend. Usually "cpu" or "cuda."
		# Automatic selection if left as None.
		# It's best to leave this as None.
		compute_device = None,
		#
		# Gymnasium environment name.
		env_name = "CartPole-v1",

		## Modules
		network,

		## Hyperparameters
		#
		# BATCH_SIZE is the size of the batch we train on, sampled from memory
		# GAMMA is the discount factor for optimization
		BATCH_SIZE = 128,
		GAMMA = 0.99,

		# Learning rate of target_net.
		# Controls how soft our soft update is.
		#
		# Should be between 0 and 1.
		# Large values make target_net change quickly,
		# small values do the opposite.
		#
		# A value of one makes target_net
		# change at the same rate as policy_net.
		#
		# A value of zero makes target_net
		# not change at all.
		TAU = 0.005,

		# Optimizer learning rate.
		OPT_LR = 1e-4,


		# Epsilon-greedy parameters
		#
		# Original docs:
		# EPS_START is the starting value of epsilon
		# EPS_END is the final value of epsilon
		# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
		EPS_START = 0.9,
		EPS_END = 0.05,
		EPS_DECAY = 1000,
	):

		## Auto-select compute device
		if compute_device is None:
			self.compute_device = torch.device(
				"cuda" if torch.cuda.is_available() else "cpu"
			)
		else:
			self.compute_device = compute_device


		## Initialize misc values
		self.steps_done = 0            # How many steps this agent has been trained on
		self.network = network         # Network class this agent should use
		self.env = gym.make(env_name)  # Gym environment
		self.env_name = env_name

		## Initialize replay memory.
		# This is a deque of Transitions.
		self.memory = deque([], maxlen = 10_000)

		## Save model hyperparameters
		self.BATCH_SIZE = BATCH_SIZE
		self.GAMMA = GAMMA
		self.TAU = TAU
		self.OPT_LR = OPT_LR
		self.EPS_START = EPS_START
		self.EPS_END = EPS_END
		self.EPS_DECAY = EPS_DECAY


		## Create networks and optimizer
		# n_actions: size of action space
		#	- 2 for cartpole: [0, 1] as "left" and "right"
		#
		# n_observations: size of observation vector
		#	- 4 for cartpole:
		#		position, velocity,
		#		angle, angular velocity
		n_actions = self.env.action_space.n  # type: ignore
		state, _ = self.env.reset()
		n_observations = len(state)

		# TODO:
		# What's the difference between these two?
		# What do they do?
		self.policy_net = self.network(n_observations, n_actions).to(self.compute_device)
		self.target_net = self.network(n_observations, n_actions).to(self.compute_device)

		# Both networks should start with the same weights
		self.target_net.load_state_dict(self.policy_net.state_dict())


		## Initialize optimizer.
		self.optimizer = torch.optim.AdamW(
			self.policy_net.parameters(),
			lr = self.OPT_LR,
			amsgrad = True
		)

	def _select_action(self, state):
		"""
		Select an action using an epsilon-greedy policy.

		Sometimes use our model, sometimes sample one uniformly.

		P(random action) starts at EPS_START and decays to EPS_END.
		Decay rate is controlled by EPS_DECAY.
		"""

		# Random number 0 <= x < 1
		sample = random.random()

		# Calculate the random-action threshold (epsilon)
		eps_threshold = (
			self.EPS_END + (self.EPS_START - self.EPS_END) *
			math.exp(
				-1.0 * self.steps_done /
				self.EPS_DECAY
			)
		)

		if sample > eps_threshold:
			with torch.no_grad():
				# t.max(1) will return the largest column value of each row.
				# second column on max result is index of where max element was
				# found, so we pick action with the larger expected reward.
				return self.policy_net(state).max(1)[1].view(1, 1)

		else:
			return torch.tensor(
				[ [self.env.action_space.sample()] ],
				device = self.compute_device,
				dtype = torch.long
			)

	def _optimize(self):
		if len(self.memory) < self.BATCH_SIZE:
			raise Exception(f"Not enough elements in memory for a batch of {self.BATCH_SIZE}")



		# Get a random sample of transitions
		batch = random.sample(self.memory, self.BATCH_SIZE)

		# Conversion.
		# Transposes batch, turning an array of Transitions
		# into a Transition of arrays.
		batch = Transition(*zip(*batch))

		# Conversion.
		# Combine states, actions, and rewards into their own tensors.
		state_batch = torch.cat(batch.state)
		action_batch = torch.cat(batch.action)
		reward_batch = torch.cat(batch.reward)



		# Compute a mask of non-final states.
		# Each element of this tensor corresponds to an element in the batch.
		# True if this is not a final state, False if it is.
		#
		# We use this to select non-final states later.
		non_final_mask = torch.tensor(
			tuple(map(
				lambda s: s is not None,
				batch.next_state
			))
		)

		non_final_next_states = torch.cat(
			[s for s in batch.next_state if s is not None]
		)



		# How .gather works:
		# if out = a.gather(1, b),
		# out[i, j] = a[ i ][ b[i,j] ]
		#
		# a is "input," b is "index"
		# If this doesn't make sense, RTFD.

		# Compute Q(s_t, a).
		# - Use policy_net to compute Q(s_t) for each state in the batch.
		#   This gives a tensor of [ Q(state, left), Q(state, right) ]
		#
		# - Action batch is a tensor that looks like [ [0], [1], [1], ... ]
		#   listing the action that was taken in each transition.
		#   0 => we went left, 1 => we went right.
		#
		# This aligns nicely with the output of policy_net. We use
		# action_batch to index the output of policy_net's prediction.
		#
		# This gives us a tensor that contains the return we expect to get
		# at that state if we follow the model's advice.

		state_action_values = self.policy_net(state_batch).gather(1, action_batch)



		# Compute V(s_t+1) for all next states.
		# V(s_t+1) = max_a ( Q(s_t+1, a) )
		# = the maximum reward over all possible actions at state s_t+1.
		next_state_values = torch.zeros(
			self.BATCH_SIZE,
			device = self.compute_device
		)

		# Don't compute gradient for operations in this block.
		# If you don't understand what this means, RTFD.
		with torch.no_grad():

			# Note the use of non_final_mask here.
			# States that are final do not have their reward set by the line
			# below, so their reward stays at zero.
			#
			# States that are not final get their predicted value
			# set to the best value the model predicts.
			#
			# Expected values of action are selected with the "older" target net,
			# and their best reward (over possible actions) is selected with max(1)[0].

			next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0]



		# TODO: What does this mean?
		# "Compute expected Q values"
		expected_state_action_values = reward_batch + (next_state_values * self.GAMMA)



		# Compute Huber loss between predicted reward and expected reward.
		# PyTorch will account for this when we compute the gradient of loss.
		#
		# loss is a single-element tensor (i.e., a scalar).
		criterion = torch.nn.SmoothL1Loss()
		loss = criterion(
			state_action_values,
			expected_state_action_values.unsqueeze(1)
		)



		# We can now run a step of backpropagation on our model.

		# TODO: what does this do?
		#
		# Calling .backward() multiple times will accumulate parameter gradients.
		# Thus, we reset the gradient before each step.
		self.optimizer.zero_grad()

		# Compute the gradient of loss with respect to policy_net's parameters.
		# The gradients end up in each parameter's .grad attribute, which is
		# why loss itself is never used again below.
		loss.backward()


		# Prevent vanishing and exploding gradients.
		# Forces gradients to be in [-clip_value, +clip_value]
		torch.nn.utils.clip_grad_value_( # type: ignore
			self.policy_net.parameters(),
			clip_value = 100
		)

		# Perform a single optimizer step.
		#
		# Uses the current gradient, which is stored
		# in the .grad attribute of the parameter.
		self.optimizer.step()

	def train(
		self,

		# Number of training episodes.
		# Need ~400 to see results.
		num_episodes = 400,

		# If true, print progress
		verbose = False
	) -> list[int]:
		# Returns a list of training episode durations.
		# Good for graphing.


		episode_durations = []

		for ep in range(num_episodes):

			# Reset environment and get game state
			state, _ = self.env.reset()
			state = torch.tensor(
				state,
				dtype = torch.float32,
				device = self.compute_device
			).unsqueeze(0)


			# Iterate until game is over
			for t in count():

				# Select next action
				action = self._select_action(state)
				self.steps_done += 1


				# Perform one step of the environment with this action.
				( next_state,	# new state
					reward,		# number: reward as a result of action
					terminated,	# bool: reached a terminal state (win or loss). If True, must reset.
					truncated,	# bool: end of time limit. If true, must reset.
					_
				) = self.env.step(action.item())

				# Conversion
				reward = torch.tensor([reward], device = self.compute_device)

				if terminated:
					# If the environment reached a terminal state,
					# observations are meaningless. Set to None.
					next_state = None
				else:
					# Conversion
					next_state = torch.tensor(
						next_state,
						dtype = torch.float32,
						device = self.compute_device
					).unsqueeze(0)


				# Add this state transition to memory.
				self.memory.append(
					Transition(
						state,
						action,
						next_state,
						reward
					)
				)

				state = next_state

				# Only train the network if we have enough
				# transitions in memory to do so.
				if len(self.memory) >= self.BATCH_SIZE:

					# Run optimizer
					self._optimize()


				# Soft update target_net weights
				target_net_state = self.target_net.state_dict()
				policy_net_state = self.policy_net.state_dict()
				for key in policy_net_state:
					target_net_state[key] = (
						policy_net_state[key] * self.TAU +
						target_net_state[key] * (1-self.TAU)
					)
				self.target_net.load_state_dict(target_net_state)

				# Move on to the next episode once we reach
				# a terminal state.
				if (terminated or truncated):
					if verbose:
						print(f"Episode {ep}/{num_episodes}, last duration {t+1}", end="\r" )
					episode_durations.append(t + 1)
					break


		return episode_durations

	def predict(self, state):
		return (
			self.policy_net(state)
			.max(1)[1]
			.view(1, 1)
			.item()
		)
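
# (Added usage sketch, not part of the original file. It mirrors how main.py,
# included elsewhere in this diff, drives the class:
#
#   agent = Agent(env_name = "CartPole-v1", network = DQN,
#                 BATCH_SIZE = 128, TAU = 0.005, OPT_LR = 1e-4)
#   durations = agent.train(600, verbose = True)
#   action = agent.predict(state)   # state: (1, n_observations) float tensor
#
# where DQN is any nn.Module mapping n_observations inputs to n_actions Q-values.)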
@ -1,132 +0,0 @@
import gymnasium as gym

import matplotlib
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
import torch.nn.functional as F

from agent import Agent



# Outline our network
class DQN(nn.Module):
	def __init__(self, n_observations: int, n_actions: int):
		super(DQN, self).__init__()
		self.layer1 = nn.Linear(n_observations, 128)
		self.layer2 = nn.Linear(128, 128)
		self.layer3 = nn.Linear(128, n_actions)

	# Can be called with one input, or with a batch.
	#
	# Returns tensor(
	#	[ Q(s, left), Q(s, right) ], ...
	# )
	#
	# Recall that Q(s, a) is the (expected) return of taking
	# action `a` at state `s`
	def forward(self, x):
		x = F.relu(self.layer1(x))
		x = F.relu(self.layer2(x))
		return self.layer3(x)




"""
from multiprocessing import Pool

def train(i):
	print(f"Running {i}")

	agent = Agent(
		env_name = "CartPole-v1",
		network = DQN,
		BATCH_SIZE = 128,
		TAU = 0.005,
		OPT_LR = 1e-4
	)

	# Train model episodes
	episode_durations = agent.train(600)

	#print(f"Model has been trained on {agent.steps_done} steps.")

	durations_t = torch.tensor(episode_durations, dtype=torch.float)

	fig, axs = plt.subplots(1, 1)
	axs.plot(durations_t.numpy())
	fig.savefig(f"main-{i}.png")

if __name__ == '__main__':
	with Pool(3) as p:
		p.map(train, list(range(10)))
"""

# Make the model
#
# Should work with...
# 	CartPole-v1
# 	Acrobot-v1
agent = Agent(
	env_name = "CartPole-v1",
	network = DQN,
	BATCH_SIZE = 128,
	TAU = 0.005,
	OPT_LR = 1e-4
)

# Train the model
episode_durations = agent.train(600, verbose = True)

# Plot training progress
durations_t = torch.tensor(episode_durations, dtype=torch.float)
fig, axs = plt.subplots(1, 1)
axs.plot(durations_t.numpy())
fig.savefig("main.png")


# Test the model
env = gym.make(
	agent.env_name,
	render_mode = "human"
)


while True:
	state, _ = env.reset()
	state = torch.tensor(
		state,
		dtype = torch.float32,
		device = agent.compute_device
	).unsqueeze(0)

	terminated = False
	truncated = False
	while not (terminated or truncated):

		# Predict best action given state
		action = agent.predict(state)

		# Do that action, get new state
		( state,
			reward,
			terminated,
			truncated,
			_
		) = env.step(action)
		state = torch.tensor(
			state,
			dtype = torch.float32,
			device = agent.compute_device
		).unsqueeze(0)

		env.render()

	# Environment needs to be reset after a session ends
	env.reset()
@ -1,316 +0,0 @@

## Setup
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

env = gym.make("CartPole-v1")

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
	from IPython import display

plt.ion()

# if gpu is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



# Replay Memory
#
# We'll be using experience replay memory for training our DQN. It stores the transitions that the agent observes, allowing us to reuse this data later. By sampling from it randomly, the transitions that build up a batch are decorrelated. It has been shown that this greatly stabilizes and improves the DQN training procedure.

# For this, we're going to need two classes:

# Transition - a named tuple representing a single transition in our environment. It essentially maps (state, action) pairs to their (next_state, reward) result.

# ReplayMemory - a cyclic buffer of bounded size that holds the transitions observed recently. It also implements a .sample() method for selecting a random batch of transitions for training.


Transition = namedtuple(
	"Transition",
	(
		"state",
		"action",
		"next_state",
		"reward"
	)
)


class ReplayMemory(object):
	def __init__(self, capacity):
		self.memory = deque([], maxlen=capacity)

	def push(self, *args):
		"""Save a transition"""
		self.memory.append(Transition(*args))

	def sample(self, batch_size):
		return random.sample(self.memory, batch_size)

	def __len__(self):
		return len(self.memory)
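
# (Added usage note, not in the original file. The class is used below as:
#   memory = ReplayMemory(10000)
#   memory.push(state, action, next_state, reward)   # store one Transition
#   batch = memory.sample(BATCH_SIZE)                # list of BATCH_SIZE Transitions
# )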


# DQN Algorithm
#
#


class DQN(nn.Module):

	def __init__(self, n_observations: int, n_actions: int):
		super(DQN, self).__init__()
		self.layer1 = nn.Linear(n_observations, 128)
		self.layer2 = nn.Linear(128, 128)
		self.layer3 = nn.Linear(128, n_actions)

	# Called with either one element to determine next action, or a batch
	# during optimization. Returns tensor([[left0exp,right0exp]...]).
	def forward(self, x):
		x = F.relu(self.layer1(x))
		x = F.relu(self.layer2(x))
		return self.layer3(x)




# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the AdamW optimizer
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4

# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = len(state)

policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)


steps_done = 0

def select_action(state):
	global steps_done
	sample = random.random()
	eps_threshold = (
		EPS_END + (EPS_START - EPS_END) *
		math.exp(
			-1.0 * steps_done /
			EPS_DECAY
		)
	)
	steps_done += 1

	if sample > eps_threshold:
		with torch.no_grad():
			# t.max(1) will return the largest column value of each row.
			# second column on max result is index of where max element was
			# found, so we pick action with the larger expected reward.
			return policy_net(state).max(1)[1].view(1, 1)

	else:
		return torch.tensor(
			[ [env.action_space.sample()] ],
			device=device,
			dtype=torch.long
		)


episode_durations = []


def plot_durations(show_result=False):
	plt.figure(1)
	durations_t = torch.tensor(episode_durations, dtype=torch.float)
	if show_result:
		plt.title('Result')
	else:
		plt.clf()
		plt.title('Training...')
	plt.xlabel('Episode')
	plt.ylabel('Duration')
	plt.plot(durations_t.numpy())
	# Take 100 episode averages and plot them too
	if len(durations_t) >= 100:
		means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
		means = torch.cat((torch.zeros(99), means))
		plt.plot(means.numpy())

	plt.pause(0.001)  # pause a bit so that plots are updated
	if is_ipython:
		if not show_result:
			display.display(plt.gcf())
			display.clear_output(wait=True)
		else:
			display.display(plt.gcf())




def optimize_model():
	if len(memory) < BATCH_SIZE:
		return
	transitions = memory.sample(BATCH_SIZE)
	# Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
	# detailed explanation). This converts batch-array of Transitions
	# to Transition of batch-arrays.
	batch = Transition(*zip(*transitions))

	# Compute a mask of non-final states and concatenate the batch elements
	# (a final state would've been the one after which simulation ended)
	non_final_mask = torch.tensor(
		tuple(
			map(
				lambda s: s is not None,
				batch.next_state
			)
		),
		device=device,
		dtype=torch.bool
	)
	non_final_next_states = torch.cat(
		[s for s in batch.next_state if s is not None]
	)
	state_batch = torch.cat(batch.state)
	action_batch = torch.cat(batch.action)
	reward_batch = torch.cat(batch.reward)

	# Compute Q(s_t, a) - the model computes Q(s_t), then we select the
	# columns of actions taken. These are the actions which would've been taken
	# for each batch state according to policy_net
	state_action_values = policy_net(state_batch).gather(1, action_batch)

	# Compute V(s_{t+1}) for all next states.
	# Expected values of actions for non_final_next_states are computed based
	# on the "older" target_net; selecting their best reward with max(1)[0].
	# This is merged based on the mask, such that we'll have either the expected
	# state value or 0 in case the state was final.
	next_state_values = torch.zeros(BATCH_SIZE, device=device)
	with torch.no_grad():
		next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
	# Compute the expected Q values
	expected_state_action_values = (next_state_values * GAMMA) + reward_batch

	# Compute Huber loss
	criterion = nn.SmoothL1Loss()
	loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

	# Optimize the model
	optimizer.zero_grad()
	loss.backward()
	# In-place gradient clipping
	torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
	optimizer.step()




if torch.cuda.is_available():
	num_episodes = 600
else:
	num_episodes = 50

for i_episode in range(num_episodes):
	# Initialize the environment and get its state
	state, info = env.reset()
	state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
	for t in count():
		action = select_action(state)
		observation, reward, terminated, truncated, _ = env.step(action.item())
		reward = torch.tensor([reward], device=device)
		done = terminated or truncated

		if terminated:
			next_state = None
		else:
			next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

		# Store the transition in memory
		memory.push(state, action, next_state, reward)

		# Move to the next state
		state = next_state

		# Perform one step of the optimization (on the policy network)
		optimize_model()

		# Soft update of the target network's weights
		# θ′ ← τ θ + (1 − τ) θ′
		target_net_state_dict = target_net.state_dict()
		policy_net_state_dict = policy_net.state_dict()
		for key in policy_net_state_dict:
			target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
		target_net.load_state_dict(target_net_state_dict)

		if done:
			episode_durations.append(t + 1)
			plot_durations()
			break

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()



en = gym.make("CartPole-v1", render_mode = "human")

while True:
	state, _ = en.reset()
	state = torch.tensor(
		state,
		dtype=torch.float32,
		device=device
	).unsqueeze(0)

	terminated = False
	truncated = False
	while not (terminated or truncated):
		action = policy_net(state).max(1)[1].view(1, 1)

		( state,	# new state
			reward,	# reward as a result of action
			terminated,	# bool: reached a terminal state (win or loss). If True, must reset.
			truncated,	# bool: end of time limit. If true, must reset.
			_
		) = en.step(action.item())

		state = torch.tensor(
			state,
			dtype=torch.float32,
			device=device
		).unsqueeze(0)

		en.render()
	en.reset()
@ -1,3 +0,0 @@
gymnasium[classic_control]==0.27.1
matplotlib==3.6.3
torch==1.13.1
BIN
report/Astra/PTAstraSans-Bold.ttf
Executable file
BIN
report/Astra/PTAstraSans-BoldItalic.ttf
Executable file
BIN
report/Astra/PTAstraSans-Italic.ttf
Executable file
BIN
report/Astra/PTAstraSans-Regular.ttf
Executable file
BIN
report/Astra/PTAstraSerif-Bold.ttf
Executable file
BIN
report/Astra/PTAstraSerif-BoldItalic.ttf
Executable file
BIN
report/Astra/PTAstraSerif-Italic.ttf
Executable file
BIN
report/Astra/PTAstraSerif-Regular.ttf
Executable file
BIN
report/images/badprediction.png
Executable file
After Width: | Height: | Size: 50 KiB |
BIN
report/images/celeste.png
Executable file
After Width: | Height: | Size: 1.5 MiB |
BIN
report/images/dash.png
Executable file
After Width: | Height: | Size: 6.0 KiB |
BIN
report/images/goodprediction.png
Executable file
After Width: | Height: | Size: 45 KiB |
BIN
report/images/jump.png
Executable file
After Width: | Height: | Size: 5.8 KiB |
BIN
report/images/plots.png
Executable file
After Width: | Height: | Size: 564 KiB |
BIN
report/images/points.png
Executable file
After Width: | Height: | Size: 8.6 KiB |
BIN
report/main.pdf
Normal file
199
report/main.tex
Executable file
@ -0,0 +1,199 @@
\documentclass{article}

\usepackage{geometry}
\geometry{
	paper = letterpaper,
	top = 25mm,
	bottom = 30mm,
	left = 30mm,
	right = 30mm,
	headheight = 75mm,
	footskip = 15mm,
	headsep = 75mm,
}

\usepackage[
	left = ``,
	right = '',
	leftsub = `,
	rightsub = '
]{dirtytalk}


\usepackage{tcolorbox}
\usepackage{fancyhdr}
\pagestyle{fancy}
\fancyhf{}
\renewcommand{\headrulewidth}{0mm}
\fancyfoot[C]{\thepage}


\usepackage{adjustbox} % For title
\usepackage{xcolor}    % Colored text
\usepackage{titlesec}  % Section customization
\usepackage{graphicx}  % For images
\usepackage{hyperref}  % Clickable references and PDF metadata
\usepackage{fontspec}  % Powerful fonts, for XeTeX
\usepackage{biblatex}  % Citations
\usepackage{enumitem}  % List customization
\usepackage{graphicx}  % Images
\usepackage{multicol}
\addbibresource{sources.bib}
%\usepackage{amsmath}
%\usepackage{amssymb}

\graphicspath{ {./images} }

\hypersetup{
	colorlinks=true,
	citecolor=black,
	filecolor=black,
	linkcolor=black,
	urlcolor=blue,
	pdftitle={Celeste-AI},
	pdfauthor={Mark},
	pdfcreator={Mark with XeLaTeX}
}


%\frenchspacing
\renewcommand*{\thefootnote}{\arabic{footnote}}

\setmainfont{PTAstraSerif}[
	Path = ./Astra/,
	Extension = .ttf,
	UprightFont = *-Regular,
	SmallCapsFont = *-Regular,
	BoldFont = *-Bold.ttf,
	ItalicFont = *-Italic.ttf,
	BoldItalicFont = *-BoldItalic.ttf,
	WordSpace = {1.1, 1.2, 1}
]

\renewcommand{\labelitemi}{$-$}
\renewcommand{\labelitemii}{$-$}
\setlist{nosep}
\setlength\parindent{0mm}


% 1: command to modify
% 2: format of label and text
% 3: label text
% 4: horizontal sep between label and text
% 5: before code
% 6: after code
\titleformat
	{\section}
	{\centering\large\bfseries}
	{Part \thesection:}
	{1ex}
	{}
	[]


\newcommand{\tag}[1]{
	\tcbox[
		nobeforeafter,
		colback=white!90!cyan,
		colframe=black!90!cyan,
		leftrule = 0.2mm,
		rightrule = 0.2mm,
		toprule = 0.2mm,
		bottomrule = 0.2mm,
		left = 0.5mm,
		right = 0.5mm,
		top = 0.5mm,
		bottom = 0.5mm
	]{#1}
}

% 5 - 7 pages
% TNR, 1 in margins
%

% However, while describing methods and results, I want each individual to emphasize the methods that they learned and used in the project (this is broadly interpreted, this could be things like learning new methods, learning how to code something new, learning how to collect and polish data, skills like learning how to read papers, or visualization tools). Projects are a great way to get hands on experience and learn from your peers, so I also want to hear about what you gained from doing the project! It is perfectly reasonable for different people to have different strengths, I have no objection to this. I want to hear what were challenges that YOU faced, how you overcame them, and what you were able to take away from doing this project!


% 2. Each group should also submit a copy of their code ( a general working code is fine, you don't have to resubmit the code each time you change a line).


% Good practices for the project report:
%
% Use figures and tables freely
% Make your figures nice
% Add a short desc to figs and tables
%
% acknowledge anyone that has helped you, as well as cite any references that you have used. You can add an acknowledgement section after contributions statement.

% Lastly, it is good practice to make sure all your results are reproducible. To do this, you need to tell people exactly what parameters you used to generate each plot. If this list is small, you can include it in the figure caption, or you can include it in the text body or in the Appendix.

\begin{document}

\thispagestyle{empty}


\begin{adjustbox}{minipage=0.7\textwidth, margin=0pt \smallskipamount, center}
	\begin{center}

		\rule{\linewidth}{0.2mm}\\

		\huge
		Celeste--AI \\
		\normalsize
		\vspace{1ex}
		Mark Ponomarenko\footnotemark[1], Timothy Chang, Ricardo Parada, Kelly Chang.
		\rule{\linewidth}{0.2mm} \\

	\end{center}
\end{adjustbox}

% Hack to get the footnote in the title at the bottom of the page.
\phantom{\footnotemark{}}
\footnotetext{Wrote this paper.}


\section{Abstract}
% 10ish line summary

From \textit{Super Mario Bros} \cite{pt-mario} and \textit{Atari} \cite{atari} to \textit{Go} \cite{alphago} and even \textit{Starcraft} \cite{sc2ai}, various forms of machine learning have been used to create game-playing algorithms. A common technique used for this task is reinforcement learning, especially deep $Q$-learning. In this paper, we present a novel attempt to use these reinforcement-learning techniques to solve the first stage of \textit{Celeste Classic} \cite{celesteclassic}.

\input{parts/background}
\input{parts/introduction}
\input{parts/methods}
\input{parts/results}
\input{parts/conclusion}




\section{Contribution Statement}

\subsection*{Ricardo:}
\tag{code} \tag{hypothesis} \tag{model design} \tag{literature review} \tag{research} \tag{report}

\subsection*{Mark:}
\tag{code} \tag{model design} \tag{report} \tag{literature review} \tag{plots}

\subsection*{Timothy:}
\tag{code} \tag{hypothesis} \tag{model design} \tag{research} \tag{code debugging} \tag{report}

\subsection*{Kelly:}
\tag{code} \tag{hypothesis} \tag{model design} \tag{organization} \tag{report} \tag{presentation}


\vfill

\printbibliography[keyword={site}, title={References: Sites}]
\printbibliography[keyword={article}, title={References: Articles}]

\vfill

\section{Appendix}

Our code is available at \texttt{https://git.betalupi.com/Mark/celeste-ai}

\end{document}
15
report/parts/background.tex
Executable file
@ -0,0 +1,15 @@
\section{Background}
% what other people did that is closely related to yours.

Our work is heavily based on the research done by Mnih et al. in \textit{Human-Level Control through Deep Reinforcement Learning} \cite{humanlevel}. The agent we developed to solve \textit{Celeste Classic} uses deep $Q$-learning supported by replay memory, with a modified reward system and explore-exploit probability. This is very similar to the architecture presented by Mnih et al.
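
\vspace{2mm}

For reference, the one-step target used in this style of deep $Q$-learning (and in our training code) is
\[
	y_t = r_t + \gamma \max_{a'} Q_{\mathrm{target}}(s_{t+1}, a'),
\]
where $r_t$ is the reward, $\gamma$ is the discount factor, and $Q_{\mathrm{target}}$ is a slowly-updated copy of the policy network; the policy network is regressed toward $y_t$.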

\vspace{2mm}

The greatest difference between our approach and the approach of \textit{Human-Level Control} is the input space and neural network type. Mnih et al. use a convolutional neural network, which takes the game screen as input. This requires a significant number of training epochs and a lot of computation time, and was thus an unreasonable approach for us. We instead used a small fully-connected neural network with two inputs: player x and player y.

\vspace{2mm}

Another project similar to ours is AiSpawn's \textit{AI Learns to Speedrun Celeste} \cite{aispawn}. Here, AiSpawn completes the same task we do---solving \textit{Celeste Classic}---but he uses a completely different, evolution-based approach.

\vfill
\pagebreak
59
report/parts/conclusion.tex
Executable file
@ -0,0 +1,59 @@
\section{Conclusion}
% What is the answer to the question?

Using the methods described above, we were able to successfully train a Q-learning agent to play \textit{Celeste Classic}.

\vspace{2mm}

The greatest limitation of our model is its slow training speed. It took the model 4000 episodes to complete the first stage, which translates to about 8 hours of training time. A simple evolutionary algorithm, such as the one presented in \textit{AI Learns to Speedrun Celeste} \cite{aispawn}, would likely have better performance than our Q-learning agent. Such an algorithm is much better suited to incremental tasks (such as this one) than a Q-learning algorithm.

\vspace{2mm}

We could further develop this model by making it more autonomous---specifically, by training it on raw pixel data rather than curated \texttt{(player\_x, player\_y)} tuples. This modification would \textit{significantly} slow down training, and is therefore best left out of a project with a ten-week time limit.


\vspace{5mm}

While developing our model, we encountered a few questions that we could not resolve. The first of these is the effect of position scaling, which is visible in the graphs below. Note that colors are inconsistent between the graphs, since we refactored our graphing tools after the right graph was generated.

\vspace{5mm}

\begin{minipage}{0.5\textwidth}
	\begin{center}
		\includegraphics[width=0.9\textwidth]{goodprediction}

		\vspace{1mm}
		\begin{minipage}{0.9\textwidth}
			\raggedright
			\say{Best-action} plot after 500 training episodes with position rescaled to the range $[0, 1]$.
		\end{minipage}
	\end{center}
\end{minipage}
\hfill
\begin{minipage}{0.5\textwidth}
	\begin{center}
		\includegraphics[width=0.9\textwidth]{badprediction}

		\vspace{1mm}
		\begin{minipage}{0.9\textwidth}
			\raggedright
			\say{Best-action} plot after 500 training episodes with position in the original range $[0, 128]$.
		\end{minipage}
	\end{center}
\end{minipage}

\vspace{5mm}

In these graphs, we see that, without changing the model, the scaling of input values has a \textit{significant} effect on the model's performance. Large inputs cause a \say{zoomed-out linear fanning} effect in the rightmost graph, while the left graph, with rescaled values, has a much more reasonable \say{blob} pattern.
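
\vspace{2mm}

Here, rescaling from the original range $[0, 128]$ to $[0, 1]$ amounts to dividing each coordinate by 128 before it is passed to the network.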

\vspace{2mm}

In addition to this, we found that re-centering the game's coordinate system so that \texttt{(0, 0)} is in the center rather than the top-left also has a significant effect on the model's performance. Without centering, the model performs perfectly well. With centering, our loss grows uncontrollably and the model fails to converge.

\vspace{5mm}

In both of these cases, the results are surprising. In theory, rescaling or re-centering the data should not affect the performance of the model: the network's weights should simply adjust to the different input ranges during training. We do not have an explanation for this behavior, and would be glad to find one.

\vfill
\pagebreak
76
report/parts/introduction.tex
Executable file
@ -0,0 +1,76 @@
|
|||||||
|
\section{Introduction}
|
||||||
|
% Detailed summary of the problem.
|
||||||
|
% Discuss why addressing this problem is important.
|
||||||
|
|
||||||
|
|
||||||
|
\textit{Celeste} \cite{celestegame} is a fairly successful 2018 platformer, known for high-quality level design, a vibrant speedrunning\footnotemark{} community, and brutally difficult progression. It is based on \textit{Celeste Classic}, a 4-day game jam project by the same authors. There are a few reasons we chose to create an agent for \textit{Celeste Classic}:
|
||||||
|
|
||||||
|
\footnotetext{\textit{speedrunning:} a competition where participants try to complete a game as quickly as possible, often abusing bugs and design mistakes.}
|
||||||
|
|
||||||
|
\vspace{4mm}
|
||||||
|
|
||||||
|
\noindent
|
||||||
|
\begin{minipage}{0.5\textwidth}
|
||||||
|
\noindent
|
||||||
|
1: \textit{Celeste Classic} is designed for humans, unlike the environments from, for example, the \texttt{gymnasium} \cite{gymnasium} library.
|
||||||
|
|
||||||
|
\noindent
|
||||||
|
2: It runs on the PICO-8 \cite{pico8}, which allows us to modify its code. This gives us a reliable way to interface with the game. The same is not true of \textit{Celeste} (2018): writing a wrapper for it would take a significant amount of time.
|
||||||
|
|
||||||
|
\noindent
|
||||||
|
3: The action space of \textit{Celeste Classic} is small, especially when ineffective actions are pruned.
|
||||||
|
\end{minipage}\hfill
|
||||||
|
\begin{minipage}{0.48\textwidth}
|
||||||
|
\begin{center}
|
||||||
|
\includegraphics[width=0.9\textwidth]{celeste}
|
||||||
|
|
||||||
|
\vspace{1mm}
|
||||||
|
\begin{minipage}{0.8\textwidth}
|
||||||
|
The first stage of \textit{Celeste} (2018), showing the player dashing to the upper-right.
|
||||||
|
\end{minipage}
|
||||||
|
\end{center}
|
||||||
|
\end{minipage}
|
||||||
|
|
||||||
|
\vspace{5mm}
|
||||||
|
|
||||||
|
\noindent
|
||||||
|
When we started this project, our goal was to develop an agent that learns to finish the first stage of this game. The agent starts in the bottom-left corner of the stage and needs to reach the top right. If it touches the spikes at the bottom of the stage, the game is reset and it must try again.
|
||||||
|
|
||||||
|
|
||||||
|
To this end, our agent selects one of nine actions (listed below) at every time step. It does this using a Q-learning algorithm, which is described in detail later in this paper.
|
||||||
|
|
||||||
|
\noindent
|
||||||
|
\begin{minipage}{0.5\textwidth}
|
||||||
|
Possible actions:
|
||||||
|
\begin{itemize}
|
||||||
|
\item \texttt{left}: move left
|
||||||
|
\item \texttt{right}: move right
|
||||||
|
\item \texttt{jump-l}: jump left
|
||||||
|
\item \texttt{jump-r}: jump right
|
||||||
|
\item \texttt{dash-l}: dash left
|
||||||
|
\item \texttt{dash-r}: dash right
|
||||||
|
\item \texttt{dash-u}: dash up
|
||||||
|
\item \texttt{dash-ru}: dash right-up
|
||||||
|
\item \texttt{dash-lu}: dash left-up
|
||||||
|
\end{itemize}
|
||||||
|
\end{minipage}\hfill
|
||||||
|
\begin{minipage}{0.48\textwidth}
|
||||||
|
\vspace{3ex}
|
||||||
|
\begin{center}
|
||||||
|
\includegraphics[width=0.48\textwidth]{jump}
|
||||||
|
\includegraphics[width=0.48\textwidth]{dash}
|
||||||
|
|
||||||
|
\vspace{1mm}
|
||||||
|
\begin{minipage}{0.9\textwidth}
|
||||||
|
The first stage of \textit{Celeste Classic}. Two possible actions our agent can take are shown: \texttt{jump-r} followed by \texttt{dash-lu}.
|
||||||
|
\end{minipage}
|
||||||
|
\end{center}
|
||||||
|
\end{minipage}
|
||||||
|
|
||||||
|
|
||||||
|
\vfill{}
|
||||||
|
|
||||||
|
This task has no direct practical applications. However, by developing an agent that completes it, we will explore possible techniques and modifications to the traditional DQN algorithm, and we will learn how a simple machine learning model can be adapted to a rather complicated problem.
|
||||||
|
|
||||||
|
\vfill
|
||||||
|
\pagebreak
|
106
report/parts/methods.tex
Executable file
@ -0,0 +1,106 @@
|
|||||||
|
\section{Methods}
|
||||||
|
% Detailed description of methods used or developed.
|
||||||
|
|
||||||
|
Our solution to \textit{Celeste Classic} consists of two major parts: the \textit{interface} and the \textit{agent}. The first provides a high-level wrapper around the game, and the second uses deep Q-learning techniques to control the player.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
\subsection{Interface}
|
||||||
|
|
||||||
|
The interface component does not contain any machine-learning logic. Its primary job is to send input to \textit{Celeste Classic} and to receive game state from it. We send input by emulating keypresses with the standard X11 utility \texttt{xdotool}. A minor consequence is that our agent can only run in a Linux environment, though this could be remedied with a bit of extra code.
|
||||||
|
|
||||||
|
\vspace{2mm}
|
||||||
|
|
||||||
|
We receive game state by abusing the PICO-8's debugging features. Since PICO-8 games are plain text files, we were able to modify the code of \textit{Celeste Classic} with a few well-placed debug-print statements. The interface captures this text, parses it, and feeds it to our model.
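As an illustration, both halves of the interface fit in a few lines of Python. The function names and the key name below are ours, not the exact ones used in our code; the state format matches the semicolon-separated \texttt{key:value} pairs printed by \texttt{hackcel.p8}.
\begin{verbatim}
import subprocess

def press_key(key: str, window_id: str):
    # Emulate a keypress in the PICO-8 window with xdotool.
    subprocess.run(
        ["xdotool", "key", "--window", window_id, key],
        check=True
    )

def parse_state(line: str) -> dict:
    # hackcel.p8 prints lines such as "dc:0;ds:t; ... ;fr:tff...;"
    state = {}
    for field in line.strip().split(";"):
        if ":" in field:
            key, value = field.split(":", 1)
            state[key] = value
    return state
\end{verbatim}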
|
||||||
|
|
||||||
|
\vspace{2mm}
|
||||||
|
|
||||||
|
The final component of the interface is timing. First, we modified \textit{Celeste Classic} to only run frames when a key is pressed. This allows the agent to run on in-game time rather than real time, which wouldn't be possible otherwise: \textit{Celeste} usually runs at 30 fps, and the hardware we used to train our model cannot compute gradients that quickly.
|
||||||
|
|
||||||
|
Second, we added a \say{frame skip} mechanism to the interface, which tells the game to run a certain number of frames (usually many more than one) after the agent selects an action. The benefit of this is twofold. First, it prevents our model from training on redundant information, since the game's state does not change significantly between consecutive frames. Second, frame skipping allows transitions to more directly reflect the consequences of an action.
|
||||||
|
|
||||||
|
For example, say the agent chooses to dash upwards. Due to the way \textit{Celeste} is designed, the player cannot take any other action until that dash is complete. Our frame-skip mechanism will run the game until the dash is complete, returning a new state only when a new action can be taken.
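A rough sketch of this loop is below, under the assumption that the game reports whether the player can currently act; in our implementation this bookkeeping lives inside \texttt{hackcel.p8}, and the names here are illustrative.
\begin{verbatim}
from dataclasses import dataclass

@dataclass
class GameState:
    player_x: float
    player_y: float
    can_act: bool  # False while a dash or similar animation is playing

def step(send_action, advance_frame, action, max_frames=120):
    # Send one action, then run frames until the player can act again,
    # returning only the state in which the next decision is made.
    send_action(action)
    state = advance_frame()
    frames = 1
    while not state.can_act and frames < max_frames:
        state = advance_frame()
        frames += 1
    return state
\end{verbatim}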
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
\subsection{Agent}
|
||||||
|
|
||||||
|
The agent we trained to solve \textit{Celeste Classic} is a plain deep Q-learning agent. A neural network estimates the reward of taking each possible action at a given state, and the agent selects the action with the highest predicted reward. This network is a four-layer fully-connected network with 128 nodes in each layer and a ReLU activation on each hidden node. It has two input nodes, which track the player's X- and Y-position, and nine output nodes, each corresponding to an action the agent can take.
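One plausible PyTorch realization of this architecture is shown below; the layer sizes follow the description above, though the exact module layout in our code may differ.
\begin{verbatim}
import torch.nn as nn

class QNetwork(nn.Module):
    def __init__(self, n_observations: int = 2, n_actions: int = 9):
        super().__init__()
        # Four fully-connected layers, 128 hidden nodes, ReLU on
        # every hidden node.
        self.layers = nn.Sequential(
            nn.Linear(n_observations, 128), nn.ReLU(),
            nn.Linear(128, 128), nn.ReLU(),
            nn.Linear(128, 128), nn.ReLU(),
            nn.Linear(128, n_actions),
        )

    def forward(self, x):
        # x: (batch, 2) player position -> (batch, 9) action values
        return self.layers(x)
\end{verbatim}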
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
\subsubsection{Reward}
|
||||||
|
|
||||||
|
\noindent
|
||||||
|
\begin{minipage}{0.58\textwidth}
|
||||||
|
During training, the agent receives a reward of 10 whenever it reaches a checkpoint (at right) or completes the stage. If the agent skips a checkpoint, it still gets credit for each checkpoint it passed: jumping from point 1 to point 3, for example, would give the agent a reward of 20. (A short sketch of this scheme follows the figure.)
|
||||||
|
|
||||||
|
\vspace{2mm}
|
||||||
|
|
||||||
|
These checkpoints are placed close enough together to keep the agent progressing, but far enough apart to give it a challenge. Points 4 and 5 are particularly interesting in this respect. When we trained an agent without point 4, it would often reach the ledge and fall off, receiving no reward.
|
||||||
|
|
||||||
|
\vspace{2mm}
|
||||||
|
|
||||||
|
Despite many thousands of training episodes, this process was unable to finish the stage. Though the ledge under point 4 is fairly easy to reach from either point 2 or point 3, it is highly unlikely that an untrained agent would make it from point 2 to point 5 without the extra reward at point 4.
|
||||||
|
|
||||||
|
\end{minipage}\hfill
|
||||||
|
\begin{minipage}{0.4\textwidth}
|
||||||
|
\begin{center}
|
||||||
|
\includegraphics[width=0.9\textwidth]{points}
|
||||||
|
|
||||||
|
\vspace{1mm}
|
||||||
|
\begin{minipage}{0.8\textwidth}
|
||||||
|
Locations of non-final checkpoints
|
||||||
|
\end{minipage}
|
||||||
|
\end{center}
|
||||||
|
\end{minipage}
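The checkpoint reward can be sketched as follows, assuming checkpoints are numbered in order and we track the furthest one reached so far; the checkpoint coordinates themselves are omitted here.
\begin{verbatim}
def checkpoint_reward(best_reached: int, now_reached: int) -> float:
    # 10 reward per checkpoint passed since the last step, so skipping
    # a checkpoint (e.g. jumping from point 1 to point 3) still pays
    # for every point passed along the way.
    newly_passed = max(0, now_reached - best_reached)
    return 10.0 * newly_passed
\end{verbatim}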
|
||||||
|
|
||||||
|
\vfill
|
||||||
|
\pagebreak
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
\subsubsection{Exploration Probability}
|
||||||
|
|
||||||
|
At every step, we use the Q network to predict the expected reward for taking each of the nine actions. Naturally, the best action to take is the one with the highest predicted reward. In order to encourage exploration, we also take a random action with a probability given by
|
||||||
|
$$
|
||||||
|
P(c) = \epsilon_1 + (\epsilon_0 - \epsilon_1) e^{-c / d}
|
||||||
|
$$
|
||||||
|
|
||||||
|
where $\epsilon_0$ is the initial random-action probability, $\epsilon_1$ is the final random-action probability, and $d$ is the rate at which $P(c)$ decays to $\epsilon_1$. Here $c$ is a rather unusual \say{time} parameter: it counts the number of times the agent has reached the next point.
|
||||||
|
|
||||||
|
\vspace{2mm}
|
||||||
|
|
||||||
|
Usually, such $\epsilon$ policies depend on the number of training steps completed. For many applications this makes sense: as a model trains over many iterations it performs better, and thus has less need to explore. In our case that doesn't work: we need to explore until we find a way to reach a checkpoint, then rely on the model's predictions once we've found one. Therefore, we compute $P$ with respect to $c$ rather than a simple iteration counter.
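In code, this schedule is a direct transcription of the formula above; the parameter values shown here are placeholders rather than the ones we trained with.
\begin{verbatim}
import math

def explore_probability(c: int,
                        eps_start: float = 0.9,
                        eps_end: float = 0.05,
                        decay: float = 50.0) -> float:
    # c counts how many times the agent has reached the next point,
    # not how many training steps have elapsed.
    return eps_end + (eps_start - eps_end) * math.exp(-c / decay)
\end{verbatim}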
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
\subsubsection{Target Network, Replay Memory}
|
||||||
|
|
||||||
|
To prevent an unstable training process, we use a \textit{target network} as described in \textit{Human-Level Control} \cite{humanlevel}. However, instead of periodically hard-resetting the target network to the Q network, we use a soft update defined by the following equation, where $W_Q$ and $W_T$ are the weights of the Q network and the target network, respectively.
|
||||||
|
$$
|
||||||
|
W_T = 0.05 W_Q + 0.95 W_T
|
||||||
|
$$
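With PyTorch state dictionaries, this soft update takes only a few lines; the mixing constant 0.05 comes directly from the equation above.
\begin{verbatim}
import torch

def soft_update(q_net: torch.nn.Module,
                target_net: torch.nn.Module,
                tau: float = 0.05):
    # W_T <- tau * W_Q + (1 - tau) * W_T, applied parameter by parameter.
    q_state = q_net.state_dict()
    t_state = target_net.state_dict()
    for key in t_state:
        t_state[key] = tau * q_state[key] + (1.0 - tau) * t_state[key]
    target_net.load_state_dict(t_state)
\end{verbatim}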
|
||||||
|
|
||||||
|
\vspace{2mm}
|
||||||
|
|
||||||
|
We also use \textit{replay memory}, as described in the same paper, with a batch size of 100 and a capacity of 50,000 transitions. Our model is optimized using Adam with a learning rate of 0.001.
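The replay memory itself can be as simple as a bounded deque with uniform random sampling; the sizes below are the ones quoted above.
\begin{verbatim}
import random
from collections import deque, namedtuple

Transition = namedtuple(
    "Transition", ("state", "action", "next_state", "reward")
)

class ReplayMemory:
    def __init__(self, capacity: int = 50_000):
        # Old transitions fall off the end once capacity is reached.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        self.memory.append(Transition(*args))

    def sample(self, batch_size: int = 100):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
\end{verbatim}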
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
\subsubsection{Bellman Equation}
|
||||||
|
|
||||||
|
Our goal is to train our model to approximate the value function $Q(s, a)$, which tells us the value of taking action $a$ at state $s$. This approximation can then be used to choose the best action at each state. We define $Q$ using the Bellman equation:
|
||||||
|
$$
|
||||||
|
Q(s, a) = r(s) + \gamma Q(s_a)
|
||||||
|
$$
|
||||||
|
where $r(s)$ is the reward at state $s$, $Q(s_a)$ is the value of the state we reach by performing action $a$ at state $s$, and $\gamma$ is a discount factor that makes present reward more valuable than future reward. In our model, we set $\gamma = 0.9$.
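In the training loop this becomes the usual DQN target. We take the value of the successor state to be the highest action value predicted by the target network, which is the standard choice; the sketch below assumes that reading and ignores terminal states for brevity.
\begin{verbatim}
import torch

def q_targets(rewards: torch.Tensor,
              next_states: torch.Tensor,
              target_net: torch.nn.Module,
              gamma: float = 0.9) -> torch.Tensor:
    # r(s) + gamma * max_a' Q_target(s', a') for a batch of transitions.
    # (Terminal transitions would need masking; omitted here.)
    with torch.no_grad():
        next_values = target_net(next_states).max(dim=1).values
    return rewards + gamma * next_values
\end{verbatim}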
|
||||||
|
|
||||||
|
|
||||||
|
\vfill
|
||||||
|
\pagebreak
|
22
report/parts/results.tex
Executable file
@ -0,0 +1,22 @@
|
|||||||
|
\section{Results}
|
||||||
|
% The results of applying the methods to the data set.
|
||||||
|
% Also discuss why the results makes sense, possible implications.
|
||||||
|
|
||||||
|
After sufficient training, our model consistently completed the first stage of \textit{Celeste Classic}. Reaching this result required 4000 training episodes.
|
||||||
|
|
||||||
|
\vspace{2mm}
|
||||||
|
|
||||||
|
The figure below summarizes our model's performance during training. The color of each pixel in the plot is determined by the action with the highest predicted value, and the path the agent takes through the stage is shown in white. The agent completes the stage in the \say{4000 Episodes} plot, and fails to complete it within the allocated time limit in all the rest. Training the model on more than 4000 episodes did not have a significant effect on the agent's behavior.
|
||||||
|
|
||||||
|
\begin{center}
|
||||||
|
\includegraphics[width=\textwidth]{plots}
|
||||||
|
\end{center}
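These best-action plots are produced by evaluating the Q network on a grid of positions and coloring each cell by the arg-max action. The sketch below illustrates the idea, assuming positions rescaled to $[0, 1]$ and the \texttt{QNetwork} interface sketched in the Methods section; the real figures are generated by \texttt{plot.py} from model snapshots.
\begin{verbatim}
import torch
import matplotlib.pyplot as plt

def best_action_grid(q_net, resolution: int = 128):
    # Evaluate the network at every (x, y) cell and keep the best action.
    xs = torch.linspace(0.0, 1.0, resolution)
    ys = torch.linspace(0.0, 1.0, resolution)
    grid = torch.cartesian_prod(xs, ys)        # (resolution**2, 2)
    with torch.no_grad():
        best = q_net(grid).argmax(dim=1)
    return best.reshape(resolution, resolution)

# plt.imshow(best_action_grid(q_net), origin="lower") then colors each
# cell by its predicted best action.
\end{verbatim}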
|
||||||
|
|
||||||
|
A few things are interesting about these results. First, the best-action patterns in the graphs above do not resemble the shape of the stage. At points the agent does not visit, the predicted best action bears little resemblance to the action an intelligent human player would take. This is because the model is never trained on these points; the predictions there are a side effect of the training steps applied to the points along the agent's path.
|
||||||
|
|
||||||
|
\vspace{2mm}
|
||||||
|
|
||||||
|
Second, the plots above clearly depict the effect of our modified explore/exploit policy. We can see that the first few segments of the agent's path are the same in each graph, and that the more the agent trains, the longer this repeated path becomes. This is a direct result of our explore/exploit policy: the agent stops exploring sections of the stage it can reliably complete, and therefore repeats paths that work.
|
||||||
|
|
||||||
|
\vfill
|
||||||
|
\pagebreak
|
113
report/sources.bib
Executable file
@ -0,0 +1,113 @@
|
|||||||
|
@article{humanlevel,
|
||||||
|
author = {Mnih, Volodymyr and Kavukcuoglu, Koray and Silver, David and Rusu, Andrei A. and Veness, Joel and Bellemare, Marc G. and Graves, Alex and Riedmiller, Martin and Fidjeland, Andreas K. and Ostrovski, Georg and Petersen, Stig and Beattie, Charles and Sadik, Amir and Antonoglou, Ioannis and King, Helen and Kumaran, Dharshan and Wierstra, Daan and Legg, Shane and Hassabis, Demis},
|
||||||
|
title = {Human-level control through deep reinforcement learning},
|
||||||
|
|
||||||
|
description = {Human-level control through deep reinforcement learning},
|
||||||
|
issn = {00280836},
|
||||||
|
journal = {Nature},
|
||||||
|
month = feb,
|
||||||
|
number = 7540,
|
||||||
|
pages = {529--533},
|
||||||
|
publisher = {Nature Publishing Group, a division of Macmillan Publishers Limited. All Rights Reserved.},
|
||||||
|
timestamp = {2015-08-26T14:46:40.000+0200},
|
||||||
|
url = {http://dx.doi.org/10.1038/nature14236},
|
||||||
|
volume = 518,
|
||||||
|
year = 2015,
|
||||||
|
keywords = {article}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@article{atari,
|
||||||
|
author = {Mnih, Volodymyr and Kavukcuoglu, Koray and Silver, David and Graves, Alex and Antonoglou, Ioannis and Wierstra, Daan and Riedmiller, Martin},
|
||||||
|
title = {Playing Atari with Deep Reinforcement Learning},
|
||||||
|
url = {http://arxiv.org/abs/1312.5602},
|
||||||
|
year = 2013,
|
||||||
|
keywords = {article}
|
||||||
|
}
|
||||||
|
|
||||||
|
@article{alphago,
|
||||||
|
author = {Silver, David and Schrittwieser, Julian and Simonyan, Karen and Antonoglou, Ioannis and Huang, Aja and Guez, Arthur and Hubert, Thomas and Baker, Lucas and Lai, Matthew and Bolton, Adrian and Chen, Yutian and Lillicrap, Timothy and Hui, Fan and Sifre, Laurent and van den Driessche, George and Graepel, Thore and Hassabis, Demis},
|
||||||
|
description = {Mastering the game of Go without human knowledge},
|
||||||
|
journal = {Nature},
|
||||||
|
pages = {354--},
|
||||||
|
publisher = {Macmillan Publishers Limited},
|
||||||
|
title = {Mastering the game of Go without human knowledge},
|
||||||
|
url = {http://dx.doi.org/10.1038/nature24270},
|
||||||
|
volume = 550,
|
||||||
|
year = 2017,
|
||||||
|
keywords = {article}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@online{sc2ai,
|
||||||
|
author = {},
|
||||||
|
title = {SC2 AI Arena},
|
||||||
|
url = {https://sc2ai.net},
|
||||||
|
addendum = {Accessed 2023-02-25},
|
||||||
|
keywords = {site}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@online{pt-mario,
|
||||||
|
author = {Feng, Yuansong and Subramanian, Suraj and Wang, Howard and Guo, Steven},
|
||||||
|
title = {Train a Mario-playing RL Agent},
|
||||||
|
url = {https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html},
|
||||||
|
addendum = {Accessed 2023-02-25},
|
||||||
|
keywords = {site}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@online{pt-cart,
|
||||||
|
author = {Paszke, Adam and Towers, Mark},
|
||||||
|
title = {Reinforcement Learning (DQN) Tutorial},
|
||||||
|
url = {https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html},
|
||||||
|
addendum = {Accessed 2023-02-25},
|
||||||
|
keywords = {site}
|
||||||
|
}
|
||||||
|
|
||||||
|
@online{celestegame,
|
||||||
|
author = {},
|
||||||
|
title = {Celeste},
|
||||||
|
url = {https://www.celestegame.com},
|
||||||
|
addendum = {Accessed 2023-02-25},
|
||||||
|
keywords = {site},
|
||||||
|
year = 2018
|
||||||
|
}
|
||||||
|
|
||||||
|
@online{celesteclassic,
|
||||||
|
author = {},
|
||||||
|
title = {Celeste Classic},
|
||||||
|
url = {https://www.lexaloffle.com/bbs/?pid=11722},
|
||||||
|
addendum = {Accessed 2023-02-25},
|
||||||
|
keywords = {site},
|
||||||
|
year = 2015
|
||||||
|
}
|
||||||
|
|
||||||
|
@online{pico8,
|
||||||
|
author = {},
|
||||||
|
title = {PICO-8},
|
||||||
|
url = {https://www.lexaloffle.com/pico-8.php},
|
||||||
|
addendum = {Accessed 2023-02-25},
|
||||||
|
keywords = {site}
|
||||||
|
}
|
||||||
|
|
||||||
|
@online{gymnasium,
|
||||||
|
author = {},
|
||||||
|
title = {Gymnasium},
|
||||||
|
url = {https://github.com/Farama-Foundation/Gymnasium},
|
||||||
|
addendum = {Accessed 2023-02-25},
|
||||||
|
keywords = {site}
|
||||||
|
}
|
||||||
|
|
||||||
|
@online{aispawn,
|
||||||
|
author = {AiSpawn},
|
||||||
|
title = {AI Learns to Speedrun Celeste},
|
||||||
|
url = {https://www.youtube.com/watch?v=y8g1AcTYovg},
|
||||||
|
organization = {Youtube},
|
||||||
|
addendum = {Accessed 2023-02-22},
|
||||||
|
keywords = {site}
|
||||||
|
}
|
||||||
|
|
21
resources/README.md
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
# Celeste-AI resources
|
||||||
|
|
||||||
|
- `./carts/celeste.p8`: Unmodified *Celeste Classic* cart
|
||||||
|
- `./carts/hackcel.p8`: *Celeste Classic* modified with delays and debug
|
||||||
|
- `./pico8`: An old version of PICO-8
|
||||||
|
- `./images`: Miscellaneous images for pretty plots. Not used by scripts yet.
|
||||||
|
|
||||||
|
|
||||||
|
## PICO-8 setup
|
||||||
|
1. Run `./pico-8/linux/pico8`
|
||||||
|
2. Once it starts, type `folder`. This will open the PICO-8 root folder in your file browser.
|
||||||
|
3. Copy both carts in `./carts` into that folder.
|
||||||
|
|
||||||
|
## PICO-8 basics
|
||||||
|
- `load name.p8`: load a game into memory
|
||||||
|
- `run`: run that game
|
||||||
|
- `folder`: open PICO-8 root in file browser
|
||||||
|
- `shutdown`: exit PICO-8
|
||||||
|
- `<escape key>`: toggle editor or exit a game.
|
||||||
|
|
||||||
|
That's all you need to know to train Celeste-AI.
|
@ -30,6 +30,16 @@ k_jump=4
|
|||||||
k_dash=5
|
k_dash=5
|
||||||
|
|
||||||
|
|
||||||
|
-- Set to false while training or running the model.
|
||||||
|
-- Set to true to play the game manually with debug print.
|
||||||
|
-- (good for finding coordinates of checkpoints)
|
||||||
|
--
|
||||||
|
-- If true, disables most hack features:
|
||||||
|
-- - screenshots at every frame
|
||||||
|
-- - frame skipping
|
||||||
|
-- - waiting for input
|
||||||
|
hack_human_mode = false
|
||||||
|
|
||||||
-- If true, disable screensake
|
-- If true, disable screensake
|
||||||
hack_no_shake = true
|
hack_no_shake = true
|
||||||
|
|
||||||
@ -276,11 +286,11 @@ player =
|
|||||||
this.djump-=1
|
this.djump-=1
|
||||||
this.dash_time=4
|
this.dash_time=4
|
||||||
has_dashed=true
|
has_dashed=true
|
||||||
|
|
||||||
-- HACK: fast-forward dashes
|
-- HACK: fast-forward dashes
|
||||||
hack_frame_foward_bonus = 10
|
hack_frame_foward_bonus = 10
|
||||||
hack_can_dash = false
|
hack_can_dash = false
|
||||||
|
|
||||||
this.dash_effect_time=10
|
this.dash_effect_time=10
|
||||||
local v_input=(btn(k_up) and -1 or (btn(k_down) and 1 or 0))
|
local v_input=(btn(k_up) and -1 or (btn(k_down) and 1 or 0))
|
||||||
if input!=0 then
|
if input!=0 then
|
||||||
@ -1201,14 +1211,48 @@ function load_room(x,y)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
-- update function --
|
|
||||||
-----------------------
|
|
||||||
|
|
||||||
|
function hack_send_state()
|
||||||
|
out_string = "dc:" .. tostr(deaths) .. ";"
|
||||||
|
|
||||||
|
-- Dash status
|
||||||
|
if hack_can_dash then
|
||||||
|
out_string = out_string .. "ds:t;"
|
||||||
|
else
|
||||||
|
out_string = out_string .. "ds:f;"
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Player state
|
||||||
|
for k, v in pairs(hack_player_state) do
|
||||||
|
out_string = out_string .. k ..":" .. v .. ";"
|
||||||
|
end
|
||||||
|
|
||||||
|
-- Fruit status
|
||||||
|
out_string = out_string .. "fr:"
|
||||||
|
for i = 0,29 do
|
||||||
|
if got_fruit[i] then
|
||||||
|
out_string = out_string .. "t"
|
||||||
|
else
|
||||||
|
out_string = out_string .. "f"
|
||||||
|
end
|
||||||
|
end
|
||||||
|
out_string = out_string .. ";"
|
||||||
|
printh(out_string)
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
-- update function --
|
||||||
|
----------------------
|
||||||
|
|
||||||
-- _update runs at 30 fps
|
-- _update runs at 30 fps
|
||||||
-- _update60 does 60 fps
|
-- _update60 does 60 fps
|
||||||
-- default for celeste is 30.
|
-- default for celeste is 30.
|
||||||
function _update()
|
function _update()
|
||||||
|
if hack_human_mode then
|
||||||
|
old_update()
|
||||||
|
hack_send_state()
|
||||||
|
return
|
||||||
|
end
|
||||||
|
|
||||||
-- Run at full speed until ready
|
-- Run at full speed until ready
|
||||||
if not hack_ready then
|
if not hack_ready then
|
||||||
@ -1273,38 +1317,16 @@ function _update()
|
|||||||
|
|
||||||
|
|
||||||
hack_has_sent_first_message = true
|
hack_has_sent_first_message = true
|
||||||
out_string = "dc:" .. tostr(deaths) .. ";"
|
hack_send_state()
|
||||||
|
|
||||||
-- Dash status
|
|
||||||
if hack_can_dash then
|
|
||||||
out_string = out_string .. "ds:t;"
|
|
||||||
else
|
|
||||||
out_string = out_string .. "ds:f;"
|
|
||||||
end
|
|
||||||
|
|
||||||
|
|
||||||
-- Fruit status
|
|
||||||
out_string = out_string .. "fr:"
|
|
||||||
for i = 0,29 do
|
|
||||||
if got_fruit[i] then
|
|
||||||
out_string = out_string .. "t"
|
|
||||||
else
|
|
||||||
out_string = out_string .. "f"
|
|
||||||
end
|
|
||||||
end
|
|
||||||
out_string = out_string .. ";"
|
|
||||||
|
|
||||||
|
|
||||||
for k, v in pairs(hack_player_state) do
|
|
||||||
out_string = out_string .. k ..":" .. v .. ";"
|
|
||||||
end
|
|
||||||
printh(out_string)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Called at the same rate as _update,
|
-- Called at the same rate as _update,
|
||||||
-- but not necessarily at the same time.
|
-- but not necessarily at the same time.
|
||||||
function _draw()
|
function _draw()
|
||||||
--old_draw()
|
if hack_human_mode then
|
||||||
|
old_draw()
|
||||||
|
return
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function old_update()
|
function old_update()
|
BIN
resources/first-bot-finish.gif
Normal file
After Width: | Height: | Size: 1.7 MiB |
Before Width: | Height: | Size: 9.5 KiB After Width: | Height: | Size: 9.5 KiB |
Before Width: | Height: | Size: 16 KiB After Width: | Height: | Size: 16 KiB |
Before Width: | Height: | Size: 442 B After Width: | Height: | Size: 442 B |