import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from multiprocessing import Pool

# All of the following are required to load
# a pickled model.
from celeste_ai.celeste import Celeste
from celeste_ai.network import DQN
from celeste_ai.network import Transition


def actual_reward(
    model_file: Path,
    target_point: tuple[int, int],
    out_filename: Path,
    *,
    device: torch.device = torch.device("cpu"),
) -> None:
    """Plot the rewards recorded in a checkpoint's memory for one target.

    Loads ``model_file`` with ``torch.load``, walks ``checkpoint["memory"]``
    and, for every entry whose state targets ``target_point``, adds 1 to the
    cell ``(y, x, action)`` when the entry's reward equals 1 and subtracts 1
    otherwise. The result is drawn as a 2x4 grid of heatmaps (one per action
    index, titled from ``Celeste.action_space``) and saved to
    ``out_filename``.

    Args:
        model_file: Path to the checkpoint file to load.
        target_point: ``(x, y)`` target to filter memory entries by. Any
            two-element sequence is accepted.
        out_filename: Destination image path; missing parent directories
            are created.
        device: Passed to ``torch.load`` as ``map_location``.

    Raises:
        FileNotFoundError: If ``model_file`` is not an existing file.
    """
    if not model_file.is_file():
        raise FileNotFoundError(f"Bad model file {model_file}")
    out_filename.parent.mkdir(exist_ok=True, parents=True)

    # NOTE: torch.load unpickles arbitrary Python objects. Only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(
        model_file,
        map_location=device,
    )
    memory = checkpoint["memory"]

    # Net reward count indexed as [y, x, action].
    r = np.zeros((128, 128, 8), dtype=np.float32)
    height, width, _ = r.shape

    # Normalise so that a list such as [x, y] compares equal to the
    # (x_target, y_target) tuple built from each memory entry.
    target = tuple(target_point)

    for m in memory:
        # Each state row unpacks into four values; presumably these are the
        # player position and the target position -- confirm against the
        # code that fills the memory.
        x, y, x_target, y_target = m.state[0]
        action = m.action[0].item()
        x = int(x.item())
        y = int(y.item())
        x_target = int(x_target.item())
        y_target = int(y_target.item())

        # Only plot memory related to this point
        if (x_target, y_target) != target:
            continue

        # Skip samples that fall outside the grid: a negative index would
        # silently wrap around to the opposite edge, and a too-large one
        # would abort the whole plot with an IndexError.
        if not (0 <= x < width and 0 <= y < height):
            continue

        if m.reward[0].item() == 1:
            r[y, x, action] += 1
        else:
            r[y, x, action] -= 1

    fig, axs = plt.subplots(2, 4, figsize=(20, 10))

    for a, ax in enumerate(axs.ravel()):
        ax.set(
            adjustable="box",
            aspect="equal",
            title=Celeste.action_space[a],
        )

        plot = ax.pcolor(
            r[:, :, a],
            cmap="seismic_r",
            vmin=-10,
            vmax=10,
        )

        # Draw target point on plot
        ax.plot(
            target[0],
            target[1],
            "k.",
        )
        ax.invert_yaxis()

        # Attach the colorbar to this panel explicitly rather than relying
        # on the figure's notion of the current axes.
        fig.colorbar(plot, ax=ax)

    fig.savefig(out_filename)
    # Close this specific figure, not whichever one happens to be current.
    plt.close(fig)