From c9e04dcd4176e7736a319d12cee5a4bee6aa6b8f Mon Sep 17 00:00:00 2001 From: Mark Date: Fri, 24 Feb 2023 14:23:38 -0800 Subject: [PATCH] Added best-action plot --- celeste/celeste_ai/plotting/__init__.py | 2 + .../celeste_ai/plotting/plot_best_action.py | 105 ++++++++++++++++++ 2 files changed, 107 insertions(+) create mode 100644 celeste/celeste_ai/plotting/plot_best_action.py diff --git a/celeste/celeste_ai/plotting/__init__.py b/celeste/celeste_ai/plotting/__init__.py index 1495b35..da903ac 100644 --- a/celeste/celeste_ai/plotting/__init__.py +++ b/celeste/celeste_ai/plotting/__init__.py @@ -1,2 +1,4 @@ from .plot_actual_reward import actual_reward from .plot_predicted_reward import predicted_reward +from .plot_best_action import best_action + diff --git a/celeste/celeste_ai/plotting/plot_best_action.py b/celeste/celeste_ai/plotting/plot_best_action.py new file mode 100644 index 0000000..bb26188 --- /dev/null +++ b/celeste/celeste_ai/plotting/plot_best_action.py @@ -0,0 +1,105 @@ +import torch +import numpy as np +from pathlib import Path +import matplotlib.pyplot as plt + +# All of the following are required to load +# a pickled model. +from celeste_ai.celeste import Celeste +from celeste_ai.network import DQN +from celeste_ai.network import Transition + + +def best_action( + model_file: Path, + out_filename: Path, + *, + device = torch.device("cpu") +): + if not model_file.is_file(): + raise Exception(f"Bad model file {model_file}") + out_filename.parent.mkdir(exist_ok = True, parents = True) + + # Create and load model + policy_net = DQN( + len(Celeste.state_number_map), + len(Celeste.action_space) + ).to(device) + checkpoint = torch.load( + model_file, + map_location = device + ) + policy_net.load_state_dict(checkpoint["policy_state_dict"]) + + + + # Compute preditions + p = np.zeros((128, 128, 2), dtype=np.float32) + with torch.no_grad(): + for r in range(len(p)): + for c in range(len(p[r])): + x = c / 128.0 + y = r / 128.0 + + k = np.asarray(policy_net( + torch.tensor( + [x, y, 0], + dtype = torch.float32, + device = device + ).unsqueeze(0) + )[0]) + p[r][c][0] = np.argmax(k) + + k = np.asarray(policy_net( + torch.tensor( + [x, y, 1], + dtype = torch.float32, + device = device + ).unsqueeze(0) + )[0]) + p[r][c][1] = np.argmax(k) + + + # Plot predictions + fig, axs = plt.subplots(1, 2, figsize = (10, 10)) + ax = axs[0] + ax.set( + adjustable = "box", + aspect = "equal", + title = "Best Action" + ) + + plot = ax.pcolor( + p[:,:,0], + cmap = "Set1", + vmin = 0, + vmax = 8 + ) + ax.invert_yaxis() + fig.colorbar(plot) + + ax = axs[1] + ax.set( + adjustable = "box", + aspect = "equal", + title = "Best Action" + ) + + plot = ax.pcolor( + p[:,:,0], + cmap = "Set1", + vmin = 0, + vmax = 8 + ) + + ax.invert_yaxis() + fig.colorbar(plot) + + fig.savefig(out_filename) + plt.close() + + + + + +