81 lines
1.5 KiB
Python
81 lines
1.5 KiB
Python
|
import torch
|
||
|
import numpy as np
|
||
|
from pathlib import Path
|
||
|
import matplotlib.pyplot as plt
|
||
|
from multiprocessing import Pool
|
||
|
|
||
|
# All of the following are required to load
|
||
|
# a pickled model (unpickling must be able to resolve these classes).
|
||
|
from celeste_ai.celeste import Celeste
|
||
|
from celeste_ai.network import DQN
|
||
|
from celeste_ai.network import Transition
|
||
|
|
||
|
def actual_reward(
    model_file: Path,
    target_point: tuple[int, int],
    out_filename: Path,
    *,
    device = torch.device("cpu")
) -> None:
    """
    Plot, per action, the rewards stored in a checkpoint's replay memory.

    Loads ``model_file`` with ``torch.load`` and iterates ``checkpoint["memory"]``.
    For every transition whose target equals ``target_point``, the cell
    ``[y, x, action]`` of a tally array gets +1 if the transition's reward
    equals 1, and -1 otherwise. One heatmap per action is drawn on a 2x4 grid
    (colour range clamped to [-10, 10]) and saved to ``out_filename``.

    Args:
        model_file: Checkpoint file containing a ``"memory"`` entry.
        target_point: ``(x, y)`` target; only transitions with this target are
            counted. Any 2-element sequence is accepted.
        out_filename: Where the figure is saved. Missing parent directories
            are created.
        device: ``map_location`` passed to ``torch.load``.

    Raises:
        FileNotFoundError: If ``model_file`` is not an existing file. This is
            a subclass of ``Exception``, so existing handlers still match.
    """
    if not model_file.is_file():
        raise FileNotFoundError(f"Bad model file {model_file}")

    out_filename.parent.mkdir(exist_ok = True, parents = True)

    # NOTE(review): torch.load unpickles arbitrary objects (the memory holds
    # project Transition instances). Only load trusted checkpoints. Recent
    # torch releases default to weights_only=True, which may reject this
    # checkpoint — confirm against the installed torch version.
    checkpoint = torch.load(
        model_file,
        map_location = device
    )
    memory = checkpoint["memory"]

    # Normalise once so a list like [x, y] compares equal to the tuple below.
    target = tuple(target_point)

    # Tally indexed as r[y, x, action]. The 128x128 size and 8 actions are
    # hard-coded; presumably the screen size in pixels and
    # len(Celeste.action_space) — TODO confirm.
    r = np.zeros((128, 128, 8), dtype=np.float32)
    for m in memory:
        # The first row of the state holds (x, y, x_target, y_target).
        x, y, x_target, y_target = (int(v.item()) for v in m.state[0])

        # Only plot memory related to this point
        if (x_target, y_target) != target:
            continue

        action = m.action[0].item()
        # Reward exactly 1 counts as a success; any other reward as a failure.
        r[y, x, action] += 1 if m.reward[0].item() == 1 else -1

    fig, axs = plt.subplots(2, 4, figsize = (20, 10))

    plot = None
    for a, ax in enumerate(axs.ravel()):
        ax.set(
            adjustable = "box",
            aspect = "equal",
            title = Celeste.action_space[a]
        )

        # pcolormesh draws the same regular grid as pcolor, but much faster.
        plot = ax.pcolormesh(
            r[:, :, a],
            cmap = "seismic_r",
            vmin = -10,
            vmax = 10
        )

        # Draw target point on plot
        ax.plot(target_point[0], target_point[1], "k.")

        # Put row y == 0 at the top, matching the r[y, x] indexing.
        ax.invert_yaxis()

    # Every panel shares vmin/vmax, so a single colorbar spanning all axes
    # describes them all without shrinking only the last subplot.
    fig.colorbar(plot, ax = axs)

    fig.savefig(out_filename)
    plt.close(fig)
|