From dc8f0ace68fe649c44a5cb80394dde8c383434a8 Mon Sep 17 00:00:00 2001 From: Mark Date: Fri, 24 Feb 2023 21:56:37 -0800 Subject: [PATCH] Added screenshot_source --- celeste/celeste_ai/train.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/celeste/celeste_ai/train.py b/celeste/celeste_ai/train.py index 42712da..5d9fa0e 100644 --- a/celeste/celeste_ai/train.py +++ b/celeste/celeste_ai/train.py @@ -15,6 +15,10 @@ if __name__ == "__main__": # Where to read/write model data. model_data_root = Path("model_data/current") + # Where PICO-8 saves screenshots. + # Probably your desktop. + screenshot_source = Path("/home/mark/Desktop") + model_save_path = model_data_root / "model.torch" model_archive_dir = model_data_root / "model_archive" model_train_log = model_data_root / "train_log" @@ -25,7 +29,7 @@ if __name__ == "__main__": # Remove old screenshots - shots = Path("/home/mark/Desktop").glob("hackcel_*.png") + shots = screenshot_source.glob("hackcel_*.png") for s in shots: s.unlink() @@ -280,8 +284,6 @@ def optimize_model(): def on_state_before(celeste): global steps_done - # Conversion to pytorch - state = celeste.state pt_state = torch.tensor( @@ -347,7 +349,7 @@ def on_state_after(celeste, before_out): pt_next_state = None reward = 0 - # Reward for finishing stage + # Reward for finishing a stage elif next_state.stage >= 1: finished_stage = True reward = next_state.next_point - state.next_point @@ -370,6 +372,7 @@ def on_state_after(celeste, before_out): if state.next_point == next_state.next_point: reward = 0 else: + print(f"Got point {state.next_point}") # Reward for reaching a point reward = next_state.next_point - state.next_point @@ -377,6 +380,14 @@ def on_state_after(celeste, before_out): for i in range(state.next_point, state.next_point + reward): point_counter[i] += 1 + # Strawberry reward + if next_state.berries[state.stage] and not state.berries[state.stage]: + print(f"Got stage {state.stage} bonus") + reward += 1 + + + + reward = reward * 10 pt_reward = torch.tensor([reward], device = compute_device) @@ -446,7 +457,7 @@ def on_state_after(celeste, before_out): # Clean up screenshots - shots = Path("/home/mark/Desktop").glob("hackcel_*.png") + shots = screenshot_source.glob("hackcel_*.png") target = screenshot_dir / Path(f"{episode_number}") target.mkdir(parents = True)