Added test script
parent
6b7abc49a6
commit
8420e719d8
|
@ -0,0 +1,100 @@
|
||||||
|
from pathlib import Path
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from celeste_ai import Celeste
|
||||||
|
from celeste_ai import DQN
|
||||||
|
from celeste_ai.util.screenshots import ScreenshotManager
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Where to read/write model data.
|
||||||
|
model_data_root = Path("model_data/current")
|
||||||
|
|
||||||
|
model_save_path = model_data_root / "model.torch"
|
||||||
|
model_data_root.mkdir(parents = True, exist_ok = True)
|
||||||
|
|
||||||
|
|
||||||
|
sm = ScreenshotManager(
|
||||||
|
# Where PICO-8 saves screenshots.
|
||||||
|
# Probably your desktop.
|
||||||
|
source = Path("/home/mark/Desktop"),
|
||||||
|
pattern = "hackcel_*.png",
|
||||||
|
target = model_data_root / "screenshots_test"
|
||||||
|
).clean() # Remove old screenshots
|
||||||
|
|
||||||
|
|
||||||
|
compute_device = torch.device(
|
||||||
|
"cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
episode_number = 0
|
||||||
|
|
||||||
|
# Celeste env properties
|
||||||
|
n_observations = len(Celeste.state_number_map)
|
||||||
|
n_actions = len(Celeste.action_space)
|
||||||
|
|
||||||
|
policy_net = DQN(
|
||||||
|
n_observations,
|
||||||
|
n_actions
|
||||||
|
).to(compute_device)
|
||||||
|
|
||||||
|
|
||||||
|
# Load model if one exists
|
||||||
|
checkpoint = torch.load(
|
||||||
|
model_save_path,
|
||||||
|
map_location = compute_device
|
||||||
|
)
|
||||||
|
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
||||||
|
|
||||||
|
|
||||||
|
def on_state_before(celeste):
|
||||||
|
global steps_done
|
||||||
|
|
||||||
|
state = celeste.state
|
||||||
|
|
||||||
|
pt_state = torch.tensor(
|
||||||
|
[getattr(state, x) for x in Celeste.state_number_map],
|
||||||
|
dtype = torch.float32,
|
||||||
|
device = compute_device
|
||||||
|
).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
action = policy_net(pt_state).max(1)[1].view(1, 1).item()
|
||||||
|
str_action = Celeste.action_space[action]
|
||||||
|
|
||||||
|
print(str_action)
|
||||||
|
celeste.act(str_action)
|
||||||
|
|
||||||
|
return state, action
|
||||||
|
|
||||||
|
|
||||||
|
def on_state_after(celeste, before_out):
|
||||||
|
global episode_number
|
||||||
|
|
||||||
|
state, action = before_out
|
||||||
|
next_state = celeste.state
|
||||||
|
finished_stage = next_state.stage >= 1
|
||||||
|
|
||||||
|
|
||||||
|
# Move on to the next episode once we reach
|
||||||
|
# a terminal state.
|
||||||
|
if (next_state.deaths != 0 or finished_stage):
|
||||||
|
s = celeste.state
|
||||||
|
|
||||||
|
sm.move()
|
||||||
|
|
||||||
|
|
||||||
|
print("Game over. Resetting.")
|
||||||
|
celeste.reset()
|
||||||
|
episode_number += 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
c = Celeste(
|
||||||
|
"resources/pico-8/linux/pico8"
|
||||||
|
)
|
||||||
|
|
||||||
|
c.update_loop(
|
||||||
|
on_state_before,
|
||||||
|
on_state_after
|
||||||
|
)
|
Reference in New Issue