diff --git a/celeste/celeste_ai/plotting/plot_best_action.py b/celeste/celeste_ai/plotting/plot_best_action.py index 100bea2..4690616 100644 --- a/celeste/celeste_ai/plotting/plot_best_action.py +++ b/celeste/celeste_ai/plotting/plot_best_action.py @@ -3,6 +3,7 @@ import numpy as np from pathlib import Path import matplotlib as mpl import matplotlib.pyplot as plt +import json # All of the following are required to load # a pickled model. @@ -15,7 +16,8 @@ def best_action( model_file: Path, out_filename: Path, *, - device = torch.device("cpu") + device = torch.device("cpu"), + draw_path = True ): if not model_file.is_file(): raise Exception(f"Bad model file {model_file}") @@ -81,12 +83,34 @@ def best_action( vmin = 0, vmax = 8 ) + + + if draw_path: + d = None + with Path("model_data/solved_4layer/paths.json").open("r") as f: + for l in f.readlines(): + t = json.loads(l) + if t["current_image"] == model_file.name: + break + d = t + assert d is not None + + plt.plot( + [max(0,x["xpos"]) for x in d["hist"]], + [max(0,x["ypos"] + 5) for x in d["hist"]], + marker = "", + markersize = 0, + linestyle = "-", + linewidth = 5, + color = "white", + solid_capstyle = "round", + solid_joinstyle = "round" + ) + ax.invert_yaxis() cbar = fig.colorbar(plot, ticks = list(range(0, 9))) cbar.ax.set_yticklabels(Celeste.action_space) - - fig.savefig(out_filename) plt.close() diff --git a/celeste/celeste_ai/paths.py b/celeste/celeste_ai/record_paths.py similarity index 95% rename from celeste/celeste_ai/paths.py rename to celeste/celeste_ai/record_paths.py index 210e5a3..a4dfc2e 100644 --- a/celeste/celeste_ai/paths.py +++ b/celeste/celeste_ai/record_paths.py @@ -7,7 +7,7 @@ from celeste_ai import DQN -model_data_root = Path("model_data/solved_1") +model_data_root = Path("model_data/current") compute_device = torch.device( "cuda" if torch.cuda.is_available() else "cpu" @@ -94,7 +94,7 @@ def on_state_after(celeste, before_out): f.write(json.dumps( { "hist": state_history, - "current_image": str(current_path) + "current_image": current_path.name } ) + "\n")