Compare commits
2 Commits
c372ef8cc7
...
f40b58508e
Author | SHA1 | Date |
---|---|---|
Mark | f40b58508e | |
Mark | dc8f0ace68 |
|
@ -28,6 +28,9 @@ class CelesteState(NamedTuple):
|
||||||
# Number of deaths since game start
|
# Number of deaths since game start
|
||||||
deaths: int
|
deaths: int
|
||||||
|
|
||||||
|
# If an index is true, we got a strawberry on that stage.
|
||||||
|
berries: list[bool]
|
||||||
|
|
||||||
# Distance to next point
|
# Distance to next point
|
||||||
dist: float
|
dist: float
|
||||||
|
|
||||||
|
@ -223,6 +226,7 @@ class Celeste:
|
||||||
xvel = float(self._internal_state["vx"]),
|
xvel = float(self._internal_state["vx"]),
|
||||||
yvel = float(self._internal_state["vy"]),
|
yvel = float(self._internal_state["vy"]),
|
||||||
deaths = int(self._internal_state["dc"]),
|
deaths = int(self._internal_state["dc"]),
|
||||||
|
berries = [x == "t" for x in self._internal_state["fr"][1:]],
|
||||||
|
|
||||||
dist = self._dist,
|
dist = self._dist,
|
||||||
next_point = self._next_checkpoint_idx,
|
next_point = self._next_checkpoint_idx,
|
||||||
|
|
|
@ -15,6 +15,10 @@ if __name__ == "__main__":
|
||||||
# Where to read/write model data.
|
# Where to read/write model data.
|
||||||
model_data_root = Path("model_data/current")
|
model_data_root = Path("model_data/current")
|
||||||
|
|
||||||
|
# Where PICO-8 saves screenshots.
|
||||||
|
# Probably your desktop.
|
||||||
|
screenshot_source = Path("/home/mark/Desktop")
|
||||||
|
|
||||||
model_save_path = model_data_root / "model.torch"
|
model_save_path = model_data_root / "model.torch"
|
||||||
model_archive_dir = model_data_root / "model_archive"
|
model_archive_dir = model_data_root / "model_archive"
|
||||||
model_train_log = model_data_root / "train_log"
|
model_train_log = model_data_root / "train_log"
|
||||||
|
@ -25,7 +29,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
|
|
||||||
# Remove old screenshots
|
# Remove old screenshots
|
||||||
shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
|
shots = screenshot_source.glob("hackcel_*.png")
|
||||||
for s in shots:
|
for s in shots:
|
||||||
s.unlink()
|
s.unlink()
|
||||||
|
|
||||||
|
@ -280,8 +284,6 @@ def optimize_model():
|
||||||
def on_state_before(celeste):
|
def on_state_before(celeste):
|
||||||
global steps_done
|
global steps_done
|
||||||
|
|
||||||
# Conversion to pytorch
|
|
||||||
|
|
||||||
state = celeste.state
|
state = celeste.state
|
||||||
|
|
||||||
pt_state = torch.tensor(
|
pt_state = torch.tensor(
|
||||||
|
@ -347,7 +349,7 @@ def on_state_after(celeste, before_out):
|
||||||
pt_next_state = None
|
pt_next_state = None
|
||||||
reward = 0
|
reward = 0
|
||||||
|
|
||||||
# Reward for finishing stage
|
# Reward for finishing a stage
|
||||||
elif next_state.stage >= 1:
|
elif next_state.stage >= 1:
|
||||||
finished_stage = True
|
finished_stage = True
|
||||||
reward = next_state.next_point - state.next_point
|
reward = next_state.next_point - state.next_point
|
||||||
|
@ -370,6 +372,7 @@ def on_state_after(celeste, before_out):
|
||||||
if state.next_point == next_state.next_point:
|
if state.next_point == next_state.next_point:
|
||||||
reward = 0
|
reward = 0
|
||||||
else:
|
else:
|
||||||
|
print(f"Got point {state.next_point}")
|
||||||
# Reward for reaching a point
|
# Reward for reaching a point
|
||||||
reward = next_state.next_point - state.next_point
|
reward = next_state.next_point - state.next_point
|
||||||
|
|
||||||
|
@ -377,6 +380,14 @@ def on_state_after(celeste, before_out):
|
||||||
for i in range(state.next_point, state.next_point + reward):
|
for i in range(state.next_point, state.next_point + reward):
|
||||||
point_counter[i] += 1
|
point_counter[i] += 1
|
||||||
|
|
||||||
|
# Strawberry reward
|
||||||
|
if next_state.berries[state.stage] and not state.berries[state.stage]:
|
||||||
|
print(f"Got stage {state.stage} bonus")
|
||||||
|
reward += 1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
reward = reward * 10
|
reward = reward * 10
|
||||||
pt_reward = torch.tensor([reward], device = compute_device)
|
pt_reward = torch.tensor([reward], device = compute_device)
|
||||||
|
|
||||||
|
@ -446,7 +457,7 @@ def on_state_after(celeste, before_out):
|
||||||
|
|
||||||
|
|
||||||
# Clean up screenshots
|
# Clean up screenshots
|
||||||
shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
|
shots = screenshot_source.glob("hackcel_*.png")
|
||||||
|
|
||||||
target = screenshot_dir / Path(f"{episode_number}")
|
target = screenshot_dir / Path(f"{episode_number}")
|
||||||
target.mkdir(parents = True)
|
target.mkdir(parents = True)
|
||||||
|
|
|
@ -1275,12 +1275,26 @@ function _update()
|
||||||
hack_has_sent_first_message = true
|
hack_has_sent_first_message = true
|
||||||
out_string = "dc:" .. tostr(deaths) .. ";"
|
out_string = "dc:" .. tostr(deaths) .. ";"
|
||||||
|
|
||||||
|
-- Dash status
|
||||||
if hack_can_dash then
|
if hack_can_dash then
|
||||||
out_string = out_string .. "ds:t;"
|
out_string = out_string .. "ds:t;"
|
||||||
else
|
else
|
||||||
out_string = out_string .. "ds:f;"
|
out_string = out_string .. "ds:f;"
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
|
-- Fruit status
|
||||||
|
out_string = out_string .. "fr:"
|
||||||
|
for i = 0,29 do
|
||||||
|
if got_fruit[i] then
|
||||||
|
out_string = out_string .. "t"
|
||||||
|
else
|
||||||
|
out_string = out_string .. "f"
|
||||||
|
end
|
||||||
|
end
|
||||||
|
out_string = out_string .. ";"
|
||||||
|
|
||||||
|
|
||||||
for k, v in pairs(hack_player_state) do
|
for k, v in pairs(hack_player_state) do
|
||||||
out_string = out_string .. k ..":" .. v .. ";"
|
out_string = out_string .. k ..":" .. v .. ";"
|
||||||
end
|
end
|
||||||
|
|
Reference in New Issue