import torch
import numpy as np
from pathlib import Path
import matplotlib as mpl
import matplotlib.pyplot as plt

# All of the following are required to load
# a pickled model.
from celeste_ai.celeste import Celeste
from celeste_ai.network import DQN
from celeste_ai.network import Transition


# One color per action index, in the order of `Celeste.action_space`.
_ACTION_COLORS = [
    "forestgreen",
    "firebrick",
    "lightgreen",
    "salmon",
    "darkturquoise",
    "sandybrown",
    "olive",
    "darkorchid",
    "mediumvioletred",
]


def _compute_best_actions(
    policy_net: torch.nn.Module,
    resolution: int,
    device: torch.device,
) -> np.ndarray:
    """Return a (resolution, resolution) float32 array of greedy action indices.

    Cell ``[r][c]`` holds the argmax over the network's outputs for the input
    ``[c / resolution, r / resolution]``. Every grid point is evaluated in a
    single batched forward pass.
    """
    # Normalized coordinates 0, 1/res, ..., (res-1)/res (exact in float32
    # for power-of-two resolutions such as the default 128).
    coords = torch.arange(resolution, dtype=torch.float32, device=device) / resolution

    # Row-major flattening: flat index i = r * resolution + c,
    # so x varies fastest (column) and y slowest (row).
    xs = coords.repeat(resolution)
    ys = coords.repeat_interleave(resolution)
    inputs = torch.stack((xs, ys), dim=1)

    with torch.no_grad():
        # NOTE(review): assumes DQN treats batch rows independently
        # (it was previously called with a batch of one) — confirm.
        q_values = policy_net(inputs)

    # argmax in torch, then move to CPU. np.asarray on a CUDA tensor fails.
    best = q_values.argmax(dim=1)
    return best.reshape(resolution, resolution).cpu().numpy().astype(np.float32)


def _plot_best_actions(best: np.ndarray, out_filename: Path) -> None:
    """Render a best-action grid as a discrete color map and save it to ``out_filename``.

    Raises:
        ValueError: if there are more actions than available colors.
    """
    n_actions = len(Celeste.action_space)
    if n_actions > len(_ACTION_COLORS):
        raise ValueError(
            f"Need {n_actions} colors for the action space, "
            f"only {len(_ACTION_COLORS)} are defined"
        )
    cmap = mpl.colors.ListedColormap(_ACTION_COLORS[:n_actions])

    fig, ax = plt.subplots(1, 1, figsize=(20, 20))
    try:
        ax.set(adjustable="box", aspect="equal", title="Best Action")
        # Half-unit padding centres each integer action on its color band,
        # so colorbar ticks line up with swatches. The value->color mapping
        # is unchanged.
        plot = ax.pcolor(best, cmap=cmap, vmin=-0.5, vmax=n_actions - 0.5)
        ax.invert_yaxis()
        cbar = fig.colorbar(plot, ticks=list(range(n_actions)))
        cbar.ax.set_yticklabels(Celeste.action_space)
        fig.savefig(out_filename)
    finally:
        # Close this specific figure, even on error, so figures don't leak.
        plt.close(fig)


def best_action(
    model_file: Path,
    out_filename: Path,
    *,
    device=torch.device("cpu"),
    resolution: int = 128,
):
    """Plot the greedy action a saved DQN policy picks at each point of a grid.

    Args:
        model_file: checkpoint containing a ``"policy_state_dict"`` entry.
        out_filename: image path to write. Parent directories are created.
        device: torch device to run the network on.
        resolution: number of grid cells per axis. Inputs are
            ``[c / resolution, r / resolution]``.

    Raises:
        FileNotFoundError: if ``model_file`` is not an existing file.
        ValueError: if the action space has more entries than plot colors.
    """
    if not model_file.is_file():
        raise FileNotFoundError(f"Bad model file {model_file}")
    out_filename.parent.mkdir(exist_ok=True, parents=True)

    # Create and load model
    policy_net = DQN(
        len(Celeste.state_number_map),
        len(Celeste.action_space),
    ).to(device)

    # NOTE(security): torch.load unpickles the checkpoint, which can execute
    # arbitrary code. Only load model files from trusted sources.
    checkpoint = torch.load(model_file, map_location=device)
    policy_net.load_state_dict(checkpoint["policy_state_dict"])
    # Inference only: disable training-mode behavior (dropout, batch-norm updates).
    policy_net.eval()

    # Compute predictions, then plot them.
    best = _compute_best_actions(policy_net, resolution, device)
    _plot_best_actions(best, out_filename)