From 55ac62dc47f6a2a83ccc3250506e701a588b51cb Mon Sep 17 00:00:00 2001
From: Mark
Date: Sun, 19 Feb 2023 20:57:19 -0800
Subject: [PATCH] Reorganized celeste code

---
 celeste/README.md                             | 15 ++++
 celeste/celeste_ai/__init__.py                |  6 ++
 celeste/{ => celeste_ai}/celeste.py           |  6 +-
 celeste/celeste_ai/network.py                 | 36 +++++++++
 celeste/celeste_ai/plotting/__init__.py       |  2 +
 .../celeste_ai/plotting/plot_actual_reward.py | 81 +++++++++++++++++++
 .../plotting/plot_predicted_reward.py}        | 49 ++++++-----
 celeste/{main.py => celeste_ai/train.py}      | 57 +++----------
 celeste/plot-actual.py                        | 79 ------------------
 celeste/pyproject.toml                        | 25 ++++++
 celeste/requirements.txt                      |  2 -
 11 files changed, 201 insertions(+), 157 deletions(-)
 create mode 100644 celeste/README.md
 create mode 100644 celeste/celeste_ai/__init__.py
 rename celeste/{ => celeste_ai}/celeste.py (99%)
 create mode 100644 celeste/celeste_ai/network.py
 create mode 100644 celeste/celeste_ai/plotting/__init__.py
 create mode 100644 celeste/celeste_ai/plotting/plot_actual_reward.py
 rename celeste/{plots.py => celeste_ai/plotting/plot_predicted_reward.py} (59%)
 rename celeste/{main.py => celeste_ai/train.py} (92%)
 delete mode 100644 celeste/plot-actual.py
 create mode 100644 celeste/pyproject.toml
 delete mode 100755 celeste/requirements.txt

diff --git a/celeste/README.md b/celeste/README.md
new file mode 100644
index 0000000..009d7c2
--- /dev/null
+++ b/celeste/README.md
@@ -0,0 +1,15 @@
+# Celeste-AI: A Celeste Classic DQL Agent
+
+This is an attempt to create an agent that learns to play Celeste Classic.
+
+## Contents
+ - `./resources`: contains files these scripts require. Notably, we have an (old) version of PICO-8 that's known to work with these scripts, and `hackcel.p8`, a version of Celeste Classic modified to add telemetry and delays.
+
+
+## Setup
+
+This is designed to work on Linux. You will need `xdotool` to send keypresses to the game.
+
+1. `cd` into this directory
+2. Create and activate a venv
+3. `pip install -e .`
\ No newline at end of file
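Once the editable install finishes, the package's public names (re-exported by `celeste_ai/__init__.py` below) are importable directly. A minimal smoke test — a sketch, assuming the venv from step 2 is active:

```python
# Sketch: verify the install. All five names are re-exported
# by celeste_ai/__init__.py in this patch.
from celeste_ai import Celeste, CelesteError, CelesteState, DQN, Transition

print(DQN)         # the Q-network class
print(Transition)  # the replay-memory record type
```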
diff --git a/celeste/celeste_ai/__init__.py b/celeste/celeste_ai/__init__.py
new file mode 100644
index 0000000..262e0bb
--- /dev/null
+++ b/celeste/celeste_ai/__init__.py
@@ -0,0 +1,6 @@
+from .network import DQN
+from .network import Transition
+
+from .celeste import Celeste
+from .celeste import CelesteError
+from .celeste import CelesteState
diff --git a/celeste/celeste.py b/celeste/celeste_ai/celeste.py
similarity index 99%
rename from celeste/celeste.py
rename to celeste/celeste_ai/celeste.py
index 9b1e254..9988080 100755
--- a/celeste/celeste.py
+++ b/celeste/celeste_ai/celeste.py
@@ -63,14 +63,15 @@ class Celeste:
 	def __init__(
 		self,
+		pico_path,
 		*,
 		state_timeout = 30,
-		cart_name = "hackcel.p8"
+		cart_name = "hackcel.p8",
 	):
 
 		# Start pico-8
 		self._process = subprocess.Popen(
-			"resources/pico-8/linux/pico8",
+			pico_path,
 			shell=True,
 			stdout=subprocess.PIPE,
 			stderr=subprocess.STDOUT
@@ -272,7 +273,6 @@ class Celeste:
 
 	def update_loop(self, before, after):
-
 		# Waits for stdout from pico-8 process
 		for line in iter(self._process.stdout.readline, ""):
 			l = line.decode("utf-8")[:-1].strip()
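With this change `Celeste` no longer hard-codes the PICO-8 binary; the caller passes `pico_path` explicitly. A sketch of the resulting driving loop — the callback bodies are placeholders here, the real ones live in `celeste_ai/train.py` further down this patch:

```python
from celeste_ai import Celeste

def before(celeste):
    # Choose an action from the current state.
    # train.py's version returns a (state, action) pair.
    ...

def after(celeste, before_out):
    # Score the resulting state and record the transition.
    ...

# The PICO-8 path is now an argument (relative, so run from celeste/).
c = Celeste("resources/pico-8/linux/pico8")
c.update_loop(before, after)
```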
diff --git a/celeste/celeste_ai/network.py b/celeste/celeste_ai/network.py
new file mode 100644
index 0000000..309ea12
--- /dev/null
+++ b/celeste/celeste_ai/network.py
@@ -0,0 +1,36 @@
+import torch
+from collections import namedtuple
+
+
+Transition = namedtuple(
+	"Transition",
+	(
+		"state",
+		"action",
+		"next_state",
+		"reward"
+	)
+)
+
+
+class DQN(torch.nn.Module):
+	def __init__(self, n_observations: int, n_actions: int):
+		super(DQN, self).__init__()
+
+		self.layers = torch.nn.Sequential(
+			torch.nn.Linear(n_observations, 128),
+			torch.nn.ReLU(),
+
+			torch.nn.Linear(128, 128),
+			torch.nn.ReLU(),
+
+			torch.nn.Linear(128, 128),
+			torch.nn.ReLU(),
+
+			torch.nn.Linear(128, n_actions)
+		)
+
+	def forward(self, x):
+		return self.layers(x)
+
+
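`DQN.forward` maps one observation vector to one Q-value per action, and works on single inputs or batches. Judging from the rest of this patch, states have four components (`[x, y, x_target, y_target]`) and there are eight actions, so a greedy policy is just an argmax over the output — a sketch with those dimensions assumed:

```python
import torch
from celeste_ai import DQN

net = DQN(4, 8)  # assumed sizes: 4 state components, 8 actions

state = torch.tensor([[60.0, 80.0, 60.0, 80.0]])  # batch of one
q_values = net(state)                   # shape (1, 8): Q(s, a) per action
action = q_values.argmax(dim=1).item()  # greedy action index
```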
diff --git a/celeste/celeste_ai/plotting/__init__.py b/celeste/celeste_ai/plotting/__init__.py
new file mode 100644
index 0000000..1495b35
--- /dev/null
+++ b/celeste/celeste_ai/plotting/__init__.py
@@ -0,0 +1,2 @@
+from .plot_actual_reward import actual_reward
+from .plot_predicted_reward import predicted_reward
diff --git a/celeste/celeste_ai/plotting/plot_actual_reward.py b/celeste/celeste_ai/plotting/plot_actual_reward.py
new file mode 100644
index 0000000..7bcfed0
--- /dev/null
+++ b/celeste/celeste_ai/plotting/plot_actual_reward.py
@@ -0,0 +1,81 @@
+import torch
+import numpy as np
+from pathlib import Path
+import matplotlib.pyplot as plt
+from multiprocessing import Pool
+
+# All of the following are required to load
+# a pickled model.
+from celeste_ai.celeste import Celeste
+from celeste_ai.network import DQN
+from celeste_ai.network import Transition
+
+
+def actual_reward(
+	model_file: Path,
+	target_point: tuple[int, int],
+	out_filename: Path,
+	*,
+	device = torch.device("cpu")
+):
+	if not model_file.is_file():
+		raise Exception(f"Bad model file {model_file}")
+	out_filename.parent.mkdir(exist_ok = True, parents = True)
+
+
+	checkpoint = torch.load(
+		model_file,
+		map_location = device
+	)
+	memory = checkpoint["memory"]
+
+
+	r = np.zeros((128, 128, 8), dtype=np.float32)
+	for m in memory:
+		x, y, x_target, y_target = list(m.state[0])
+
+		action = m.action[0].item()
+		x = int(x.item())
+		y = int(y.item())
+		x_target = int(x_target.item())
+		y_target = int(y_target.item())
+
+		# Only plot memory related to this point
+		if (x_target, y_target) != target_point:
+			continue
+
+		if m.reward[0].item() == 1:
+			r[y][x][action] += 1
+		else:
+			r[y][x][action] -= 1
+
+
+	fig, axs = plt.subplots(2, 4, figsize = (20, 10))
+
+
+	for a in range(len(axs.ravel())):
+		ax = axs.ravel()[a]
+		ax.set(
+			adjustable = "box",
+			aspect = "equal",
+			title = Celeste.action_space[a]
+		)
+
+		plot = ax.pcolor(
+			r[:,:,a],
+			cmap = "seismic_r",
+			vmin = -10,
+			vmax = 10
+		)
+
+		# Draw target point on plot
+		ax.plot(
+			target_point[0],
+			target_point[1],
+			"k."
+		)
+
+		ax.invert_yaxis()
+		fig.colorbar(plot)
+
+	fig.savefig(out_filename)
+	plt.close()
\ No newline at end of file
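This replaces the standalone `plot-actual.py` deleted at the end of this patch; the paths and the `(60, 80)` target that script hard-coded are now arguments. A hypothetical call that reproduces the old behavior:

```python
from pathlib import Path
from celeste_ai.plotting import actual_reward

# Directory layout borrowed from the deleted plot-actual.py.
model_dir = Path("model_data/after_change")

actual_reward(
    model_dir / "model.torch",
    (60, 80),  # the target point the old script hard-coded
    model_dir / "plots/actual_reward/actual.png"
)
```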
diff --git a/celeste/plots.py b/celeste/celeste_ai/plotting/plot_predicted_reward.py
similarity index 59%
rename from celeste/plots.py
rename to celeste/celeste_ai/plotting/plot_predicted_reward.py
index 4e76e82..98e9c99 100644
--- a/celeste/plots.py
+++ b/celeste/celeste_ai/plotting/plot_predicted_reward.py
@@ -2,38 +2,36 @@ import torch
 import numpy as np
 from pathlib import Path
 import matplotlib.pyplot as plt
-from multiprocessing import Pool
 
-from celeste import Celeste
-from main import DQN
-from main import Transition
-
-# Use cpu, this script is faster in parallel.
-compute_device = torch.device("cpu")
-
-input_model = Path("model_data/current")
-
-src_dir = input_model / "model_archive"
-out_dir = input_model_dir / "plots/predicted_value"
-out_dir.mkdir(parents = True, exist_ok = True)
+# All of the following are required to load
+# a pickled model.
+from celeste_ai.celeste import Celeste
+from celeste_ai.network import DQN
+from celeste_ai.network import Transition
 
 
+def predicted_reward(
+	model_file: Path,
+	out_filename: Path,
+	*,
+	device = torch.device("cpu")
+):
+	if not model_file.is_file():
+		raise Exception(f"Bad model file {model_file}")
+	out_filename.parent.mkdir(exist_ok = True, parents = True)
 
-def plot(src):
+	# Create and load model
 	policy_net = DQN(
 		len(Celeste.state_number_map),
 		len(Celeste.action_space)
-	).to(compute_device)
-
+	).to(device)
 	checkpoint = torch.load(
-		src,
-		map_location = compute_device
+		model_file,
+		map_location = device
 	)
 	policy_net.load_state_dict(checkpoint["policy_state_dict"])
 
-	fig, axs = plt.subplots(2, 4, figsize = (20, 10))
-
 	# Compute predictions
 	p = np.zeros((128, 128, 8), dtype=np.float32)
 
@@ -44,13 +42,14 @@ def plot(src):
 				torch.tensor(
 					[c, r, 60, 80],
 					dtype = torch.float32,
-					device = compute_device
+					device = device
 				).unsqueeze(0)
 			)[0])
 			p[r][c] = k
 
 
 	# Plot predictions
+	fig, axs = plt.subplots(2, 4, figsize = (20, 10))
 	for a in range(len(axs.ravel())):
 		ax = axs.ravel()[a]
 		ax.set(
@@ -67,14 +66,12 @@ def plot(src):
 		ax.invert_yaxis()
 		fig.colorbar(plot)
 
-	print(src)
-	fig.savefig(out_dir / f"{src.stem}.png")
+
+	fig.savefig(out_filename)
 	plt.close()
 
-if __name__ == "__main__":
-	with Pool(5) as p:
-		p.map(plot, list(src_dir.iterdir()))
+
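The old `plots.py` entry point — a `Pool(5)` map over every checkpoint in `model_archive` — was dropped in this rename, so the sweep now has to live in the caller. A sketch that restores it, assuming the old directory layout:

```python
from pathlib import Path
from multiprocessing import Pool
from celeste_ai.plotting import predicted_reward

# Layout assumed from the deleted driver code.
src_dir = Path("model_data/current/model_archive")
out_dir = Path("model_data/current/plots/predicted_value")

def plot_one(src: Path):
    # One PNG per archived checkpoint, as the old script produced.
    predicted_reward(src, out_dir / f"{src.stem}.png")

if __name__ == "__main__":
    with Pool(5) as p:  # the old script's worker count
        p.map(plot_one, list(src_dir.iterdir()))
```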
diff --git a/celeste/main.py b/celeste/celeste_ai/train.py
similarity index 92%
rename from celeste/main.py
rename to celeste/celeste_ai/train.py
index a1043c4..adb093f 100644
--- a/celeste/main.py
+++ b/celeste/celeste_ai/train.py
@@ -6,7 +6,9 @@ import math
 import json
 import torch
 
-from celeste import Celeste
+from celeste_ai import Celeste
+from celeste_ai import DQN
+from celeste_ai import Transition
 
 
 if __name__ == "__main__":
@@ -62,48 +64,6 @@ if __name__ == "__main__":
 	# GAMMA is the discount factor as mentioned in the previous section
 	GAMMA = 0.9
 
-
-# Outline our network
-class DQN(torch.nn.Module):
-	def __init__(self, n_observations: int, n_actions: int):
-		super(DQN, self).__init__()
-
-		self.layers = torch.nn.Sequential(
-			torch.nn.Linear(n_observations, 128),
-			torch.nn.ReLU(),
-
-			torch.nn.Linear(128, 128),
-			torch.nn.ReLU(),
-
-			torch.nn.Linear(128, 128),
-			torch.nn.ReLU(),
-
-			torch.torch.nn.Linear(128, n_actions)
-		)
-
-	# Can be called with one input, or with a batch.
-	#
-	# Returns tensor(
-	#	[ Q(s, left), Q(s, right) ], ...
-	# )
-	#
-	# Recall that Q(s, a) is the (expected) return of taking
-	# action `a` at state `s`
-	def forward(self, x):
-		return self.layers(x)
-
-Transition = namedtuple(
-	"Transition",
-	(
-		"state",
-		"action",
-		"next_state",
-		"reward"
-	)
-)
-
-
-if __name__ == "__main__":
 	steps_done = 0
 	num_episodes = 100
 	episode_number = 0
@@ -139,7 +99,10 @@ if __name__ == "__main__":
 
 	if model_save_path.is_file():
 		# Load model if one exists
-		checkpoint = torch.load(model_save_path)
+		checkpoint = torch.load(
+			model_save_path,
+			map_location = compute_device
+		)
 		policy_net.load_state_dict(checkpoint["policy_state_dict"])
 		target_net.load_state_dict(checkpoint["target_state_dict"])
 		optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
@@ -315,7 +278,6 @@ def optimize_model():
 
 	return loss
 
-
 def on_state_before(celeste):
 	global steps_done
@@ -351,7 +313,6 @@ def on_state_before(celeste):
 
 	return state, action
 
-
 def on_state_after(celeste, before_out):
 	global episode_number
@@ -478,7 +439,9 @@ def on_state_after(celeste, before_out):
 
 if __name__ == "__main__":
-	c = Celeste()
+	c = Celeste(
+		"resources/pico-8/linux/pico8"
+	)
 
 	c.update_loop(
 		on_state_before,
diff --git a/celeste/plot-actual.py b/celeste/plot-actual.py
deleted file mode 100644
index f034b3c..0000000
--- a/celeste/plot-actual.py
+++ /dev/null
@@ -1,79 +0,0 @@
-import torch
-import numpy as np
-from pathlib import Path
-import matplotlib.pyplot as plt
-from multiprocessing import Pool
-
-from celeste import Celeste
-from main import DQN
-from main import Transition
-
-# Use cpu, this script is faster in parallel.
-compute_device = torch.device("cpu")
-
-input_model = Path("model_data/after_change")
-
-out_dir = input_model / "plots/actual_reward"
-out_dir.mkdir(parents = True, exist_ok = True)
-
-
-checkpoint = torch.load(
-	input_model / "model.torch",
-	map_location = compute_device
-)
-memory = checkpoint["memory"]
-
-r = np.zeros((128, 128, 8), dtype=np.float32)
-
-for m in memory:
-	x, y, x_target, y_target = list(m.state[0])
-
-	action = m.action[0].item()
-	str_action = Celeste.action_space[action]
-	x = int(x.item())
-	y = int(y.item())
-	x_target = int(x_target.item())
-	y_target = int(y_target.item())
-
-	if (x_target, y_target) != (60, 80):
-		continue
-
-	if m.reward[0].item() == 1:
-		r[y][x][action] += 1
-	else:
-		r[y][x][action] -= 1
-
-
-fig, axs = plt.subplots(2, 4, figsize = (20, 10))
-
-
-# Plot predictions
-for a in range(len(axs.ravel())):
-	ax = axs.ravel()[a]
-	ax.set(
-		adjustable = "box",
-		aspect = "equal",
-		title = Celeste.action_space[a]
-	)
-
-	plot = ax.pcolor(
-		r[:,:,a],
-		cmap = "seismic_r",
-		vmin = -10,
-		vmax = 10
-	)
-
-	ax.plot(60, 80, "k.")
-	#ax.annotate(
-	#	"Target",
-	#	(60, 80),
-	#	textcoords = "offset points",
-	#	xytext = (0, -20),
-	#	ha = "center"
-	#)
-
-	ax.invert_yaxis()
-	fig.colorbar(plot)
-
-fig.savefig(out_dir / "actual.png")
-plt.close()
diff --git a/celeste/pyproject.toml b/celeste/pyproject.toml
new file mode 100644
index 0000000..dd82713
--- /dev/null
+++ b/celeste/pyproject.toml
@@ -0,0 +1,25 @@
+[build-system]
+requires = [ "setuptools>=61.0" ]
+build-backend = "setuptools.build_meta"
+
+
+[tool.setuptools.packages.find]
+where = [ "." ]
+include = ["celeste_ai*"]
+namespaces = false
+
+
+[project]
+name = "celeste_ai"
+description = "A reinforcement learning agent for Celeste Classic"
+version = "1.0.0"
+dependencies = [
+	"matplotlib==3.6.3",
+	"torch==1.13.1"
+]
+authors = [
+	{ name="Mark", email="mark@betalupi.com" }
+]
+readme = "README.md"
+requires-python = ">=3.9"
+license = {text = "GNU General Public License v3 (GPLv3)"}
diff --git a/celeste/requirements.txt b/celeste/requirements.txt
deleted file mode 100755
index 222f50a..0000000
--- a/celeste/requirements.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-matplotlib==3.6.3
-torch==1.13.1
\ No newline at end of file