Cleanup
parent
8420e719d8
commit
058292c0bd
|
@ -3,6 +3,7 @@ import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import matplotlib as mpl
|
import matplotlib as mpl
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import json
|
||||||
|
|
||||||
# All of the following are required to load
|
# All of the following are required to load
|
||||||
# a pickled model.
|
# a pickled model.
|
||||||
|
@ -15,7 +16,8 @@ def best_action(
|
||||||
model_file: Path,
|
model_file: Path,
|
||||||
out_filename: Path,
|
out_filename: Path,
|
||||||
*,
|
*,
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu"),
|
||||||
|
draw_path = True
|
||||||
):
|
):
|
||||||
if not model_file.is_file():
|
if not model_file.is_file():
|
||||||
raise Exception(f"Bad model file {model_file}")
|
raise Exception(f"Bad model file {model_file}")
|
||||||
|
@ -81,12 +83,34 @@ def best_action(
|
||||||
vmin = 0,
|
vmin = 0,
|
||||||
vmax = 8
|
vmax = 8
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if draw_path:
|
||||||
|
d = None
|
||||||
|
with Path("model_data/solved_4layer/paths.json").open("r") as f:
|
||||||
|
for l in f.readlines():
|
||||||
|
t = json.loads(l)
|
||||||
|
if t["current_image"] == model_file.name:
|
||||||
|
break
|
||||||
|
d = t
|
||||||
|
assert d is not None
|
||||||
|
|
||||||
|
plt.plot(
|
||||||
|
[max(0,x["xpos"]) for x in d["hist"]],
|
||||||
|
[max(0,x["ypos"] + 5) for x in d["hist"]],
|
||||||
|
marker = "",
|
||||||
|
markersize = 0,
|
||||||
|
linestyle = "-",
|
||||||
|
linewidth = 5,
|
||||||
|
color = "white",
|
||||||
|
solid_capstyle = "round",
|
||||||
|
solid_joinstyle = "round"
|
||||||
|
)
|
||||||
|
|
||||||
ax.invert_yaxis()
|
ax.invert_yaxis()
|
||||||
cbar = fig.colorbar(plot, ticks = list(range(0, 9)))
|
cbar = fig.colorbar(plot, ticks = list(range(0, 9)))
|
||||||
cbar.ax.set_yticklabels(Celeste.action_space)
|
cbar.ax.set_yticklabels(Celeste.action_space)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
fig.savefig(out_filename)
|
fig.savefig(out_filename)
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,7 @@ from celeste_ai import DQN
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
model_data_root = Path("model_data/solved_1")
|
model_data_root = Path("model_data/current")
|
||||||
|
|
||||||
compute_device = torch.device(
|
compute_device = torch.device(
|
||||||
"cuda" if torch.cuda.is_available() else "cpu"
|
"cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
@ -94,7 +94,7 @@ def on_state_after(celeste, before_out):
|
||||||
f.write(json.dumps(
|
f.write(json.dumps(
|
||||||
{
|
{
|
||||||
"hist": state_history,
|
"hist": state_history,
|
||||||
"current_image": str(current_path)
|
"current_image": current_path.name
|
||||||
}
|
}
|
||||||
) + "\n")
|
) + "\n")
|
||||||
|
|
Reference in New Issue