Added plotter
parent
97f3cabd75
commit
88f871816c
|
@ -0,0 +1,79 @@
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from multiprocessing import Pool
|
||||||
|
|
||||||
|
from celeste import Celeste
|
||||||
|
from main import DQN
|
||||||
|
from main import Transition
|
||||||
|
|
||||||
|
# Use cpu, this script is faster in parallel.
|
||||||
|
compute_device = torch.device("cpu")
|
||||||
|
|
||||||
|
input_model = Path("model_data/after_change")
|
||||||
|
|
||||||
|
out_dir = input_model / "plots/actual_reward"
|
||||||
|
out_dir.mkdir(parents = True, exist_ok = True)
|
||||||
|
|
||||||
|
|
||||||
|
checkpoint = torch.load(
|
||||||
|
input_model / "model.torch",
|
||||||
|
map_location = compute_device
|
||||||
|
)
|
||||||
|
memory = checkpoint["memory"]
|
||||||
|
|
||||||
|
r = np.zeros((128, 128, 8), dtype=np.float32)
|
||||||
|
|
||||||
|
for m in memory:
|
||||||
|
x, y, x_target, y_target = list(m.state[0])
|
||||||
|
|
||||||
|
action = m.action[0].item()
|
||||||
|
str_action = Celeste.action_space[action]
|
||||||
|
x = int(x.item())
|
||||||
|
y = int(y.item())
|
||||||
|
x_target = int(x_target.item())
|
||||||
|
y_target = int(y_target.item())
|
||||||
|
|
||||||
|
if (x_target, y_target) != (60, 80):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if m.reward[0].item() == 1:
|
||||||
|
r[y][x][action] += 1
|
||||||
|
else:
|
||||||
|
r[y][x][action] -= 1
|
||||||
|
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(2, 4, figsize = (20, 10))
|
||||||
|
|
||||||
|
|
||||||
|
# Plot predictions
|
||||||
|
for a in range(len(axs.ravel())):
|
||||||
|
ax = axs.ravel()[a]
|
||||||
|
ax.set(
|
||||||
|
adjustable = "box",
|
||||||
|
aspect = "equal",
|
||||||
|
title = Celeste.action_space[a]
|
||||||
|
)
|
||||||
|
|
||||||
|
plot = ax.pcolor(
|
||||||
|
r[:,:,a],
|
||||||
|
cmap = "seismic_r",
|
||||||
|
vmin = -10,
|
||||||
|
vmax = 10
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.plot(60, 80, "k.")
|
||||||
|
#ax.annotate(
|
||||||
|
# "Target",
|
||||||
|
# (60, 80),
|
||||||
|
# textcoords = "offset points",
|
||||||
|
# xytext = (0, -20),
|
||||||
|
# ha = "center"
|
||||||
|
#)
|
||||||
|
|
||||||
|
ax.invert_yaxis()
|
||||||
|
fig.colorbar(plot)
|
||||||
|
|
||||||
|
fig.savefig(out_dir / "actual.png")
|
||||||
|
plt.close()
|
|
@ -11,10 +11,12 @@ from main import Transition
|
||||||
# Use cpu, this script is faster in parallel.
|
# Use cpu, this script is faster in parallel.
|
||||||
compute_device = torch.device("cpu")
|
compute_device = torch.device("cpu")
|
||||||
|
|
||||||
out_dir = Path("out/plots")
|
input_model = Path("model_data/current")
|
||||||
|
|
||||||
|
src_dir = input_model / "model_archive"
|
||||||
|
out_dir = input_model_dir / "plots/predicted_value"
|
||||||
out_dir.mkdir(parents = True, exist_ok = True)
|
out_dir.mkdir(parents = True, exist_ok = True)
|
||||||
|
|
||||||
src_dir = Path("model_data/current/model_archive")
|
|
||||||
|
|
||||||
|
|
||||||
def plot(src):
|
def plot(src):
|
||||||
|
@ -23,7 +25,10 @@ def plot(src):
|
||||||
len(Celeste.action_space)
|
len(Celeste.action_space)
|
||||||
).to(compute_device)
|
).to(compute_device)
|
||||||
|
|
||||||
checkpoint = torch.load(src)
|
checkpoint = torch.load(
|
||||||
|
src,
|
||||||
|
map_location = compute_device
|
||||||
|
)
|
||||||
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
||||||
|
|
||||||
|
|
||||||
|
|
Reference in New Issue