Mark
/
celeste-ai
Archived
1
0
Fork 0

Added plotter

master
Mark 2023-02-19 19:12:51 -08:00
parent 97f3cabd75
commit 88f871816c
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
2 changed files with 87 additions and 3 deletions

79
celeste/plot-actual.py Normal file
View File

@ -0,0 +1,79 @@
import torch
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from multiprocessing import Pool
from celeste import Celeste
from main import DQN
from main import Transition
# Use cpu, this script is faster in parallel.
compute_device = torch.device("cpu")
input_model = Path("model_data/after_change")
out_dir = input_model / "plots/actual_reward"
out_dir.mkdir(parents = True, exist_ok = True)
checkpoint = torch.load(
input_model / "model.torch",
map_location = compute_device
)
memory = checkpoint["memory"]
r = np.zeros((128, 128, 8), dtype=np.float32)
for m in memory:
x, y, x_target, y_target = list(m.state[0])
action = m.action[0].item()
str_action = Celeste.action_space[action]
x = int(x.item())
y = int(y.item())
x_target = int(x_target.item())
y_target = int(y_target.item())
if (x_target, y_target) != (60, 80):
continue
if m.reward[0].item() == 1:
r[y][x][action] += 1
else:
r[y][x][action] -= 1
fig, axs = plt.subplots(2, 4, figsize = (20, 10))
# Plot predictions
for a in range(len(axs.ravel())):
ax = axs.ravel()[a]
ax.set(
adjustable = "box",
aspect = "equal",
title = Celeste.action_space[a]
)
plot = ax.pcolor(
r[:,:,a],
cmap = "seismic_r",
vmin = -10,
vmax = 10
)
ax.plot(60, 80, "k.")
#ax.annotate(
# "Target",
# (60, 80),
# textcoords = "offset points",
# xytext = (0, -20),
# ha = "center"
#)
ax.invert_yaxis()
fig.colorbar(plot)
fig.savefig(out_dir / "actual.png")
plt.close()

View File

@ -11,10 +11,12 @@ from main import Transition
# Use cpu, this script is faster in parallel.
compute_device = torch.device("cpu")
out_dir = Path("out/plots")
input_model = Path("model_data/current")
src_dir = input_model / "model_archive"
out_dir = input_model_dir / "plots/predicted_value"
out_dir.mkdir(parents = True, exist_ok = True)
src_dir = Path("model_data/current/model_archive")
def plot(src):
@ -23,7 +25,10 @@ def plot(src):
len(Celeste.action_space)
).to(compute_device)
checkpoint = torch.load(src)
checkpoint = torch.load(
src,
map_location = compute_device
)
policy_net.load_state_dict(checkpoint["policy_state_dict"])