Added screenshot_source
parent
c372ef8cc7
commit
dc8f0ace68
|
@ -15,6 +15,10 @@ if __name__ == "__main__":
|
|||
# Where to read/write model data.
|
||||
model_data_root = Path("model_data/current")
|
||||
|
||||
# Where PICO-8 saves screenshots.
|
||||
# Probably your desktop.
|
||||
screenshot_source = Path("/home/mark/Desktop")
|
||||
|
||||
model_save_path = model_data_root / "model.torch"
|
||||
model_archive_dir = model_data_root / "model_archive"
|
||||
model_train_log = model_data_root / "train_log"
|
||||
|
@ -25,7 +29,7 @@ if __name__ == "__main__":
|
|||
|
||||
|
||||
# Remove old screenshots
|
||||
shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
|
||||
shots = screenshot_source.glob("hackcel_*.png")
|
||||
for s in shots:
|
||||
s.unlink()
|
||||
|
||||
|
@ -280,8 +284,6 @@ def optimize_model():
|
|||
def on_state_before(celeste):
|
||||
global steps_done
|
||||
|
||||
# Conversion to pytorch
|
||||
|
||||
state = celeste.state
|
||||
|
||||
pt_state = torch.tensor(
|
||||
|
@ -347,7 +349,7 @@ def on_state_after(celeste, before_out):
|
|||
pt_next_state = None
|
||||
reward = 0
|
||||
|
||||
# Reward for finishing stage
|
||||
# Reward for finishing a stage
|
||||
elif next_state.stage >= 1:
|
||||
finished_stage = True
|
||||
reward = next_state.next_point - state.next_point
|
||||
|
@ -370,6 +372,7 @@ def on_state_after(celeste, before_out):
|
|||
if state.next_point == next_state.next_point:
|
||||
reward = 0
|
||||
else:
|
||||
print(f"Got point {state.next_point}")
|
||||
# Reward for reaching a point
|
||||
reward = next_state.next_point - state.next_point
|
||||
|
||||
|
@ -377,6 +380,14 @@ def on_state_after(celeste, before_out):
|
|||
for i in range(state.next_point, state.next_point + reward):
|
||||
point_counter[i] += 1
|
||||
|
||||
# Strawberry reward
|
||||
if next_state.berries[state.stage] and not state.berries[state.stage]:
|
||||
print(f"Got stage {state.stage} bonus")
|
||||
reward += 1
|
||||
|
||||
|
||||
|
||||
|
||||
reward = reward * 10
|
||||
pt_reward = torch.tensor([reward], device = compute_device)
|
||||
|
||||
|
@ -446,7 +457,7 @@ def on_state_after(celeste, before_out):
|
|||
|
||||
|
||||
# Clean up screenshots
|
||||
shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
|
||||
shots = screenshot_source.glob("hackcel_*.png")
|
||||
|
||||
target = screenshot_dir / Path(f"{episode_number}")
|
||||
target.mkdir(parents = True)
|
||||
|
|
Reference in New Issue