Mark
/
celeste-ai
Archived
1
0
Fork 0
This repository has been archived on 2023-11-28. You can view files and clone it, but cannot push or open issues/pull-requests.
celeste-ai/celeste/celeste_ai/plotting/plot_predicted_reward.py

85 lines
1.5 KiB
Python
Raw Normal View History

2023-02-17 22:29:12 -08:00
import torch
import numpy as np
2023-02-18 19:50:43 -08:00
from pathlib import Path
2023-02-17 22:29:12 -08:00
import matplotlib.pyplot as plt
2023-02-19 20:57:19 -08:00
# All of the following are required to load
# a pickled model.
from celeste_ai.celeste import Celeste
from celeste_ai.network import DQN
from celeste_ai.network import Transition
2023-02-17 22:29:12 -08:00
2023-02-19 20:57:19 -08:00
def predicted_reward(
    model_file: Path,
    out_filename: Path,
    *,
    device = torch.device("cpu"),
    resolution: int = 128,
    vmin = 0,
    vmax = None,
):
    """
    Plot a policy network's predicted reward for every action over a grid
    of (x, y) inputs and save the figure.

    The network is evaluated on a ``resolution`` x ``resolution`` grid of
    inputs ``[x, y]`` with ``x = c / resolution`` and ``y = r / resolution``
    (row ``r``, column ``c``), so both coordinates lie in ``[0, 1)``.
    One heatmap is drawn per entry of ``Celeste.action_space``.

    Args:
        model_file: Checkpoint file containing a ``"policy_state_dict"`` entry.
        out_filename: Where to save the figure; parent directories are created.
        device: Torch device used to load and evaluate the network.
        resolution: Number of grid cells along each axis (default 128).
        vmin, vmax: Colour-scale limits passed to ``pcolor``. ``None`` lets
            matplotlib choose that limit from the data.

    Raises:
        FileNotFoundError: If ``model_file`` is not a regular file.
        ValueError: If ``resolution`` is not positive.
    """
    if not model_file.is_file():
        # FileNotFoundError subclasses Exception, so existing handlers still match.
        raise FileNotFoundError(f"Bad model file {model_file}")
    if resolution < 1:
        raise ValueError(f"resolution must be positive, got {resolution}")
    out_filename.parent.mkdir(exist_ok = True, parents = True)

    policy_net = _load_policy_net(model_file, device)
    predictions = _predict_grid(policy_net, resolution, device)
    _plot_predictions(predictions, out_filename, vmin = vmin, vmax = vmax)


def _load_policy_net(model_file: Path, device):
    """Build a DQN sized for Celeste and load the checkpoint's policy weights."""
    policy_net = DQN(
        len(Celeste.state_number_map),
        len(Celeste.action_space)
    ).to(device)

    # torch.load unpickles the checkpoint, which can execute arbitrary code:
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(
        model_file,
        map_location = device
    )
    policy_net.load_state_dict(checkpoint["policy_state_dict"])

    # Inference only: switch off any training-time layer behaviour.
    policy_net.eval()
    return policy_net


def _predict_grid(policy_net, resolution: int, device) -> np.ndarray:
    """Evaluate ``policy_net`` on every grid cell; return a (rows, cols, outputs) float32 array."""
    # Computed in float64 and then cast, matching `c / resolution` fed into a float32 tensor.
    coords = (np.arange(resolution) / resolution).astype(np.float32)
    # ys[r, c] = r / resolution and xs[r, c] = c / resolution.
    ys, xs = np.meshgrid(coords, coords, indexing = "ij")
    # One input row per cell, in row-major (r, c) order; each row is [x, y].
    inputs = torch.from_numpy(
        np.stack((xs.ravel(), ys.ravel()), axis = 1)
    ).to(device)

    with torch.no_grad():
        # A single batched forward pass instead of one call per cell.
        outputs = policy_net(inputs)

    # Copy to host memory first: numpy cannot read tensors stored on another device.
    return (
        outputs.cpu().numpy()
        .astype(np.float32, copy = False)
        .reshape(resolution, resolution, -1)
    )


def _plot_predictions(predictions: np.ndarray, out_filename: Path, *, vmin = 0, vmax = None):
    """Draw one heatmap per action in ``Celeste.action_space`` and save the figure."""
    n_actions = len(Celeste.action_space)
    ncols = 5
    # Keep the original 2x5 layout, adding rows only if there are more than 10 actions.
    nrows = max(2, (n_actions + ncols - 1) // ncols)
    fig, axs = plt.subplots(nrows, ncols, figsize = (4 * ncols, 5 * nrows))

    for a, ax in enumerate(axs.ravel()):
        if a >= n_actions:
            # Hide leftover panels instead of drawing empty axes.
            ax.set_axis_off()
            continue

        ax.set(
            adjustable = "box",
            aspect = "equal",
            title = Celeste.action_space[a]
        )
        plot = ax.pcolor(
            predictions[:, :, a],
            cmap = "Greens",
            vmin = vmin,
            vmax = vmax,
        )
        # pcolor puts row 0 at the bottom; flip so row 0 (y = 0) is at the top.
        ax.invert_yaxis()
        # Each panel has its own colour scale, so attach a colorbar to each one.
        fig.colorbar(plot, ax = ax)

    fig.savefig(out_filename)
    # Close this specific figure rather than whichever one is "current".
    plt.close(fig)
2023-02-18 19:50:43 -08:00
2023-02-19 20:57:19 -08:00
2023-02-18 19:50:43 -08:00