Mark / celeste-ai

Added path configuration

master
Mark 2023-02-18 19:35:46 -08:00
parent 610e5eef92
commit 4876664178
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
1 changed file with 18 additions and 11 deletions


@@ -9,8 +9,17 @@ import torch
 from celeste import Celeste
 
-run_data_path = Path("out") # Where to read/write model data.
-run_data_path.mkdir(parents = True, exist_ok = True)
+model_data_root = Path("model_data")
+
+model_save_path = model_data_root / "model.torch"
+model_archive_dir = model_data_root / "model_archive"
+model_train_log = model_data_root / "train_log"
+screenshot_dir = model_data_root / "screenshots"
+
+model_data_root.mkdir(parents = True, exist_ok = True)
+model_archive_dir.mkdir(parents = True, exist_ok = True)
+screenshot_dir.mkdir(parents = True, exist_ok = True)
 
 compute_device = torch.device(
 	"cuda" if torch.cuda.is_available() else "cpu"
@@ -307,9 +316,9 @@ def optimize_model():
 episode_number = 0
-if (run_data_path/"checkpoint.torch").is_file():
+if model_save_path.is_file():
 	# Load model if one exists
-	checkpoint = torch.load((run_data_path/"checkpoint.torch"))
+	checkpoint = torch.load(model_save_path)
 	policy_net.load_state_dict(checkpoint["policy_state_dict"])
 	target_net.load_state_dict(checkpoint["target_state_dict"])
 	optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
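The load branch mirrors the torch.save calls later in the diff. Here is a minimal round-trip sketch using the same checkpoint keys; the Linear modules and Adam optimizer are placeholders, not the project's networks, and the extra keys the repository also saves (memory, episode_number, steps_done) are omitted.

import torch

# Placeholder modules; the repository defines its own policy/target networks.
policy_net = torch.nn.Linear(4, 2)
target_net = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(policy_net.parameters())

# Save: one dict, one file (same keys as the diff).
torch.save({
	"policy_state_dict": policy_net.state_dict(),
	"target_state_dict": target_net.state_dict(),
	"optimizer_state_dict": optimizer.state_dict(),
}, "model.torch")

# Load: restore each component from its state dict.
checkpoint = torch.load("model.torch")
policy_net.load_state_dict(checkpoint["policy_state_dict"])
target_net.load_state_dict(checkpoint["target_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])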
@@ -437,7 +446,7 @@ def on_state_after(celeste, before_out):
 	# a terminal state.
 	if (next_state.deaths != 0):
 		s = celeste.state
-		with open(run_data_path / "train.log", "a") as f:
+		with model_train_log.open("a") as f:
 			f.write(json.dumps({
 				"checkpoints": s.next_point,
 				"state_count": s.state_count
@@ -452,13 +461,13 @@ def on_state_after(celeste, before_out):
 			"memory": memory,
 			"episode_number": episode_number,
 			"steps_done": steps_done
-		}, run_data_path / "checkpoint.torch")
+		}, model_save_path)
 
 		# Clean up screenshots
 		shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
-		target = run_data_path / Path(f"screenshots/{episode_number}")
+		target = screenshot_dir / Path(f"{episode_number}")
 		target.mkdir(parents = True)
 		for s in shots:
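The hunk ends before the loop body, so the move itself is not shown. Below is a sketch of the usual pathlib idiom for this kind of sweep; the rename call is an assumption about the hidden body, and since Path.rename fails across filesystems, shutil.move would be the robust variant.

from pathlib import Path

screenshot_dir = Path("model_data/screenshots")
episode_number = 0  # example value

shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
target = screenshot_dir / str(episode_number)
target.mkdir(parents = True, exist_ok = True)
for s in shots:
	# Move each screenshot into the per-episode directory.
	s.rename(target / s.name)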
@@ -466,8 +475,6 @@ def on_state_after(celeste, before_out):
 		# Save a prediction graph
 		if episode_number % image_interval == 0:
-			p = run_data_path / Path("model_images")
-			p.mkdir(parents = True, exist_ok = True)
 			torch.save({
 				"policy_state_dict": policy_net.state_dict(),
 				"target_state_dict": target_net.state_dict(),
@@ -475,10 +482,10 @@ def on_state_after(celeste, before_out):
 				"memory": memory,
 				"episode_number": episode_number,
 				"steps_done": steps_done
-			}, p / f"{episode_number}.torch")
+			}, model_archive_dir / f"{episode_number}.torch")
 
-		print("State over, resetting")
+		print("Game over. Resetting.")
 		episode_number += 1
 		celeste.reset()