Added path configuration
parent
610e5eef92
commit
4876664178
|
@ -9,8 +9,17 @@ import torch
|
|||
from celeste import Celeste
|
||||
|
||||
|
||||
run_data_path = Path("out")
|
||||
run_data_path.mkdir(parents = True, exist_ok = True)
|
||||
# Where to read/write model data.
|
||||
model_data_root = Path("model_data")
|
||||
|
||||
model_save_path = model_data_root / "model.torch"
|
||||
model_archive_dir = model_data_root / "model_archive"
|
||||
model_train_log = model_data_root / "train_log"
|
||||
screenshot_dir = model_data_root / "screenshots"
|
||||
model_data_root.mkdir(parents = True, exist_ok = True)
|
||||
model_archive_dir.mkdir(parents = True, exist_ok = True)
|
||||
screenshot_dir.mkdir(parents = True, exist_ok = True)
|
||||
|
||||
|
||||
compute_device = torch.device(
|
||||
"cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
@ -307,9 +316,9 @@ def optimize_model():
|
|||
episode_number = 0
|
||||
|
||||
|
||||
if (run_data_path/"checkpoint.torch").is_file():
|
||||
if model_save_path.is_file():
|
||||
# Load model if one exists
|
||||
checkpoint = torch.load((run_data_path/"checkpoint.torch"))
|
||||
checkpoint = torch.load(model_save_path)
|
||||
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
||||
target_net.load_state_dict(checkpoint["target_state_dict"])
|
||||
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||
|
@ -437,7 +446,7 @@ def on_state_after(celeste, before_out):
|
|||
# a terminal state.
|
||||
if (next_state.deaths != 0):
|
||||
s = celeste.state
|
||||
with open(run_data_path / "train.log", "a") as f:
|
||||
with model_train_log.open("a") as f:
|
||||
f.write(json.dumps({
|
||||
"checkpoints": s.next_point,
|
||||
"state_count": s.state_count
|
||||
|
@ -452,13 +461,13 @@ def on_state_after(celeste, before_out):
|
|||
"memory": memory,
|
||||
"episode_number": episode_number,
|
||||
"steps_done": steps_done
|
||||
}, run_data_path / "checkpoint.torch")
|
||||
}, model_save_path)
|
||||
|
||||
|
||||
# Clean up screenshots
|
||||
shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
|
||||
|
||||
target = run_data_path / Path(f"screenshots/{episode_number}")
|
||||
target = screenshot_dir / Path(f"{episode_number}")
|
||||
target.mkdir(parents = True)
|
||||
|
||||
for s in shots:
|
||||
|
@ -466,8 +475,6 @@ def on_state_after(celeste, before_out):
|
|||
|
||||
# Save a prediction graph
|
||||
if episode_number % image_interval == 0:
|
||||
p = run_data_path / Path("model_images")
|
||||
p.mkdir(parents = True, exist_ok = True)
|
||||
torch.save({
|
||||
"policy_state_dict": policy_net.state_dict(),
|
||||
"target_state_dict": target_net.state_dict(),
|
||||
|
@ -475,10 +482,10 @@ def on_state_after(celeste, before_out):
|
|||
"memory": memory,
|
||||
"episode_number": episode_number,
|
||||
"steps_done": steps_done
|
||||
}, p / f"{episode_number}.torch")
|
||||
}, model_archive_dir / f"{episode_number}.torch")
|
||||
|
||||
|
||||
print("State over, resetting")
|
||||
print("Game over. Resetting.")
|
||||
episode_number += 1
|
||||
celeste.reset()
|
||||
|
||||
|
|
Reference in New Issue