From 0b617026775c2dba406f2775e7e7080ffd243cca Mon Sep 17 00:00:00 2001 From: Mark Date: Sun, 26 Feb 2023 12:09:05 -0800 Subject: [PATCH] Removed "can_dash" input value --- celeste/celeste_ai/celeste.py | 2 +- .../celeste_ai/plotting/plot_best_action.py | 56 ++++++++----------- .../plotting/plot_predicted_reward.py | 2 +- celeste/plot.py | 16 +++--- 4 files changed, 34 insertions(+), 42 deletions(-) diff --git a/celeste/celeste_ai/celeste.py b/celeste/celeste_ai/celeste.py index 72ce277..f1e738f 100755 --- a/celeste/celeste_ai/celeste.py +++ b/celeste/celeste_ai/celeste.py @@ -70,7 +70,7 @@ class Celeste: #"ypos", "xpos_scaled", "ypos_scaled", - "can_dash_int" + #"can_dash_int" #"next_point_x", #"next_point_y" ] diff --git a/celeste/celeste_ai/plotting/plot_best_action.py b/celeste/celeste_ai/plotting/plot_best_action.py index bb26188..100bea2 100644 --- a/celeste/celeste_ai/plotting/plot_best_action.py +++ b/celeste/celeste_ai/plotting/plot_best_action.py @@ -1,6 +1,7 @@ import torch import numpy as np from pathlib import Path +import matplotlib as mpl import matplotlib.pyplot as plt # All of the following are required to load @@ -34,7 +35,7 @@ def best_action( # Compute preditions - p = np.zeros((128, 128, 2), dtype=np.float32) + p = np.zeros((128, 128), dtype=np.float32) with torch.no_grad(): for r in range(len(p)): for c in range(len(p[r])): @@ -43,26 +44,31 @@ def best_action( k = np.asarray(policy_net( torch.tensor( - [x, y, 0], + [x, y], dtype = torch.float32, device = device ).unsqueeze(0) )[0]) - p[r][c][0] = np.argmax(k) + p[r][c] = np.argmax(k) - k = np.asarray(policy_net( - torch.tensor( - [x, y, 1], - dtype = torch.float32, - device = device - ).unsqueeze(0) - )[0]) - p[r][c][1] = np.argmax(k) + cmap = mpl.colors.ListedColormap( + [ + "forestgreen", + "firebrick", + "lightgreen", + "salmon", + "darkturquoise", + "sandybrown", + "olive", + "darkorchid", + "mediumvioletred" + ] + ) # Plot predictions - fig, axs = plt.subplots(1, 2, figsize = (10, 10)) - ax = axs[0] + fig, axs = plt.subplots(1, 1, figsize = (20, 20)) + ax = axs ax.set( adjustable = "box", aspect = "equal", @@ -70,30 +76,16 @@ def best_action( ) plot = ax.pcolor( - p[:,:,0], - cmap = "Set1", + p, + cmap = cmap, vmin = 0, vmax = 8 ) ax.invert_yaxis() - fig.colorbar(plot) + cbar = fig.colorbar(plot, ticks = list(range(0, 9))) + cbar.ax.set_yticklabels(Celeste.action_space) - ax = axs[1] - ax.set( - adjustable = "box", - aspect = "equal", - title = "Best Action" - ) - - plot = ax.pcolor( - p[:,:,0], - cmap = "Set1", - vmin = 0, - vmax = 8 - ) - - ax.invert_yaxis() - fig.colorbar(plot) + fig.savefig(out_filename) plt.close() diff --git a/celeste/celeste_ai/plotting/plot_predicted_reward.py b/celeste/celeste_ai/plotting/plot_predicted_reward.py index 0b61487..05cc50d 100644 --- a/celeste/celeste_ai/plotting/plot_predicted_reward.py +++ b/celeste/celeste_ai/plotting/plot_predicted_reward.py @@ -43,7 +43,7 @@ def predicted_reward( k = np.asarray(policy_net( torch.tensor( - [x, y, 0], + [x, y], dtype = torch.float32, device = device ).unsqueeze(0) diff --git a/celeste/plot.py b/celeste/plot.py index a9314f4..7c64793 100644 --- a/celeste/plot.py +++ b/celeste/plot.py @@ -47,14 +47,6 @@ plots = { if __name__ == "__main__": - if plots["prediction"]: - print("Making prediction plots...") - with Pool(5) as p: - p.map( - plot_pred, - list((m / "model_archive").iterdir()) - ) - if plots["best"]: print("Making best-action plots...") with Pool(5) as p: @@ -63,6 +55,14 @@ if __name__ == "__main__": list((m / "model_archive").iterdir()) ) + if plots["prediction"]: + print("Making prediction plots...") + with Pool(5) as p: + p.map( + plot_pred, + list((m / "model_archive").iterdir()) + ) + if plots["actual"]: print("Making actual plots...") with Pool(5) as p: