Added action plot generator
parent
2706a0af3f
commit
e76a78d199
|
@ -0,0 +1,119 @@
|
||||||
|
from pathlib import Path
|
||||||
|
import torch
|
||||||
|
from celeste import Celeste
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
|
|
||||||
|
compute_device = torch.device(
|
||||||
|
"cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Celeste env properties
|
||||||
|
n_observations = len(Celeste.state_number_map)
|
||||||
|
n_actions = len(Celeste.action_space)
|
||||||
|
|
||||||
|
|
||||||
|
# Outline our network
|
||||||
|
class DQN(torch.nn.Module):
|
||||||
|
def __init__(self, n_observations: int, n_actions: int):
|
||||||
|
super(DQN, self).__init__()
|
||||||
|
|
||||||
|
self.layers = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(n_observations, 128),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
|
||||||
|
torch.nn.Linear(128, 128),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
|
||||||
|
torch.nn.Linear(128, 128),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
|
||||||
|
torch.torch.nn.Linear(128, n_actions)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Can be called with one input, or with a batch.
|
||||||
|
#
|
||||||
|
# Returns tensor(
|
||||||
|
# [ Q(s, left), Q(s, right) ], ...
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# Recall that Q(s, a) is the (expected) return of taking
|
||||||
|
# action `a` at state `s`
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers(x)
|
||||||
|
|
||||||
|
|
||||||
|
policy_net = DQN(
|
||||||
|
n_observations,
|
||||||
|
n_actions
|
||||||
|
).to(compute_device)
|
||||||
|
|
||||||
|
target_net = DQN(
|
||||||
|
n_observations,
|
||||||
|
n_actions
|
||||||
|
).to(compute_device)
|
||||||
|
|
||||||
|
optimizer = torch.optim.AdamW(
|
||||||
|
policy_net.parameters(),
|
||||||
|
lr = 0.01, # Hyperparameter: learning rate
|
||||||
|
amsgrad = True
|
||||||
|
)
|
||||||
|
|
||||||
|
Transition = namedtuple(
|
||||||
|
"Transition",
|
||||||
|
(
|
||||||
|
"state",
|
||||||
|
"action",
|
||||||
|
"next_state",
|
||||||
|
"reward"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def makeplt(i, net):
|
||||||
|
p = np.zeros((128, 128), dtype=np.float32)
|
||||||
|
|
||||||
|
for r in range(len(p)):
|
||||||
|
for c in range(len(p[r])):
|
||||||
|
with torch.no_grad():
|
||||||
|
k = net(
|
||||||
|
torch.tensor(
|
||||||
|
[c, r, 60, 80],
|
||||||
|
dtype = torch.float32,
|
||||||
|
device = compute_device
|
||||||
|
).unsqueeze(0)
|
||||||
|
)[0][i].item()
|
||||||
|
|
||||||
|
p[r][c] = k
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
for i in Path("out/model_images").iterdir():
|
||||||
|
|
||||||
|
|
||||||
|
checkpoint = torch.load(i)
|
||||||
|
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
||||||
|
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(2, 4, figsize = (15, 10))
|
||||||
|
|
||||||
|
for a in range(len(axs.ravel())):
|
||||||
|
ax = axs.ravel()[a]
|
||||||
|
ax.set(adjustable="box", aspect="equal")
|
||||||
|
plot = ax.pcolor(
|
||||||
|
makeplt(a, policy_net),
|
||||||
|
cmap = "Greens_r",
|
||||||
|
vmin = 0,
|
||||||
|
vmax = 20
|
||||||
|
)
|
||||||
|
ax.set_title(Celeste.action_space[a])
|
||||||
|
ax.invert_yaxis()
|
||||||
|
fig.colorbar(plot)
|
||||||
|
print(i)
|
||||||
|
fig.savefig(f"out/{i.stem}.png")
|
||||||
|
plt.close()
|
Reference in New Issue