diff --git a/celeste/plot-actual.py b/celeste/plot-actual.py
new file mode 100644
index 0000000..f034b3c
--- /dev/null
+++ b/celeste/plot-actual.py
@@ -0,0 +1,83 @@
+"""Plot the rewards actually recorded in a model's saved memory.
+
+Loads ``model.torch`` from ``input_model``, tallies every stored transition
+whose target is ``target`` by position and action, and writes one heatmap
+per action to ``plots/actual_reward/actual.png`` under ``input_model``.
+"""
+
+import torch
+import numpy as np
+from pathlib import Path
+import matplotlib.pyplot as plt
+from multiprocessing import Pool
+
+from celeste import Celeste
+from main import DQN
+from main import Transition
+
+# Map the checkpoint onto the cpu so it loads without the device it was
+# saved from.
+compute_device = torch.device("cpu")
+
+input_model = Path("model_data/after_change")
+
+out_dir = input_model / "plots/actual_reward"
+out_dir.mkdir(parents = True, exist_ok = True)
+
+# Only transitions recorded for this (x, y) target are plotted.
+target = (60, 80)
+
+
+# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
+# which may refuse to unpickle the memory entries; confirm against the
+# installed version.
+checkpoint = torch.load(
+    input_model / "model.torch",
+    map_location = compute_device
+)
+memory = checkpoint["memory"]
+
+# Net tally indexed [y, x, action]: +1 for each transition with reward 1,
+# -1 for every other transition.
+r = np.zeros((128, 128, 8), dtype=np.float32)
+
+for m in memory:
+    x, y, x_target, y_target = (int(v.item()) for v in m.state[0])
+    if (x_target, y_target) != target:
+        continue
+
+    action = m.action[0].item()
+    if m.reward[0].item() == 1:
+        r[y, x, action] += 1
+    else:
+        r[y, x, action] -= 1
+
+
+fig, axs = plt.subplots(2, 4, figsize = (20, 10))
+
+
+# One heatmap per action, all on the same color scale.
+for a, ax in enumerate(axs.ravel()):
+    ax.set(
+        adjustable = "box",
+        aspect = "equal",
+        title = Celeste.action_space[a]
+    )
+
+    plot = ax.pcolor(
+        r[:,:,a],
+        cmap = "seismic_r",
+        vmin = -10,
+        vmax = 10
+    )
+
+    # Mark the target position.
+    ax.plot(*target, "k.")
+
+    ax.invert_yaxis()
+    # Pass ax explicitly: older matplotlib takes colorbar space from the
+    # current axes instead of the one that owns the mappable.
+    fig.colorbar(plot, ax = ax)
+
+fig.savefig(out_dir / "actual.png")
+plt.close(fig)
diff --git a/celeste/plots.py b/celeste/plots.py
index ac81346..4e76e82 100644
--- a/celeste/plots.py
+++ b/celeste/plots.py
@@ -11,10 +11,12 @@ from main import Transition
 # Use cpu, this script is faster in parallel.
compute_device = torch.device("cpu") -out_dir = Path("out/plots") +input_model = Path("model_data/current") + +src_dir = input_model / "model_archive" +out_dir = input_model_dir / "plots/predicted_value" out_dir.mkdir(parents = True, exist_ok = True) -src_dir = Path("model_data/current/model_archive") def plot(src): @@ -23,7 +25,10 @@ def plot(src): len(Celeste.action_space) ).to(compute_device) - checkpoint = torch.load(src) + checkpoint = torch.load( + src, + map_location = compute_device + ) policy_net.load_state_dict(checkpoint["policy_state_dict"])