Cleanup
parent
8420e719d8
commit
058292c0bd
|
@ -3,6 +3,7 @@ import numpy as np
|
|||
from pathlib import Path
|
||||
import matplotlib as mpl
|
||||
import matplotlib.pyplot as plt
|
||||
import json
|
||||
|
||||
# All of the following are required to load
|
||||
# a pickled model.
|
||||
|
@ -15,7 +16,8 @@ def best_action(
|
|||
model_file: Path,
|
||||
out_filename: Path,
|
||||
*,
|
||||
device = torch.device("cpu")
|
||||
device = torch.device("cpu"),
|
||||
draw_path = True
|
||||
):
|
||||
if not model_file.is_file():
|
||||
raise Exception(f"Bad model file {model_file}")
|
||||
|
@ -81,12 +83,34 @@ def best_action(
|
|||
vmin = 0,
|
||||
vmax = 8
|
||||
)
|
||||
|
||||
|
||||
if draw_path:
|
||||
d = None
|
||||
with Path("model_data/solved_4layer/paths.json").open("r") as f:
|
||||
for l in f.readlines():
|
||||
t = json.loads(l)
|
||||
if t["current_image"] == model_file.name:
|
||||
break
|
||||
d = t
|
||||
assert d is not None
|
||||
|
||||
plt.plot(
|
||||
[max(0,x["xpos"]) for x in d["hist"]],
|
||||
[max(0,x["ypos"] + 5) for x in d["hist"]],
|
||||
marker = "",
|
||||
markersize = 0,
|
||||
linestyle = "-",
|
||||
linewidth = 5,
|
||||
color = "white",
|
||||
solid_capstyle = "round",
|
||||
solid_joinstyle = "round"
|
||||
)
|
||||
|
||||
ax.invert_yaxis()
|
||||
cbar = fig.colorbar(plot, ticks = list(range(0, 9)))
|
||||
cbar.ax.set_yticklabels(Celeste.action_space)
|
||||
|
||||
|
||||
|
||||
fig.savefig(out_filename)
|
||||
plt.close()
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from celeste_ai import DQN
|
|||
|
||||
|
||||
|
||||
model_data_root = Path("model_data/solved_1")
|
||||
model_data_root = Path("model_data/current")
|
||||
|
||||
compute_device = torch.device(
|
||||
"cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
@ -94,7 +94,7 @@ def on_state_after(celeste, before_out):
|
|||
f.write(json.dumps(
|
||||
{
|
||||
"hist": state_history,
|
||||
"current_image": str(current_path)
|
||||
"current_image": current_path.name
|
||||
}
|
||||
) + "\n")
|
||||
|
Reference in New Issue