From 3d25d63efe8e7b41d07ff65b0f5972d4304a75c5 Mon Sep 17 00:00:00 2001 From: Mark Date: Sun, 26 Feb 2023 15:26:45 -0800 Subject: [PATCH] Plotter cleanup --- celeste/celeste_ai/plotting/__init__.py | 1 - .../celeste_ai/plotting/plot_actual_reward.py | 81 ------------------- celeste/plot.py | 57 +++---------- 3 files changed, 11 insertions(+), 128 deletions(-) delete mode 100644 celeste/celeste_ai/plotting/plot_actual_reward.py diff --git a/celeste/celeste_ai/plotting/__init__.py b/celeste/celeste_ai/plotting/__init__.py index da903ac..609a157 100644 --- a/celeste/celeste_ai/plotting/__init__.py +++ b/celeste/celeste_ai/plotting/__init__.py @@ -1,4 +1,3 @@ -from .plot_actual_reward import actual_reward from .plot_predicted_reward import predicted_reward from .plot_best_action import best_action diff --git a/celeste/celeste_ai/plotting/plot_actual_reward.py b/celeste/celeste_ai/plotting/plot_actual_reward.py deleted file mode 100644 index 7bcfed0..0000000 --- a/celeste/celeste_ai/plotting/plot_actual_reward.py +++ /dev/null @@ -1,81 +0,0 @@ -import torch -import numpy as np -from pathlib import Path -import matplotlib.pyplot as plt -from multiprocessing import Pool - -# All of the following are required to load -# a pickled model. -from celeste_ai.celeste import Celeste -from celeste_ai.network import DQN -from celeste_ai.network import Transition - -def actual_reward( - model_file: Path, - target_point: tuple[int, int], - out_filename: Path, - *, - device = torch.device("cpu") -): - if not model_file.is_file(): - raise Exception(f"Bad model file {model_file}") - out_filename.parent.mkdir(exist_ok = True, parents = True) - - - checkpoint = torch.load( - model_file, - map_location = device - ) - memory = checkpoint["memory"] - - - r = np.zeros((128, 128, 8), dtype=np.float32) - for m in memory: - x, y, x_target, y_target = list(m.state[0]) - - action = m.action[0].item() - x = int(x.item()) - y = int(y.item()) - x_target = int(x_target.item()) - y_target = int(y_target.item()) - - # Only plot memory related to this point - if (x_target, y_target) != target_point: - continue - - if m.reward[0].item() == 1: - r[y][x][action] += 1 - else: - r[y][x][action] -= 1 - - - fig, axs = plt.subplots(2, 4, figsize = (20, 10)) - - - for a in range(len(axs.ravel())): - ax = axs.ravel()[a] - ax.set( - adjustable = "box", - aspect = "equal", - title = Celeste.action_space[a] - ) - - plot = ax.pcolor( - r[:,:,a], - cmap = "seismic_r", - vmin = -10, - vmax = 10 - ) - - # Draw target point on plot - ax.plot( - target_point[0], - target_point[1], - "k." - ) - - ax.invert_yaxis() - fig.colorbar(plot) - - fig.savefig(out_filename) - plt.close() \ No newline at end of file diff --git a/celeste/plot.py b/celeste/plot.py index 7c64793..7eae43d 100644 --- a/celeste/plot.py +++ b/celeste/plot.py @@ -5,10 +5,9 @@ import celeste_ai.plotting as plotting from multiprocessing import Pool - m = Path("model_data/current") -# Make "predicted reward" plots + def plot_pred(src_model): plotting.predicted_reward( src_model, @@ -17,7 +16,6 @@ def plot_pred(src_model): device = torch.device("cpu") ) -# Make "best action" plots def plot_best(src_model): plotting.best_action( src_model, @@ -26,47 +24,14 @@ def plot_best(src_model): device = torch.device("cpu") ) -# Make "actual reward" plots -def plot_act(src_model): - plotting.actual_reward( - src_model, - (60, 80), - m / f"plots/actual/{src_model.stem}.png", - device = torch.device("cpu") - ) - - -# Which plots should we make? -plots = { - "prediction": True, - "actual": False, - "best": True -} - - -if __name__ == "__main__": - - if plots["best"]: - print("Making best-action plots...") - with Pool(5) as p: - p.map( - plot_best, - list((m / "model_archive").iterdir()) - ) - - if plots["prediction"]: - print("Making prediction plots...") - with Pool(5) as p: - p.map( - plot_pred, - list((m / "model_archive").iterdir()) - ) - - if plots["actual"]: - print("Making actual plots...") - with Pool(5) as p: - p.map( - plot_act, - list((m / "model_archive").iterdir()) - ) \ No newline at end of file +for k, v in { + #"prediction": plot_pred, + "best_action": plot_best, +}.items(): + print(f"Making {k} plots...") + with Pool(5) as p: + p.map( + v, + list((m / "model_archive").iterdir()) + ) \ No newline at end of file