from pathlib import Path
from collections import namedtuple

import torch
import numpy as np
import matplotlib.pyplot as plt

from celeste import Celeste

compute_device = torch.device(
	"cuda" if torch.cuda.is_available() else "cpu"
)

# Celeste env properties
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)


# Outline our network
class DQN(torch.nn.Module):
	def __init__(self, n_observations: int, n_actions: int):
		super(DQN, self).__init__()
		self.layers = torch.nn.Sequential(
			torch.nn.Linear(n_observations, 128),
			torch.nn.ReLU(),
			torch.nn.Linear(128, 128),
			torch.nn.ReLU(),
			torch.nn.Linear(128, 128),
			torch.nn.ReLU(),
			torch.nn.Linear(128, n_actions)
		)

	# Can be called with one input, or with a batch.
	#
	# Returns a tensor of Q-values, one per action:
	# tensor([ Q(s, a_0), Q(s, a_1), ... ])
	#
	# Recall that Q(s, a) is the (expected) return of taking
	# action `a` at state `s`.
	def forward(self, x):
		return self.layers(x)


policy_net = DQN(
	n_observations,
	n_actions
).to(compute_device)

target_net = DQN(
	n_observations,
	n_actions
).to(compute_device)

optimizer = torch.optim.AdamW(
	policy_net.parameters(),
	lr = 0.01,  # Hyperparameter: learning rate
	amsgrad = True
)

Transition = namedtuple(
	"Transition",
	(
		"state",
		"action",
		"next_state",
		"reward"
	)
)


# Evaluate the network over a 128x128 grid of (x, y) positions,
# holding the remaining two state components fixed at (60, 80).
# Returns the Q-value of action `i` at every grid point.
def makeplt(i, net):
	p = np.zeros((128, 128), dtype=np.float32)

	for r in range(len(p)):
		for c in range(len(p[r])):
			with torch.no_grad():
				k = net(
					torch.tensor(
						[c, r, 60, 80],
						dtype = torch.float32,
						device = compute_device
					).unsqueeze(0)
				)[0][i].item()
			p[r][c] = k

	return p


# Load each saved checkpoint and plot the Q-value of every action
# as a heatmap over the level.
for i in Path("out/model_images").iterdir():
	# map_location keeps this working on CPU-only machines
	checkpoint = torch.load(i, map_location = compute_device)
	policy_net.load_state_dict(checkpoint["policy_state_dict"])

	fig, axs = plt.subplots(2, 4, figsize = (15, 10))

	for a in range(len(axs.ravel())):
		ax = axs.ravel()[a]
		ax.set(adjustable = "box", aspect = "equal")

		plot = ax.pcolor(
			makeplt(a, policy_net),
			cmap = "Greens_r",
			vmin = 0,
			vmax = 20
		)

		ax.set_title(Celeste.action_space[a])
		ax.invert_yaxis()

	fig.colorbar(plot)

	print(i)
	fig.savefig(f"out/{i.stem}.png")
	plt.close()