Archived
1
0
This repository has been archived on 2023-11-28. You can view files and clone it, but cannot push or open issues or pull requests.
Files
celeste-ai/celeste/plots.py
2023-02-19 19:12:51 -08:00

81 lines
1.5 KiB
Python

import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from multiprocessing import Pool
from celeste import Celeste
from main import DQN
from main import Transition
# Use the CPU: this script is faster when checkpoints are plotted in
# parallel across worker processes.
compute_device = torch.device("cpu")

# Root of the model data whose archived checkpoints are plotted.
input_model = Path("model_data/current")
# Directory of checkpoint files; one plot is produced per file.
src_dir = input_model / "model_archive"
# Destination for rendered plots. Derived from `input_model` (the previous
# `input_model_dir` name was undefined and raised NameError at import).
out_dir = input_model / "plots/predicted_value"
# Create it once up front so every worker can write into it.
out_dir.mkdir(parents=True, exist_ok=True)
def _predict_action_values(policy_net, device, size=128, fixed_state=(60.0, 80.0)):
    """Evaluate ``policy_net`` on every cell of a ``size`` x ``size`` grid.

    Cell ``(r, c)`` is fed the state vector ``[c, r, *fixed_state]``; all
    cells are sent through the network as a single batch.

    NOTE(review): batching assumes DQN processes each batch row
    independently (true for a plain MLP) -- confirm against ``main.DQN``.

    Args:
        policy_net: Network mapping a ``(batch, state_dim)`` float tensor to
            a ``(batch, n_outputs)`` tensor.
        device: Torch device the input states are created on.
        size: Number of rows and columns in the grid.
        fixed_state: Values appended after ``[c, r]`` in every state vector.

    Returns:
        float32 ndarray of shape ``(size, size, n_outputs)`` where
        ``out[r, c, a]`` is the network output ``a`` for cell ``(r, c)``.
    """
    coords = torch.arange(size, dtype=torch.float32, device=device)
    n_cells = size * size
    # Row-major enumeration of cells: flat index i = r * size + c.
    cols = coords.repeat(size)             # 0, 1, ..., size-1, 0, 1, ...
    rows = coords.repeat_interleave(size)  # 0, 0, ..., 0, 1, 1, ...
    fixed = torch.tensor(fixed_state, dtype=torch.float32, device=device)
    states = torch.cat(
        (torch.stack((cols, rows), dim=1), fixed.expand(n_cells, -1)),
        dim=1,
    )
    with torch.no_grad():
        values = policy_net(states)
    return values.reshape(size, size, -1).cpu().numpy().astype(np.float32, copy=False)


def plot(src):
    """Render the predicted action values of one archived checkpoint.

    Loads ``checkpoint["policy_state_dict"]`` from ``src`` into a fresh DQN,
    evaluates it over a 128x128 grid of ``[c, r, 60, 80]`` states, draws one
    heat-map panel per action in ``Celeste.action_space``, and saves the
    figure to ``out_dir / f"{src.stem}.png"``.

    Args:
        src: Path to a checkpoint file loadable with ``torch.load``.
    """
    n_actions = len(Celeste.action_space)
    policy_net = DQN(
        len(Celeste.state_number_map),
        n_actions,
    ).to(compute_device)
    checkpoint = torch.load(src, map_location=compute_device)
    policy_net.load_state_dict(checkpoint["policy_state_dict"])
    # Inference only: disables dropout/batch-norm updates if DQN has any.
    policy_net.eval()

    # Compute predictions: values[r, c, a] for every grid cell and action.
    values = _predict_action_values(policy_net, compute_device)

    # Two rows of panels; for 8 actions this is the 2x4, 20x10 layout.
    ncols = (n_actions + 1) // 2
    fig, axs = plt.subplots(2, ncols, figsize=(5 * ncols, 10), squeeze=False)

    # Plot predictions, one panel per action.
    for a, ax in enumerate(axs.ravel()):
        if a >= n_actions:
            ax.set_axis_off()  # spare panel when n_actions is odd
            continue
        ax.set(
            adjustable="box",
            aspect="equal",
            title=Celeste.action_space[a],
        )
        # pcolormesh draws the same regular grid as pcolor, much faster.
        mesh = ax.pcolormesh(
            values[:, :, a],
            cmap="Greens",
            vmin=0,
        )
        # Put row 0 at the top of the panel.
        ax.invert_yaxis()
        fig.colorbar(mesh, ax=ax)

    print(src)
    fig.savefig(out_dir / f"{src.stem}.png")
    # Close this specific figure so long-lived pool workers don't accumulate figures.
    plt.close(fig)
if __name__ == "__main__":
    # Checkpoints are independent of each other, so render them across a
    # pool of 5 worker processes.
    checkpoint_paths = list(src_dir.iterdir())
    with Pool(5) as pool:
        pool.map(plot, checkpoint_paths)