Mark
/
celeste-ai
Archived
1
0
Fork 0

Added test script

master
Mark 2023-03-04 13:35:34 -08:00
parent 6b7abc49a6
commit 8420e719d8
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
1 changed files with 100 additions and 0 deletions

100
celeste/celeste_ai/test.py Normal file
View File

@ -0,0 +1,100 @@
from pathlib import Path
import torch
from celeste_ai import Celeste
from celeste_ai import DQN
from celeste_ai.util.screenshots import ScreenshotManager
if __name__ == "__main__":
# Where to read/write model data.
model_data_root = Path("model_data/current")
model_save_path = model_data_root / "model.torch"
model_data_root.mkdir(parents = True, exist_ok = True)
sm = ScreenshotManager(
# Where PICO-8 saves screenshots.
# Probably your desktop.
source = Path("/home/mark/Desktop"),
pattern = "hackcel_*.png",
target = model_data_root / "screenshots_test"
).clean() # Remove old screenshots
compute_device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
episode_number = 0
# Celeste env properties
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)
policy_net = DQN(
n_observations,
n_actions
).to(compute_device)
# Load model if one exists
checkpoint = torch.load(
model_save_path,
map_location = compute_device
)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
def on_state_before(celeste):
global steps_done
state = celeste.state
pt_state = torch.tensor(
[getattr(state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0)
action = policy_net(pt_state).max(1)[1].view(1, 1).item()
str_action = Celeste.action_space[action]
print(str_action)
celeste.act(str_action)
return state, action
def on_state_after(celeste, before_out):
global episode_number
state, action = before_out
next_state = celeste.state
finished_stage = next_state.stage >= 1
# Move on to the next episode once we reach
# a terminal state.
if (next_state.deaths != 0 or finished_stage):
s = celeste.state
sm.move()
print("Game over. Resetting.")
celeste.reset()
episode_number += 1
if __name__ == "__main__":
c = Celeste(
"resources/pico-8/linux/pico8"
)
c.update_loop(
on_state_before,
on_state_after
)