Mark
/
celeste-ai
Archived
1
0
Fork 0
This repository has been archived on 2023-11-28. You can view files and clone it, but cannot push or open issues/pull-requests.
celeste-ai/celeste/plots.py

120 lines
2.1 KiB
Python

from pathlib import Path
import torch
from celeste import Celeste
import numpy as np
import matplotlib.pyplot as plt
from collections import namedtuple
# Run on CUDA when PyTorch reports it as available, otherwise on the CPU.
if torch.cuda.is_available():
    compute_device = torch.device("cuda")
else:
    compute_device = torch.device("cpu")

# Celeste env properties, read straight from the environment class.
# These two sizes are used as the network's input and output widths.
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)
# Outline our network
class DQN(torch.nn.Module):
    """Fully-connected Q-network with three hidden layers of 128 ReLU units.

    Args:
        n_observations: Width of the input (state) vector.
        n_actions: Width of the output; one value per action.
    """

    def __init__(self, n_observations: int, n_actions: int):
        super().__init__()

        # The attribute name `layers` and the order of the modules below
        # determine the state-dict keys ("layers.0.weight", ...), so they
        # must stay as they are for saved checkpoints to keep loading.
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(n_observations, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, n_actions),
        )

    def forward(self, x):
        """Return the Q-values for `x`.

        Can be called with one input, or with a batch. The last dimension
        of the result has `n_actions` entries:

            tensor([ Q(s, a_0), Q(s, a_1), ... ], ...)

        Recall that Q(s, a) is the (expected) return of taking
        action `a` at state `s`.
        """
        return self.layers(x)
# Two networks with the same architecture, both placed on `compute_device`.
# Each gets its own randomly initialised weights.
policy_net = DQN(n_observations, n_actions).to(compute_device)
target_net = DQN(n_observations, n_actions).to(compute_device)

# The optimizer is bound to the policy network's parameters only.
optimizer = torch.optim.AdamW(
    policy_net.parameters(),
    lr=0.01,  # Hyperparameter: learning rate
    amsgrad=True,
)
# Lightweight record type with four named fields, in this order.
# NOTE(review): presumably one agent/environment step — confirm with the
# training code that produces these.
Transition = namedtuple(
    "Transition",
    ["state", "action", "next_state", "reward"],
)
def makeplt(i, net, *, size=128):
    """Evaluate output `i` of `net` over a square grid of inputs.

    For every row `r` and column `c` in `range(size)`, the network is fed
    the 4-element input `[c, r, 60, 80]`; entry `[r][c]` of the result is
    element `i` of the network's output for that input.

    All grid points are evaluated in a single batched forward pass instead
    of one call per cell, so individual values may differ from a per-cell
    evaluation by floating-point rounding only.

    Args:
        i: Index into the network's output vector (the action to plot).
        net: Callable mapping a `(batch, 4)` float32 tensor to a
            `(batch, n_outputs)` tensor, e.g. a `DQN` instance.
        size: Side length of the grid. Defaults to 128.

    Returns:
        A `(size, size)` NumPy array of dtype float32.
    """
    # Build the inputs on the device the network actually lives on. Fall
    # back to the module-wide device for callables without parameters.
    try:
        device = next(net.parameters()).device
    except (AttributeError, StopIteration):
        device = compute_device

    axis = torch.arange(size, dtype=torch.float32, device=device)
    # Flattened cell k = r * size + c, so columns cycle fastest.
    cols = axis.repeat(size)
    rows = axis.repeat_interleave(size)

    # NOTE(review): 60 and 80 are the fixed third and fourth input values
    # from the original code; their meaning is not visible here — confirm
    # against the Celeste state definition.
    batch = torch.stack(
        (cols, rows, torch.full_like(cols, 60), torch.full_like(cols, 80)),
        dim=1,
    )

    with torch.no_grad():
        values = net(batch)[:, i]

    grid = values.reshape(size, size).cpu().numpy()
    return grid.astype(np.float32, copy=False)
# Render one PNG per checkpoint found in out/model_images: a 2x4 grid of
# heatmaps, one panel per action index, each showing makeplt() for that
# action.
#
# NOTE: torch.load unpickles its input, so only trusted checkpoint files
# should be placed in this directory.
#
# Sorted so that checkpoints are processed in a deterministic order.
for i in sorted(Path("out/model_images").iterdir()):
    # map_location lets checkpoints that were saved on a GPU load on a
    # CPU-only machine (and vice versa).
    checkpoint = torch.load(i, map_location=compute_device)
    policy_net.load_state_dict(checkpoint["policy_state_dict"])

    fig, axs = plt.subplots(2, 4, figsize=(15, 10))

    for a, ax in enumerate(axs.ravel()):
        ax.set(adjustable="box", aspect="equal")
        plot = ax.pcolor(
            makeplt(a, policy_net),
            cmap="Greens_r",
            vmin=0,
            vmax=20,
        )
        ax.set_title(Celeste.action_space[a])
        # Put row 0 at the top of the panel.
        ax.invert_yaxis()
        # Attach the colorbar to this panel explicitly rather than relying
        # on matplotlib's choice of axes.
        # NOTE(review): the original call was `fig.colorbar(plot)` and the
        # source's indentation was lost; one colorbar per panel is assumed
        # here — confirm.
        fig.colorbar(plot, ax=ax)

    print(i)
    fig.savefig(f"out/{i.stem}.png")
    # Close this specific figure so figures do not accumulate in memory.
    plt.close(fig)