Reorganized celeste code
parent d648d896c0
commit 55ac62dc47
@ -0,0 +1,15 @@
+# Celeste-AI: A Celeste Classic DQL Agent
+
+This is an attempt to create an agent that learns to play Celeste Classic.
+
+## Contents
+- `./resources`: contains files these scripts require. Notably, we have an (old) version of PICO-8 that's known to work with this script, and a version of Celeste Classic with telemetry and delays called `hackcel.p8`.
+
+
+## Setup
+
+This is designed to work on Linux. You will need `xdotool` to send keypresses to the game.
+
+1. `cd` into this directory
+2. Make and enter a venv
+3. `pip install -e .`
@ -0,0 +1,6 @@
+from .network import DQN
+from .network import Transition
+
+from .celeste import Celeste
+from .celeste import CelesteError
+from .celeste import CelesteState
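With this `__init__.py` (presumably `celeste_ai/__init__.py`), the public surface can be imported straight from the package root; a minimal sketch:

```python
# Minimal sketch of the new import surface; these names are
# re-exported by the package __init__ above.
from celeste_ai import Celeste, DQN, Transition
```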
@ -63,14 +63,15 @@ class Celeste:
 	def __init__(
 		self,
+		pico_path,
 		*,
 		state_timeout = 30,
-		cart_name = "hackcel.p8"
+		cart_name = "hackcel.p8",
 	):

 		# Start pico-8
 		self._process = subprocess.Popen(
-			"resources/pico-8/linux/pico8",
+			pico_path,
 			shell=True,
 			stdout=subprocess.PIPE,
 			stderr=subprocess.STDOUT
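The constructor now takes the PICO-8 binary path as an explicit argument instead of hard-coding it. A minimal sketch of the new call, mirroring the `main.py` change further down:

```python
from celeste_ai import Celeste

# pico_path is now explicit; cart_name keeps its default of "hackcel.p8".
c = Celeste(
	"resources/pico-8/linux/pico8",
	cart_name = "hackcel.p8",
)
```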
@ -272,7 +273,6 @@ class Celeste:
 
-
 	def update_loop(self, before, after):
 
 		# Waits for stdout from pico-8 process
 		for line in iter(self._process.stdout.readline, ""):
 			l = line.decode("utf-8")[:-1].strip()
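`update_loop` drives training through two callbacks, invoked before and after each game step. A hedged sketch of the shape they take (the real `on_state_before` / `on_state_after` appear in `main.py` below; how the state is read off the `Celeste` instance is an assumption here):

```python
def on_state_before(celeste):
	# Choose an action from the current state; whatever is returned
	# here is passed back to the "after" callback as before_out.
	state = ...   # read telemetry from the Celeste instance
	action = ...  # e.g. an epsilon-greedy pick from the policy net
	return state, action

def on_state_after(celeste, before_out):
	state, action = before_out
	# Record the transition and optimize the model here.

c = Celeste("resources/pico-8/linux/pico8")
c.update_loop(on_state_before, on_state_after)
```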
@ -0,0 +1,36 @@
+import torch
+from collections import namedtuple
+
+
+Transition = namedtuple(
+	"Transition",
+	(
+		"state",
+		"action",
+		"next_state",
+		"reward"
+	)
+)
+
+
+class DQN(torch.nn.Module):
+	def __init__(self, n_observations: int, n_actions: int):
+		super(DQN, self).__init__()
+
+		self.layers = torch.nn.Sequential(
+			torch.nn.Linear(n_observations, 128),
+			torch.nn.ReLU(),
+
+			torch.nn.Linear(128, 128),
+			torch.nn.ReLU(),
+
+			torch.nn.Linear(128, 128),
+			torch.nn.ReLU(),
+
+			torch.nn.Linear(128, n_actions)
+		)
+
+	def forward(self, x):
+		return self.layers(x)
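For orientation, a small sketch of how this network gets used. The sizes are assumptions read off the plotting code below (4-element states `[x, y, x_target, y_target]`, 8 actions):

```python
import torch
from celeste_ai import DQN

policy_net = DQN(n_observations = 4, n_actions = 8)

# One state, batched; forward() also accepts a whole batch.
state = torch.tensor([[60.0, 80.0, 60.0, 80.0]])
q_values = policy_net(state)  # shape (1, 8): one Q-value per action
```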
@ -0,0 +1,2 @@
+from .plot_actual_reward import actual_reward
+from .plot_predicted_reward import predicted_reward
@ -0,0 +1,81 @@
+import torch
+import numpy as np
+from pathlib import Path
+import matplotlib.pyplot as plt
+from multiprocessing import Pool
+
+# All of the following are required to load
+# a pickled model.
+from celeste_ai.celeste import Celeste
+from celeste_ai.network import DQN
+from celeste_ai.network import Transition
+
+def actual_reward(
+	model_file: Path,
+	target_point: tuple[int, int],
+	out_filename: Path,
+	*,
+	device = torch.device("cpu")
+):
+	if not model_file.is_file():
+		raise Exception(f"Bad model file {model_file}")
+	out_filename.parent.mkdir(exist_ok = True, parents = True)
+
+
+	checkpoint = torch.load(
+		model_file,
+		map_location = device
+	)
+	memory = checkpoint["memory"]
+
+
+	r = np.zeros((128, 128, 8), dtype=np.float32)
+	for m in memory:
+		x, y, x_target, y_target = list(m.state[0])
+
+		action = m.action[0].item()
+		x = int(x.item())
+		y = int(y.item())
+		x_target = int(x_target.item())
+		y_target = int(y_target.item())
+
+		# Only plot memory related to this point
+		if (x_target, y_target) != target_point:
+			continue
+
+		if m.reward[0].item() == 1:
+			r[y][x][action] += 1
+		else:
+			r[y][x][action] -= 1
+
+
+	fig, axs = plt.subplots(2, 4, figsize = (20, 10))
+
+
+	for a in range(len(axs.ravel())):
+		ax = axs.ravel()[a]
+		ax.set(
+			adjustable = "box",
+			aspect = "equal",
+			title = Celeste.action_space[a]
+		)
+
+		plot = ax.pcolor(
+			r[:,:,a],
+			cmap = "seismic_r",
+			vmin = -10,
+			vmax = 10
+		)
+
+		# Draw target point on plot
+		ax.plot(
+			target_point[0],
+			target_point[1],
+			"k."
+		)
+
+		ax.invert_yaxis()
+		fig.colorbar(plot)
+
+	fig.savefig(out_filename)
+	plt.close()
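A hedged usage sketch for `actual_reward`; the paths are assumptions modeled on the old hard-coded script this replaces, and (60, 80) is the target point it used:

```python
from pathlib import Path
from celeste_ai.plots import actual_reward

actual_reward(
	model_file = Path("model_data/current/model.torch"),
	target_point = (60, 80),
	out_filename = Path("model_data/current/plots/actual_reward/actual.png"),
)
```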
@ -2,38 +2,36 @@ import torch
 import numpy as np
 from pathlib import Path
 import matplotlib.pyplot as plt
 from multiprocessing import Pool
 
-from celeste import Celeste
-from main import DQN
-from main import Transition
-
-# Use cpu, this script is faster in parallel.
-compute_device = torch.device("cpu")
-
-input_model = Path("model_data/current")
-
-src_dir = input_model / "model_archive"
-out_dir = input_model_dir / "plots/predicted_value"
-out_dir.mkdir(parents = True, exist_ok = True)
+# All of the following are required to load
+# a pickled model.
+from celeste_ai.celeste import Celeste
+from celeste_ai.network import DQN
+from celeste_ai.network import Transition
+
+
+def predicted_reward(
+	model_file: Path,
+	out_filename: Path,
+	*,
+	device = torch.device("cpu")
+):
+	if not model_file.is_file():
+		raise Exception(f"Bad model file {model_file}")
+	out_filename.parent.mkdir(exist_ok = True, parents = True)
 
-def plot(src):
 	# Create and load model
 	policy_net = DQN(
 		len(Celeste.state_number_map),
 		len(Celeste.action_space)
-	).to(compute_device)
-
+	).to(device)
 	checkpoint = torch.load(
-		src,
-		map_location = compute_device
+		model_file,
+		map_location = device
 	)
 	policy_net.load_state_dict(checkpoint["policy_state_dict"])
 
 
-	fig, axs = plt.subplots(2, 4, figsize = (20, 10))
-
-
 	# Compute predictions
 	p = np.zeros((128, 128, 8), dtype=np.float32)
@ -44,13 +42,14 @@ def plot(src):
 			torch.tensor(
 				[c, r, 60, 80],
 				dtype = torch.float32,
-				device = compute_device
+				device = device
 			).unsqueeze(0)
 		)[0])
 		p[r][c] = k
 
 
 	# Plot predictions
+	fig, axs = plt.subplots(2, 4, figsize = (20, 10))
 	for a in range(len(axs.ravel())):
 		ax = axs.ravel()[a]
 		ax.set(
@ -67,14 +66,12 @@ def plot(src):
 
 	ax.invert_yaxis()
 	fig.colorbar(plot)
-	print(src)
-	fig.savefig(out_dir / f"{src.stem}.png")
-
+	fig.savefig(out_filename)
 	plt.close()
 
 
-
-if __name__ == "__main__":
-	with Pool(5) as p:
-		p.map(plot, list(src_dir.iterdir()))
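The deleted `__main__` block mapped `plot` over a model archive with a process pool; a sketch of the same sweep through the new function (directory names taken from the removed code):

```python
from pathlib import Path
from multiprocessing import Pool
from celeste_ai.plots import predicted_reward

src_dir = Path("model_data/current/model_archive")
out_dir = Path("model_data/current/plots/predicted_value")

def plot_one(src: Path):
	predicted_reward(src, out_dir / f"{src.stem}.png")

if __name__ == "__main__":
	# CPU inference parallelizes well, hence the pool of workers.
	with Pool(5) as p:
		p.map(plot_one, list(src_dir.iterdir()))
```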
@ -6,7 +6,9 @@ import math
 import json
 import torch
 
-from celeste import Celeste
+from celeste_ai import Celeste
+from celeste_ai import DQN
+from celeste_ai import Transition
 
 
 if __name__ == "__main__":
@ -62,48 +64,6 @@ if __name__ == "__main__":
 	# GAMMA is the discount factor as mentioned in the previous section
 	GAMMA = 0.9
 
-
-# Outline our network
-class DQN(torch.nn.Module):
-	def __init__(self, n_observations: int, n_actions: int):
-		super(DQN, self).__init__()
-
-		self.layers = torch.nn.Sequential(
-			torch.nn.Linear(n_observations, 128),
-			torch.nn.ReLU(),
-
-			torch.nn.Linear(128, 128),
-			torch.nn.ReLU(),
-
-			torch.nn.Linear(128, 128),
-			torch.nn.ReLU(),
-
-			torch.torch.nn.Linear(128, n_actions)
-		)
-
-	# Can be called with one input, or with a batch.
-	#
-	# Returns tensor(
-	#	[ Q(s, left), Q(s, right) ], ...
-	# )
-	#
-	# Recall that Q(s, a) is the (expected) return of taking
-	# action `a` at state `s`
-	def forward(self, x):
-		return self.layers(x)
-
-Transition = namedtuple(
-	"Transition",
-	(
-		"state",
-		"action",
-		"next_state",
-		"reward"
-	)
-)
-
-
-if __name__ == "__main__":
 	steps_done = 0
 	num_episodes = 100
 	episode_number = 0
@ -139,7 +99,10 @@ if __name__ == "__main__":
 
 	if model_save_path.is_file():
 		# Load model if one exists
-		checkpoint = torch.load(model_save_path)
+		checkpoint = torch.load(
+			model_save_path,
+			map_location = compute_device
+		)
 		policy_net.load_state_dict(checkpoint["policy_state_dict"])
 		target_net.load_state_dict(checkpoint["target_state_dict"])
 		optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
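`map_location` makes a checkpoint saved on one device (say, a GPU) load onto another; a tiny illustrative sketch, with the filename assumed:

```python
import torch

compute_device = torch.device("cpu")

# Remaps all tensors in the checkpoint onto the CPU at load time.
checkpoint = torch.load("model.torch", map_location = compute_device)
```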
@ -315,7 +278,6 @@ def optimize_model():
 	return loss
 
 
-
 def on_state_before(celeste):
 	global steps_done
 
@ -351,7 +313,6 @@ def on_state_before(celeste):
 	return state, action
 
 
-
 def on_state_after(celeste, before_out):
 	global episode_number
 
@ -478,7 +439,9 @@ def on_state_after(celeste, before_out):
 
 
 if __name__ == "__main__":
-	c = Celeste()
+	c = Celeste(
+		"resources/pico-8/linux/pico8"
+	)
 
 	c.update_loop(
 		on_state_before,
@ -1,79 +0,0 @@
-import torch
-import numpy as np
-from pathlib import Path
-import matplotlib.pyplot as plt
-from multiprocessing import Pool
-
-from celeste import Celeste
-from main import DQN
-from main import Transition
-
-# Use cpu, this script is faster in parallel.
-compute_device = torch.device("cpu")
-
-input_model = Path("model_data/after_change")
-
-out_dir = input_model / "plots/actual_reward"
-out_dir.mkdir(parents = True, exist_ok = True)
-
-
-checkpoint = torch.load(
-	input_model / "model.torch",
-	map_location = compute_device
-)
-memory = checkpoint["memory"]
-
-r = np.zeros((128, 128, 8), dtype=np.float32)
-
-for m in memory:
-	x, y, x_target, y_target = list(m.state[0])
-
-	action = m.action[0].item()
-	str_action = Celeste.action_space[action]
-	x = int(x.item())
-	y = int(y.item())
-	x_target = int(x_target.item())
-	y_target = int(y_target.item())
-
-	if (x_target, y_target) != (60, 80):
-		continue
-
-	if m.reward[0].item() == 1:
-		r[y][x][action] += 1
-	else:
-		r[y][x][action] -= 1
-
-
-fig, axs = plt.subplots(2, 4, figsize = (20, 10))
-
-
-# Plot predictions
-for a in range(len(axs.ravel())):
-	ax = axs.ravel()[a]
-	ax.set(
-		adjustable = "box",
-		aspect = "equal",
-		title = Celeste.action_space[a]
-	)
-
-	plot = ax.pcolor(
-		r[:,:,a],
-		cmap = "seismic_r",
-		vmin = -10,
-		vmax = 10
-	)
-
-	ax.plot(60, 80, "k.")
-	#ax.annotate(
-	#	"Target",
-	#	(60, 80),
-	#	textcoords = "offset points",
-	#	xytext = (0, -20),
-	#	ha = "center"
-	#)
-
-	ax.invert_yaxis()
-	fig.colorbar(plot)
-
-fig.savefig(out_dir / "actual.png")
-plt.close()
@ -0,0 +1,25 @@
+[build-system]
+requires = [ "setuptools>=61.0" ]
+build-backend = "setuptools.build_meta"
+
+
+[tool.setuptools.packages.find]
+where = [ "." ]
+include = ["celeste_ai*"]
+namespaces = false
+
+
+[project]
+name = "celeste_ai"
+description = "A reinforcement learning agent for Celeste Classic"
+version = "1.0.0"
+dependencies = [
+	"matplotlib==3.6.3",
+	"torch==1.13.1"
+]
+authors = [
+	{ name="Mark", email="mark@betalupi.com" }
+]
+readme = "README.md"
+requires-python = ">=3.7"
+license = {text = "GNU General Public License v3 (GPLv3)"}
@ -1,2 +0,0 @@
-matplotlib==3.6.3
-torch==1.13.1