Mark
/
celeste-ai
Archived
1
0
Fork 0

Added stage completion handling

master
Mark 2023-02-24 17:46:07 -08:00
parent 4ff32b91ea
commit 589f41c205
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
2 changed files with 22 additions and 5 deletions

View File

@ -50,7 +50,6 @@ class Celeste:
action_space = [
"left", # move left 0
"right", # move right 1
#"jump", # jump
"jump-l", # jump left 2
"jump-r", # jump right 3
@ -86,6 +85,13 @@ class Celeste:
]
]
# Maps room_x, room_y coordinates to stage number.
stage_map = [
[0, 1, 2, 3, 4]
]
def __init__(
self,
pico_path,
@ -194,9 +200,7 @@ class Celeste:
def state(self):
try:
stage = (
[
[0, 1, 2, 3, 4]
]
Celeste.stage_map
[int(self._internal_state["ry"])]
[int(self._internal_state["rx"])]
)

View File

@ -341,10 +341,23 @@ def on_state_after(celeste, before_out):
dtype = torch.long
)
finished_stage = False
# No reward if dead
if next_state.deaths != 0:
pt_next_state = None
reward = 0
# Reward for finishing stage
elif next_state.stage >= 1:
finished_stage = True
reward = next_state.next_point - state.next_point
reward += 1
# Add to point counter
for i in range(state.next_point, state.next_point + reward):
point_counter[i] += 1
# Regular reward
else:
pt_next_state = torch.tensor(
[getattr(next_state, x) for x in Celeste.state_number_map],
@ -401,7 +414,7 @@ def on_state_after(celeste, before_out):
# Move on to the next episode once we reach
# a terminal state.
if (next_state.deaths != 0):
if (next_state.deaths != 0 or finished_stage):
s = celeste.state
with model_train_log.open("a") as f:
f.write(json.dumps({