Reorganized celeste code
parent
d648d896c0
commit
55ac62dc47
|
@ -0,0 +1,15 @@
|
||||||
|
# Celeste-AI: A Celeste Classic DQL Agent
|
||||||
|
|
||||||
|
This is an attempt to create an agent that learns to play Celeste Classic.
|
||||||
|
|
||||||
|
## Contents
|
||||||
|
- `./resources`: contains files these scripts require. Notably, we have an (old) version of PICO-8 that's known to work with this script, and a version of Celeste Classic with telemetry and delays called `hackcel.p8`.
|
||||||
|
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
This is designed to work on Linux. You will need `xdotool` to send keypresses to the game.
|
||||||
|
|
||||||
|
1. `cd` into this directory
|
||||||
|
2. Make and enter a venv
|
||||||
|
3. `pip install -e .`
|
|
@ -0,0 +1,6 @@
|
||||||
|
from .network import DQN
|
||||||
|
from .network import Transition
|
||||||
|
|
||||||
|
from .celeste import Celeste
|
||||||
|
from .celeste import CelesteError
|
||||||
|
from .celeste import CelesteState
|
|
@ -63,14 +63,15 @@ class Celeste:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
pico_path,
|
||||||
*,
|
*,
|
||||||
state_timeout = 30,
|
state_timeout = 30,
|
||||||
cart_name = "hackcel.p8"
|
cart_name = "hackcel.p8",
|
||||||
):
|
):
|
||||||
|
|
||||||
# Start pico-8
|
# Start pico-8
|
||||||
self._process = subprocess.Popen(
|
self._process = subprocess.Popen(
|
||||||
"resources/pico-8/linux/pico8",
|
pico_path,
|
||||||
shell=True,
|
shell=True,
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
stderr=subprocess.STDOUT
|
stderr=subprocess.STDOUT
|
||||||
|
@ -272,7 +273,6 @@ class Celeste:
|
||||||
|
|
||||||
|
|
||||||
def update_loop(self, before, after):
|
def update_loop(self, before, after):
|
||||||
|
|
||||||
# Waits for stdout from pico-8 process
|
# Waits for stdout from pico-8 process
|
||||||
for line in iter(self._process.stdout.readline, ""):
|
for line in iter(self._process.stdout.readline, ""):
|
||||||
l = line.decode("utf-8")[:-1].strip()
|
l = line.decode("utf-8")[:-1].strip()
|
|
@ -0,0 +1,36 @@
|
||||||
|
import torch
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
|
|
||||||
|
# Record type with four positional fields. Field names and their order are
# unchanged, so existing attribute access (`.state`, `.action`,
# `.next_state`, `.reward`) and tuple unpacking keep working.
Transition = namedtuple(
    "Transition",
    ["state", "action", "next_state", "reward"],
)
|
||||||
|
|
||||||
|
|
||||||
|
class DQN(torch.nn.Module):
    """Fully connected Q-network mapping a state to one value per action.

    The network is three 128-unit linear layers, each followed by a ReLU,
    and a final linear layer of width ``n_actions``.

    Args:
        n_observations: Number of input features per state.
        n_actions: Number of output values (one per action).
    """

    def __init__(self, n_observations: int, n_actions: int):
        super().__init__()

        # The attribute name and the layer order are kept exactly as before:
        # state-dict keys ("layers.0.weight", ...) are derived from them, so
        # changing either would break loading of existing checkpoints.
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(n_observations, 128),
            torch.nn.ReLU(),

            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),

            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),

            torch.nn.Linear(128, n_actions),
        )

    def forward(self, x):
        """Return the Q-values for ``x``.

        Can be called with one input or with a batch. The result holds
        Q(s, a) for every action ``a``, where Q(s, a) is the (expected)
        return of taking action ``a`` at state ``s``.
        """
        return self.layers(x)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,2 @@
|
||||||
|
from .plot_actual_reward import actual_reward
|
||||||
|
from .plot_predicted_reward import predicted_reward
|
|
@ -0,0 +1,81 @@
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from multiprocessing import Pool
|
||||||
|
|
||||||
|
# All of the following are required to load
|
||||||
|
# a pickled model.
|
||||||
|
from celeste_ai.celeste import Celeste
|
||||||
|
from celeste_ai.network import DQN
|
||||||
|
from celeste_ai.network import Transition
|
||||||
|
|
||||||
|
def actual_reward(
    model_file: Path,
    target_point: tuple[int, int],
    out_filename: Path,
    *,
    device = torch.device("cpu")
):
    """Plot the rewards stored in a checkpoint's memory, one panel per action.

    Every entry of the checkpoint's ``"memory"`` whose state targets
    ``target_point`` is tallied at its (x, y) position: +1 when the stored
    reward equals 1, -1 otherwise. The tallies are drawn as a 2x4 grid of
    ``pcolor`` panels (colour scale clipped to [-10, 10]), each titled with
    the matching entry of ``Celeste.action_space``, and saved to
    ``out_filename``.

    Args:
        model_file: Checkpoint file to read with ``torch.load``.
        target_point: ``(x, y)`` target to plot; entries recorded for any
            other target are ignored. Lists are accepted as well as tuples.
        out_filename: Destination image path. Missing parent directories
            are created.
        device: Passed to ``torch.load`` as ``map_location``.

    Raises:
        FileNotFoundError: If ``model_file`` is not an existing file.
    """
    if not model_file.is_file():
        # FileNotFoundError is still an Exception, so existing handlers
        # that catch Exception keep working.
        raise FileNotFoundError(f"Bad model file {model_file}")
    out_filename.parent.mkdir(exist_ok=True, parents=True)

    # NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source.
    checkpoint = torch.load(
        model_file,
        map_location=device
    )
    memory = checkpoint["memory"]

    # Normalise once so a list such as [60, 80] compares equal to the
    # (x, y) tuple built from each stored state.
    target = tuple(target_point)

    r = np.zeros((128, 128, 8), dtype=np.float32)
    height, width = r.shape[:2]

    for m in memory:
        x, y, x_target, y_target = (int(v.item()) for v in m.state[0])
        action = m.action[0].item()

        # Only plot memory related to this point
        if (x_target, y_target) != target:
            continue

        # Positions outside the grid cannot be drawn. Without this check a
        # negative coordinate would wrap around and be tallied on the
        # opposite edge of the plot.
        # NOTE(review): whether stored positions can leave 0..127 is not
        # visible here -- confirm against the state source.
        if not (0 <= x < width and 0 <= y < height):
            continue

        if m.reward[0].item() == 1:
            r[y, x, action] += 1
        else:
            r[y, x, action] -= 1

    fig, axs = plt.subplots(2, 4, figsize=(20, 10))

    for a, ax in enumerate(axs.ravel()):
        ax.set(
            adjustable="box",
            aspect="equal",
            title=Celeste.action_space[a]
        )

        plot = ax.pcolor(
            r[:, :, a],
            cmap="seismic_r",
            vmin=-10,
            vmax=10
        )

        # Draw target point on plot
        ax.plot(
            target_point[0],
            target_point[1],
            "k."
        )

        ax.invert_yaxis()
        fig.colorbar(plot)

    fig.savefig(out_filename)
    # Close this figure explicitly rather than whichever one is current.
    plt.close(fig)
|
|
@ -2,38 +2,36 @@ import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from multiprocessing import Pool
|
|
||||||
|
|
||||||
from celeste import Celeste
|
# All of the following are required to load
|
||||||
from main import DQN
|
# a pickled model.
|
||||||
from main import Transition
|
from celeste_ai.celeste import Celeste
|
||||||
|
from celeste_ai.network import DQN
|
||||||
# Use cpu, this script is faster in parallel.
|
from celeste_ai.network import Transition
|
||||||
compute_device = torch.device("cpu")
|
|
||||||
|
|
||||||
input_model = Path("model_data/current")
|
|
||||||
|
|
||||||
src_dir = input_model / "model_archive"
|
|
||||||
out_dir = input_model_dir / "plots/predicted_value"
|
|
||||||
out_dir.mkdir(parents = True, exist_ok = True)
|
|
||||||
|
|
||||||
|
|
||||||
|
def predicted_reward(
|
||||||
|
model_file: Path,
|
||||||
|
out_filename: Path,
|
||||||
|
*,
|
||||||
|
device = torch.device("cpu")
|
||||||
|
):
|
||||||
|
if not model_file.is_file():
|
||||||
|
raise Exception(f"Bad model file {model_file}")
|
||||||
|
out_filename.parent.mkdir(exist_ok = True, parents = True)
|
||||||
|
|
||||||
def plot(src):
|
# Create and load model
|
||||||
policy_net = DQN(
|
policy_net = DQN(
|
||||||
len(Celeste.state_number_map),
|
len(Celeste.state_number_map),
|
||||||
len(Celeste.action_space)
|
len(Celeste.action_space)
|
||||||
).to(compute_device)
|
).to(device)
|
||||||
|
|
||||||
checkpoint = torch.load(
|
checkpoint = torch.load(
|
||||||
src,
|
model_file,
|
||||||
map_location = compute_device
|
map_location = device
|
||||||
)
|
)
|
||||||
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
||||||
|
|
||||||
|
|
||||||
fig, axs = plt.subplots(2, 4, figsize = (20, 10))
|
|
||||||
|
|
||||||
|
|
||||||
# Compute predictions
|
# Compute predictions
|
||||||
p = np.zeros((128, 128, 8), dtype=np.float32)
|
p = np.zeros((128, 128, 8), dtype=np.float32)
|
||||||
|
@ -44,13 +42,14 @@ def plot(src):
|
||||||
torch.tensor(
|
torch.tensor(
|
||||||
[c, r, 60, 80],
|
[c, r, 60, 80],
|
||||||
dtype = torch.float32,
|
dtype = torch.float32,
|
||||||
device = compute_device
|
device = device
|
||||||
).unsqueeze(0)
|
).unsqueeze(0)
|
||||||
)[0])
|
)[0])
|
||||||
p[r][c] = k
|
p[r][c] = k
|
||||||
|
|
||||||
|
|
||||||
# Plot predictions
|
# Plot predictions
|
||||||
|
fig, axs = plt.subplots(2, 4, figsize = (20, 10))
|
||||||
for a in range(len(axs.ravel())):
|
for a in range(len(axs.ravel())):
|
||||||
ax = axs.ravel()[a]
|
ax = axs.ravel()[a]
|
||||||
ax.set(
|
ax.set(
|
||||||
|
@ -67,14 +66,12 @@ def plot(src):
|
||||||
|
|
||||||
ax.invert_yaxis()
|
ax.invert_yaxis()
|
||||||
fig.colorbar(plot)
|
fig.colorbar(plot)
|
||||||
print(src)
|
|
||||||
fig.savefig(out_dir / f"{src.stem}.png")
|
fig.savefig(out_filename)
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
with Pool(5) as p:
|
|
||||||
p.map(plot, list(src_dir.iterdir()))
|
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,9 @@ import math
|
||||||
import json
|
import json
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from celeste import Celeste
|
from celeste_ai import Celeste
|
||||||
|
from celeste_ai import DQN
|
||||||
|
from celeste_ai import Transition
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -62,48 +64,6 @@ if __name__ == "__main__":
|
||||||
# GAMMA is the discount factor as mentioned in the previous section
|
# GAMMA is the discount factor as mentioned in the previous section
|
||||||
GAMMA = 0.9
|
GAMMA = 0.9
|
||||||
|
|
||||||
|
|
||||||
# Outline our network
|
|
||||||
class DQN(torch.nn.Module):
|
|
||||||
def __init__(self, n_observations: int, n_actions: int):
|
|
||||||
super(DQN, self).__init__()
|
|
||||||
|
|
||||||
self.layers = torch.nn.Sequential(
|
|
||||||
torch.nn.Linear(n_observations, 128),
|
|
||||||
torch.nn.ReLU(),
|
|
||||||
|
|
||||||
torch.nn.Linear(128, 128),
|
|
||||||
torch.nn.ReLU(),
|
|
||||||
|
|
||||||
torch.nn.Linear(128, 128),
|
|
||||||
torch.nn.ReLU(),
|
|
||||||
|
|
||||||
torch.torch.nn.Linear(128, n_actions)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Can be called with one input, or with a batch.
|
|
||||||
#
|
|
||||||
# Returns tensor(
|
|
||||||
# [ Q(s, left), Q(s, right) ], ...
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# Recall that Q(s, a) is the (expected) return of taking
|
|
||||||
# action `a` at state `s`
|
|
||||||
def forward(self, x):
|
|
||||||
return self.layers(x)
|
|
||||||
|
|
||||||
Transition = namedtuple(
|
|
||||||
"Transition",
|
|
||||||
(
|
|
||||||
"state",
|
|
||||||
"action",
|
|
||||||
"next_state",
|
|
||||||
"reward"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
steps_done = 0
|
steps_done = 0
|
||||||
num_episodes = 100
|
num_episodes = 100
|
||||||
episode_number = 0
|
episode_number = 0
|
||||||
|
@ -139,7 +99,10 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
if model_save_path.is_file():
|
if model_save_path.is_file():
|
||||||
# Load model if one exists
|
# Load model if one exists
|
||||||
checkpoint = torch.load(model_save_path)
|
checkpoint = torch.load(
|
||||||
|
model_save_path,
|
||||||
|
map_location = compute_device
|
||||||
|
)
|
||||||
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
||||||
target_net.load_state_dict(checkpoint["target_state_dict"])
|
target_net.load_state_dict(checkpoint["target_state_dict"])
|
||||||
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||||
|
@ -315,7 +278,6 @@ def optimize_model():
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def on_state_before(celeste):
|
def on_state_before(celeste):
|
||||||
global steps_done
|
global steps_done
|
||||||
|
|
||||||
|
@ -351,7 +313,6 @@ def on_state_before(celeste):
|
||||||
return state, action
|
return state, action
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def on_state_after(celeste, before_out):
|
def on_state_after(celeste, before_out):
|
||||||
global episode_number
|
global episode_number
|
||||||
|
|
||||||
|
@ -478,7 +439,9 @@ def on_state_after(celeste, before_out):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
c = Celeste()
|
c = Celeste(
|
||||||
|
"resources/pico-8/linux/pico8"
|
||||||
|
)
|
||||||
|
|
||||||
c.update_loop(
|
c.update_loop(
|
||||||
on_state_before,
|
on_state_before,
|
|
@ -1,79 +0,0 @@
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from pathlib import Path
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from multiprocessing import Pool
|
|
||||||
|
|
||||||
from celeste import Celeste
|
|
||||||
from main import DQN
|
|
||||||
from main import Transition
|
|
||||||
|
|
||||||
# Use cpu, this script is faster in parallel.
|
|
||||||
compute_device = torch.device("cpu")
|
|
||||||
|
|
||||||
input_model = Path("model_data/after_change")
|
|
||||||
|
|
||||||
out_dir = input_model / "plots/actual_reward"
|
|
||||||
out_dir.mkdir(parents = True, exist_ok = True)
|
|
||||||
|
|
||||||
|
|
||||||
checkpoint = torch.load(
|
|
||||||
input_model / "model.torch",
|
|
||||||
map_location = compute_device
|
|
||||||
)
|
|
||||||
memory = checkpoint["memory"]
|
|
||||||
|
|
||||||
r = np.zeros((128, 128, 8), dtype=np.float32)
|
|
||||||
|
|
||||||
for m in memory:
|
|
||||||
x, y, x_target, y_target = list(m.state[0])
|
|
||||||
|
|
||||||
action = m.action[0].item()
|
|
||||||
str_action = Celeste.action_space[action]
|
|
||||||
x = int(x.item())
|
|
||||||
y = int(y.item())
|
|
||||||
x_target = int(x_target.item())
|
|
||||||
y_target = int(y_target.item())
|
|
||||||
|
|
||||||
if (x_target, y_target) != (60, 80):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if m.reward[0].item() == 1:
|
|
||||||
r[y][x][action] += 1
|
|
||||||
else:
|
|
||||||
r[y][x][action] -= 1
|
|
||||||
|
|
||||||
|
|
||||||
fig, axs = plt.subplots(2, 4, figsize = (20, 10))
|
|
||||||
|
|
||||||
|
|
||||||
# Plot predictions
|
|
||||||
for a in range(len(axs.ravel())):
|
|
||||||
ax = axs.ravel()[a]
|
|
||||||
ax.set(
|
|
||||||
adjustable = "box",
|
|
||||||
aspect = "equal",
|
|
||||||
title = Celeste.action_space[a]
|
|
||||||
)
|
|
||||||
|
|
||||||
plot = ax.pcolor(
|
|
||||||
r[:,:,a],
|
|
||||||
cmap = "seismic_r",
|
|
||||||
vmin = -10,
|
|
||||||
vmax = 10
|
|
||||||
)
|
|
||||||
|
|
||||||
ax.plot(60, 80, "k.")
|
|
||||||
#ax.annotate(
|
|
||||||
# "Target",
|
|
||||||
# (60, 80),
|
|
||||||
# textcoords = "offset points",
|
|
||||||
# xytext = (0, -20),
|
|
||||||
# ha = "center"
|
|
||||||
#)
|
|
||||||
|
|
||||||
ax.invert_yaxis()
|
|
||||||
fig.colorbar(plot)
|
|
||||||
|
|
||||||
fig.savefig(out_dir / "actual.png")
|
|
||||||
plt.close()
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
[build-system]
|
||||||
|
requires = [ "setuptools>=61.0" ]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
|
||||||
|
[tool.setuptools.packages.find]
|
||||||
|
where = [ "." ]
|
||||||
|
include = ["celeste_ai*"]
|
||||||
|
namespaces = false
|
||||||
|
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "celeste_ai"
|
||||||
|
description = "A reinforcement learning agent for Celeste Classic"
|
||||||
|
version = "1.0.0"
|
||||||
|
dependencies = [
|
||||||
|
"matplotlib==3.6.3",
|
||||||
|
"torch==1.13.1"
|
||||||
|
]
|
||||||
|
authors = [
|
||||||
|
{ name="Mark", email="mark@betalupi.com" }
|
||||||
|
]
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.7"
|
||||||
|
license = {text = "GNU General Public License v3 (GPLv3)"}
|
|
@ -1,2 +0,0 @@
|
||||||
matplotlib==3.6.3
|
|
||||||
torch==1.13.1
|
|
Reference in New Issue