Mark
/
celeste-ai
Archived
1
0
Fork 0
This repository has been archived on 2023-11-28. You can view files and clone it, but cannot push or open issues/pull-requests.
celeste-ai/celeste/plots.py

90 lines
1.6 KiB
Python

import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from multiprocessing import Pool
from celeste import Celeste
from main import DQN
from main import Transition
# Everything runs on the CPU: per the original note, the script is faster
# in parallel that way.
compute_device = torch.device("cpu")

# Celeste env properties: network input and output sizes.
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)

# Plots are written to out_dir (created on import if missing);
# checkpoints are read from src_dir.
out_dir = Path("out/plots")
out_dir.mkdir(parents=True, exist_ok=True)
src_dir = Path("model_data/model_archive")

# policy_net receives the weights of each checkpoint before it is plotted.
policy_net = DQN(n_observations, n_actions).to(compute_device)

# NOTE(review): target_net and optimizer are not referenced by the plotting
# code visible in this file -- confirm whether they are still needed.
target_net = DQN(n_observations, n_actions).to(compute_device)

optimizer = torch.optim.AdamW(
    policy_net.parameters(),
    lr=0.01,  # Hyperparameter: learning rate
    amsgrad=True,
)
def makeplt(i, net):
    """
    Sample one output of `net` over a 128 x 128 grid of inputs.

    Grid cell (r, c) corresponds to the input [c, r, 60, 80]: the first two
    input components sweep the grid while the last two are held at 60 and 80.

    The whole grid is evaluated in a single batched forward pass instead of
    one pass per cell. This assumes `net` treats the rows of a batch
    independently (NOTE(review): confirm DQN has no batch-dependent layers);
    values may differ from per-cell evaluation by float rounding only.

    Args:
        i: index of the network output to sample.
        net: callable taking a float32 tensor of shape (N, 4) and returning
            a tensor of shape (N, n_outputs).

    Returns:
        A float32 numpy array of shape (128, 128) whose element [r][c] is
        output `i` of `net` for the input [c, r, 60, 80].
    """
    grid_size = 128
    axis = torch.arange(grid_size, dtype=torch.float32, device=compute_device)

    # Flattened row-major grid: entry k = r * grid_size + c.
    cols = axis.repeat(grid_size)             # c varies fastest
    rows = axis.repeat_interleave(grid_size)  # r changes every grid_size entries

    states = torch.stack(
        [cols, rows, torch.full_like(cols, 60), torch.full_like(cols, 80)],
        dim=1,
    )

    # Inference only; no gradients are needed.
    with torch.no_grad():
        values = net(states)[:, i]

    grid = values.detach().reshape(grid_size, grid_size).cpu().numpy()
    return grid.astype(np.float32, copy=False)
def plot(src):
    """
    Render the heatmaps of one checkpoint and save them as a PNG.

    Loads checkpoint["policy_state_dict"] from `src` into the module-level
    policy_net, draws a 2 x 4 grid of subplots where subplot `a` shows
    makeplt(a, policy_net) titled with Celeste.action_space[a], prints
    `src`, and writes the figure to out_dir / "<src.stem>.png".

    Args:
        src: pathlib.Path of the checkpoint file to plot.
    """
    # map_location keeps the tensors on compute_device, so checkpoints that
    # were saved from a GPU run also load on a CPU-only machine.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(src, map_location=compute_device)
    policy_net.load_state_dict(checkpoint["policy_state_dict"])

    fig, axs = plt.subplots(2, 4, figsize=(15, 10))

    for a, ax in enumerate(axs.ravel()):
        ax.set(adjustable="box", aspect="equal")
        mesh = ax.pcolor(
            makeplt(a, policy_net),
            cmap="Greens",
            vmin=0,
        )
        ax.set_title(Celeste.action_space[a])
        ax.invert_yaxis()
        # Only vmin is fixed, so every panel has its own colour scale and
        # therefore gets its own colorbar, attached explicitly to its axes.
        fig.colorbar(mesh, ax=ax)

    print(src)
    fig.savefig(out_dir / f"{src.stem}.png")
    # Close this specific figure so worker processes do not accumulate
    # open figures across checkpoints.
    plt.close(fig)
if __name__ == "__main__":
    # Render every entry of src_dir, spread over five worker processes.
    with Pool(processes=5) as workers:
        checkpoint_paths = list(src_dir.iterdir())
        workers.map(plot, checkpoint_paths)