"""Render per-action Q-value heatmaps for every saved Celeste DQN checkpoint.

For each file in ``out/model_images`` the policy network weights are loaded
and one heatmap per action is drawn over a grid of the first two state
components, then saved to ``out/<checkpoint stem>.png``.
"""

from collections import namedtuple
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

from celeste import Celeste


compute_device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)


# Celeste env properties
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)


# Outline our network
class DQN(torch.nn.Module):
    """Fully connected Q-network with three hidden layers of 128 ReLU units.

    The attribute name ``layers`` and the order of the modules inside it
    determine the state_dict keys, so they must stay in sync with the
    checkpoints loaded below.
    """

    def __init__(self, n_observations: int, n_actions: int):
        super().__init__()

        self.layers = torch.nn.Sequential(
            torch.nn.Linear(n_observations, 128),
            torch.nn.ReLU(),

            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),

            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),

            torch.nn.Linear(128, n_actions),
        )

    # Can be called with one input, or with a batch.
    #
    # For a batch of states the result has one row per state and one
    # column per action: row k is [Q(s_k, a_0), ..., Q(s_k, a_{n-1})].
    #
    # Recall that Q(s, a) is the (expected) return of taking
    # action `a` at state `s`
    def forward(self, x):
        return self.layers(x)


policy_net = DQN(
    n_observations,
    n_actions
).to(compute_device)

# target_net, optimizer and Transition are not used by the plotting code
# in this file; they are kept so the module-level names stay available.
target_net = DQN(
    n_observations,
    n_actions
).to(compute_device)

optimizer = torch.optim.AdamW(
    policy_net.parameters(),
    lr=0.01,  # Hyperparameter: learning rate
    amsgrad=True
)

Transition = namedtuple(
    "Transition",
    (
        "state",
        "action",
        "next_state",
        "reward"
    )
)


def makeplt(i, net, size=128, extra_state=(60, 80)):
    """Return a ``size`` x ``size`` grid of Q-values for action index ``i``.

    Cell ``[r][c]`` holds ``net([c, r, *extra_state])[i]``: the column index
    is fed as the first state component and the row index as the second.

    Args:
        i: Index of the action (output column of ``net``) to sample.
        net: Network mapping a batch of state vectors to per-action values.
            It is expected to live on ``compute_device``.
        size: Side length of the sampled grid; coordinates run 0..size-1.
        extra_state: Values for the remaining state components, held
            constant over the grid. NOTE(review): their meaning comes from
            ``Celeste.state_number_map``, which is not visible here -- the
            total state length must equal the network's input width.

    Returns:
        A float32 NumPy array of shape ``(size, size)``.
    """
    coords = torch.arange(size, dtype=torch.float32, device=compute_device)

    # Flattened row-major grid: flat index r * size + c -> (col=c, row=r).
    cols = coords.repeat(size)
    rows = coords.repeat_interleave(size)
    fixed = torch.tensor(
        extra_state,
        dtype=torch.float32,
        device=compute_device
    ).expand(size * size, -1)

    states = torch.cat(
        (cols.unsqueeze(1), rows.unsqueeze(1), fixed),
        dim=1
    )

    # One batched forward pass instead of size * size single-state calls.
    with torch.no_grad():
        q_values = net(states)[:, i]

    return q_values.reshape(size, size).cpu().numpy().astype(
        np.float32, copy=False
    )


# Sorted so checkpoints are processed (and printed) in a stable order.
for checkpoint_path in sorted(Path("out/model_images").iterdir()):

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    # map_location lets checkpoints saved on a GPU load on a CPU-only host.
    checkpoint = torch.load(checkpoint_path, map_location=compute_device)
    policy_net.load_state_dict(checkpoint["policy_state_dict"])

    fig, axs = plt.subplots(2, 4, figsize=(15, 10))

    for a, ax in enumerate(axs.ravel()):
        if a >= n_actions:
            # Fewer actions than panels: leave the spare panel blank.
            ax.axis("off")
            continue

        ax.set(adjustable="box", aspect="equal")
        plot = ax.pcolor(
            makeplt(a, policy_net),
            cmap="Greens_r",
            vmin=0,
            vmax=20
        )
        ax.set_title(Celeste.action_space[a])
        # Put row 0 at the top so the plot reads like screen coordinates.
        ax.invert_yaxis()
        fig.colorbar(plot, ax=ax)

    print(checkpoint_path)
    fig.savefig(f"out/{checkpoint_path.stem}.png")
    # Close this specific figure so figures do not accumulate across files.
    plt.close(fig)