From 4fbf1ea3f5ed08b81a7a7446fd48842846c7263d Mon Sep 17 00:00:00 2001 From: Mark Date: Sun, 19 Feb 2023 08:49:33 -0800 Subject: [PATCH] Made plotter faster --- celeste/plots.py | 43 +++++++++++++++++++++---------------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/celeste/plots.py b/celeste/plots.py index bfc10fa..ac81346 100644 --- a/celeste/plots.py +++ b/celeste/plots.py @@ -17,24 +17,6 @@ out_dir.mkdir(parents = True, exist_ok = True) src_dir = Path("model_data/current/model_archive") -def makeplt(i, net): - p = np.zeros((128, 128), dtype=np.float32) - - with torch.no_grad(): - for r in range(len(p)): - for c in range(len(p[r])): - k = net( - torch.tensor( - [c, r, 60, 80], - dtype = torch.float32, - device = compute_device - ).unsqueeze(0) - )[0][i].item() - p[r][c] = k - return p - - - def plot(src): policy_net = DQN( len(Celeste.state_number_map), @@ -47,21 +29,38 @@ def plot(src): fig, axs = plt.subplots(2, 4, figsize = (20, 10)) + + # Compute preditions + p = np.zeros((128, 128, 8), dtype=np.float32) + with torch.no_grad(): + for r in range(len(p)): + for c in range(len(p[r])): + k = np.asarray(policy_net( + torch.tensor( + [c, r, 60, 80], + dtype = torch.float32, + device = compute_device + ).unsqueeze(0) + )[0]) + p[r][c] = k + + + # Plot predictions for a in range(len(axs.ravel())): ax = axs.ravel()[a] ax.set( adjustable = "box", - aspect = "equal" + aspect = "equal", + title = Celeste.action_space[a] ) - ax.set_title(Celeste.action_space[a]) - ax.invert_yaxis() plot = ax.pcolor( - makeplt(a, policy_net), + p[:,:,a], cmap = "Greens", vmin = 0, ) + ax.invert_yaxis() fig.colorbar(plot) print(src) fig.savefig(out_dir / f"{src.stem}.png")