Mark
/
celeste-ai
Archived
1
0
Fork 0

Compare commits

...

2 Commits

Author SHA1 Message Date
Mark 8420e719d8
Added test script 2023-03-04 13:35:34 -08:00
Mark 6b7abc49a6
Added path tracer 2023-03-04 13:33:57 -08:00
2 changed files with 219 additions and 0 deletions

119
celeste/celeste_ai/paths.py Normal file
View File

@ -0,0 +1,119 @@
from pathlib import Path
import torch
import json
from celeste_ai import Celeste
from celeste_ai import DQN
model_data_root = Path("model_data/solved_1")
compute_device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
# Celeste env properties
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)
policy_net = DQN(
n_observations,
n_actions
).to(compute_device)
k = (model_data_root / "model_archive").iterdir()
i = 0
state_history = []
current_path = None
def next_image():
global policy_net
global current_path
global i
i += 1
try:
current_path = k.__next__()
except StopIteration:
return False
print(f"Pathing {current_path} ({i})")
# Load model if one exists
checkpoint = torch.load(
current_path,
map_location = compute_device
)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
next_image()
def on_state_before(celeste):
global steps_done
state = celeste.state
pt_state = torch.tensor(
[getattr(state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0)
action = policy_net(pt_state).max(1)[1].view(1, 1).item()
str_action = Celeste.action_space[action]
celeste.act(str_action)
return state, action
def on_state_after(celeste, before_out):
global episode_number
global state_history
state, action = before_out
next_state = celeste.state
finished_stage = next_state.stage >= 1
state_history.append({
"xpos": state.xpos,
"ypos": state.ypos,
"action": Celeste.action_space[action]
})
# Move on to the next episode once we reach
# a terminal state.
if (next_state.deaths != 0 or finished_stage):
with (model_data_root / "paths.json").open("a") as f:
f.write(json.dumps(
{
"hist": state_history,
"current_image": str(current_path)
}
) + "\n")
state_history = []
k = next_image()
if k is False:
raise Exception("Done.")
print("Game over. Resetting.")
celeste.reset()
c = Celeste(
"resources/pico-8/linux/pico8"
)
c.update_loop(
on_state_before,
on_state_after
)

100
celeste/celeste_ai/test.py Normal file
View File

@ -0,0 +1,100 @@
from pathlib import Path
import torch
from celeste_ai import Celeste
from celeste_ai import DQN
from celeste_ai.util.screenshots import ScreenshotManager
if __name__ == "__main__":
# Where to read/write model data.
model_data_root = Path("model_data/current")
model_save_path = model_data_root / "model.torch"
model_data_root.mkdir(parents = True, exist_ok = True)
sm = ScreenshotManager(
# Where PICO-8 saves screenshots.
# Probably your desktop.
source = Path("/home/mark/Desktop"),
pattern = "hackcel_*.png",
target = model_data_root / "screenshots_test"
).clean() # Remove old screenshots
compute_device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
episode_number = 0
# Celeste env properties
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)
policy_net = DQN(
n_observations,
n_actions
).to(compute_device)
# Load model if one exists
checkpoint = torch.load(
model_save_path,
map_location = compute_device
)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
def on_state_before(celeste):
global steps_done
state = celeste.state
pt_state = torch.tensor(
[getattr(state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0)
action = policy_net(pt_state).max(1)[1].view(1, 1).item()
str_action = Celeste.action_space[action]
print(str_action)
celeste.act(str_action)
return state, action
def on_state_after(celeste, before_out):
global episode_number
state, action = before_out
next_state = celeste.state
finished_stage = next_state.stage >= 1
# Move on to the next episode once we reach
# a terminal state.
if (next_state.deaths != 0 or finished_stage):
s = celeste.state
sm.move()
print("Game over. Resetting.")
celeste.reset()
episode_number += 1
if __name__ == "__main__":
c = Celeste(
"resources/pico-8/linux/pico8"
)
c.update_loop(
on_state_before,
on_state_after
)