Mark
/
celeste-ai
Archived
1
0
Fork 0

Compare commits

..

No commits in common. "8420e719d8d93b8a65325902a544729b26591fda" and "ee232329b7d3df6c69f2f9ec7990a773ec33d3b9" have entirely different histories.

2 changed files with 0 additions and 219 deletions

View File

@@ -1,119 +0,0 @@
from pathlib import Path
import torch
import json
from celeste_ai import Celeste
from celeste_ai import DQN
# Where archived model checkpoints for the solved stage are stored.
model_data_root = Path("model_data/solved_1")

# Run inference on a GPU when one is available.
compute_device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)

# Celeste env properties
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)

policy_net = DQN(
    n_observations,
    n_actions
).to(compute_device)

# Iterate archived checkpoints in a deterministic (sorted) order.
# A bare iterdir() yields entries in arbitrary filesystem order,
# which made the replay/logging order vary between runs.
k = iter(sorted((model_data_root / "model_archive").iterdir()))

i = 0                 # number of checkpoints processed so far
state_history = []    # per-step position/action records for the current episode
current_path = None   # checkpoint file currently loaded into policy_net
def next_image():
    """Advance to the next archived checkpoint and load it into policy_net.

    Returns:
        True once a checkpoint has been loaded, False when the archive
        iterator is exhausted.
    """
    global policy_net
    global current_path
    global i

    i += 1

    # Idiomatic next() instead of calling the dunder __next__() directly.
    try:
        current_path = next(k)
    except StopIteration:
        return False

    print(f"Pathing {current_path} ({i})")

    # Load the archived weights for this checkpoint.
    checkpoint = torch.load(
        current_path,
        map_location = compute_device
    )
    policy_net.load_state_dict(checkpoint["policy_state_dict"])

    # Explicit success value: callers test the result with `is False`,
    # so returning True (rather than an implicit None) is clearer and
    # fully backward compatible.
    return True

next_image()
def on_state_before(celeste):
    """Choose the greedy action for the current state and perform it.

    Returns:
        (state, action) so on_state_after can log the transition.
    """
    # NOTE(review): removed `global steps_done` — steps_done was never
    # read or written in this function.
    state = celeste.state

    # Build a 1 x n_observations float tensor from the named state fields.
    pt_state = torch.tensor(
        [getattr(state, x) for x in Celeste.state_number_map],
        dtype = torch.float32,
        device = compute_device
    ).unsqueeze(0)

    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        action = policy_net(pt_state).max(1)[1].view(1, 1).item()
    str_action = Celeste.action_space[action]

    celeste.act(str_action)

    return state, action
def on_state_after(celeste, before_out):
    """Record the step; on a terminal state, dump the path and advance.

    Appends one position/action record per step. When the episode ends
    (death or stage cleared), writes the whole path to paths.json, loads
    the next checkpoint, and resets the game.
    """
    # NOTE(review): removed `global episode_number` — that name is never
    # defined or used anywhere in this file.
    global state_history

    state, action = before_out
    next_state = celeste.state
    finished_stage = next_state.stage >= 1

    state_history.append({
        "xpos": state.xpos,
        "ypos": state.ypos,
        "action": Celeste.action_space[action]
    })

    # Move on to the next episode once we reach
    # a terminal state.
    if (next_state.deaths != 0 or finished_stage):
        # Append this episode's path, tagged with the checkpoint that
        # produced it, as one JSON object per line.
        with (model_data_root / "paths.json").open("a") as f:
            f.write(json.dumps(
                {
                    "hist": state_history,
                    "current_image": str(current_path)
                }
            ) + "\n")

        state_history = []

        # Renamed from `k`: the old local name shadowed the global
        # checkpoint iterator `k`, which was confusing to read.
        advanced = next_image()
        if advanced is False:
            raise Exception("Done.")

        print("Game over. Resetting.")
        celeste.reset()
# Launch PICO-8 and drive the game with the loaded policy, one
# callback pair per frame update.
c = Celeste("resources/pico-8/linux/pico8")
c.update_loop(on_state_before, on_state_after)

View File

@@ -1,100 +0,0 @@
from pathlib import Path
import torch
from celeste_ai import Celeste
from celeste_ai import DQN
from celeste_ai.util.screenshots import ScreenshotManager
if __name__ == "__main__":
    # Where to read/write model data.
    model_data_root = Path("model_data/current")
    model_save_path = model_data_root / "model.torch"
    model_data_root.mkdir(parents = True, exist_ok = True)

    sm = ScreenshotManager(
        # Where PICO-8 saves screenshots.
        # Probably your desktop.
        source = Path("/home/mark/Desktop"),
        pattern = "hackcel_*.png",
        target = model_data_root / "screenshots_test"
    ).clean()  # Remove old screenshots

    compute_device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )

    episode_number = 0

    # Celeste env properties
    n_observations = len(Celeste.state_number_map)
    n_actions = len(Celeste.action_space)

    policy_net = DQN(
        n_observations,
        n_actions
    ).to(compute_device)

    # Load model if one exists. The previous code called torch.load
    # unconditionally, which raised FileNotFoundError on a fresh
    # model_data directory despite the comment's stated intent.
    if model_save_path.exists():
        checkpoint = torch.load(
            model_save_path,
            map_location = compute_device
        )
        policy_net.load_state_dict(checkpoint["policy_state_dict"])
def on_state_before(celeste):
    """Choose the greedy action for the current state and perform it.

    Returns:
        (state, action) for on_state_after.
    """
    # NOTE(review): removed `global steps_done` — steps_done was never
    # read or written in this function.
    state = celeste.state

    # Build a 1 x n_observations float tensor from the named state fields.
    pt_state = torch.tensor(
        [getattr(state, x) for x in Celeste.state_number_map],
        dtype = torch.float32,
        device = compute_device
    ).unsqueeze(0)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        action = policy_net(pt_state).max(1)[1].view(1, 1).item()
    str_action = Celeste.action_space[action]

    print(str_action)
    celeste.act(str_action)

    return state, action
def on_state_after(celeste, before_out):
    """On a terminal state, archive screenshots and reset the game."""
    global episode_number

    state, action = before_out
    next_state = celeste.state
    finished_stage = next_state.stage >= 1

    # Move on to the next episode once we reach
    # a terminal state.
    if (next_state.deaths != 0 or finished_stage):
        # NOTE(review): removed `s = celeste.state` — the local was
        # assigned but never used.
        # Move this episode's screenshots into the archive directory.
        sm.move()

        print("Game over. Resetting.")
        celeste.reset()
        episode_number += 1
if __name__ == "__main__":
    # Launch PICO-8 and run the policy-driven evaluation loop, one
    # callback pair per frame update.
    c = Celeste("resources/pico-8/linux/pico8")
    c.update_loop(on_state_before, on_state_after)