Added path configuration
parent
610e5eef92
commit
4876664178
|
@ -9,8 +9,17 @@ import torch
|
||||||
from celeste import Celeste
|
from celeste import Celeste
|
||||||
|
|
||||||
|
|
||||||
run_data_path = Path("out")
|
# Where to read/write model data.
|
||||||
run_data_path.mkdir(parents = True, exist_ok = True)
|
model_data_root = Path("model_data")
|
||||||
|
|
||||||
|
model_save_path = model_data_root / "model.torch"
|
||||||
|
model_archive_dir = model_data_root / "model_archive"
|
||||||
|
model_train_log = model_data_root / "train_log"
|
||||||
|
screenshot_dir = model_data_root / "screenshots"
|
||||||
|
model_data_root.mkdir(parents = True, exist_ok = True)
|
||||||
|
model_archive_dir.mkdir(parents = True, exist_ok = True)
|
||||||
|
screenshot_dir.mkdir(parents = True, exist_ok = True)
|
||||||
|
|
||||||
|
|
||||||
compute_device = torch.device(
|
compute_device = torch.device(
|
||||||
"cuda" if torch.cuda.is_available() else "cpu"
|
"cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
@ -307,9 +316,9 @@ def optimize_model():
|
||||||
episode_number = 0
|
episode_number = 0
|
||||||
|
|
||||||
|
|
||||||
if (run_data_path/"checkpoint.torch").is_file():
|
if model_save_path.is_file():
|
||||||
# Load model if one exists
|
# Load model if one exists
|
||||||
checkpoint = torch.load((run_data_path/"checkpoint.torch"))
|
checkpoint = torch.load(model_save_path)
|
||||||
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
||||||
target_net.load_state_dict(checkpoint["target_state_dict"])
|
target_net.load_state_dict(checkpoint["target_state_dict"])
|
||||||
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||||
|
@ -437,7 +446,7 @@ def on_state_after(celeste, before_out):
|
||||||
# a terminal state.
|
# a terminal state.
|
||||||
if (next_state.deaths != 0):
|
if (next_state.deaths != 0):
|
||||||
s = celeste.state
|
s = celeste.state
|
||||||
with open(run_data_path / "train.log", "a") as f:
|
with model_train_log.open("a") as f:
|
||||||
f.write(json.dumps({
|
f.write(json.dumps({
|
||||||
"checkpoints": s.next_point,
|
"checkpoints": s.next_point,
|
||||||
"state_count": s.state_count
|
"state_count": s.state_count
|
||||||
|
@ -452,13 +461,13 @@ def on_state_after(celeste, before_out):
|
||||||
"memory": memory,
|
"memory": memory,
|
||||||
"episode_number": episode_number,
|
"episode_number": episode_number,
|
||||||
"steps_done": steps_done
|
"steps_done": steps_done
|
||||||
}, run_data_path / "checkpoint.torch")
|
}, model_save_path)
|
||||||
|
|
||||||
|
|
||||||
# Clean up screenshots
|
# Clean up screenshots
|
||||||
shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
|
shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
|
||||||
|
|
||||||
target = run_data_path / Path(f"screenshots/{episode_number}")
|
target = screenshot_dir / Path(f"{episode_number}")
|
||||||
target.mkdir(parents = True)
|
target.mkdir(parents = True)
|
||||||
|
|
||||||
for s in shots:
|
for s in shots:
|
||||||
|
@ -466,8 +475,6 @@ def on_state_after(celeste, before_out):
|
||||||
|
|
||||||
# Save a prediction graph
|
# Save a prediction graph
|
||||||
if episode_number % image_interval == 0:
|
if episode_number % image_interval == 0:
|
||||||
p = run_data_path / Path("model_images")
|
|
||||||
p.mkdir(parents = True, exist_ok = True)
|
|
||||||
torch.save({
|
torch.save({
|
||||||
"policy_state_dict": policy_net.state_dict(),
|
"policy_state_dict": policy_net.state_dict(),
|
||||||
"target_state_dict": target_net.state_dict(),
|
"target_state_dict": target_net.state_dict(),
|
||||||
|
@ -475,10 +482,10 @@ def on_state_after(celeste, before_out):
|
||||||
"memory": memory,
|
"memory": memory,
|
||||||
"episode_number": episode_number,
|
"episode_number": episode_number,
|
||||||
"steps_done": steps_done
|
"steps_done": steps_done
|
||||||
}, p / f"{episode_number}.torch")
|
}, model_archive_dir / f"{episode_number}.torch")
|
||||||
|
|
||||||
|
|
||||||
print("State over, resetting")
|
print("Game over. Resetting.")
|
||||||
episode_number += 1
|
episode_number += 1
|
||||||
celeste.reset()
|
celeste.reset()
|
||||||
|
|
||||||
|
|
Reference in New Issue