import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from multiprocessing import Pool

from celeste import Celeste
from main import DQN
from main import Transition

# Use cpu, this script is faster in parallel.
compute_device = torch.device("cpu")

input_model = Path("model_data/current")
src_dir = input_model / "model_archive"
# Bug fix: this previously referenced the undefined name `input_model_dir`,
# which raised NameError as soon as the script was started.
out_dir = input_model / "plots/predicted_value"

# Side length of the square grid of (x, y) cells that is evaluated.
GRID_SIZE = 128
# Values fed as the 3rd and 4th state features for every grid cell.
# NOTE(review): their meaning is not visible here — confirm against
# Celeste.state_number_map.
FIXED_FEATURES = (60, 80)


def _predict_grid(policy_net):
    """Evaluate ``policy_net`` on every grid cell in a single batched call.

    Each state is ``[col, row, *FIXED_FEATURES]`` (column first, then row),
    exactly the states the former per-cell loop built one at a time.

    Returns:
        float32 array of shape ``(GRID_SIZE, GRID_SIZE, n_outputs)`` indexed
        as ``[row, col, output]``.
    """
    rows, cols = np.indices((GRID_SIZE, GRID_SIZE))
    n_cells = GRID_SIZE * GRID_SIZE
    states = np.column_stack([
        cols.ravel(),
        rows.ravel(),
        np.full(n_cells, FIXED_FEATURES[0]),
        np.full(n_cells, FIXED_FEATURES[1]),
    ]).astype(np.float32)
    with torch.no_grad():
        q_values = policy_net(torch.from_numpy(states).to(compute_device))
    # Row-major ravel above means reshaping back yields [row, col, output].
    return (
        q_values.cpu().numpy().astype(np.float32)
        .reshape(GRID_SIZE, GRID_SIZE, -1)
    )


def plot(src):
    """Plot the predicted value of each action over the grid for one checkpoint.

    Loads ``checkpoint["policy_state_dict"]`` from ``src`` into a fresh DQN.
    Evaluates it on every grid cell and saves a 2x4 figure, one heat map per
    action, to ``out_dir / f"{src.stem}.png"``.

    Args:
        src: Path to a checkpoint file produced by training.
    """
    policy_net = DQN(
        len(Celeste.state_number_map),
        len(Celeste.action_space)
    ).to(compute_device)

    # NOTE: torch.load unpickles the file; only use on trusted checkpoints.
    checkpoint = torch.load(src, map_location=compute_device)
    policy_net.load_state_dict(checkpoint["policy_state_dict"])
    # Inference only: disables dropout / batch-norm updates if DQN has any.
    policy_net.eval()

    # Compute predictions
    predictions = _predict_grid(policy_net)

    # Plot predictions
    fig, axs = plt.subplots(2, 4, figsize=(20, 10))
    for a, ax in enumerate(axs.ravel()):
        ax.set(
            adjustable="box",
            aspect="equal",
            title=Celeste.action_space[a],
        )
        # Named `mesh`, not `plot`, so it doesn't shadow this function.
        mesh = ax.pcolor(predictions[:, :, a], cmap="Greens", vmin=0)
        ax.invert_yaxis()
        # Attach each colorbar to its own subplot. Without ax=, older
        # matplotlib takes space from the current (last) axes for all of them.
        fig.colorbar(mesh, ax=ax)

    print(src)
    fig.savefig(out_dir / f"{src.stem}.png")
    plt.close(fig)


if __name__ == "__main__":
    # Created here rather than at import time so worker processes, which
    # may re-import this module, don't repeat the side effect.
    out_dir.mkdir(parents=True, exist_ok=True)
    # Sorted for a deterministic order; skip anything that isn't a file.
    checkpoints = sorted(path for path in src_dir.iterdir() if path.is_file())
    with Pool(5) as pool:
        pool.map(plot, checkpoints)