import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

# All of the following are required to load
# a pickled model.
from celeste_ai.celeste import Celeste
from celeste_ai.network import DQN
from celeste_ai.network import Transition


def predicted_reward(
	model_file: Path,
	out_filename: Path,
	*,
	device = torch.device("cpu")
):
	if not model_file.is_file():
		raise Exception(f"Bad model file {model_file}")
	out_filename.parent.mkdir(exist_ok = True, parents = True)

	# Create and load model
	policy_net = DQN(
		len(Celeste.state_number_map),
		len(Celeste.action_space)
	).to(device)

	checkpoint = torch.load(
		model_file,
		map_location = device
	)
	policy_net.load_state_dict(checkpoint["policy_state_dict"])

	# Compute predictions.
	# Evaluate the network on a 128x128 grid of normalized (x, y)
	# positions, holding the third state element at zero. Each grid
	# cell stores one predicted value per action.
	p = np.zeros((128, 128, 9), dtype=np.float32)
	with torch.no_grad():
		for r in range(len(p)):
			for c in range(len(p[r])):
				x = c / 128.0
				y = r / 128.0

				k = policy_net(
					torch.tensor(
						[x, y, 0],
						dtype = torch.float32,
						device = device
					).unsqueeze(0)
				)[0].cpu().numpy()
				p[r][c] = k

	# Plot predictions: one heatmap per action.
	fig, axs = plt.subplots(2, 5, figsize = (20, 10))

	for a in range(len(axs.ravel())):
		# Leave unused subplots blank.
		if a >= len(Celeste.action_space):
			continue

		ax = axs.ravel()[a]
		ax.set(
			adjustable = "box",
			aspect = "equal",
			title = Celeste.action_space[a]
		)

		plot = ax.pcolor(
			p[:, :, a],
			cmap = "Greens",
			vmin = 0,
			#vmax = 5
		)

		# Put row 0 (y = 0) at the top of the plot.
		ax.invert_yaxis()
		fig.colorbar(plot)

	fig.savefig(out_filename)
	plt.close(fig)
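

# A minimal usage sketch with hypothetical paths; adjust them to wherever
# your training run saves its checkpoint. The checkpoint must contain a
# "policy_state_dict" entry, as loaded above.
if __name__ == "__main__":
	predicted_reward(
		Path("model_data/model.torch"),
		Path("plots/predicted_reward.png"),
		device = torch.device("cpu")
	)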