Added stage completion handling
parent
4ff32b91ea
commit
589f41c205
|
@ -50,7 +50,6 @@ class Celeste:
|
||||||
action_space = [
|
action_space = [
|
||||||
"left", # move left 0
|
"left", # move left 0
|
||||||
"right", # move right 1
|
"right", # move right 1
|
||||||
#"jump", # jump
|
|
||||||
"jump-l", # jump left 2
|
"jump-l", # jump left 2
|
||||||
"jump-r", # jump right 3
|
"jump-r", # jump right 3
|
||||||
|
|
||||||
|
@ -86,6 +85,13 @@ class Celeste:
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Maps room_x, room_y coordinates to stage number.
|
||||||
|
stage_map = [
|
||||||
|
[0, 1, 2, 3, 4]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
pico_path,
|
pico_path,
|
||||||
|
@ -194,9 +200,7 @@ class Celeste:
|
||||||
def state(self):
|
def state(self):
|
||||||
try:
|
try:
|
||||||
stage = (
|
stage = (
|
||||||
[
|
Celeste.stage_map
|
||||||
[0, 1, 2, 3, 4]
|
|
||||||
]
|
|
||||||
[int(self._internal_state["ry"])]
|
[int(self._internal_state["ry"])]
|
||||||
[int(self._internal_state["rx"])]
|
[int(self._internal_state["rx"])]
|
||||||
)
|
)
|
||||||
|
|
|
@ -341,10 +341,23 @@ def on_state_after(celeste, before_out):
|
||||||
dtype = torch.long
|
dtype = torch.long
|
||||||
)
|
)
|
||||||
|
|
||||||
|
finished_stage = False
|
||||||
|
# No reward if dead
|
||||||
if next_state.deaths != 0:
|
if next_state.deaths != 0:
|
||||||
pt_next_state = None
|
pt_next_state = None
|
||||||
reward = 0
|
reward = 0
|
||||||
|
|
||||||
|
# Reward for finishing stage
|
||||||
|
elif next_state.stage >= 1:
|
||||||
|
finished_stage = True
|
||||||
|
reward = next_state.next_point - state.next_point
|
||||||
|
reward += 1
|
||||||
|
|
||||||
|
# Add to point counter
|
||||||
|
for i in range(state.next_point, state.next_point + reward):
|
||||||
|
point_counter[i] += 1
|
||||||
|
|
||||||
|
# Regular reward
|
||||||
else:
|
else:
|
||||||
pt_next_state = torch.tensor(
|
pt_next_state = torch.tensor(
|
||||||
[getattr(next_state, x) for x in Celeste.state_number_map],
|
[getattr(next_state, x) for x in Celeste.state_number_map],
|
||||||
|
@ -401,7 +414,7 @@ def on_state_after(celeste, before_out):
|
||||||
|
|
||||||
# Move on to the next episode once we reach
|
# Move on to the next episode once we reach
|
||||||
# a terminal state.
|
# a terminal state.
|
||||||
if (next_state.deaths != 0):
|
if (next_state.deaths != 0 or finished_stage):
|
||||||
s = celeste.state
|
s = celeste.state
|
||||||
with model_train_log.open("a") as f:
|
with model_train_log.open("a") as f:
|
||||||
f.write(json.dumps({
|
f.write(json.dumps({
|
||||||
|
|
Reference in New Issue