From bae70e0cfaeca3327555b7572ae12a66370f7150 Mon Sep 17 00:00:00 2001 From: Mark Date: Fri, 24 Feb 2023 14:23:25 -0800 Subject: [PATCH] Tweaked plotter for new model --- .../celeste_ai/plotting/plot_predicted_reward.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/celeste/celeste_ai/plotting/plot_predicted_reward.py b/celeste/celeste_ai/plotting/plot_predicted_reward.py index 98e9c99..0b61487 100644 --- a/celeste/celeste_ai/plotting/plot_predicted_reward.py +++ b/celeste/celeste_ai/plotting/plot_predicted_reward.py @@ -34,13 +34,16 @@ def predicted_reward( # Compute preditions - p = np.zeros((128, 128, 8), dtype=np.float32) + p = np.zeros((128, 128, 9), dtype=np.float32) with torch.no_grad(): for r in range(len(p)): for c in range(len(p[r])): + x = c / 128.0 + y = r / 128.0 + k = np.asarray(policy_net( torch.tensor( - [c, r, 60, 80], + [x, y, 0], dtype = torch.float32, device = device ).unsqueeze(0) @@ -49,8 +52,11 @@ def predicted_reward( # Plot predictions - fig, axs = plt.subplots(2, 4, figsize = (20, 10)) + fig, axs = plt.subplots(2, 5, figsize = (20, 10)) for a in range(len(axs.ravel())): + if a >= len(Celeste.action_space): + continue + ax = axs.ravel()[a] ax.set( adjustable = "box", @@ -62,6 +68,7 @@ def predicted_reward( p[:,:,a], cmap = "Greens", vmin = 0, + #vmax = 5 ) ax.invert_yaxis()