Mark
/
celeste-ai
Archived
1
0
Fork 0
This repository has been archived on 2023-11-28. You can view files and clone it, but cannot push or open issues/pull-requests.
celeste-ai/celeste/plots.py

77 lines
1.4 KiB
Python

import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from multiprocessing import Pool
from celeste import Celeste
from main import DQN
from main import Transition
# Run every network evaluation on the CPU. Per the original note, this script
# is faster when parallelised across worker processes (see the Pool below).
compute_device = torch.device("cpu")

# Archived model checkpoints are read from src_dir. Rendered images go to
# out_dir, which is created (along with any missing parents) at import time.
src_dir = Path("model_data/current/model_archive")
out_dir = Path("out/plots")
out_dir.mkdir(exist_ok=True, parents=True)
def makeplt(i, net, size=128, fixed_state=(60, 80), device=None):
    """Return a ``size`` x ``size`` heat-map of output ``i`` of ``net``.

    Each grid cell ``(r, c)`` becomes the state vector
    ``[c, r, *fixed_state]``: column first, then row, then the fixed trailing
    values. That state is fed to ``net``, and ``p[r][c]`` holds output
    column ``i``.

    All ``size * size`` states are evaluated in one batched forward pass
    instead of one call per cell. Batched float arithmetic can differ from
    per-sample evaluation in the last few bits.

    Args:
        i: Index of the network output (action) to plot.
        net: Callable mapping an ``(N, 2 + len(fixed_state))`` float32 tensor
            to an ``(N, n_outputs)`` tensor. Rows are assumed to be evaluated
            independently (no batch-dependent layers); TODO confirm for DQN.
        size: Width and height of the grid. Defaults to 128.
        fixed_state: Values appended after ``(x, y)`` in every state.
            Defaults to ``(60, 80)``.
        device: Device on which to build the input batch. Defaults to the
            module-level ``compute_device``.

    Returns:
        A float32 ``np.ndarray`` of shape ``(size, size)``.
    """
    if device is None:
        device = compute_device
    n_cells = size * size

    coords = torch.arange(size, dtype=torch.float32, device=device)
    # Row-major cell order: cell k corresponds to (r, c) = (k // size, k % size).
    xs = coords.repeat(size)             # c for each cell: 0..size-1, 0..size-1, ...
    ys = coords.repeat_interleave(size)  # r for each cell: 0,0,...,1,1,...
    fixed = torch.tensor(fixed_state, dtype=torch.float32, device=device)
    states = torch.cat(
        (xs.unsqueeze(1), ys.unsqueeze(1), fixed.expand(n_cells, -1)),
        dim=1,
    )

    with torch.no_grad():
        q = net(states)[:, i]
    return np.asarray(q.reshape(size, size).cpu().numpy(), dtype=np.float32)
def plot(src):
    """Render one checkpoint's per-action heat-maps to ``out_dir/<stem>.png``.

    Builds a fresh ``DQN`` on ``compute_device`` and loads the checkpoint's
    ``"policy_state_dict"`` into it. It then draws one ``makeplt`` heat-map
    per subplot of a 2 x 4 grid, each with its own colorbar and titled with
    the matching entry of ``Celeste.action_space``.

    Args:
        src: ``pathlib.Path`` of a checkpoint file readable by ``torch.load``
            that contains a ``"policy_state_dict"`` entry.
    """
    policy_net = DQN(
        len(Celeste.state_number_map),
        len(Celeste.action_space),
    ).to(compute_device)
    # Remap tensors onto compute_device so checkpoints saved during a GPU
    # training run still load when this script runs CPU-only.
    checkpoint = torch.load(src, map_location=compute_device)
    policy_net.load_state_dict(checkpoint["policy_state_dict"])
    # Inference only: put any train/eval-dependent layers into eval mode.
    policy_net.eval()

    fig, axs = plt.subplots(2, 4, figsize=(20, 10))
    for a, ax in enumerate(axs.ravel()):
        ax.set(adjustable="box", aspect="equal")
        ax.set_title(Celeste.action_space[a])
        ax.invert_yaxis()
        mesh = ax.pcolor(makeplt(a, policy_net), cmap="Greens", vmin=0)
        # Pass ax explicitly so the colorbar is attached to this subplot.
        # Without ax, older Matplotlib takes space from the current axes.
        fig.colorbar(mesh, ax=ax)
    print(src)
    fig.savefig(out_dir / f"{src.stem}.png")
    plt.close(fig)
if __name__ == "__main__":
    with Pool(processes=5) as workers:
        # One plot task per entry in the checkpoint archive; map() blocks
        # until every worker has finished.
        checkpoints = list(src_dir.iterdir())
        workers.map(plot, checkpoints)