Mark / celeste-ai (archived)

Added path configuration

master
Mark 2023-02-18 19:35:46 -08:00
parent 610e5eef92
commit 4876664178
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
1 changed file with 18 additions and 11 deletions

@@ -9,8 +9,17 @@ import torch
from celeste import Celeste
run_data_path = Path("out")
run_data_path.mkdir(parents = True, exist_ok = True)
# Where to read/write model data.
model_data_root = Path("model_data")
model_save_path = model_data_root / "model.torch"
model_archive_dir = model_data_root / "model_archive"
model_train_log = model_data_root / "train_log"
screenshot_dir = model_data_root / "screenshots"
model_data_root.mkdir(parents = True, exist_ok = True)
model_archive_dir.mkdir(parents = True, exist_ok = True)
screenshot_dir.mkdir(parents = True, exist_ok = True)
compute_device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
@@ -307,9 +316,9 @@ def optimize_model():
episode_number = 0
if (run_data_path/"checkpoint.torch").is_file():
if model_save_path.is_file():
# Load model if one exists
checkpoint = torch.load((run_data_path/"checkpoint.torch"))
checkpoint = torch.load(model_save_path)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
target_net.load_state_dict(checkpoint["target_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
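Here the hard-coded run_data_path / "checkpoint.torch" gives way to model_save_path when resuming. A sketch of the resume logic as it reads after this change, assuming policy_net, target_net, and optimizer are already constructed; restoring memory, episode_number, and steps_done is an assumption based on the keys the torch.save calls below write, and map_location is an extra not in the diff that keeps the checkpoint loadable on a CPU-only machine:

episode_number = 0

if model_save_path.is_file():
    # Resume from the latest checkpoint if one exists.
    checkpoint = torch.load(model_save_path, map_location = compute_device)
    policy_net.load_state_dict(checkpoint["policy_state_dict"])
    target_net.load_state_dict(checkpoint["target_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # Assumed: the remaining saved keys are restored as well.
    memory = checkpoint["memory"]
    episode_number = checkpoint["episode_number"]
    steps_done = checkpoint["steps_done"]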
@@ -437,7 +446,7 @@ def on_state_after(celeste, before_out):
# a terminal state.
if (next_state.deaths != 0):
s = celeste.state
with open(run_data_path / "train.log", "a") as f:
with model_train_log.open("a") as f:
f.write(json.dumps({
"checkpoints": s.next_point,
"state_count": s.state_count
@@ -452,13 +461,13 @@ def on_state_after(celeste, before_out):
"memory": memory,
"episode_number": episode_number,
"steps_done": steps_done
}, run_data_path / "checkpoint.torch")
}, model_save_path)
# Clean up screenshots
shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
target = run_data_path / Path(f"screenshots/{episode_number}")
target = screenshot_dir / Path(f"{episode_number}")
target.mkdir(parents = True)
for s in shots:
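The screenshot cleanup now files shots under screenshot_dir directly (screenshot_dir / str(episode_number) would give the same path without the extra Path() wrapper). The body of the for loop is outside the hunk, so the move below is only an assumption about what it does:

import shutil

shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
target = screenshot_dir / str(episode_number)
target.mkdir(parents = True, exist_ok = True)

for s in shots:
    # Assumed loop body: move each captured frame into this episode's directory.
    shutil.move(str(s), str(target / s.name))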
@@ -466,8 +475,6 @@ def on_state_after(celeste, before_out):
# Save a prediction graph
if episode_number % image_interval == 0:
p = run_data_path / Path("model_images")
p.mkdir(parents = True, exist_ok = True)
torch.save({
"policy_state_dict": policy_net.state_dict(),
"target_state_dict": target_net.state_dict(),
@@ -475,10 +482,10 @@ def on_state_after(celeste, before_out):
"memory": memory,
"episode_number": episode_number,
"steps_done": steps_done
}, p / f"{episode_number}.torch")
}, model_archive_dir / f"{episode_number}.torch")
print("State over, resetting")
print("Game over. Resetting.")
episode_number += 1
celeste.reset()