Archived
1
0

Reorganized celeste code

This commit is contained in:
2023-02-19 20:57:19 -08:00
parent d648d896c0
commit 55ac62dc47
11 changed files with 201 additions and 157 deletions

View File

@ -0,0 +1,2 @@
from .plot_actual_reward import actual_reward
from .plot_predicted_reward import predicted_reward

View File

@ -0,0 +1,81 @@
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from multiprocessing import Pool
# All of the following are required to load
# a pickled model.
from celeste_ai.celeste import Celeste
from celeste_ai.network import DQN
from celeste_ai.network import Transition
def actual_reward(
model_file: Path,
target_point: tuple[int, int],
out_filename: Path,
*,
device = torch.device("cpu")
):
if not model_file.is_file():
raise Exception(f"Bad model file {model_file}")
out_filename.parent.mkdir(exist_ok = True, parents = True)
checkpoint = torch.load(
model_file,
map_location = device
)
memory = checkpoint["memory"]
r = np.zeros((128, 128, 8), dtype=np.float32)
for m in memory:
x, y, x_target, y_target = list(m.state[0])
action = m.action[0].item()
x = int(x.item())
y = int(y.item())
x_target = int(x_target.item())
y_target = int(y_target.item())
# Only plot memory related to this point
if (x_target, y_target) != target_point:
continue
if m.reward[0].item() == 1:
r[y][x][action] += 1
else:
r[y][x][action] -= 1
fig, axs = plt.subplots(2, 4, figsize = (20, 10))
for a in range(len(axs.ravel())):
ax = axs.ravel()[a]
ax.set(
adjustable = "box",
aspect = "equal",
title = Celeste.action_space[a]
)
plot = ax.pcolor(
r[:,:,a],
cmap = "seismic_r",
vmin = -10,
vmax = 10
)
# Draw target point on plot
ax.plot(
target_point[0],
target_point[1],
"k."
)
ax.invert_yaxis()
fig.colorbar(plot)
fig.savefig(out_filename)
plt.close()

View File

@ -0,0 +1,77 @@
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
# All of the following are required to load
# a pickled model.
from celeste_ai.celeste import Celeste
from celeste_ai.network import DQN
from celeste_ai.network import Transition
def predicted_reward(
model_file: Path,
out_filename: Path,
*,
device = torch.device("cpu")
):
if not model_file.is_file():
raise Exception(f"Bad model file {model_file}")
out_filename.parent.mkdir(exist_ok = True, parents = True)
# Create and load model
policy_net = DQN(
len(Celeste.state_number_map),
len(Celeste.action_space)
).to(device)
checkpoint = torch.load(
model_file,
map_location = device
)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
# Compute preditions
p = np.zeros((128, 128, 8), dtype=np.float32)
with torch.no_grad():
for r in range(len(p)):
for c in range(len(p[r])):
k = np.asarray(policy_net(
torch.tensor(
[c, r, 60, 80],
dtype = torch.float32,
device = device
).unsqueeze(0)
)[0])
p[r][c] = k
# Plot predictions
fig, axs = plt.subplots(2, 4, figsize = (20, 10))
for a in range(len(axs.ravel())):
ax = axs.ravel()[a]
ax.set(
adjustable = "box",
aspect = "equal",
title = Celeste.action_space[a]
)
plot = ax.pcolor(
p[:,:,a],
cmap = "Greens",
vmin = 0,
)
ax.invert_yaxis()
fig.colorbar(plot)
fig.savefig(out_filename)
plt.close()