import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

# All of the following are required to load
# a pickled model.
from celeste_ai.celeste import Celeste
from celeste_ai.network import DQN
from celeste_ai.network import Transition


def _best_action_grid(
    policy_net,
    flag: int,
    *,
    size: int = 128,
    device=torch.device("cpu"),
) -> np.ndarray:
    """Return a (size, size) float32 array of greedy actions over a grid.

    Cell [r, c] holds the argmax action index of ``policy_net`` for the
    input vector ``[c / size, r / size, flag]`` -- columns sweep the first
    input component and rows sweep the second.  All grid points are
    evaluated in one batched forward pass instead of one pass per cell.

    NOTE(review): assumes ``policy_net`` accepts a (batch, 3) float tensor
    and returns (batch, n_actions) -- confirm against DQN.
    """
    coords = torch.arange(size, dtype=torch.float32, device=device) / size
    # Row-major flattening: flat index i = r * size + c.
    xs = coords.repeat(size)               # c / size for each cell
    ys = coords.repeat_interleave(size)    # r / size for each cell
    flags = torch.full_like(xs, float(flag))
    inputs = torch.stack((xs, ys, flags), dim=1)

    with torch.no_grad():
        q_values = policy_net(inputs)

    # .cpu() is required before converting to numpy when device is a GPU.
    actions = q_values.argmax(dim=1).reshape(size, size).cpu().numpy()
    return actions.astype(np.float32)


def _plot_action_map(fig, ax, actions: np.ndarray, title: str) -> None:
    """Draw one action-index heat map with its own colorbar on ``ax``."""
    ax.set(
        adjustable="box",
        aspect="equal",
        title=title,
    )
    # Fixed colour scale: indices 0..8 map onto the 9 colours of "Set1".
    # NOTE(review): presumably the action space has at most 9 actions -- confirm.
    mesh = ax.pcolor(actions, cmap="Set1", vmin=0, vmax=8)
    # Put row 0 at the top so the map reads like screen coordinates.
    ax.invert_yaxis()
    fig.colorbar(mesh, ax=ax)


def best_action(
    model_file: Path,
    out_filename: Path,
    *,
    device=torch.device("cpu"),
):
    """Plot the policy network's greedy action over a 128x128 input grid.

    Loads the policy network from a checkpoint, evaluates it on every grid
    point ``[x, y, flag]`` with x, y in [0, 1) step 1/128 and flag in {0, 1},
    and saves a two-panel figure (one panel per flag value) showing the
    argmax action index at each point.

    Args:
        model_file: checkpoint containing a ``"policy_state_dict"`` entry.
        out_filename: path of the image to write; parent dirs are created.
        device: torch device used for loading and inference.

    Raises:
        Exception: if ``model_file`` is not an existing file.
    """
    if not model_file.is_file():
        raise Exception(f"Bad model file {model_file}")
    out_filename.parent.mkdir(exist_ok=True, parents=True)

    # Create and load model
    policy_net = DQN(
        len(Celeste.state_number_map),
        len(Celeste.action_space),
    ).to(device)

    checkpoint = torch.load(model_file, map_location=device)
    policy_net.load_state_dict(checkpoint["policy_state_dict"])
    # Inference only: disable any training-mode layers (dropout, batchnorm).
    policy_net.eval()

    # Compute predictions; p[:, :, k] is the action map for flag value k.
    p = np.zeros((128, 128, 2), dtype=np.float32)
    for flag in (0, 1):
        p[:, :, flag] = _best_action_grid(policy_net, flag, size=128, device=device)

    # Plot predictions: one panel per flag value (previously both panels
    # plotted flag 0 by mistake).
    fig, axs = plt.subplots(1, 2, figsize=(10, 10))
    for flag, ax in enumerate(axs):
        _plot_action_map(fig, ax, p[:, :, flag], f"Best Action (input[2] = {flag})")

    fig.savefig(out_filename)
    plt.close(fig)