Compare commits
2 Commits
ee232329b7
...
8420e719d8
Author | SHA1 | Date |
---|---|---|
Mark | 8420e719d8 | |
Mark | 6b7abc49a6 |
|
@ -0,0 +1,119 @@
|
|||
from pathlib import Path
|
||||
import torch
|
||||
import json
|
||||
|
||||
from celeste_ai import Celeste
|
||||
from celeste_ai import DQN
|
||||
|
||||
|
||||
|
||||
model_data_root = Path("model_data/solved_1")
|
||||
|
||||
compute_device = torch.device(
|
||||
"cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
|
||||
|
||||
# Celeste env properties
|
||||
n_observations = len(Celeste.state_number_map)
|
||||
n_actions = len(Celeste.action_space)
|
||||
|
||||
policy_net = DQN(
|
||||
n_observations,
|
||||
n_actions
|
||||
).to(compute_device)
|
||||
|
||||
k = (model_data_root / "model_archive").iterdir()
|
||||
i = 0
|
||||
|
||||
state_history = []
|
||||
current_path = None
|
||||
|
||||
def next_image():
|
||||
global policy_net
|
||||
global current_path
|
||||
global i
|
||||
i += 1
|
||||
|
||||
try:
|
||||
current_path = k.__next__()
|
||||
except StopIteration:
|
||||
return False
|
||||
|
||||
print(f"Pathing {current_path} ({i})")
|
||||
|
||||
# Load model if one exists
|
||||
checkpoint = torch.load(
|
||||
current_path,
|
||||
map_location = compute_device
|
||||
)
|
||||
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
||||
|
||||
|
||||
next_image()
|
||||
|
||||
def on_state_before(celeste):
|
||||
global steps_done
|
||||
|
||||
state = celeste.state
|
||||
|
||||
pt_state = torch.tensor(
|
||||
[getattr(state, x) for x in Celeste.state_number_map],
|
||||
dtype = torch.float32,
|
||||
device = compute_device
|
||||
).unsqueeze(0)
|
||||
|
||||
|
||||
action = policy_net(pt_state).max(1)[1].view(1, 1).item()
|
||||
str_action = Celeste.action_space[action]
|
||||
|
||||
celeste.act(str_action)
|
||||
|
||||
return state, action
|
||||
|
||||
|
||||
def on_state_after(celeste, before_out):
|
||||
global episode_number
|
||||
global state_history
|
||||
|
||||
state, action = before_out
|
||||
next_state = celeste.state
|
||||
finished_stage = next_state.stage >= 1
|
||||
|
||||
state_history.append({
|
||||
"xpos": state.xpos,
|
||||
"ypos": state.ypos,
|
||||
"action": Celeste.action_space[action]
|
||||
})
|
||||
|
||||
# Move on to the next episode once we reach
|
||||
# a terminal state.
|
||||
if (next_state.deaths != 0 or finished_stage):
|
||||
|
||||
with (model_data_root / "paths.json").open("a") as f:
|
||||
f.write(json.dumps(
|
||||
{
|
||||
"hist": state_history,
|
||||
"current_image": str(current_path)
|
||||
}
|
||||
) + "\n")
|
||||
|
||||
state_history = []
|
||||
k = next_image()
|
||||
|
||||
if k is False:
|
||||
raise Exception("Done.")
|
||||
|
||||
print("Game over. Resetting.")
|
||||
celeste.reset()
|
||||
|
||||
|
||||
|
||||
c = Celeste(
|
||||
"resources/pico-8/linux/pico8"
|
||||
)
|
||||
|
||||
c.update_loop(
|
||||
on_state_before,
|
||||
on_state_after
|
||||
)
|
|
@ -0,0 +1,100 @@
|
|||
from pathlib import Path
|
||||
import torch
|
||||
|
||||
from celeste_ai import Celeste
|
||||
from celeste_ai import DQN
|
||||
from celeste_ai.util.screenshots import ScreenshotManager
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Where to read/write model data.
|
||||
model_data_root = Path("model_data/current")
|
||||
|
||||
model_save_path = model_data_root / "model.torch"
|
||||
model_data_root.mkdir(parents = True, exist_ok = True)
|
||||
|
||||
|
||||
sm = ScreenshotManager(
|
||||
# Where PICO-8 saves screenshots.
|
||||
# Probably your desktop.
|
||||
source = Path("/home/mark/Desktop"),
|
||||
pattern = "hackcel_*.png",
|
||||
target = model_data_root / "screenshots_test"
|
||||
).clean() # Remove old screenshots
|
||||
|
||||
|
||||
compute_device = torch.device(
|
||||
"cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
|
||||
episode_number = 0
|
||||
|
||||
# Celeste env properties
|
||||
n_observations = len(Celeste.state_number_map)
|
||||
n_actions = len(Celeste.action_space)
|
||||
|
||||
policy_net = DQN(
|
||||
n_observations,
|
||||
n_actions
|
||||
).to(compute_device)
|
||||
|
||||
|
||||
# Load model if one exists
|
||||
checkpoint = torch.load(
|
||||
model_save_path,
|
||||
map_location = compute_device
|
||||
)
|
||||
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
||||
|
||||
|
||||
def on_state_before(celeste):
|
||||
global steps_done
|
||||
|
||||
state = celeste.state
|
||||
|
||||
pt_state = torch.tensor(
|
||||
[getattr(state, x) for x in Celeste.state_number_map],
|
||||
dtype = torch.float32,
|
||||
device = compute_device
|
||||
).unsqueeze(0)
|
||||
|
||||
|
||||
action = policy_net(pt_state).max(1)[1].view(1, 1).item()
|
||||
str_action = Celeste.action_space[action]
|
||||
|
||||
print(str_action)
|
||||
celeste.act(str_action)
|
||||
|
||||
return state, action
|
||||
|
||||
|
||||
def on_state_after(celeste, before_out):
|
||||
global episode_number
|
||||
|
||||
state, action = before_out
|
||||
next_state = celeste.state
|
||||
finished_stage = next_state.stage >= 1
|
||||
|
||||
|
||||
# Move on to the next episode once we reach
|
||||
# a terminal state.
|
||||
if (next_state.deaths != 0 or finished_stage):
|
||||
s = celeste.state
|
||||
|
||||
sm.move()
|
||||
|
||||
|
||||
print("Game over. Resetting.")
|
||||
celeste.reset()
|
||||
episode_number += 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
c = Celeste(
|
||||
"resources/pico-8/linux/pico8"
|
||||
)
|
||||
|
||||
c.update_loop(
|
||||
on_state_before,
|
||||
on_state_after
|
||||
)
|
Reference in New Issue