"""Plot the rewards actually recorded in a saved replay memory.

Loads ``model_data/after_change/model.torch``, keeps only transitions whose
target is ``TARGET``, and tallies each one at its (y, x, action) cell:
+1 when the recorded reward equals 1, -1 otherwise. One heat-map panel per
action is written to ``model_data/after_change/plots/actual_reward/actual.png``.
"""

import math
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from multiprocessing import Pool

from celeste import Celeste
from main import DQN
from main import Transition


# Only transitions whose (x_target, y_target) equals this are tallied; it is
# also marked on every panel.
TARGET = (60, 80)

# Side length of the square position grid. Assumes positions are integer
# pixel coordinates in [0, GRID_SIZE) — TODO confirm against Celeste.
GRID_SIZE = 128

# Number of subplot columns; rows are added as needed to fit every action.
PLOT_COLUMNS = 4

# Symmetric colour limit: tallies beyond +/-COLOR_LIMIT saturate the colormap.
COLOR_LIMIT = 10

# Map the checkpoint onto the CPU so loading does not depend on the device
# it may have been saved from. (No parallelism is used in this script.)
compute_device = torch.device("cpu")

input_model = Path("model_data/after_change")
out_dir = input_model / "plots/actual_reward"
out_dir.mkdir(parents=True, exist_ok=True)


def tally_rewards(memory, n_actions, target=TARGET, grid_size=GRID_SIZE):
    """Tally recorded rewards per (y, x, action) for one target.

    Args:
        memory: Iterable of transitions exposing ``state``, ``action`` and
            ``reward``; ``state[0]`` must yield four single-element values
            (x, y, x_target, y_target).
        n_actions: Size of the action axis of the result.
        target: (x, y) target a transition must have to be counted.
        grid_size: Side length of the square position grid.

    Returns:
        ``(tally, skipped)``. ``tally`` is a float32 array of shape
        ``(grid_size, grid_size, n_actions)`` indexed ``[y, x, action]``,
        holding +1 per reward of exactly 1 and -1 per any other reward.
        ``skipped`` counts matching transitions whose position fell outside
        the grid.
    """
    tally = np.zeros((grid_size, grid_size, n_actions), dtype=np.float32)
    skipped = 0

    for m in memory:
        x, y, x_target, y_target = (int(v.item()) for v in m.state[0])

        if (x_target, y_target) != target:
            continue

        # A negative index would silently wrap to the opposite edge and an
        # index >= grid_size would raise, so count and drop off-grid samples.
        if not (0 <= x < grid_size and 0 <= y < grid_size):
            skipped += 1
            continue

        action = m.action[0].item()
        tally[y, x, action] += 1 if m.reward[0].item() == 1 else -1

    return tally, skipped


def plot_tally(tally, titles, target=TARGET, ncols=PLOT_COLUMNS):
    """Draw one heat-map panel per action and return the figure.

    Args:
        tally: Array indexed ``[y, x, action]`` as returned by
            ``tally_rewards``.
        titles: Sequence of panel titles, one per action.
        target: (x, y) point marked on every panel.
        ncols: Number of subplot columns.

    Returns:
        The matplotlib Figure; the caller saves and closes it.
    """
    n_actions = tally.shape[2]
    nrows = max(1, math.ceil(n_actions / ncols))
    fig, axs = plt.subplots(
        nrows, ncols, figsize=(5 * ncols, 5 * nrows), squeeze=False
    )

    mesh = None
    for a, ax in enumerate(axs.ravel()):
        if a >= n_actions:
            # Hide grid slots left over when n_actions is not a multiple of
            # ncols.
            ax.set_visible(False)
            continue

        ax.set(adjustable="box", aspect="equal", title=titles[a])
        mesh = ax.pcolormesh(
            tally[:, :, a],
            cmap="seismic_r",
            vmin=-COLOR_LIMIT,
            vmax=COLOR_LIMIT,
        )
        ax.plot(*target, "k.")
        # Put row 0 at the top of each panel.
        ax.invert_yaxis()

    if mesh is not None:
        # Every panel shares the same vmin/vmax, so one colorbar spanning all
        # axes describes them all without shrinking just the last panel.
        fig.colorbar(mesh, ax=axs)

    return fig


# NOTE(review): the checkpoint stores the replay memory as pickled Python
# objects, not only tensors. torch >= 2.6 defaults to weights_only=True,
# which may refuse to load them; pass weights_only=False only for
# checkpoints you trust.
checkpoint = torch.load(input_model / "model.torch", map_location=compute_device)
memory = checkpoint["memory"]

n_actions = len(Celeste.action_space)
r, n_skipped = tally_rewards(memory, n_actions)
if n_skipped:
    print(
        f"Skipped {n_skipped} transition(s) with positions outside "
        f"the {GRID_SIZE}x{GRID_SIZE} grid."
    )

fig = plot_tally(r, Celeste.action_space)
fig.savefig(out_dir / "actual.png")
plt.close(fig)