"""Run a trained Celeste DQN policy in pure inference mode.

Loads the latest checkpoint (if one exists), hooks greedy action
selection into the Celeste update loop, and archives PICO-8
screenshots at the end of every episode.
"""

from pathlib import Path

import torch

from celeste_ai import Celeste
from celeste_ai import DQN
from celeste_ai.util.screenshots import ScreenshotManager


def on_state_before(celeste):
    """Pick and perform a greedy action for the current frame.

    Returns ``(state, action)`` so ``on_state_after`` can inspect the
    pre-action state alongside the chosen action index.
    """
    state = celeste.state

    # Encode the state as a single-row float tensor in the order the
    # network was trained on (Celeste.state_number_map).
    pt_state = torch.tensor(
        [getattr(state, x) for x in Celeste.state_number_map],
        dtype=torch.float32,
        device=compute_device,
    ).unsqueeze(0)

    # Inference only: no_grad avoids building an autograd graph on
    # every frame. Greedy policy — take the argmax Q-value action.
    with torch.no_grad():
        action = policy_net(pt_state).max(1)[1].view(1, 1).item()
    str_action = Celeste.action_space[action]

    print(str_action)
    celeste.act(str_action)

    return state, action


def on_state_after(celeste, before_out):
    """Detect terminal states and reset the game between episodes."""
    global episode_number

    state, action = before_out
    next_state = celeste.state

    # NOTE(review): assumes stage >= 1 means the first stage was
    # cleared — confirm against Celeste's stage numbering.
    finished_stage = next_state.stage >= 1

    # Move on to the next episode once we reach
    # a terminal state (death or stage clear).
    if next_state.deaths != 0 or finished_stage:
        sm.move()  # Archive this episode's screenshots.
        print("Game over. Resetting.")
        celeste.reset()
        episode_number += 1


if __name__ == "__main__":
    # Where to read/write model data.
    model_data_root = Path("model_data/current")
    model_save_path = model_data_root / "model.torch"
    model_data_root.mkdir(parents=True, exist_ok=True)

    sm = ScreenshotManager(
        # Where PICO-8 saves screenshots.
        # Probably your desktop.
        source=Path("/home/mark/Desktop"),
        pattern="hackcel_*.png",
        target=model_data_root / "screenshots_test",
    ).clean()  # Remove old screenshots

    compute_device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )

    episode_number = 0

    # Celeste env properties
    n_observations = len(Celeste.state_number_map)
    n_actions = len(Celeste.action_space)

    policy_net = DQN(
        n_observations,
        n_actions,
    ).to(compute_device)

    # Load model if one exists. The original called torch.load
    # unconditionally, which crashed with FileNotFoundError when no
    # checkpoint had been saved yet; now we fall back to the freshly
    # initialized network instead.
    if model_save_path.is_file():
        checkpoint = torch.load(
            model_save_path,
            map_location=compute_device,
        )
        policy_net.load_state_dict(checkpoint["policy_state_dict"])
    else:
        print(f"No checkpoint at {model_save_path}, using untrained weights.")

    # Inference mode: disables dropout/batchnorm training behavior.
    policy_net.eval()

    c = Celeste(
        "resources/pico-8/linux/pico8"
    )

    c.update_loop(
        on_state_before,
        on_state_after,
    )