Mark / celeste-ai
Commit 058292c0bd (parent 8420e719d8)
Mark, 2023-03-04 14:40:24 -08:00
Signed by: Mark (GPG Key ID: AD62BB059C2AAEE4)
2 changed files with 29 additions and 5 deletions

Changed file 1 of 2:

@@ -3,6 +3,7 @@ import numpy as np
 from pathlib import Path
 import matplotlib as mpl
 import matplotlib.pyplot as plt
+import json
 # All of the following are required to load
 # a pickled model.
@@ -15,7 +16,8 @@ def best_action(
     model_file: Path,
     out_filename: Path,
     *,
-    device = torch.device("cpu")
+    device = torch.device("cpu"),
+    draw_path = True
 ):
     if not model_file.is_file():
         raise Exception(f"Bad model file {model_file}")
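
Because draw_path is added after the bare *, it is keyword-only, just like device. A minimal sketch of a call site under the new signature; the import path and both file paths below are placeholders I am assuming, not shown in this diff:

    from pathlib import Path
    import torch
    from celeste_ai.plotting import best_action  # import path assumed, not shown in this diff

    # Hypothetical call; only the keyword-only parameters come from this commit.
    best_action(
        Path("model.torch"),      # placeholder model file
        Path("best_action.png"),  # placeholder output image
        device = torch.device("cpu"),
        draw_path = False,        # skip the white path overlay added further down
    )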
@@ -81,12 +83,34 @@ def best_action(
         vmin = 0,
         vmax = 8
     )
+    if draw_path:
+        d = None
+        with Path("model_data/solved_4layer/paths.json").open("r") as f:
+            for l in f.readlines():
+                t = json.loads(l)
+                if t["current_image"] == model_file.name:
+                    break
+            d = t
+        assert d is not None
+        plt.plot(
+            [max(0,x["xpos"]) for x in d["hist"]],
+            [max(0,x["ypos"] + 5) for x in d["hist"]],
+            marker = "",
+            markersize = 0,
+            linestyle = "-",
+            linewidth = 5,
+            color = "white",
+            solid_capstyle = "round",
+            solid_joinstyle = "round"
+        )
     ax.invert_yaxis()
     cbar = fig.colorbar(plot, ticks = list(range(0, 9)))
     cbar.ax.set_yticklabels(Celeste.action_space)
     fig.savefig(out_filename)
     plt.close()
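
The block added above looks up the agent's recorded path for this model: model_data/solved_4layer/paths.json appears to be a JSON-Lines file (one object per line with "hist" and "current_image" keys, written by the second file below), the loop breaks on the record whose "current_image" matches the model file's name, and the x/y history is drawn as a thick white line over the heatmap. A hedged sketch of that lookup as a standalone helper; the function name and the explicit None-on-miss behaviour are mine, while the committed loop relies on the loop variable still holding the matching record after the break:

    import json
    from pathlib import Path

    def find_path_record(paths_file: Path, image_name: str) -> dict | None:
        # Return the first JSON-Lines record whose "current_image" equals image_name,
        # or None if no line matches. Hypothetical helper mirroring the lookup above.
        with paths_file.open("r") as f:
            for line in f:
                record = json.loads(line)
                if record["current_image"] == image_name:
                    return record
        return None

    # Example use (key names follow the diff; the model file name is a placeholder):
    # record = find_path_record(Path("model_data/solved_4layer/paths.json"), "model.torch")
    # xs = [max(0, p["xpos"]) for p in record["hist"]]
    # ys = [max(0, p["ypos"] + 5) for p in record["hist"]]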

Changed file 2 of 2:

@@ -7,7 +7,7 @@ from celeste_ai import DQN
-model_data_root = Path("model_data/solved_1")
+model_data_root = Path("model_data/current")
 compute_device = torch.device(
     "cuda" if torch.cuda.is_available() else "cpu"
@@ -94,7 +94,7 @@ def on_state_after(celeste, before_out):
         f.write(json.dumps(
             {
                 "hist": state_history,
-                "current_image": str(current_path)
+                "current_image": current_path.name
             }
         ) + "\n")
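
With this change the writer records only the file's name: each line of paths.json now stores "current_image": current_path.name instead of the full path string, so the plotting code in the first file can compare it directly against model_file.name. A small sketch of what that write reduces to; the helper name and its parameters are assumptions, only the record layout and the .name change come from the diff:

    import json
    from pathlib import Path

    # Hypothetical writer mirroring the changed lines above.
    def append_path_record(paths_file: Path, state_history: list, current_path: Path) -> None:
        with paths_file.open("a") as f:
            f.write(json.dumps(
                {
                    "hist": state_history,
                    "current_image": current_path.name  # previously str(current_path)
                }
            ) + "\n")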