2023-02-24 14:23:38 -08:00
|
|
|
import torch
|
|
|
|
import numpy as np
|
|
|
|
from pathlib import Path
|
2023-02-26 12:09:05 -08:00
|
|
|
import matplotlib as mpl
|
2023-02-24 14:23:38 -08:00
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
# All of the following are required to load
|
|
|
|
# a pickled model.
|
|
|
|
from celeste_ai.celeste import Celeste
|
|
|
|
from celeste_ai.network import DQN
|
|
|
|
from celeste_ai.network import Transition
|
|
|
|
|
|
|
|
|
|
|
|
def best_action(
    model_file: Path,
    out_filename: Path,
    *,
    device: torch.device = torch.device("cpu")
) -> None:
    """Plot the highest-scoring action of a saved DQN over a 2-D state grid.

    Loads the policy network stored under ``"policy_state_dict"`` in
    ``model_file``, evaluates it at every point of a 128x128 grid of
    ``[x, y]`` inputs in ``[0, 1)``, and saves a colour-coded map of the
    argmax action to ``out_filename``.

    Args:
        model_file: Checkpoint file readable by ``torch.load``.
        out_filename: Destination image path; missing parent directories
            are created.
        device: Device the network is loaded onto and evaluated on.

    Raises:
        Exception: If ``model_file`` is not an existing regular file.
    """
    if not model_file.is_file():
        raise Exception(f"Bad model file {model_file}")
    out_filename.parent.mkdir(exist_ok = True, parents = True)

    # Create and load model
    policy_net = DQN(
        len(Celeste.state_number_map),
        len(Celeste.action_space)
    ).to(device)

    # NOTE: torch.load unpickles the file, which can execute arbitrary
    # code. Only pass checkpoints from a trusted source.
    checkpoint = torch.load(
        model_file,
        map_location = device
    )
    policy_net.load_state_dict(checkpoint["policy_state_dict"])

    # Compute predictions.
    # A single constant ties the array shape to the normalisation divisor.
    grid_size = 128
    p = np.zeros((grid_size, grid_size), dtype = np.float32)
    with torch.no_grad():
        for r in range(grid_size):
            for c in range(grid_size):
                # Column -> x and row -> y, both scaled into [0, 1).
                # presumably a normalised position in Celeste's state
                # space -- confirm against the training code.
                x = c / float(grid_size)
                y = r / float(grid_size)

                state = torch.tensor(
                    [x, y],
                    dtype = torch.float32,
                    device = device
                ).unsqueeze(0)

                # Move the output to host memory before handing it to
                # NumPy; converting a tensor that lives on an accelerator
                # directly raises a TypeError.
                k = policy_net(state)[0].cpu().numpy()
                p[r][c] = np.argmax(k)

    # One colour per action index 0..8.
    # NOTE(review): assumes Celeste.action_space has nine entries, matching
    # the nine colours and nine colorbar ticks below -- confirm.
    cmap = mpl.colors.ListedColormap(
        [
            "forestgreen",
            "firebrick",
            "lightgreen",
            "salmon",
            "darkturquoise",
            "sandybrown",
            "olive",
            "darkorchid",
            "mediumvioletred"
        ]
    )

    # Plot predictions
    fig, axs = plt.subplots(1, 1, figsize = (20, 20))
    ax = axs
    ax.set(
        adjustable = "box",
        aspect = "equal",
        title = "Best Action"
    )

    plot = ax.pcolor(
        p,
        cmap = cmap,
        vmin = 0,
        vmax = 8
    )
    # Put row 0 (y = 0) at the top of the image.
    ax.invert_yaxis()
    cbar = fig.colorbar(plot, ticks = list(range(0, 9)))
    cbar.ax.set_yticklabels(Celeste.action_space)

    fig.savefig(out_filename)
    # Close this specific figure rather than whichever one pyplot
    # currently considers active.
    plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|