Cleanup
This commit is contained in:
3
celeste_ai/plotting/__init__.py
Normal file
3
celeste_ai/plotting/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .plot_predicted_reward import predicted_reward
|
||||
from .plot_best_action import best_action
|
||||
|
121
celeste_ai/plotting/plot_best_action.py
Normal file
121
celeste_ai/plotting/plot_best_action.py
Normal file
@ -0,0 +1,121 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import matplotlib as mpl
|
||||
import matplotlib.pyplot as plt
|
||||
import json
|
||||
|
||||
# All of the following are required to load
|
||||
# a pickled model.
|
||||
from celeste_ai.celeste import Celeste
|
||||
from celeste_ai.network import DQN
|
||||
from celeste_ai.network import Transition
|
||||
|
||||
|
||||
def best_action(
|
||||
model_file: Path,
|
||||
out_filename: Path,
|
||||
*,
|
||||
device = torch.device("cpu"),
|
||||
draw_path = True
|
||||
):
|
||||
if not model_file.is_file():
|
||||
raise Exception(f"Bad model file {model_file}")
|
||||
out_filename.parent.mkdir(exist_ok = True, parents = True)
|
||||
|
||||
# Create and load model
|
||||
policy_net = DQN(
|
||||
len(Celeste.state_number_map),
|
||||
len(Celeste.action_space)
|
||||
).to(device)
|
||||
checkpoint = torch.load(
|
||||
model_file,
|
||||
map_location = device
|
||||
)
|
||||
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
||||
|
||||
|
||||
|
||||
# Compute preditions
|
||||
p = np.zeros((128, 128), dtype=np.float32)
|
||||
with torch.no_grad():
|
||||
for r in range(len(p)):
|
||||
for c in range(len(p[r])):
|
||||
x = c / 128.0
|
||||
y = r / 128.0
|
||||
|
||||
k = np.asarray(policy_net(
|
||||
torch.tensor(
|
||||
[x, y],
|
||||
dtype = torch.float32,
|
||||
device = device
|
||||
).unsqueeze(0)
|
||||
)[0])
|
||||
p[r][c] = np.argmax(k)
|
||||
|
||||
|
||||
cmap = mpl.colors.ListedColormap(
|
||||
[
|
||||
"forestgreen",
|
||||
"firebrick",
|
||||
"lightgreen",
|
||||
"salmon",
|
||||
"darkturquoise",
|
||||
"sandybrown",
|
||||
"olive",
|
||||
"darkorchid",
|
||||
"mediumvioletred"
|
||||
]
|
||||
)
|
||||
|
||||
# Plot predictions
|
||||
fig, axs = plt.subplots(1, 1, figsize = (20, 20))
|
||||
ax = axs
|
||||
ax.set(
|
||||
adjustable = "box",
|
||||
aspect = "equal",
|
||||
title = "Best Action"
|
||||
)
|
||||
|
||||
plot = ax.pcolor(
|
||||
p,
|
||||
cmap = cmap,
|
||||
vmin = 0,
|
||||
vmax = 8
|
||||
)
|
||||
|
||||
|
||||
if draw_path:
|
||||
d = None
|
||||
with Path("model_data/solved_4layer/paths.json").open("r") as f:
|
||||
for l in f.readlines():
|
||||
t = json.loads(l)
|
||||
if t["current_image"] == model_file.name:
|
||||
break
|
||||
d = t
|
||||
assert d is not None
|
||||
|
||||
plt.plot(
|
||||
[max(0,x["xpos"]) for x in d["hist"]],
|
||||
[max(0,x["ypos"] + 5) for x in d["hist"]],
|
||||
marker = "",
|
||||
markersize = 0,
|
||||
linestyle = "-",
|
||||
linewidth = 5,
|
||||
color = "white",
|
||||
solid_capstyle = "round",
|
||||
solid_joinstyle = "round"
|
||||
)
|
||||
|
||||
ax.invert_yaxis()
|
||||
cbar = fig.colorbar(plot, ticks = list(range(0, 9)))
|
||||
cbar.ax.set_yticklabels(Celeste.action_space)
|
||||
|
||||
fig.savefig(out_filename)
|
||||
plt.close()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
84
celeste_ai/plotting/plot_predicted_reward.py
Normal file
84
celeste_ai/plotting/plot_predicted_reward.py
Normal file
@ -0,0 +1,84 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# All of the following are required to load
|
||||
# a pickled model.
|
||||
from celeste_ai.celeste import Celeste
|
||||
from celeste_ai.network import DQN
|
||||
from celeste_ai.network import Transition
|
||||
|
||||
|
||||
def predicted_reward(
|
||||
model_file: Path,
|
||||
out_filename: Path,
|
||||
*,
|
||||
device = torch.device("cpu")
|
||||
):
|
||||
if not model_file.is_file():
|
||||
raise Exception(f"Bad model file {model_file}")
|
||||
out_filename.parent.mkdir(exist_ok = True, parents = True)
|
||||
|
||||
# Create and load model
|
||||
policy_net = DQN(
|
||||
len(Celeste.state_number_map),
|
||||
len(Celeste.action_space)
|
||||
).to(device)
|
||||
checkpoint = torch.load(
|
||||
model_file,
|
||||
map_location = device
|
||||
)
|
||||
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
||||
|
||||
|
||||
|
||||
# Compute preditions
|
||||
p = np.zeros((128, 128, 9), dtype=np.float32)
|
||||
with torch.no_grad():
|
||||
for r in range(len(p)):
|
||||
for c in range(len(p[r])):
|
||||
x = c / 128.0
|
||||
y = r / 128.0
|
||||
|
||||
k = np.asarray(policy_net(
|
||||
torch.tensor(
|
||||
[x, y],
|
||||
dtype = torch.float32,
|
||||
device = device
|
||||
).unsqueeze(0)
|
||||
)[0])
|
||||
p[r][c] = k
|
||||
|
||||
|
||||
# Plot predictions
|
||||
fig, axs = plt.subplots(2, 5, figsize = (20, 10))
|
||||
for a in range(len(axs.ravel())):
|
||||
if a >= len(Celeste.action_space):
|
||||
continue
|
||||
|
||||
ax = axs.ravel()[a]
|
||||
ax.set(
|
||||
adjustable = "box",
|
||||
aspect = "equal",
|
||||
title = Celeste.action_space[a]
|
||||
)
|
||||
|
||||
plot = ax.pcolor(
|
||||
p[:,:,a],
|
||||
cmap = "Greens",
|
||||
vmin = 0,
|
||||
#vmax = 5
|
||||
)
|
||||
|
||||
ax.invert_yaxis()
|
||||
fig.colorbar(plot)
|
||||
|
||||
fig.savefig(out_filename)
|
||||
plt.close()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user