Mark
/
celeste-ai
Archived
1
0
Fork 0

Compare commits

..

No commits in common. "f40b58508e910cbf99ebf8163c91220ee9cba1aa" and "c372ef8cc7d25a8c532a98ac8b6ffb414808be62" have entirely different histories.

3 changed files with 5 additions and 34 deletions

View File

@ -28,9 +28,6 @@ class CelesteState(NamedTuple):
# Number of deaths since game start # Number of deaths since game start
deaths: int deaths: int
# If an index is true, we got a strawberry on that stage.
berries: list[bool]
# Distance to next point # Distance to next point
dist: float dist: float
@ -226,7 +223,6 @@ class Celeste:
xvel = float(self._internal_state["vx"]), xvel = float(self._internal_state["vx"]),
yvel = float(self._internal_state["vy"]), yvel = float(self._internal_state["vy"]),
deaths = int(self._internal_state["dc"]), deaths = int(self._internal_state["dc"]),
berries = [x == "t" for x in self._internal_state["fr"][1:]],
dist = self._dist, dist = self._dist,
next_point = self._next_checkpoint_idx, next_point = self._next_checkpoint_idx,

View File

@ -15,10 +15,6 @@ if __name__ == "__main__":
# Where to read/write model data. # Where to read/write model data.
model_data_root = Path("model_data/current") model_data_root = Path("model_data/current")
# Where PICO-8 saves screenshots.
# Probably your desktop.
screenshot_source = Path("/home/mark/Desktop")
model_save_path = model_data_root / "model.torch" model_save_path = model_data_root / "model.torch"
model_archive_dir = model_data_root / "model_archive" model_archive_dir = model_data_root / "model_archive"
model_train_log = model_data_root / "train_log" model_train_log = model_data_root / "train_log"
@ -29,7 +25,7 @@ if __name__ == "__main__":
# Remove old screenshots # Remove old screenshots
shots = screenshot_source.glob("hackcel_*.png") shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
for s in shots: for s in shots:
s.unlink() s.unlink()
@ -284,6 +280,8 @@ def optimize_model():
def on_state_before(celeste): def on_state_before(celeste):
global steps_done global steps_done
# Conversion to pytorch
state = celeste.state state = celeste.state
pt_state = torch.tensor( pt_state = torch.tensor(
@ -349,7 +347,7 @@ def on_state_after(celeste, before_out):
pt_next_state = None pt_next_state = None
reward = 0 reward = 0
# Reward for finishing a stage # Reward for finishing stage
elif next_state.stage >= 1: elif next_state.stage >= 1:
finished_stage = True finished_stage = True
reward = next_state.next_point - state.next_point reward = next_state.next_point - state.next_point
@ -372,7 +370,6 @@ def on_state_after(celeste, before_out):
if state.next_point == next_state.next_point: if state.next_point == next_state.next_point:
reward = 0 reward = 0
else: else:
print(f"Got point {state.next_point}")
# Reward for reaching a point # Reward for reaching a point
reward = next_state.next_point - state.next_point reward = next_state.next_point - state.next_point
@ -380,14 +377,6 @@ def on_state_after(celeste, before_out):
for i in range(state.next_point, state.next_point + reward): for i in range(state.next_point, state.next_point + reward):
point_counter[i] += 1 point_counter[i] += 1
# Strawberry reward
if next_state.berries[state.stage] and not state.berries[state.stage]:
print(f"Got stage {state.stage} bonus")
reward += 1
reward = reward * 10 reward = reward * 10
pt_reward = torch.tensor([reward], device = compute_device) pt_reward = torch.tensor([reward], device = compute_device)
@ -457,7 +446,7 @@ def on_state_after(celeste, before_out):
# Clean up screenshots # Clean up screenshots
shots = screenshot_source.glob("hackcel_*.png") shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
target = screenshot_dir / Path(f"{episode_number}") target = screenshot_dir / Path(f"{episode_number}")
target.mkdir(parents = True) target.mkdir(parents = True)

View File

@ -1275,26 +1275,12 @@ function _update()
hack_has_sent_first_message = true hack_has_sent_first_message = true
out_string = "dc:" .. tostr(deaths) .. ";" out_string = "dc:" .. tostr(deaths) .. ";"
-- Dash status
if hack_can_dash then if hack_can_dash then
out_string = out_string .. "ds:t;" out_string = out_string .. "ds:t;"
else else
out_string = out_string .. "ds:f;" out_string = out_string .. "ds:f;"
end end
-- Fruit status
out_string = out_string .. "fr:"
for i = 0,29 do
if got_fruit[i] then
out_string = out_string .. "t"
else
out_string = out_string .. "f"
end
end
out_string = out_string .. ";"
for k, v in pairs(hack_player_state) do for k, v in pairs(hack_player_state) do
out_string = out_string .. k ..":" .. v .. ";" out_string = out_string .. k ..":" .. v .. ";"
end end