Compare commits
2 Commits
ee232329b7
...
8420e719d8
Author | SHA1 | Date |
---|---|---|
Mark | 8420e719d8 | |
Mark | 6b7abc49a6 |
|
@ -0,0 +1,119 @@
|
||||||
|
from pathlib import Path
|
||||||
|
import torch
|
||||||
|
import json
|
||||||
|
|
||||||
|
from celeste_ai import Celeste
|
||||||
|
from celeste_ai import DQN
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
model_data_root = Path("model_data/solved_1")
|
||||||
|
|
||||||
|
compute_device = torch.device(
|
||||||
|
"cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Celeste env properties
|
||||||
|
n_observations = len(Celeste.state_number_map)
|
||||||
|
n_actions = len(Celeste.action_space)
|
||||||
|
|
||||||
|
policy_net = DQN(
|
||||||
|
n_observations,
|
||||||
|
n_actions
|
||||||
|
).to(compute_device)
|
||||||
|
|
||||||
|
k = (model_data_root / "model_archive").iterdir()
|
||||||
|
i = 0
|
||||||
|
|
||||||
|
state_history = []
|
||||||
|
current_path = None
|
||||||
|
|
||||||
|
def next_image():
|
||||||
|
global policy_net
|
||||||
|
global current_path
|
||||||
|
global i
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
current_path = k.__next__()
|
||||||
|
except StopIteration:
|
||||||
|
return False
|
||||||
|
|
||||||
|
print(f"Pathing {current_path} ({i})")
|
||||||
|
|
||||||
|
# Load model if one exists
|
||||||
|
checkpoint = torch.load(
|
||||||
|
current_path,
|
||||||
|
map_location = compute_device
|
||||||
|
)
|
||||||
|
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
||||||
|
|
||||||
|
|
||||||
|
next_image()
|
||||||
|
|
||||||
|
def on_state_before(celeste):
|
||||||
|
global steps_done
|
||||||
|
|
||||||
|
state = celeste.state
|
||||||
|
|
||||||
|
pt_state = torch.tensor(
|
||||||
|
[getattr(state, x) for x in Celeste.state_number_map],
|
||||||
|
dtype = torch.float32,
|
||||||
|
device = compute_device
|
||||||
|
).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
action = policy_net(pt_state).max(1)[1].view(1, 1).item()
|
||||||
|
str_action = Celeste.action_space[action]
|
||||||
|
|
||||||
|
celeste.act(str_action)
|
||||||
|
|
||||||
|
return state, action
|
||||||
|
|
||||||
|
|
||||||
|
def on_state_after(celeste, before_out):
|
||||||
|
global episode_number
|
||||||
|
global state_history
|
||||||
|
|
||||||
|
state, action = before_out
|
||||||
|
next_state = celeste.state
|
||||||
|
finished_stage = next_state.stage >= 1
|
||||||
|
|
||||||
|
state_history.append({
|
||||||
|
"xpos": state.xpos,
|
||||||
|
"ypos": state.ypos,
|
||||||
|
"action": Celeste.action_space[action]
|
||||||
|
})
|
||||||
|
|
||||||
|
# Move on to the next episode once we reach
|
||||||
|
# a terminal state.
|
||||||
|
if (next_state.deaths != 0 or finished_stage):
|
||||||
|
|
||||||
|
with (model_data_root / "paths.json").open("a") as f:
|
||||||
|
f.write(json.dumps(
|
||||||
|
{
|
||||||
|
"hist": state_history,
|
||||||
|
"current_image": str(current_path)
|
||||||
|
}
|
||||||
|
) + "\n")
|
||||||
|
|
||||||
|
state_history = []
|
||||||
|
k = next_image()
|
||||||
|
|
||||||
|
if k is False:
|
||||||
|
raise Exception("Done.")
|
||||||
|
|
||||||
|
print("Game over. Resetting.")
|
||||||
|
celeste.reset()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
c = Celeste(
|
||||||
|
"resources/pico-8/linux/pico8"
|
||||||
|
)
|
||||||
|
|
||||||
|
c.update_loop(
|
||||||
|
on_state_before,
|
||||||
|
on_state_after
|
||||||
|
)
|
|
@ -0,0 +1,100 @@
|
||||||
|
from pathlib import Path
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from celeste_ai import Celeste
|
||||||
|
from celeste_ai import DQN
|
||||||
|
from celeste_ai.util.screenshots import ScreenshotManager
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Where to read/write model data.
|
||||||
|
model_data_root = Path("model_data/current")
|
||||||
|
|
||||||
|
model_save_path = model_data_root / "model.torch"
|
||||||
|
model_data_root.mkdir(parents = True, exist_ok = True)
|
||||||
|
|
||||||
|
|
||||||
|
sm = ScreenshotManager(
|
||||||
|
# Where PICO-8 saves screenshots.
|
||||||
|
# Probably your desktop.
|
||||||
|
source = Path("/home/mark/Desktop"),
|
||||||
|
pattern = "hackcel_*.png",
|
||||||
|
target = model_data_root / "screenshots_test"
|
||||||
|
).clean() # Remove old screenshots
|
||||||
|
|
||||||
|
|
||||||
|
compute_device = torch.device(
|
||||||
|
"cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
episode_number = 0
|
||||||
|
|
||||||
|
# Celeste env properties
|
||||||
|
n_observations = len(Celeste.state_number_map)
|
||||||
|
n_actions = len(Celeste.action_space)
|
||||||
|
|
||||||
|
policy_net = DQN(
|
||||||
|
n_observations,
|
||||||
|
n_actions
|
||||||
|
).to(compute_device)
|
||||||
|
|
||||||
|
|
||||||
|
# Load model if one exists
|
||||||
|
checkpoint = torch.load(
|
||||||
|
model_save_path,
|
||||||
|
map_location = compute_device
|
||||||
|
)
|
||||||
|
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
||||||
|
|
||||||
|
|
||||||
|
def on_state_before(celeste):
|
||||||
|
global steps_done
|
||||||
|
|
||||||
|
state = celeste.state
|
||||||
|
|
||||||
|
pt_state = torch.tensor(
|
||||||
|
[getattr(state, x) for x in Celeste.state_number_map],
|
||||||
|
dtype = torch.float32,
|
||||||
|
device = compute_device
|
||||||
|
).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
action = policy_net(pt_state).max(1)[1].view(1, 1).item()
|
||||||
|
str_action = Celeste.action_space[action]
|
||||||
|
|
||||||
|
print(str_action)
|
||||||
|
celeste.act(str_action)
|
||||||
|
|
||||||
|
return state, action
|
||||||
|
|
||||||
|
|
||||||
|
def on_state_after(celeste, before_out):
|
||||||
|
global episode_number
|
||||||
|
|
||||||
|
state, action = before_out
|
||||||
|
next_state = celeste.state
|
||||||
|
finished_stage = next_state.stage >= 1
|
||||||
|
|
||||||
|
|
||||||
|
# Move on to the next episode once we reach
|
||||||
|
# a terminal state.
|
||||||
|
if (next_state.deaths != 0 or finished_stage):
|
||||||
|
s = celeste.state
|
||||||
|
|
||||||
|
sm.move()
|
||||||
|
|
||||||
|
|
||||||
|
print("Game over. Resetting.")
|
||||||
|
celeste.reset()
|
||||||
|
episode_number += 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
c = Celeste(
|
||||||
|
"resources/pico-8/linux/pico8"
|
||||||
|
)
|
||||||
|
|
||||||
|
c.update_loop(
|
||||||
|
on_state_before,
|
||||||
|
on_state_after
|
||||||
|
)
|
Reference in New Issue