diff --git a/celeste/plots.py b/celeste/plots.py index d2e338d..bfc10fa 100644 --- a/celeste/plots.py +++ b/celeste/plots.py @@ -8,43 +8,21 @@ from celeste import Celeste from main import DQN from main import Transition -# Use cpu, the script is faster in parallel. +# Use cpu, this script is faster in parallel. compute_device = torch.device("cpu") - -# Celeste env properties -n_observations = len(Celeste.state_number_map) -n_actions = len(Celeste.action_space) - - out_dir = Path("out/plots") out_dir.mkdir(parents = True, exist_ok = True) -src_dir = Path("model_data/model_archive") - -policy_net = DQN( - n_observations, - n_actions -).to(compute_device) - -target_net = DQN( - n_observations, - n_actions -).to(compute_device) - -optimizer = torch.optim.AdamW( - policy_net.parameters(), - lr = 0.01, # Hyperparameter: learning rate - amsgrad = True -) +src_dir = Path("model_data/current/model_archive") def makeplt(i, net): p = np.zeros((128, 128), dtype=np.float32) - for r in range(len(p)): - for c in range(len(p[r])): - with torch.no_grad(): + with torch.no_grad(): + for r in range(len(p)): + for c in range(len(p[r])): k = net( torch.tensor( [c, r, 60, 80], @@ -52,29 +30,38 @@ def makeplt(i, net): device = compute_device ).unsqueeze(0) )[0][i].item() - p[r][c] = k return p def plot(src): + policy_net = DQN( + len(Celeste.state_number_map), + len(Celeste.action_space) + ).to(compute_device) + checkpoint = torch.load(src) policy_net.load_state_dict(checkpoint["policy_state_dict"]) - fig, axs = plt.subplots(2, 4, figsize = (15, 10)) + fig, axs = plt.subplots(2, 4, figsize = (20, 10)) for a in range(len(axs.ravel())): ax = axs.ravel()[a] - ax.set(adjustable="box", aspect="equal") + ax.set( + adjustable = "box", + aspect = "equal" + ) + ax.set_title(Celeste.action_space[a]) + ax.invert_yaxis() + plot = ax.pcolor( makeplt(a, policy_net), cmap = "Greens", vmin = 0, ) - ax.set_title(Celeste.action_space[a]) - ax.invert_yaxis() + fig.colorbar(plot) print(src) fig.savefig(out_dir / f"{src.stem}.png")