
Reorganized celeste code

master
Mark 2023-02-19 20:57:19 -08:00
parent d648d896c0
commit 55ac62dc47
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
11 changed files with 201 additions and 157 deletions

celeste/README.md Normal file

@@ -0,0 +1,15 @@
# Celeste-AI: A Celeste Classic DQL Agent
This is an attempt to create an agent that learns to play Celeste Classic.
## Contents
- `./resources`: contains files these scripts require. Notably, it includes an (old) version of PICO-8 that is known to work with these scripts, and a version of Celeste Classic with telemetry and delays, called `hackcel.p8`.
## Setup
This is designed to work on Linux. You will need `xdotool` to send keypresses to the game.
1. `cd` into this directory
2. Make and enter a venv
3. `pip install -e .` (a minimal usage sketch follows below)
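
Once installed, the package can be driven from a short script. This is a minimal sketch based on the `Celeste` class in this commit: the PICO-8 path matches the one the training script passes, and the two callback bodies are placeholders.

```python
from celeste_ai import Celeste

# Path to the bundled PICO-8 binary, as used by the training script.
c = Celeste("resources/pico-8/linux/pico8")

def on_state_before(celeste):
    # Pick an action from the current state (placeholder).
    ...

def on_state_after(celeste, before_out):
    # Record the result of the last action (placeholder).
    ...

# Blocks, reading telemetry from PICO-8 and calling the two hooks.
c.update_loop(on_state_before, on_state_after)
```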

celeste/celeste_ai/__init__.py Normal file

@@ -0,0 +1,6 @@
from .network import DQN
from .network import Transition
from .celeste import Celeste
from .celeste import CelesteError
from .celeste import CelesteState

celeste/celeste_ai/celeste.py

@@ -63,14 +63,15 @@ class Celeste:
     def __init__(
         self,
+        pico_path,
         *,
         state_timeout = 30,
-        cart_name = "hackcel.p8"
+        cart_name = "hackcel.p8",
     ):
 
         # Start pico-8
         self._process = subprocess.Popen(
-            "resources/pico-8/linux/pico8",
+            pico_path,
             shell=True,
             stdout=subprocess.PIPE,
             stderr=subprocess.STDOUT
@@ -272,7 +273,6 @@ class Celeste:
     def update_loop(self, before, after):
         # Waits for stdout from pico-8 process
         for line in iter(self._process.stdout.readline, ""):
             l = line.decode("utf-8")[:-1].strip()

celeste/celeste_ai/network.py Normal file

@@ -0,0 +1,36 @@
import torch
from collections import namedtuple


Transition = namedtuple(
    "Transition",
    (
        "state",
        "action",
        "next_state",
        "reward"
    )
)


class DQN(torch.nn.Module):
    def __init__(self, n_observations: int, n_actions: int):
        super(DQN, self).__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(n_observations, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, n_actions)
        )

    def forward(self, x):
        return self.layers(x)
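
As a quick sanity check, the network can be exercised on a single state. This is a sketch, not part of the commit; the 4 and 8 below assume the `[x, y, x_target, y_target]` state and the eight-action space that the plotting code in this commit expects.

# Sketch: one forward pass through the DQN.
# 4 observations = [x, y, x_target, y_target]; 8 actions, matching the
# (128, 128, 8) reward grids built by the plotting scripts.
net = DQN(n_observations = 4, n_actions = 8)
state = torch.tensor([[60.0, 80.0, 60.0, 80.0]])  # batch of one state
q_values = net(state)            # shape (1, 8): one Q-value per action
best = q_values.argmax(dim = 1)  # greedy action index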


@@ -0,0 +1,2 @@
from .plot_actual_reward import actual_reward
from .plot_predicted_reward import predicted_reward


@@ -0,0 +1,81 @@
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from multiprocessing import Pool

# All of the following are required to load
# a pickled model.
from celeste_ai.celeste import Celeste
from celeste_ai.network import DQN
from celeste_ai.network import Transition


def actual_reward(
    model_file: Path,
    target_point: tuple[int, int],
    out_filename: Path,
    *,
    device = torch.device("cpu")
):
    if not model_file.is_file():
        raise Exception(f"Bad model file {model_file}")
    out_filename.parent.mkdir(exist_ok = True, parents = True)

    checkpoint = torch.load(
        model_file,
        map_location = device
    )
    memory = checkpoint["memory"]

    r = np.zeros((128, 128, 8), dtype=np.float32)
    for m in memory:
        x, y, x_target, y_target = list(m.state[0])
        action = m.action[0].item()
        x = int(x.item())
        y = int(y.item())
        x_target = int(x_target.item())
        y_target = int(y_target.item())

        # Only plot memory related to this point
        if (x_target, y_target) != target_point:
            continue

        if m.reward[0].item() == 1:
            r[y][x][action] += 1
        else:
            r[y][x][action] -= 1

    fig, axs = plt.subplots(2, 4, figsize = (20, 10))

    for a in range(len(axs.ravel())):
        ax = axs.ravel()[a]
        ax.set(
            adjustable = "box",
            aspect = "equal",
            title = Celeste.action_space[a]
        )

        plot = ax.pcolor(
            r[:,:,a],
            cmap = "seismic_r",
            vmin = -10,
            vmax = 10
        )

        # Draw target point on plot
        ax.plot(
            target_point[0],
            target_point[1],
            "k."
        )
        ax.invert_yaxis()
        fig.colorbar(plot)

    fig.savefig(out_filename)
    plt.close()
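
A usage sketch for the function above (not part of the commit): the paths are illustrative, with `model_data/current` mirroring the old hard-coded script, and the call is assumed to run somewhere `actual_reward` is importable.

# Sketch: plot remembered rewards around the (60, 80) target point.
from pathlib import Path

actual_reward(
    Path("model_data/current/model.torch"),  # trained checkpoint (illustrative)
    (60, 80),                                # target point used elsewhere in this commit
    Path("plots/actual_reward/actual.png")   # output image (illustrative)
)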


@@ -2,38 +2,36 @@ import torch
 import numpy as np
 from pathlib import Path
 import matplotlib.pyplot as plt
-from multiprocessing import Pool
 
-from celeste import Celeste
-from main import DQN
-from main import Transition
+# All of the following are required to load
+# a pickled model.
+from celeste_ai.celeste import Celeste
+from celeste_ai.network import DQN
+from celeste_ai.network import Transition
 
-# Use cpu, this script is faster in parallel.
-compute_device = torch.device("cpu")
-input_model = Path("model_data/current")
-src_dir = input_model / "model_archive"
-out_dir = input_model_dir / "plots/predicted_value"
-out_dir.mkdir(parents = True, exist_ok = True)
 
-def plot(src):
+def predicted_reward(
+    model_file: Path,
+    out_filename: Path,
+    *,
+    device = torch.device("cpu")
+):
+    if not model_file.is_file():
+        raise Exception(f"Bad model file {model_file}")
+    out_filename.parent.mkdir(exist_ok = True, parents = True)
+
+    # Create and load model
     policy_net = DQN(
         len(Celeste.state_number_map),
         len(Celeste.action_space)
-    ).to(compute_device)
+    ).to(device)
 
     checkpoint = torch.load(
-        src,
-        map_location = compute_device
+        model_file,
+        map_location = device
     )
     policy_net.load_state_dict(checkpoint["policy_state_dict"])
-    fig, axs = plt.subplots(2, 4, figsize = (20, 10))
 
     # Compute predictions
     p = np.zeros((128, 128, 8), dtype=np.float32)
@@ -44,13 +42,14 @@ def plot(src):
             torch.tensor(
                 [c, r, 60, 80],
                 dtype = torch.float32,
-                device = compute_device
+                device = device
             ).unsqueeze(0)
         )[0])
         p[r][c] = k
 
     # Plot predictions
+    fig, axs = plt.subplots(2, 4, figsize = (20, 10))
     for a in range(len(axs.ravel())):
         ax = axs.ravel()[a]
         ax.set(
@@ -67,14 +66,12 @@ def plot(src):
     ax.invert_yaxis()
     fig.colorbar(plot)
 
-    print(src)
-    fig.savefig(out_dir / f"{src.stem}.png")
+    fig.savefig(out_filename)
     plt.close()
-
-if __name__ == "__main__":
-    with Pool(5) as p:
-        p.map(plot, list(src_dir.iterdir()))
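
Likewise, a sketch for the refactored predicted-reward plot (not part of the commit): `model_archive` and `plots/predicted_value` echo the old script's directories, and the file names are made up.

# Sketch: render the network's predicted Q-values for one saved checkpoint.
from pathlib import Path

predicted_reward(
    Path("model_data/current/model_archive/0.torch"),  # illustrative checkpoint
    Path("plots/predicted_value/0.png")                # illustrative output
)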


@@ -6,7 +6,9 @@ import math
 import json
 import torch
 
-from celeste import Celeste
+from celeste_ai import Celeste
+from celeste_ai import DQN
+from celeste_ai import Transition
 
 if __name__ == "__main__":
@@ -62,48 +64,6 @@ if __name__ == "__main__":
     # GAMMA is the discount factor as mentioned in the previous section
     GAMMA = 0.9
 
-# Outline our network
-class DQN(torch.nn.Module):
-    def __init__(self, n_observations: int, n_actions: int):
-        super(DQN, self).__init__()
-        self.layers = torch.nn.Sequential(
-            torch.nn.Linear(n_observations, 128),
-            torch.nn.ReLU(),
-            torch.nn.Linear(128, 128),
-            torch.nn.ReLU(),
-            torch.nn.Linear(128, 128),
-            torch.nn.ReLU(),
-            torch.torch.nn.Linear(128, n_actions)
-        )
-
-    # Can be called with one input, or with a batch.
-    #
-    # Returns tensor(
-    #    [ Q(s, left), Q(s, right) ], ...
-    # )
-    #
-    # Recall that Q(s, a) is the (expected) return of taking
-    # action `a` at state `s`
-    def forward(self, x):
-        return self.layers(x)
-
-Transition = namedtuple(
-    "Transition",
-    (
-        "state",
-        "action",
-        "next_state",
-        "reward"
-    )
-)
-
-if __name__ == "__main__":
     steps_done = 0
     num_episodes = 100
     episode_number = 0
@@ -139,7 +99,10 @@ if __name__ == "__main__":
     if model_save_path.is_file():
         # Load model if one exists
-        checkpoint = torch.load(model_save_path)
+        checkpoint = torch.load(
+            model_save_path,
+            map_location = compute_device
+        )
         policy_net.load_state_dict(checkpoint["policy_state_dict"])
         target_net.load_state_dict(checkpoint["target_state_dict"])
         optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
@@ -315,7 +278,6 @@ def optimize_model():
     return loss
 
 def on_state_before(celeste):
     global steps_done
@@ -351,7 +313,6 @@ def on_state_before(celeste):
     return state, action
 
 def on_state_after(celeste, before_out):
     global episode_number
@@ -478,7 +439,9 @@ def on_state_after(celeste, before_out):
 if __name__ == "__main__":
-    c = Celeste()
+    c = Celeste(
+        "resources/pico-8/linux/pico8"
+    )
 
     c.update_loop(
         on_state_before,

@ -1,79 +0,0 @@
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from multiprocessing import Pool
from celeste import Celeste
from main import DQN
from main import Transition
# Use cpu, this script is faster in parallel.
compute_device = torch.device("cpu")
input_model = Path("model_data/after_change")
out_dir = input_model / "plots/actual_reward"
out_dir.mkdir(parents = True, exist_ok = True)
checkpoint = torch.load(
input_model / "model.torch",
map_location = compute_device
)
memory = checkpoint["memory"]
r = np.zeros((128, 128, 8), dtype=np.float32)
for m in memory:
x, y, x_target, y_target = list(m.state[0])
action = m.action[0].item()
str_action = Celeste.action_space[action]
x = int(x.item())
y = int(y.item())
x_target = int(x_target.item())
y_target = int(y_target.item())
if (x_target, y_target) != (60, 80):
continue
if m.reward[0].item() == 1:
r[y][x][action] += 1
else:
r[y][x][action] -= 1
fig, axs = plt.subplots(2, 4, figsize = (20, 10))
# Plot predictions
for a in range(len(axs.ravel())):
ax = axs.ravel()[a]
ax.set(
adjustable = "box",
aspect = "equal",
title = Celeste.action_space[a]
)
plot = ax.pcolor(
r[:,:,a],
cmap = "seismic_r",
vmin = -10,
vmax = 10
)
ax.plot(60, 80, "k.")
#ax.annotate(
# "Target",
# (60, 80),
# textcoords = "offset points",
# xytext = (0, -20),
# ha = "center"
#)
ax.invert_yaxis()
fig.colorbar(plot)
fig.savefig(out_dir / "actual.png")
plt.close()

celeste/pyproject.toml Normal file

@@ -0,0 +1,25 @@
[build-system]
requires = [ "setuptools>=61.0" ]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = [ "." ]
include = ["celeste_ai*"]
namespaces = false

[project]
name = "celeste_ai"
description = "A reinforcement learning agent for Celeste Classic"
version = "1.0.0"
dependencies = [
    "matplotlib==3.6.3",
    "torch==1.13.1"
]
authors = [
    { name="Mark", email="mark@betalupi.com" }
]
readme = "README.md"
requires-python = ">=3.7"
license = {text = "GNU General Public License v3 (GPLv3)"}

celeste/requirements.txt

@@ -1,2 +0,0 @@
matplotlib==3.6.3
torch==1.13.1