"""Replay every archived policy network and record the path it takes.

Reconstructed from the patch "Added path tracer" (commit 6b7abc4), which
creates ``celeste/celeste_ai/paths.py``.

For each checkpoint found in ``model_data_root / "model_archive"`` the script
loads the stored policy weights, lets the greedy policy play one episode, and
appends one JSON object per episode (one per line) to
``model_data_root / "paths.json"``.

An episode ends when the observed state reports ``deaths != 0`` or
``stage >= 1``.

NOTE: running (or importing) this module starts the game loop immediately.
"""

from pathlib import Path
import json

import torch

from celeste_ai import Celeste
from celeste_ai import DQN


model_data_root = Path("model_data/solved_1")
model_archive_dir = model_data_root / "model_archive"

compute_device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)


# Celeste env properties
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)

policy_net = DQN(
    n_observations,
    n_actions
).to(compute_device)

# Path.iterdir() yields entries in arbitrary order; sort them so that the
# checkpoints are visited in the same order on every run.
k = iter(sorted(model_archive_dir.iterdir()))

# Number of checkpoints loaded so far (only used for progress output).
i = 0

# Steps of the episode currently being played, oldest first.
state_history = []

# Checkpoint file whose weights are currently loaded into policy_net.
current_path = None


def next_image():
    """Load the next archived checkpoint into ``policy_net``.

    Advances the module-level checkpoint iterator ``k``, updates
    ``current_path`` and the progress counter ``i``, and replaces the
    weights of ``policy_net`` with the checkpoint's ``"policy_state_dict"``.

    Returns:
        ``True`` if a checkpoint was loaded, ``False`` if the archive is
        exhausted (in which case no module state is modified).
    """
    global current_path
    global i

    next_path = next(k, None)
    if next_path is None:
        return False

    i += 1
    current_path = next_path

    print(f"Pathing {current_path} ({i})")

    # NOTE: torch.load unpickles the file, so only checkpoints from a
    # trusted source should be placed in the archive directory.
    checkpoint = torch.load(
        current_path,
        map_location=compute_device
    )
    policy_net.load_state_dict(checkpoint["policy_state_dict"])
    return True


# Without at least one checkpoint the recorded paths would come from an
# untrained network and be labelled with current_image "None".
if next_image() is False:
    raise FileNotFoundError(
        f"No model checkpoints found in {model_archive_dir}"
    )


def on_state_before(celeste):
    """Choose the greedy action for the current state and perform it.

    Args:
        celeste: The running game wrapper; its ``state`` is read and its
            ``act`` method is called with the chosen action.

    Returns:
        A ``(state, action)`` tuple: the state observed before acting and
        the index of the chosen action in ``Celeste.action_space``.
    """
    state = celeste.state

    pt_state = torch.tensor(
        [getattr(state, x) for x in Celeste.state_number_map],
        dtype=torch.float32,
        device=compute_device
    ).unsqueeze(0)

    # Pure inference: no gradients are needed, so skip autograd tracking.
    with torch.no_grad():
        action = policy_net(pt_state).max(1)[1].view(1, 1).item()

    str_action = Celeste.action_space[action]
    celeste.act(str_action)

    return state, action


def on_state_after(celeste, before_out):
    """Record the step just taken and handle the end of an episode.

    Appends the pre-action position and the action taken to
    ``state_history``. When the new state is terminal, the history is
    written as one JSON line to ``paths.json``, the next checkpoint is
    loaded, and the game is reset.

    Args:
        celeste: The running game wrapper; its ``state`` is read and, at the
            end of an episode, its ``reset`` method is called.
        before_out: The ``(state, action)`` tuple returned by
            ``on_state_before`` for this step.

    Raises:
        Exception: With the message ``"Done."`` once every archived
            checkpoint has been traced; this is how the script terminates.
    """
    global state_history

    state, action = before_out
    next_state = celeste.state
    finished_stage = next_state.stage >= 1

    state_history.append({
        "xpos": state.xpos,
        "ypos": state.ypos,
        "action": Celeste.action_space[action]
    })

    # Nothing more to do until we reach a terminal state.
    if next_state.deaths == 0 and not finished_stage:
        return

    # Append mode: the file accumulates one JSON object per line.
    with (model_data_root / "paths.json").open("a") as f:
        f.write(json.dumps(
            {
                "hist": state_history,
                "current_image": str(current_path)
            }
        ) + "\n")

    state_history = []

    if next_image() is False:
        raise Exception("Done.")

    print("Game over. Resetting.")
    celeste.reset()


# NOTE(review): the argument is presumably the path to the PICO-8
# executable - confirm against the Celeste wrapper.
c = Celeste(
    "resources/pico-8/linux/pico8"
)

# NOTE(review): update_loop is expected to call on_state_before, then pass
# its return value to on_state_after as before_out - confirm in Celeste.
c.update_loop(
    on_state_before,
    on_state_after
)