Mark
/
celeste-ai
Archived
1
0
Fork 0

Added best-action plot

master
Mark 2023-02-24 14:23:38 -08:00
parent bae70e0cfa
commit c9e04dcd41
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
2 changed files with 107 additions and 0 deletions

View File

@ -1,2 +1,4 @@
from .plot_actual_reward import actual_reward
from .plot_predicted_reward import predicted_reward
from .plot_best_action import best_action

View File

@ -0,0 +1,105 @@
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
# All of the following are required to load
# a pickled model.
from celeste_ai.celeste import Celeste
from celeste_ai.network import DQN
from celeste_ai.network import Transition
def best_action(
    model_file: Path,
    out_filename: Path,
    *,
    device = torch.device("cpu")
) -> None:
    """Plot the highest-valued action of a saved model over a 128x128 grid.

    The checkpoint at ``model_file`` is loaded into a fresh ``DQN``. For
    every grid cell the network is evaluated on the inputs ``[x, y, 0]``
    and ``[x, y, 1]`` (``x = col / 128``, ``y = row / 128``), and the
    index of the largest output is recorded. The two resulting maps are
    drawn side by side and written to ``out_filename``.

    Args:
        model_file: Checkpoint file readable by ``torch.load``; it must
            contain a ``"policy_state_dict"`` entry.
        out_filename: Destination image path. Missing parent directories
            are created.
        device: Torch device used to load and evaluate the model.

    Raises:
        Exception: If ``model_file`` is not an existing file.
    """
    if not model_file.is_file():
        raise Exception(f"Bad model file {model_file}")
    out_filename.parent.mkdir(exist_ok = True, parents = True)

    # Create and load model
    policy_net = DQN(
        len(Celeste.state_number_map),
        len(Celeste.action_space)
    ).to(device)
    checkpoint = torch.load(
        model_file,
        map_location = device
    )
    policy_net.load_state_dict(checkpoint["policy_state_dict"])

    # Compute predictions.
    # p[r, c, k] is the best action index at grid cell (r, c) when the
    # third network input is k.
    p = np.zeros((128, 128, 2), dtype=np.float32)
    rows, cols, n_flags = p.shape
    with torch.no_grad():
        for r in range(rows):
            y = r / 128.0
            for c in range(cols):
                x = c / 128.0
                for flag in range(n_flags):
                    state = torch.tensor(
                        [x, y, flag],
                        dtype = torch.float32,
                        device = device
                    ).unsqueeze(0)
                    # Move the output to host memory before handing it to
                    # numpy; a tensor on a non-CPU device cannot be
                    # converted directly.
                    q_values = policy_net(state)[0].cpu().numpy()
                    p[r, c, flag] = np.argmax(q_values)

    # Plot predictions: one panel per value of the third input.
    # NOTE(review): both panels keep the original title "Best Action";
    # consider distinguishing them once the meaning of the third state
    # component is confirmed.
    fig, axs = plt.subplots(1, 2, figsize = (10, 10))
    for flag, ax in enumerate(axs):
        ax.set(
            adjustable = "box",
            aspect = "equal",
            title = "Best Action"
        )
        plot = ax.pcolor(
            p[:, :, flag],
            cmap = "Set1",
            vmin = 0,
            vmax = 8
        )
        # Put row 0 at the top of the panel.
        ax.invert_yaxis()
        fig.colorbar(plot, ax = ax)

    fig.savefig(out_filename)
    # Close this specific figure rather than whichever one is "current".
    plt.close(fig)