2023-02-17 22:29:12 -08:00
|
|
|
import torch
|
|
|
|
import numpy as np
|
2023-02-18 19:50:43 -08:00
|
|
|
from pathlib import Path
|
2023-02-17 22:29:12 -08:00
|
|
|
import matplotlib.pyplot as plt
|
2023-02-18 19:50:43 -08:00
|
|
|
from multiprocessing import Pool
|
2023-02-17 22:29:12 -08:00
|
|
|
|
2023-02-18 19:50:43 -08:00
|
|
|
from celeste import Celeste
|
|
|
|
from main import DQN
|
|
|
|
from main import Transition
|
2023-02-17 22:29:12 -08:00
|
|
|
|
2023-02-18 21:10:13 -08:00
|
|
|
# Use cpu, this script is faster in parallel.
compute_device = torch.device("cpu")

# Model directory whose archived checkpoints are plotted.
input_model = Path("model_data/current")

# Checkpoints to render; plot() writes one PNG per file found here.
src_dir = input_model / "model_archive"

# Destination for the PNGs. Derived from `input_model` (the only model
# directory defined in this script).
out_dir = input_model / "plots/predicted_value"
# exist_ok makes this safe to run more than once, e.g. if worker processes
# re-import this module.
out_dir.mkdir(parents=True, exist_ok=True)
|
2023-02-17 22:29:12 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
2023-02-18 19:50:43 -08:00
|
|
|
def _load_policy(src):
    """Build a DQN sized for Celeste and load the policy weights stored in *src*.

    The network is moved to ``compute_device`` and switched to eval mode.
    """
    policy_net = DQN(
        len(Celeste.state_number_map),
        len(Celeste.action_space),
    ).to(compute_device)
    # NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
    checkpoint = torch.load(src, map_location=compute_device)
    policy_net.load_state_dict(checkpoint["policy_state_dict"])
    # Inference only: turns off dropout / batch-norm training behaviour if
    # DQN has any (its definition in main.py is not visible here).
    policy_net.eval()
    return policy_net


def _predict_grid(policy_net, n_rows, n_cols, fixed_state, device):
    """Evaluate *policy_net* on every cell of an ``n_rows x n_cols`` grid.

    The state fed to the network for cell ``(r, c)`` is ``[c, r, *fixed_state]``.
    All cells are evaluated in one batched forward pass rather than one call
    per cell.

    Returns:
        float32 ndarray of shape ``(n_rows, n_cols, n_outputs)`` where
        ``out[r, c]`` is the network output for the state of cell ``(r, c)``.
    """
    rows = torch.arange(n_rows, dtype=torch.float32, device=device)
    cols = torch.arange(n_cols, dtype=torch.float32, device=device)
    # "ij" indexing enumerates cells in row-major (r, c) order, which is the
    # order the flat batch is reshaped back into below.
    row_grid, col_grid = torch.meshgrid(rows, cols, indexing="ij")
    n_cells = n_rows * n_cols
    extra = torch.tensor(fixed_state, dtype=torch.float32, device=device)
    states = torch.cat(
        (
            col_grid.reshape(n_cells, 1),
            row_grid.reshape(n_cells, 1),
            extra.expand(n_cells, len(fixed_state)),
        ),
        dim=1,
    )
    with torch.no_grad():
        outputs = policy_net(states)
    grid = outputs.reshape(n_rows, n_cols, -1).cpu().numpy()
    return grid.astype(np.float32, copy=False)


def plot(src, grid_size=128, fixed_state=(60, 80)):
    """Save a heatmap of each action's predicted value for checkpoint *src*.

    The first two state entries are swept over a ``grid_size x grid_size``
    (col, row) grid, and *fixed_state* is appended unchanged for every cell.
    One panel is drawn per entry of ``Celeste.action_space``. The figure is
    written to ``out_dir / f"{src.stem}.png"``.

    Args:
        src: checkpoint path containing a ``"policy_state_dict"`` entry. Its
            ``.stem`` names the output file, so a ``pathlib.Path`` is expected.
        grid_size: number of rows and columns in the swept grid.
        fixed_state: values appended after ``(col, row)`` in every state.
            NOTE(review): what these entries mean depends on Celeste's state
            layout, which is not visible here — confirm.
    """
    policy_net = _load_policy(src)

    # Compute predictions
    q_values = _predict_grid(
        policy_net, grid_size, grid_size, fixed_state, compute_device
    )

    # Plot predictions: one panel per action, laid out on two rows.
    # For 8 actions this is a 2 x 4 grid at 20 x 10 inches.
    n_actions = len(Celeste.action_space)
    n_cols = max(1, (n_actions + 1) // 2)
    fig, axs = plt.subplots(2, n_cols, figsize=(5 * n_cols, 10), squeeze=False)
    for idx, ax in enumerate(axs.ravel()):
        if idx >= n_actions:
            # Odd action count: hide the spare panel.
            ax.set_visible(False)
            continue
        ax.set(
            adjustable="box",
            aspect="equal",
            title=Celeste.action_space[idx],
        )
        # vmin=0 clips negative predictions to the lightest colour.
        mesh = ax.pcolor(q_values[:, :, idx], cmap="Greens", vmin=0)
        # Put row 0 at the top of the panel.
        ax.invert_yaxis()
        # Attach the colorbar to this panel explicitly. Without ax=, older
        # matplotlib takes space from the figure's current axes instead.
        fig.colorbar(mesh, ax=ax)

    print(src)
    fig.savefig(out_dir / f"{src.stem}.png")
    # Close this specific figure; each worker renders many checkpoints.
    plt.close(fig)
|
2023-02-18 19:50:43 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Render every archived checkpoint, five worker processes at a time.
    # Only regular files are passed to plot() (torch.load cannot read a
    # sub-directory). They are sorted so work is dispatched in a stable,
    # name-ordered sequence instead of arbitrary directory order.
    checkpoints = sorted(path for path in src_dir.iterdir() if path.is_file())
    with Pool(5) as pool:
        pool.map(plot, checkpoints)
|
|
|
|
|
|
|
|
|