diff --git a/celeste/main.py b/celeste/main.py
index 5b27be5..4f6e5f8 100644
--- a/celeste/main.py
+++ b/celeste/main.py
@@ -9,8 +9,17 @@ import torch
 
 from celeste import Celeste
 
-run_data_path = Path("out")
-run_data_path.mkdir(parents = True, exist_ok = True)
+# Where to read/write model data.
+model_data_root = Path("model_data")
+
+model_save_path = model_data_root / "model.torch"
+model_archive_dir = model_data_root / "model_archive"
+model_train_log = model_data_root / "train_log"
+screenshot_dir = model_data_root / "screenshots"
+model_data_root.mkdir(parents = True, exist_ok = True)
+model_archive_dir.mkdir(parents = True, exist_ok = True)
+screenshot_dir.mkdir(parents = True, exist_ok = True)
+
 
 compute_device = torch.device(
 	"cuda" if torch.cuda.is_available() else "cpu"
@@ -307,9 +316,9 @@ def optimize_model():
 
 episode_number = 0
 
-if (run_data_path/"checkpoint.torch").is_file():
+if model_save_path.is_file():
 	# Load model if one exists
-	checkpoint = torch.load((run_data_path/"checkpoint.torch"))
+	checkpoint = torch.load(model_save_path)
 	policy_net.load_state_dict(checkpoint["policy_state_dict"])
 	target_net.load_state_dict(checkpoint["target_state_dict"])
 	optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
@@ -437,7 +446,7 @@ def on_state_after(celeste, before_out):
 	# a terminal state.
 	if (next_state.deaths != 0):
 		s = celeste.state
-		with open(run_data_path / "train.log", "a") as f:
+		with model_train_log.open("a") as f:
 			f.write(json.dumps({
 				"checkpoints": s.next_point,
 				"state_count": s.state_count
@@ -452,13 +461,13 @@ def on_state_after(celeste, before_out):
 		"memory": memory,
 		"episode_number": episode_number,
 		"steps_done": steps_done
-	}, run_data_path / "checkpoint.torch")
+	}, model_save_path)
 
 
 	# Clean up screenshots
 	shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
 
-	target = run_data_path / Path(f"screenshots/{episode_number}")
+	target = screenshot_dir / Path(f"{episode_number}")
 	target.mkdir(parents = True)
 
 	for s in shots:
@@ -466,8 +475,6 @@ def on_state_after(celeste, before_out):
 
 	# Save a prediction graph
 	if episode_number % image_interval == 0:
-		p = run_data_path / Path("model_images")
-		p.mkdir(parents = True, exist_ok = True)
 		torch.save({
 			"policy_state_dict": policy_net.state_dict(),
 			"target_state_dict": target_net.state_dict(),
@@ -475,10 +482,10 @@ def on_state_after(celeste, before_out):
 			"memory": memory,
 			"episode_number": episode_number,
 			"steps_done": steps_done
-		}, p / f"{episode_number}.torch")
+		}, model_archive_dir / f"{episode_number}.torch")
 
 
-	print("State over, resetting")
+	print("Game over. Resetting.")
 	episode_number += 1
 	celeste.reset()
 