Mark / celeste-ai (archived)

Reorganized celeste code

master
Mark 2023-02-19 20:57:19 -08:00
parent d648d896c0
commit 55ac62dc47
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
11 changed files with 201 additions and 157 deletions

celeste/README.md Normal file
View File

@@ -0,0 +1,15 @@
# Celeste-AI: A Celeste Classic DQL Agent
This is an attempt to create an agent that learns to play Celeste Classic.
## Contents
- `./resources`: contains the files these scripts require. Notably, it includes an (old) version of PICO-8 that is known to work with these scripts, and a version of Celeste Classic with telemetry and delays, called `hackcel.p8`.
## Setup
This is designed to work on Linux. You will need `xdotool` to send keypresses to the game.
1. `cd` into this directory
2. Create and activate a venv
3. `pip install -e .`

View File

@@ -0,0 +1,6 @@
from .network import DQN
from .network import Transition
from .celeste import Celeste
from .celeste import CelesteError
from .celeste import CelesteState
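With these re-exports in place, downstream scripts can import everything from the package root instead of reaching into submodules. A minimal sketch of the intended usage after `pip install -e .` (nothing here beyond the names this file exposes):

# The public API is importable from the package root.
from celeste_ai import Celeste, CelesteError, CelesteState, DQN, Transition

# These names resolve to celeste_ai.celeste and celeste_ai.network respectively.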

View File

@@ -63,14 +63,15 @@ class Celeste:
def __init__(
self,
pico_path,
*,
state_timeout = 30,
cart_name = "hackcel.p8"
cart_name = "hackcel.p8",
):
# Start pico-8
self._process = subprocess.Popen(
"resources/pico-8/linux/pico8",
pico_path,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT
@@ -272,7 +273,6 @@ class Celeste:
def update_loop(self, before, after):
# Waits for stdout from pico-8 process
for line in iter(self._process.stdout.readline, ""):
l = line.decode("utf-8")[:-1].strip()
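For context on the signature change above, a sketch of how the constructor is called once `pico_path` is a required argument and the remaining options are keyword-only (values shown are the defaults from this diff):

game = Celeste(
    "resources/pico-8/linux/pico8",  # path to the pico-8 binary, now supplied by the caller
    state_timeout = 30,              # keyword-only; default from this diff
    cart_name = "hackcel.p8",        # keyword-only; the instrumented Celeste cart
)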

View File

@@ -0,0 +1,36 @@
import torch
from collections import namedtuple
Transition = namedtuple(
"Transition",
(
"state",
"action",
"next_state",
"reward"
)
)
class DQN(torch.nn.Module):
def __init__(self, n_observations: int, n_actions: int):
super(DQN, self).__init__()
self.layers = torch.nn.Sequential(
torch.nn.Linear(n_observations, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, n_actions)
)
def forward(self, x):
return self.layers(x)
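A quick sketch of a forward pass through this network. The observation and action counts below are assumptions taken from the plotting scripts (a four-element state, eight actions), not values fixed by `network.py` itself:

import torch
from celeste_ai.network import DQN

net = DQN(n_observations = 4, n_actions = 8)

# One batched state: [x, y, x_target, y_target].
state = torch.tensor([[60.0, 80.0, 60.0, 80.0]])

q_values = net(state)  # shape (1, 8): one Q-value per action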

View File

@@ -0,0 +1,2 @@
from .plot_actual_reward import actual_reward
from .plot_predicted_reward import predicted_reward

View File

@@ -0,0 +1,81 @@
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from multiprocessing import Pool
# All of the following are required to load
# a pickled model.
from celeste_ai.celeste import Celeste
from celeste_ai.network import DQN
from celeste_ai.network import Transition
def actual_reward(
model_file: Path,
target_point: tuple[int, int],
out_filename: Path,
*,
device = torch.device("cpu")
):
if not model_file.is_file():
raise Exception(f"Bad model file {model_file}")
out_filename.parent.mkdir(exist_ok = True, parents = True)
checkpoint = torch.load(
model_file,
map_location = device
)
memory = checkpoint["memory"]
r = np.zeros((128, 128, 8), dtype=np.float32)
for m in memory:
x, y, x_target, y_target = list(m.state[0])
action = m.action[0].item()
x = int(x.item())
y = int(y.item())
x_target = int(x_target.item())
y_target = int(y_target.item())
# Only plot memory related to this point
if (x_target, y_target) != target_point:
continue
if m.reward[0].item() == 1:
r[y][x][action] += 1
else:
r[y][x][action] -= 1
fig, axs = plt.subplots(2, 4, figsize = (20, 10))
for a in range(len(axs.ravel())):
ax = axs.ravel()[a]
ax.set(
adjustable = "box",
aspect = "equal",
title = Celeste.action_space[a]
)
plot = ax.pcolor(
r[:,:,a],
cmap = "seismic_r",
vmin = -10,
vmax = 10
)
# Draw target point on plot
ax.plot(
target_point[0],
target_point[1],
"k."
)
ax.invert_yaxis()
fig.colorbar(plot)
fig.savefig(out_filename)
plt.close()
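A sketch of how `actual_reward` might be called after this reorganization; the checkpoint and output paths are placeholders, and the (60, 80) target point is the one hard-coded in the old standalone script:

import torch
from pathlib import Path
from celeste_ai.plots import actual_reward

actual_reward(
    Path("model_data/current/model.torch"),                     # placeholder checkpoint path
    (60, 80),                                                   # only memories for this target are plotted
    Path("model_data/current/plots/actual_reward/actual.png"),  # placeholder output path
    device = torch.device("cpu"),
)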

View File

@@ -2,38 +2,36 @@ import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from multiprocessing import Pool
from celeste import Celeste
from main import DQN
from main import Transition
# Use cpu, this script is faster in parallel.
compute_device = torch.device("cpu")
input_model = Path("model_data/current")
src_dir = input_model / "model_archive"
out_dir = input_model_dir / "plots/predicted_value"
out_dir.mkdir(parents = True, exist_ok = True)
# All of the following are required to load
# a pickled model.
from celeste_ai.celeste import Celeste
from celeste_ai.network import DQN
from celeste_ai.network import Transition
def predicted_reward(
model_file: Path,
out_filename: Path,
*,
device = torch.device("cpu")
):
if not model_file.is_file():
raise Exception(f"Bad model file {model_file}")
out_filename.parent.mkdir(exist_ok = True, parents = True)
def plot(src):
# Create and load model
policy_net = DQN(
len(Celeste.state_number_map),
len(Celeste.action_space)
).to(compute_device)
).to(device)
checkpoint = torch.load(
src,
map_location = compute_device
model_file,
map_location = device
)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
fig, axs = plt.subplots(2, 4, figsize = (20, 10))
# Compute predictions
p = np.zeros((128, 128, 8), dtype=np.float32)
@@ -44,13 +42,14 @@ def plot(src):
torch.tensor(
[c, r, 60, 80],
dtype = torch.float32,
device = compute_device
device = device
).unsqueeze(0)
)[0])
p[r][c] = k
# Plot predictions
fig, axs = plt.subplots(2, 4, figsize = (20, 10))
for a in range(len(axs.ravel())):
ax = axs.ravel()[a]
ax.set(
@@ -67,14 +66,12 @@ def plot(src):
ax.invert_yaxis()
fig.colorbar(plot)
print(src)
fig.savefig(out_dir / f"{src.stem}.png")
fig.savefig(out_filename)
plt.close()
if __name__ == "__main__":
with Pool(5) as p:
p.map(plot, list(src_dir.iterdir()))
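With the standalone `__main__` block removed, iterating over archived models is now the caller's job. A sketch of what that might look like, with placeholder paths modelled on the old script:

from pathlib import Path
from multiprocessing import Pool
from celeste_ai.plots import predicted_reward

src_dir = Path("model_data/current/model_archive")          # placeholder archive location
out_dir = Path("model_data/current/plots/predicted_value")  # placeholder output location

def plot_one(model_file: Path):
    predicted_reward(model_file, out_dir / f"{model_file.stem}.png")

if __name__ == "__main__":
    # CPU-only plotting parallelizes well, as the old script's comment noted.
    with Pool(5) as p:
        p.map(plot_one, list(src_dir.iterdir()))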

View File

@@ -6,7 +6,9 @@ import math
import json
import torch
from celeste import Celeste
from celeste_ai import Celeste
from celeste_ai import DQN
from celeste_ai import Transition
if __name__ == "__main__":
@@ -62,48 +64,6 @@ if __name__ == "__main__":
# GAMMA is the discount factor as mentioned in the previous section
GAMMA = 0.9
# Outline our network
class DQN(torch.nn.Module):
def __init__(self, n_observations: int, n_actions: int):
super(DQN, self).__init__()
self.layers = torch.nn.Sequential(
torch.nn.Linear(n_observations, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 128),
torch.nn.ReLU(),
torch.torch.nn.Linear(128, n_actions)
)
# Can be called with one input, or with a batch.
#
# Returns tensor(
# [ Q(s, left), Q(s, right) ], ...
# )
#
# Recall that Q(s, a) is the (expected) return of taking
# action `a` at state `s`
def forward(self, x):
return self.layers(x)
Transition = namedtuple(
"Transition",
(
"state",
"action",
"next_state",
"reward"
)
)
if __name__ == "__main__":
steps_done = 0
num_episodes = 100
episode_number = 0
@@ -139,7 +99,10 @@ if __name__ == "__main__":
if model_save_path.is_file():
# Load model if one exists
checkpoint = torch.load(model_save_path)
checkpoint = torch.load(
model_save_path,
map_location = compute_device
)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
target_net.load_state_dict(checkpoint["target_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
@@ -315,7 +278,6 @@ def optimize_model():
return loss
def on_state_before(celeste):
global steps_done
@@ -351,7 +313,6 @@ def on_state_before(celeste):
return state, action
def on_state_after(celeste, before_out):
global episode_number
@@ -478,7 +439,9 @@ def on_state_after(celeste, before_out):
if __name__ == "__main__":
c = Celeste()
c = Celeste(
"resources/pico-8/linux/pico8"
)
c.update_loop(
on_state_before,
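For readers skimming the diff, `GAMMA = 0.9` above is the discount factor in the usual DQN target. A minimal, self-contained sketch of that target with illustrative numbers only; the project's actual `optimize_model()` is not shown in this diff:

import torch

GAMMA = 0.9  # discount factor, as in the training script

# A made-up batch of three transitions.
reward     = torch.tensor([0.0, 0.0, 1.0])
done       = torch.tensor([0.0, 0.0, 1.0])   # 1 where the episode ended
next_q_max = torch.tensor([0.8, 0.5, 0.0])   # max_a' Q_target(s', a') from the target network

# Standard DQN target: r + GAMMA * max_a' Q_target(s', a'), zeroed for terminal states.
expected_q = reward + GAMMA * next_q_max * (1.0 - done)
# expected_q == tensor([0.7200, 0.4500, 1.0000])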

View File

@@ -1,79 +0,0 @@
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from multiprocessing import Pool
from celeste import Celeste
from main import DQN
from main import Transition
# Use cpu, this script is faster in parallel.
compute_device = torch.device("cpu")
input_model = Path("model_data/after_change")
out_dir = input_model / "plots/actual_reward"
out_dir.mkdir(parents = True, exist_ok = True)
checkpoint = torch.load(
input_model / "model.torch",
map_location = compute_device
)
memory = checkpoint["memory"]
r = np.zeros((128, 128, 8), dtype=np.float32)
for m in memory:
x, y, x_target, y_target = list(m.state[0])
action = m.action[0].item()
str_action = Celeste.action_space[action]
x = int(x.item())
y = int(y.item())
x_target = int(x_target.item())
y_target = int(y_target.item())
if (x_target, y_target) != (60, 80):
continue
if m.reward[0].item() == 1:
r[y][x][action] += 1
else:
r[y][x][action] -= 1
fig, axs = plt.subplots(2, 4, figsize = (20, 10))
# Plot predictions
for a in range(len(axs.ravel())):
ax = axs.ravel()[a]
ax.set(
adjustable = "box",
aspect = "equal",
title = Celeste.action_space[a]
)
plot = ax.pcolor(
r[:,:,a],
cmap = "seismic_r",
vmin = -10,
vmax = 10
)
ax.plot(60, 80, "k.")
#ax.annotate(
# "Target",
# (60, 80),
# textcoords = "offset points",
# xytext = (0, -20),
# ha = "center"
#)
ax.invert_yaxis()
fig.colorbar(plot)
fig.savefig(out_dir / "actual.png")
plt.close()

celeste/pyproject.toml Normal file
View File

@@ -0,0 +1,25 @@
[build-system]
requires = [ "setuptools>=61.0" ]
build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
where = [ "." ]
include = ["celeste_ai*"]
namespaces = false
[project]
name = "celeste_ai"
description = "A reinforcement learning agent for Celeste Classic"
version = "1.0.0"
dependencies = [
"matplotlib==3.6.3",
"torch==1.13.1"
]
authors = [
{ name="Mark", email="mark@betalupi.com" }
]
readme = "README.md"
requires-python = ">=3.7"
license = {text = "GNU General Public License v3 (GPLv3)"}

View File

@@ -1,2 +0,0 @@
matplotlib==3.6.3
torch==1.13.1