Mark
/
celeste-ai
Archived
1
0
Fork 0
This repository has been archived on 2023-11-28. You can view files and clone it, but cannot push or open issues/pull-requests.
celeste-ai/celeste/celeste_ai/plotting/plot_best_action.py

122 lines
2.2 KiB
Python

import torch
import numpy as np
from pathlib import Path
import matplotlib as mpl
import matplotlib.pyplot as plt
import json
# All of the following are required to load
# a pickled model.
from celeste_ai.celeste import Celeste
from celeste_ai.network import DQN
from celeste_ai.network import Transition
# Default location of the recorded agent paths (one JSON object per line).
_DEFAULT_PATHS_FILE = Path("model_data/solved_4layer/paths.json")

# Side length of the square grid of positions the network is sampled on.
_GRID_SIZE = 128


def _compute_best_actions(policy_net, device: torch.device) -> np.ndarray:
    """Return a (_GRID_SIZE, _GRID_SIZE) float32 array of argmax action indices."""
    with torch.no_grad():
        # Coordinates are index / 128, i.e. 0, 1/128, ..., 127/128.
        coords = torch.arange(
            _GRID_SIZE,
            dtype = torch.float32,
            device = device
        ) / float(_GRID_SIZE)

        # Flat index is r * _GRID_SIZE + c: x follows the column (fastest
        # varying), y follows the row.
        xs = coords.repeat(_GRID_SIZE)
        ys = coords.repeat_interleave(_GRID_SIZE)

        # One batched forward pass instead of one call per cell.
        # Assumes the network returns one row of action values per input
        # row, as it does when called with a batch of one.
        q_values = policy_net(torch.stack((xs, ys), dim = 1))

        # Move to the CPU before handing the tensor to NumPy; NumPy cannot
        # read tensors that live on another device.
        q_values = q_values.cpu().numpy()

    q_values = q_values.reshape(_GRID_SIZE * _GRID_SIZE, -1)
    best = np.argmax(q_values, axis = 1)
    return best.reshape(_GRID_SIZE, _GRID_SIZE).astype(np.float32)


def _find_path_entry(paths_file: Path, model_name: str) -> dict:
    """Return the record just before the first one whose "current_image" is model_name."""
    previous = None
    with paths_file.open("r", encoding = "utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            record = json.loads(line)
            if record["current_image"] == model_name:
                break
            previous = record

    # Raised explicitly (not asserted) so the check survives `python -O`.
    if previous is None:
        raise ValueError(
            f"No path entry precedes {model_name} in {paths_file}"
        )
    return previous


def best_action(
    model_file: Path,
    out_filename: Path,
    *,
    device: torch.device = torch.device("cpu"),
    draw_path: bool = True,
    paths_file: Path = _DEFAULT_PATHS_FILE
) -> None:
    """Plot the network's preferred action at every grid position.

    Loads a DQN checkpoint, evaluates it on a 128x128 grid of (x, y)
    inputs in [0, 1), colors each cell by the index of the highest-valued
    action, and saves the figure to `out_filename`.

    Args:
        model_file: Checkpoint file containing a "policy_state_dict" entry.
        out_filename: Where the figure is saved. Missing parent
            directories are created.
        device: Device the network is evaluated on.
        draw_path: If true, overlay a recorded path read from `paths_file`.
        paths_file: File with one JSON object per line, each holding
            "current_image" and a "hist" list of {"xpos", "ypos"} entries.

    Raises:
        FileNotFoundError: If `model_file` is not an existing file.
        ValueError: If `draw_path` is set and no path entry can be selected.

    NOTE(review): the path drawn comes from the entry immediately before
    the first entry whose "current_image" equals `model_file.name`; if no
    entry matches, the last entry in the file is used. Confirm that this
    is the intended record.
    """
    if not model_file.is_file():
        raise FileNotFoundError(f"Bad model file {model_file}")
    out_filename.parent.mkdir(exist_ok = True, parents = True)

    # Create and load model
    policy_net = DQN(
        len(Celeste.state_number_map),
        len(Celeste.action_space)
    ).to(device)

    # NOTE: torch.load unpickles the checkpoint. Only load trusted files.
    checkpoint = torch.load(
        model_file,
        map_location = device
    )
    policy_net.load_state_dict(checkpoint["policy_state_dict"])

    # Inference mode, so any train-only layers behave deterministically.
    policy_net.eval()

    # Compute predictions
    p = _compute_best_actions(policy_net, device)

    # One color per action index, 0 through 8.
    cmap = mpl.colors.ListedColormap(
        [
            "forestgreen",
            "firebrick",
            "lightgreen",
            "salmon",
            "darkturquoise",
            "sandybrown",
            "olive",
            "darkorchid",
            "mediumvioletred"
        ]
    )

    # Plot predictions
    fig, ax = plt.subplots(1, 1, figsize = (20, 20))
    try:
        ax.set(
            adjustable = "box",
            aspect = "equal",
            title = "Best Action"
        )

        plot = ax.pcolor(
            p,
            cmap = cmap,
            vmin = 0,
            vmax = 8
        )

        if draw_path:
            entry = _find_path_entry(paths_file, model_file.name)

            # Negative coordinates are clamped to the edge of the plot.
            ax.plot(
                [max(0, step["xpos"]) for step in entry["hist"]],
                [max(0, step["ypos"] + 5) for step in entry["hist"]],
                marker = "",
                markersize = 0,
                linestyle = "-",
                linewidth = 5,
                color = "white",
                solid_capstyle = "round",
                solid_joinstyle = "round"
            )

        # Row 0 is drawn at the top.
        ax.invert_yaxis()

        cbar = fig.colorbar(plot, ticks = list(range(0, 9)))
        cbar.ax.set_yticklabels(Celeste.action_space)

        fig.savefig(out_filename)
    finally:
        # Close this specific figure, even if plotting or saving failed.
        plt.close(fig)