"""Render per-action network-output heatmaps for each archived checkpoint.

For every file in ``model_data/model_archive`` the policy network weights are
loaded and one heatmap per action is drawn over a grid of the network's first
two inputs. The resulting figure is written to ``out/plots/<stem>.png``.
"""

import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

from multiprocessing import Pool

from celeste import Celeste
from main import DQN
# NOTE(review): Transition is not referenced directly in this script;
# presumably it must be importable so that pickled checkpoints can be
# deserialized -- confirm before removing.
from main import Transition

# Use cpu, the script is faster in parallel.
compute_device = torch.device("cpu")

# Celeste env properties
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)

out_dir = Path("out/plots")
out_dir.mkdir(parents=True, exist_ok=True)

src_dir = Path("model_data/model_archive")

policy_net = DQN(
    n_observations,
    n_actions
).to(compute_device)

# target_net and optimizer are not used by the plotting code in this script;
# they are retained so the module-level names stay unchanged.
target_net = DQN(
    n_observations,
    n_actions
).to(compute_device)

optimizer = torch.optim.AdamW(
    policy_net.parameters(),
    lr=0.01,  # Hyperparameter: learning rate
    amsgrad=True
)


def makeplt(i, net, size=128, fixed_inputs=(60, 80)):
    """Evaluate output ``i`` of ``net`` over a square grid of its first two inputs.

    Each grid cell ``[r][c]`` holds ``net([c, r, *fixed_inputs])[0][i]``: the
    column index is fed as the first input and the row index as the second.
    (Presumably these are the x/y position -- TODO confirm against Celeste.)

    All ``size * size`` states are evaluated in a single batched forward pass
    instead of one forward pass per cell.

    Args:
        i: Index of the network output (action) to sample.
        net: Callable mapping a ``(batch, n_inputs)`` float tensor to a
            ``(batch, n_outputs)`` tensor.
        size: Side length of the grid; rows and columns range over
            ``0 .. size - 1``.
        fixed_inputs: Values used for the remaining inputs, held constant
            across the whole grid.

    Returns:
        A ``(size, size)`` ``float32`` NumPy array of the sampled outputs.
    """
    axis = torch.arange(size, dtype=torch.float32, device=compute_device)
    # Flattened row-major grid: index r * size + c <-> cell (r, c).
    cols = axis.repeat(size)             # 0, 1, ..., size-1, 0, 1, ...
    rows = axis.repeat_interleave(size)  # 0, 0, ..., 0, 1, 1, ...
    constants = torch.tensor(
        fixed_inputs,
        dtype=torch.float32,
        device=compute_device
    ).expand(size * size, -1)

    states = torch.cat(
        [cols.unsqueeze(1), rows.unsqueeze(1), constants],
        dim=1
    )

    with torch.no_grad():
        values = net(states)[:, i]

    return values.reshape(size, size).cpu().numpy().astype(np.float32)


def plot(src):
    """Load the checkpoint at ``src`` and save a 2x4 grid of action heatmaps.

    The checkpoint's ``"policy_state_dict"`` entry is loaded into the
    module-level ``policy_net``. One heatmap is drawn per subplot, titled with
    the matching entry of ``Celeste.action_space``, and the figure is saved to
    ``out_dir / f"{src.stem}.png"``.

    Args:
        src: ``pathlib.Path`` of the checkpoint file to visualise.
    """
    # map_location keeps loading on the CPU even when the checkpoint was
    # saved from another device.
    # NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
    checkpoint = torch.load(src, map_location=compute_device)
    policy_net.load_state_dict(checkpoint["policy_state_dict"])
    # Inference only: make sure any train-time-only layers are disabled.
    policy_net.eval()

    fig, axs = plt.subplots(2, 4, figsize=(15, 10))
    try:
        for a, ax in enumerate(axs.ravel()):
            ax.set(adjustable="box", aspect="equal")
            mesh = ax.pcolor(
                makeplt(a, policy_net),
                cmap="Greens",
                vmin=0,
            )
            ax.set_title(Celeste.action_space[a])
            ax.invert_yaxis()
            # Attach the colorbar to this subplot explicitly rather than
            # relying on matplotlib's choice of parent axes.
            fig.colorbar(mesh, ax=ax)

        print(src)
        fig.savefig(out_dir / f"{src.stem}.png")
    finally:
        # Close this specific figure so worker processes do not accumulate
        # open figures, even if drawing or saving fails.
        plt.close(fig)


if __name__ == "__main__":
    # Skip sub-directories (torch.load cannot read them) and process the
    # checkpoints in a deterministic order.
    checkpoints = sorted(
        path for path in src_dir.iterdir() if path.is_file()
    )
    with Pool(5) as pool:
        pool.map(plot, checkpoints)