Compare commits
No commits in common. "8420e719d8d93b8a65325902a544729b26591fda" and "ee232329b7d3df6c69f2f9ec7990a773ec33d3b9" have entirely different histories.
8420e719d8
...
ee232329b7
|
@ -1,119 +0,0 @@
|
||||||
from pathlib import Path
|
|
||||||
import torch
|
|
||||||
import json
|
|
||||||
|
|
||||||
from celeste_ai import Celeste
|
|
||||||
from celeste_ai import DQN
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
model_data_root = Path("model_data/solved_1")
|
|
||||||
|
|
||||||
compute_device = torch.device(
|
|
||||||
"cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Celeste env properties
|
|
||||||
n_observations = len(Celeste.state_number_map)
|
|
||||||
n_actions = len(Celeste.action_space)
|
|
||||||
|
|
||||||
policy_net = DQN(
|
|
||||||
n_observations,
|
|
||||||
n_actions
|
|
||||||
).to(compute_device)
|
|
||||||
|
|
||||||
k = (model_data_root / "model_archive").iterdir()
|
|
||||||
i = 0
|
|
||||||
|
|
||||||
state_history = []
|
|
||||||
current_path = None
|
|
||||||
|
|
||||||
def next_image():
|
|
||||||
global policy_net
|
|
||||||
global current_path
|
|
||||||
global i
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
try:
|
|
||||||
current_path = k.__next__()
|
|
||||||
except StopIteration:
|
|
||||||
return False
|
|
||||||
|
|
||||||
print(f"Pathing {current_path} ({i})")
|
|
||||||
|
|
||||||
# Load model if one exists
|
|
||||||
checkpoint = torch.load(
|
|
||||||
current_path,
|
|
||||||
map_location = compute_device
|
|
||||||
)
|
|
||||||
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
|
||||||
|
|
||||||
|
|
||||||
next_image()
|
|
||||||
|
|
||||||
def on_state_before(celeste):
|
|
||||||
global steps_done
|
|
||||||
|
|
||||||
state = celeste.state
|
|
||||||
|
|
||||||
pt_state = torch.tensor(
|
|
||||||
[getattr(state, x) for x in Celeste.state_number_map],
|
|
||||||
dtype = torch.float32,
|
|
||||||
device = compute_device
|
|
||||||
).unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
action = policy_net(pt_state).max(1)[1].view(1, 1).item()
|
|
||||||
str_action = Celeste.action_space[action]
|
|
||||||
|
|
||||||
celeste.act(str_action)
|
|
||||||
|
|
||||||
return state, action
|
|
||||||
|
|
||||||
|
|
||||||
def on_state_after(celeste, before_out):
|
|
||||||
global episode_number
|
|
||||||
global state_history
|
|
||||||
|
|
||||||
state, action = before_out
|
|
||||||
next_state = celeste.state
|
|
||||||
finished_stage = next_state.stage >= 1
|
|
||||||
|
|
||||||
state_history.append({
|
|
||||||
"xpos": state.xpos,
|
|
||||||
"ypos": state.ypos,
|
|
||||||
"action": Celeste.action_space[action]
|
|
||||||
})
|
|
||||||
|
|
||||||
# Move on to the next episode once we reach
|
|
||||||
# a terminal state.
|
|
||||||
if (next_state.deaths != 0 or finished_stage):
|
|
||||||
|
|
||||||
with (model_data_root / "paths.json").open("a") as f:
|
|
||||||
f.write(json.dumps(
|
|
||||||
{
|
|
||||||
"hist": state_history,
|
|
||||||
"current_image": str(current_path)
|
|
||||||
}
|
|
||||||
) + "\n")
|
|
||||||
|
|
||||||
state_history = []
|
|
||||||
k = next_image()
|
|
||||||
|
|
||||||
if k is False:
|
|
||||||
raise Exception("Done.")
|
|
||||||
|
|
||||||
print("Game over. Resetting.")
|
|
||||||
celeste.reset()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
c = Celeste(
|
|
||||||
"resources/pico-8/linux/pico8"
|
|
||||||
)
|
|
||||||
|
|
||||||
c.update_loop(
|
|
||||||
on_state_before,
|
|
||||||
on_state_after
|
|
||||||
)
|
|
|
@ -1,100 +0,0 @@
|
||||||
from pathlib import Path
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from celeste_ai import Celeste
|
|
||||||
from celeste_ai import DQN
|
|
||||||
from celeste_ai.util.screenshots import ScreenshotManager
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Where to read/write model data.
|
|
||||||
model_data_root = Path("model_data/current")
|
|
||||||
|
|
||||||
model_save_path = model_data_root / "model.torch"
|
|
||||||
model_data_root.mkdir(parents = True, exist_ok = True)
|
|
||||||
|
|
||||||
|
|
||||||
sm = ScreenshotManager(
|
|
||||||
# Where PICO-8 saves screenshots.
|
|
||||||
# Probably your desktop.
|
|
||||||
source = Path("/home/mark/Desktop"),
|
|
||||||
pattern = "hackcel_*.png",
|
|
||||||
target = model_data_root / "screenshots_test"
|
|
||||||
).clean() # Remove old screenshots
|
|
||||||
|
|
||||||
|
|
||||||
compute_device = torch.device(
|
|
||||||
"cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
)
|
|
||||||
|
|
||||||
episode_number = 0
|
|
||||||
|
|
||||||
# Celeste env properties
|
|
||||||
n_observations = len(Celeste.state_number_map)
|
|
||||||
n_actions = len(Celeste.action_space)
|
|
||||||
|
|
||||||
policy_net = DQN(
|
|
||||||
n_observations,
|
|
||||||
n_actions
|
|
||||||
).to(compute_device)
|
|
||||||
|
|
||||||
|
|
||||||
# Load model if one exists
|
|
||||||
checkpoint = torch.load(
|
|
||||||
model_save_path,
|
|
||||||
map_location = compute_device
|
|
||||||
)
|
|
||||||
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
|
||||||
|
|
||||||
|
|
||||||
def on_state_before(celeste):
|
|
||||||
global steps_done
|
|
||||||
|
|
||||||
state = celeste.state
|
|
||||||
|
|
||||||
pt_state = torch.tensor(
|
|
||||||
[getattr(state, x) for x in Celeste.state_number_map],
|
|
||||||
dtype = torch.float32,
|
|
||||||
device = compute_device
|
|
||||||
).unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
action = policy_net(pt_state).max(1)[1].view(1, 1).item()
|
|
||||||
str_action = Celeste.action_space[action]
|
|
||||||
|
|
||||||
print(str_action)
|
|
||||||
celeste.act(str_action)
|
|
||||||
|
|
||||||
return state, action
|
|
||||||
|
|
||||||
|
|
||||||
def on_state_after(celeste, before_out):
|
|
||||||
global episode_number
|
|
||||||
|
|
||||||
state, action = before_out
|
|
||||||
next_state = celeste.state
|
|
||||||
finished_stage = next_state.stage >= 1
|
|
||||||
|
|
||||||
|
|
||||||
# Move on to the next episode once we reach
|
|
||||||
# a terminal state.
|
|
||||||
if (next_state.deaths != 0 or finished_stage):
|
|
||||||
s = celeste.state
|
|
||||||
|
|
||||||
sm.move()
|
|
||||||
|
|
||||||
|
|
||||||
print("Game over. Resetting.")
|
|
||||||
celeste.reset()
|
|
||||||
episode_number += 1
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
c = Celeste(
|
|
||||||
"resources/pico-8/linux/pico8"
|
|
||||||
)
|
|
||||||
|
|
||||||
c.update_loop(
|
|
||||||
on_state_before,
|
|
||||||
on_state_after
|
|
||||||
)
|
|
Reference in New Issue