"""Replay archived DQN checkpoints in Celeste and log the path each one takes.

For every file in ``model_data/solved_1/model_archive`` this script loads the
policy network weights and plays one episode greedily. An episode ends when
the player dies or reaches stage >= 1. The visited ``(xpos, ypos, action)``
sequence is then appended to ``model_data/solved_1/paths.json`` as one JSON
object per line.
"""

import json
from pathlib import Path

import torch

from celeste_ai import Celeste
from celeste_ai import DQN


model_data_root = Path("model_data/solved_1")

compute_device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)

# Celeste env properties
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)

policy_net = DQN(
    n_observations,
    n_actions,
).to(compute_device)
# Inference only: put dropout/batch-norm layers (if DQN has any) in eval mode.
policy_net.eval()

# Checkpoints to replay, one episode each, in the (unsorted) order that
# iterdir() yields them.
checkpoint_paths = (model_data_root / "model_archive").iterdir()

# How many times next_image() has been called; used only for progress output.
checkpoint_count = 0

# Per-step records of the current episode; flushed to paths.json at episode end.
state_history = []

# Path of the checkpoint currently loaded into policy_net.
current_path = None


def next_image():
    """Load the next archived checkpoint into ``policy_net``.

    Updates ``current_path`` and ``checkpoint_count``.

    Returns:
        True if a checkpoint was loaded, False if the archive is exhausted.
    """
    global current_path
    global checkpoint_count

    checkpoint_count += 1
    try:
        current_path = next(checkpoint_paths)
    except StopIteration:
        return False

    print(f"Pathing {current_path} ({checkpoint_count})")

    checkpoint = torch.load(current_path, map_location=compute_device)
    # Assumes each archived file is a dict holding the network weights under
    # "policy_state_dict" (the only key read here).
    policy_net.load_state_dict(checkpoint["policy_state_dict"])
    return True


def on_state_before(celeste):
    """Pick the greedy policy action for the current state and apply it.

    Args:
        celeste: the running ``Celeste`` instance.

    Returns:
        ``(state, action)``: the observed state and the index into
        ``Celeste.action_space`` of the action taken. ``on_state_after``
        unpacks this pair from its ``before_out`` argument (presumably
        forwarded by ``Celeste.update_loop``).
    """
    state = celeste.state

    # Features in Celeste.state_number_map order, with a leading batch dim of 1.
    pt_state = torch.tensor(
        [getattr(state, x) for x in Celeste.state_number_map],
        dtype=torch.float32,
        device=compute_device,
    ).unsqueeze(0)

    # Pure inference: no_grad avoids building an autograd graph every frame.
    with torch.no_grad():
        action = policy_net(pt_state).argmax(dim=1).item()

    celeste.act(Celeste.action_space[action])

    return state, action


def on_state_after(celeste, before_out):
    """Record the step and, at the end of an episode, advance to the next checkpoint.

    An episode ends when ``deaths != 0`` or ``stage >= 1`` in the post-action
    state. At that point the episode history is appended to ``paths.json``
    (one JSON object per line), the next checkpoint is loaded and the game
    is reset.

    Args:
        celeste: the running ``Celeste`` instance.
        before_out: the ``(state, action)`` pair returned by ``on_state_before``.

    Raises:
        Exception: "Done." once every checkpoint has been replayed. The script
            has no other stop condition.
    """
    global state_history

    state, action = before_out
    next_state = celeste.state
    finished_stage = next_state.stage >= 1

    state_history.append({
        "xpos": state.xpos,
        "ypos": state.ypos,
        "action": Celeste.action_space[action],
    })

    # Keep playing until we reach a terminal state.
    if next_state.deaths == 0 and not finished_stage:
        return

    with (model_data_root / "paths.json").open("a", encoding="utf-8") as f:
        f.write(json.dumps(
            {
                "hist": state_history,
                "current_image": str(current_path),
            }
        ) + "\n")

    state_history = []

    if not next_image():
        raise Exception("Done.")

    print("Game over. Resetting.")
    celeste.reset()


if __name__ == "__main__":
    # Load the first checkpoint up front. Without one, the first episode would
    # be played (and logged) with an untrained network.
    if not next_image():
        raise FileNotFoundError(
            f"No checkpoints found in {model_data_root / 'model_archive'}"
        )

    c = Celeste(
        "resources/pico-8/linux/pico8"
    )

    c.update_loop(
        on_state_before,
        on_state_after,
    )