Added stage completion handling
parent 4ff32b91ea
commit 589f41c205

@@ -50,7 +50,6 @@ class Celeste:
    action_space = [
        "left",   # move left   0
        "right",  # move right  1
        #"jump",  # jump
        "jump-l", # jump left   2
        "jump-r", # jump right  3

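The numeric comments suggest the agent selects an action by index into this list. A tiny sketch of that lookup, assuming index-based selection; the chosen index below is a made-up example:

# Sketch of an index-based action lookup; the list mirrors the diff,
# the chosen index is a made-up example (e.g. an argmax over Q-values).
action_space = ["left", "right", "jump-l", "jump-r"]

chosen_index = 3
print(action_space[chosen_index])  # -> jump-r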
|
@@ -86,6 +85,13 @@ class Celeste:
        ]
    ]

    # Maps room_x, room_y coordinates to stage number.
    stage_map = [
        [0, 1, 2, 3, 4]
    ]

    def __init__(
        self,
        pico_path,

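The new stage_map is a row-major lookup table: the outer index is the room's y coordinate and the inner index its x coordinate, so the single row [0, 1, 2, 3, 4] covers five rooms laid out horizontally. A minimal sketch of the lookup this enables; the room coordinates below are made-up values:

# stage_map as added in this commit; the room coordinates are made up.
stage_map = [
    [0, 1, 2, 3, 4]
]

room_x, room_y = 2, 0
stage = stage_map[room_y][room_x]
print(stage)  # -> 2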
|
@@ -194,9 +200,7 @@ class Celeste:
    def state(self):
        try:
            stage = (
                [
                    [0, 1, 2, 3, 4]
                ]
                Celeste.stage_map
                [int(self._internal_state["ry"])]
                [int(self._internal_state["rx"])]
            )

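This hunk swaps the inline table that state() previously built for the shared Celeste.stage_map class attribute, so the room-to-stage mapping is defined in one place. A rough standalone sketch of the resulting lookup; the except fallback is an assumption for illustration, not code from the repository:

# Rough sketch of the stage lookup done in state(); the dictionary keys
# mirror the diff, the except fallback is an assumption for illustration.
stage_map = [
    [0, 1, 2, 3, 4]
]

def lookup_stage(internal_state):
    try:
        return stage_map[int(internal_state["ry"])][int(internal_state["rx"])]
    except (KeyError, ValueError, IndexError):
        # Assume stage 0 if the game has not reported room coordinates yet.
        return 0

print(lookup_stage({"rx": "3", "ry": "0"}))  # -> 3
print(lookup_stage({}))                      # -> 0
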
@@ -341,10 +341,23 @@ def on_state_after(celeste, before_out):
        dtype = torch.long
    )

    finished_stage = False
    # No reward if dead
    if next_state.deaths != 0:
        pt_next_state = None
        reward = 0

    # Reward for finishing stage
    elif next_state.stage >= 1:
        finished_stage = True
        reward = next_state.next_point - state.next_point
        reward += 1

        # Add to point counter
        for i in range(state.next_point, state.next_point + reward):
            point_counter[i] += 1

    # Regular reward
    else:
        pt_next_state = torch.tensor(
            [getattr(next_state, x) for x in Celeste.state_number_map],

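The new finished_stage branch rewards a stage clear with every checkpoint the agent passed (next_state.next_point - state.next_point) plus a bonus of one, and bumps point_counter for each of those checkpoints. A condensed, torch-free sketch of the three branches; the State stand-in, the point_counter size, and the regular-branch formula are assumptions for illustration:

# Condensed sketch of the reward logic added here. The field names mirror
# the diff; the State dataclass, point_counter size and the regular-reward
# formula are assumptions for illustration only.
from dataclasses import dataclass

@dataclass
class State:
    deaths: int
    stage: int
    next_point: int

point_counter = [0] * 8  # assumed number of checkpoints

def compute_reward(state, next_state):
    finished_stage = False
    if next_state.deaths != 0:
        # No reward if dead; the episode is terminated further down.
        reward = 0
    elif next_state.stage >= 1:
        # Reward for finishing the stage: remaining checkpoints plus a bonus.
        finished_stage = True
        reward = next_state.next_point - state.next_point
        reward += 1
        for i in range(state.next_point, state.next_point + reward):
            point_counter[i] += 1
    else:
        # Regular reward (the real handler also builds the next state tensor).
        reward = next_state.next_point - state.next_point
    return reward, finished_stage

print(compute_reward(State(0, 0, 3), State(0, 1, 3)))  # -> (1, True)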
|
@@ -401,7 +414,7 @@ def on_state_after(celeste, before_out):

    # Move on to the next episode once we reach
    # a terminal state.
    if (next_state.deaths != 0):
    if (next_state.deaths != 0 or finished_stage):
        s = celeste.state
        with model_train_log.open("a") as f:
            f.write(json.dumps({

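The final hunk widens the terminal check so an episode also ends when finished_stage was set above, not only on death; in both cases the handler logs the episode and moves on to the next one. A toy illustration of the widened condition; FakeState and the transition list are placeholders, not repository code:

# Toy illustration of the terminal check; FakeState and the transitions
# below are placeholders, not repository code.
from dataclasses import dataclass

@dataclass
class FakeState:
    deaths: int

def episode_length(transitions):
    """Count steps until death or a cleared stage ends the episode."""
    for step, (next_state, finished_stage) in enumerate(transitions, start=1):
        if next_state.deaths != 0 or finished_stage:
            return step
    return len(transitions)

# The episode now also ends on a stage clear, even with zero deaths.
print(episode_length([(FakeState(0), False), (FakeState(0), True)]))  # -> 2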
|