Mark
/
celeste-ai
Archived
1
0
Fork 0
This repository has been archived on 2023-11-28. You can view files and clone it, but cannot push or open issues/pull-requests.
celeste-ai/celeste/plots.py

77 lines
1.4 KiB
Python
Raw Normal View History

2023-02-17 22:29:12 -08:00
import torch
import numpy as np
2023-02-18 19:50:43 -08:00
from pathlib import Path
2023-02-17 22:29:12 -08:00
import matplotlib.pyplot as plt
2023-02-18 19:50:43 -08:00
from multiprocessing import Pool
2023-02-17 22:29:12 -08:00
2023-02-18 19:50:43 -08:00
from celeste import Celeste
from main import DQN
from main import Transition
2023-02-17 22:29:12 -08:00
2023-02-18 21:10:13 -08:00
# Directory of archived model checkpoints; every entry is rendered to a plot.
src_dir = Path("model_data/current/model_archive")

# Destination for the rendered PNGs; created eagerly (with parents) on import.
out_dir = Path("out/plots")
out_dir.mkdir(exist_ok=True, parents=True)

# Deliberately CPU: this script runs faster parallelised across processes.
compute_device = torch.device("cpu")
2023-02-17 22:29:12 -08:00
def makeplt(i, net, *, shape=(128, 128), fixed_state=(60, 80), device=None):
    """Evaluate one network output over a 2-D grid of positions.

    Each grid cell ``(r, c)`` is fed to ``net`` as the float32 state
    ``[c, r, *fixed_state]``, and output column ``i`` is recorded.  All cells
    are evaluated in ONE batched forward pass instead of one call per cell
    (the original issued rows * cols separate calls).

    Args:
        i: Index of the network output (column of ``net``'s result) to keep.
        net: Callable taking a ``(batch, 2 + len(fixed_state))`` float32
            tensor and returning a ``(batch, n_outputs)`` tensor.  The
            original code already relied on a leading batch dimension via
            ``unsqueeze(0)``.
        shape: ``(rows, cols)`` of the grid; defaults to the original 128x128.
        fixed_state: Values appended after ``(c, r)`` in every state;
            defaults to the original ``(60, 80)``.
        device: Device for the input tensor; ``None`` means the module-level
            ``compute_device``.

    Returns:
        A float32 ``numpy`` array of shape ``(rows, cols)`` where
        ``p[r][c]`` is output ``i`` for state ``[c, r, *fixed_state]``.
    """
    if device is None:
        device = compute_device

    rows, cols = shape
    p = np.zeros((rows, cols), dtype=np.float32)
    n_cells = rows * cols
    if n_cells == 0:
        # Nothing to evaluate; avoid handing an empty batch to the network.
        return p

    # Row-major enumeration of the grid (r outer, c inner), the same order
    # as the original nested loops, so reshape(rows, cols) restores p[r][c].
    r_idx = torch.arange(rows, dtype=torch.float32, device=device).repeat_interleave(cols)
    c_idx = torch.arange(cols, dtype=torch.float32, device=device).repeat(rows)
    extra = torch.tensor(fixed_state, dtype=torch.float32, device=device).expand(n_cells, -1)
    states = torch.cat((c_idx.unsqueeze(1), r_idx.unsqueeze(1), extra), dim=1)

    with torch.no_grad():
        values = net(states)[:, i]
        p[:] = values.cpu().numpy().reshape(rows, cols)
    return p
2023-02-18 19:50:43 -08:00
def plot(src):
    """Render per-action heatmaps for one checkpoint and save them as a PNG.

    Loads ``checkpoint["policy_state_dict"]`` from ``src`` into a fresh
    ``DQN``. Each action's output is drawn with ``makeplt`` on its own
    subplot, with its own colorbar. The figure is written to
    ``out_dir / f"{src.stem}.png"``.

    Args:
        src: Path to a checkpoint file (a ``pathlib.Path``; its ``.stem``
            names the output image).
    """
    policy_net = DQN(
        len(Celeste.state_number_map),
        len(Celeste.action_space)
    ).to(compute_device)

    # map_location keeps every tensor on compute_device.  Without it, a
    # checkpoint saved from CUDA is restored onto the GPU (or fails to load
    # on a CPU-only machine / inside a forked Pool worker).
    checkpoint = torch.load(src, map_location=compute_device)
    policy_net.load_state_dict(checkpoint["policy_state_dict"])
    policy_net.eval()  # inference only

    actions = Celeste.action_space
    ncols = 4
    nrows = max(1, -(-len(actions) // ncols))  # ceil division; 8 actions -> 2x4
    fig, axs = plt.subplots(
        nrows, ncols, figsize=(5 * ncols, 5 * nrows), squeeze=False
    )
    flat_axes = axs.ravel()

    for a, (ax, action) in enumerate(zip(flat_axes, actions)):
        ax.set(adjustable="box", aspect="equal")
        ax.set_title(action)
        ax.invert_yaxis()

        mesh = ax.pcolor(
            makeplt(a, policy_net),
            cmap="Greens",
            vmin=0,
        )
        # Pass ax explicitly: before matplotlib 3.6, colorbar() without ax
        # stole space from the *current* axes instead of this subplot.
        fig.colorbar(mesh, ax=ax)

    # Hide any grid cells left over when the action count isn't a multiple of ncols.
    for ax in flat_axes[len(actions):]:
        ax.axis("off")

    print(src)
    fig.savefig(out_dir / f"{src.stem}.png")
    plt.close(fig)
2023-02-18 19:50:43 -08:00
if __name__ == "__main__":
    # Only regular files can be checkpoints: a stray subdirectory would make
    # torch.load raise and abort the whole map.  Sorting gives a stable
    # dispatch order (iterdir order is arbitrary).
    checkpoints = sorted(path for path in src_dir.iterdir() if path.is_file())
    with Pool(5) as pool:
        pool.map(plot, checkpoints)