Mark
/
celeste-ai
Archived
1
0
Fork 0
This repository has been archived on 2023-11-28. You can view files and clone it, but cannot push or open issues/pull-requests.
celeste-ai/celeste/celeste_ai/paths.py

120 lines
2.0 KiB
Python

from pathlib import Path
import torch
import json
from celeste_ai import Celeste
from celeste_ai import DQN
# Directory holding the training run whose archived checkpoints are replayed.
model_data_root = Path("model_data/solved_1")

# Run inference on the GPU when one is available.
compute_device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)

# Celeste env properties
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)

# Network whose weights are replaced by next_image() with each checkpoint.
policy_net = DQN(
    n_observations,
    n_actions
).to(compute_device)

# Iterator over the archived checkpoint files, consumed by next_image().
# Path.iterdir() yields entries in arbitrary, filesystem-dependent order, so
# the listing is sorted to make the replay order (and hence the order of the
# records written to paths.json) reproducible. This is a plain lexicographic
# sort of the paths.
k = iter(sorted((model_data_root / "model_archive").iterdir()))

# Counter advanced by next_image(); only used in its progress message.
i = 0

# Per-step records ({"xpos", "ypos", "action"}) for the current episode.
state_history = []

# Checkpoint file currently loaded into policy_net (None until the first load).
current_path = None
def next_image():
    """Load the next archived checkpoint into ``policy_net``.

    Advances the module-level checkpoint iterator ``k``, records the chosen
    file in ``current_path``, prints a progress line and loads the
    checkpoint's ``"policy_state_dict"`` entry into ``policy_net``.

    Returns:
        True when a checkpoint was loaded, False once ``k`` is exhausted
        (in which case ``current_path`` and ``policy_net`` are unchanged).
    """
    global current_path
    global i

    try:
        current_path = next(k)
    except StopIteration:
        return False

    i += 1
    print(f"Pathing {current_path} ({i})")

    # NOTE(review): torch.load is pickle-based; only point it at trusted
    # checkpoint files such as the ones this project writes itself.
    checkpoint = torch.load(
        current_path,
        map_location = compute_device
    )
    policy_net.load_state_dict(checkpoint["policy_state_dict"])
    return True
# Load the first checkpoint before the game loop starts. Without this check an
# empty model_archive would leave policy_net with its freshly constructed,
# untrained weights, and a whole episode would be played and logged to
# paths.json before the run stopped.
if next_image() is False:
    raise FileNotFoundError(
        f"No checkpoints found in {model_data_root / 'model_archive'}"
    )
def on_state_before(celeste):
    """Choose an action for the current game state and send it to the game.

    The fields named in ``Celeste.state_number_map`` are read from
    ``celeste.state`` (in that order) into a 1 x n_observations float32
    tensor on ``compute_device``. The greedy (arg-max) action of
    ``policy_net`` is then performed via ``celeste.act``.

    Args:
        celeste: the running ``Celeste`` environment.

    Returns:
        ``(state, action)``: the state the action was chosen from and the
        index of that action in ``Celeste.action_space``. Presumably handed
        to ``on_state_after`` as ``before_out`` by ``update_loop``.
    """
    state = celeste.state

    pt_state = torch.tensor(
        [getattr(state, x) for x in Celeste.state_number_map],
        dtype = torch.float32,
        device = compute_device
    ).unsqueeze(0)

    # Pure inference: skip building an autograd graph on every game step.
    with torch.no_grad():
        action = policy_net(pt_state).argmax(dim = 1).item()

    celeste.act(Celeste.action_space[action])
    return state, action
def on_state_after(celeste, before_out):
    """Record the step just taken and handle the end of an episode.

    Appends the pre-action position and the chosen action to
    ``state_history``. The episode is terminal when the new state's
    ``deaths`` counter is non-zero or its ``stage`` is at least 1. At that
    point the history is appended as one JSON line to
    ``model_data_root / "paths.json"``, together with the checkpoint that
    produced it. The next checkpoint is then loaded and the game is reset.

    Args:
        celeste: the running ``Celeste`` environment.
        before_out: the ``(state, action)`` tuple from ``on_state_before``.

    Raises:
        Exception: ``"Done."`` once every archived checkpoint has been
            replayed, to stop the run.
    """
    global state_history

    state, action = before_out
    next_state = celeste.state
    finished_stage = next_state.stage >= 1

    state_history.append({
        "xpos": state.xpos,
        "ypos": state.ypos,
        "action": Celeste.action_space[action]
    })

    # Keep playing until we reach a terminal state.
    if not (next_state.deaths != 0 or finished_stage):
        return

    # Append one JSON object per line (JSON Lines) for this episode.
    with (model_data_root / "paths.json").open("a") as f:
        f.write(json.dumps(
            {
                "hist": state_history,
                "current_image": str(current_path)
            }
        ) + "\n")
    state_history = []

    # Not named `k`: that name is the module-level checkpoint iterator, while
    # this is only next_image()'s result (False when no checkpoints remain).
    loaded = next_image()
    if loaded is False:
        raise Exception("Done.")

    print("Game over. Resetting.")
    celeste.reset()
# Create the environment from the PICO-8 executable at this path, then run
# its update loop with the two callbacks defined above.
c = Celeste("resources/pico-8/linux/pico8")
c.update_loop(on_state_before, on_state_after)