diff --git a/celeste/celeste_ai/test.py b/celeste/celeste_ai/test.py new file mode 100644 index 0000000..f963eb7 --- /dev/null +++ b/celeste/celeste_ai/test.py @@ -0,0 +1,100 @@ +from pathlib import Path +import torch + +from celeste_ai import Celeste +from celeste_ai import DQN +from celeste_ai.util.screenshots import ScreenshotManager + + +if __name__ == "__main__": + # Where to read/write model data. + model_data_root = Path("model_data/current") + + model_save_path = model_data_root / "model.torch" + model_data_root.mkdir(parents = True, exist_ok = True) + + + sm = ScreenshotManager( + # Where PICO-8 saves screenshots. + # Probably your desktop. + source = Path("/home/mark/Desktop"), + pattern = "hackcel_*.png", + target = model_data_root / "screenshots_test" + ).clean() # Remove old screenshots + + + compute_device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + + episode_number = 0 + + # Celeste env properties + n_observations = len(Celeste.state_number_map) + n_actions = len(Celeste.action_space) + + policy_net = DQN( + n_observations, + n_actions + ).to(compute_device) + + + # Load model if one exists + checkpoint = torch.load( + model_save_path, + map_location = compute_device + ) + policy_net.load_state_dict(checkpoint["policy_state_dict"]) + + +def on_state_before(celeste): + global steps_done + + state = celeste.state + + pt_state = torch.tensor( + [getattr(state, x) for x in Celeste.state_number_map], + dtype = torch.float32, + device = compute_device + ).unsqueeze(0) + + + action = policy_net(pt_state).max(1)[1].view(1, 1).item() + str_action = Celeste.action_space[action] + + print(str_action) + celeste.act(str_action) + + return state, action + + +def on_state_after(celeste, before_out): + global episode_number + + state, action = before_out + next_state = celeste.state + finished_stage = next_state.stage >= 1 + + + # Move on to the next episode once we reach + # a terminal state. + if (next_state.deaths != 0 or finished_stage): + s = celeste.state + + sm.move() + + + print("Game over. Resetting.") + celeste.reset() + episode_number += 1 + + +if __name__ == "__main__": + c = Celeste( + "resources/pico-8/linux/pico8" + ) + + c.update_loop( + on_state_before, + on_state_after + )