Added test script
parent
6b7abc49a6
commit
8420e719d8
|
@ -0,0 +1,100 @@
|
|||
from pathlib import Path
|
||||
import torch
|
||||
|
||||
from celeste_ai import Celeste
|
||||
from celeste_ai import DQN
|
||||
from celeste_ai.util.screenshots import ScreenshotManager
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Where to read/write model data.
|
||||
model_data_root = Path("model_data/current")
|
||||
|
||||
model_save_path = model_data_root / "model.torch"
|
||||
model_data_root.mkdir(parents = True, exist_ok = True)
|
||||
|
||||
|
||||
sm = ScreenshotManager(
|
||||
# Where PICO-8 saves screenshots.
|
||||
# Probably your desktop.
|
||||
source = Path("/home/mark/Desktop"),
|
||||
pattern = "hackcel_*.png",
|
||||
target = model_data_root / "screenshots_test"
|
||||
).clean() # Remove old screenshots
|
||||
|
||||
|
||||
compute_device = torch.device(
|
||||
"cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
|
||||
episode_number = 0
|
||||
|
||||
# Celeste env properties
|
||||
n_observations = len(Celeste.state_number_map)
|
||||
n_actions = len(Celeste.action_space)
|
||||
|
||||
policy_net = DQN(
|
||||
n_observations,
|
||||
n_actions
|
||||
).to(compute_device)
|
||||
|
||||
|
||||
# Load model if one exists
|
||||
checkpoint = torch.load(
|
||||
model_save_path,
|
||||
map_location = compute_device
|
||||
)
|
||||
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
||||
|
||||
|
||||
def on_state_before(celeste):
|
||||
global steps_done
|
||||
|
||||
state = celeste.state
|
||||
|
||||
pt_state = torch.tensor(
|
||||
[getattr(state, x) for x in Celeste.state_number_map],
|
||||
dtype = torch.float32,
|
||||
device = compute_device
|
||||
).unsqueeze(0)
|
||||
|
||||
|
||||
action = policy_net(pt_state).max(1)[1].view(1, 1).item()
|
||||
str_action = Celeste.action_space[action]
|
||||
|
||||
print(str_action)
|
||||
celeste.act(str_action)
|
||||
|
||||
return state, action
|
||||
|
||||
|
||||
def on_state_after(celeste, before_out):
|
||||
global episode_number
|
||||
|
||||
state, action = before_out
|
||||
next_state = celeste.state
|
||||
finished_stage = next_state.stage >= 1
|
||||
|
||||
|
||||
# Move on to the next episode once we reach
|
||||
# a terminal state.
|
||||
if (next_state.deaths != 0 or finished_stage):
|
||||
s = celeste.state
|
||||
|
||||
sm.move()
|
||||
|
||||
|
||||
print("Game over. Resetting.")
|
||||
celeste.reset()
|
||||
episode_number += 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
c = Celeste(
|
||||
"resources/pico-8/linux/pico8"
|
||||
)
|
||||
|
||||
c.update_loop(
|
||||
on_state_before,
|
||||
on_state_after
|
||||
)
|
Reference in New Issue