diff --git a/celeste/celeste_ai/celeste.py b/celeste/celeste_ai/celeste.py
index c7db9d8..164e56f 100755
--- a/celeste/celeste_ai/celeste.py
+++ b/celeste/celeste_ai/celeste.py
@@ -50,7 +50,6 @@ class Celeste:
 	action_space = [
 		"left",		# move left   0
 		"right",	# move right  1
-		#"jump",	# jump
 		"jump-l",	# jump left   2
 		"jump-r",	# jump right  3
 
@@ -86,6 +85,13 @@ class Celeste:
 		]
 	]
 
+
+	# Maps room_x, room_y coordinates to stage number.
+	stage_map = [
+		[0, 1, 2, 3, 4]
+	]
+
+
 	def __init__(
 		self,
 		pico_path,
@@ -194,9 +200,7 @@ class Celeste:
 	def state(self):
 		try:
 			stage = (
-				[
-					[0, 1, 2, 3, 4]
-				]
+				Celeste.stage_map
 				[int(self._internal_state["ry"])]
 				[int(self._internal_state["rx"])]
 			)
diff --git a/celeste/celeste_ai/train.py b/celeste/celeste_ai/train.py
index 32ea61b..42712da 100644
--- a/celeste/celeste_ai/train.py
+++ b/celeste/celeste_ai/train.py
@@ -341,10 +341,23 @@ def on_state_after(celeste, before_out):
 		dtype = torch.long
 	)
 
+	finished_stage = False
+
 	# No reward if dead
 	if next_state.deaths != 0:
 		pt_next_state = None
 		reward = 0
+	# Reward for finishing stage
+	elif next_state.stage >= 1:
+		finished_stage = True
+		reward = next_state.next_point - state.next_point
+		reward += 1
+
+		# Add to point counter
+		for i in range(state.next_point, state.next_point + reward):
+			point_counter[i] += 1
+
+	# Regular reward
 	else:
 		pt_next_state = torch.tensor(
 			[getattr(next_state, x) for x in Celeste.state_number_map],
@@ -401,7 +414,7 @@ def on_state_after(celeste, before_out):
 
 	# Move on to the next episode once we reach
 	# a terminal state.
-	if (next_state.deaths != 0):
+	if (next_state.deaths != 0 or finished_stage):
 		s = celeste.state
 		with model_train_log.open("a") as f:
 			f.write(json.dumps({
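
For reference, the reward rule the train.py hunk introduces, written out as a minimal standalone sketch. This is illustrative only: the State stand-in, the compute_reward helper, and the ten-slot point_counter are hypothetical, and the regular (non-terminal) reward branch is elided because the diff does not show how it is computed.

	from dataclasses import dataclass

	@dataclass
	class State:
		stage: int        # 0 until the stage is cleared, >= 1 afterwards
		deaths: int       # nonzero means the agent died on this frame
		next_point: int   # index of the next checkpoint to reach

	point_counter = [0] * 10  # times each checkpoint index was reached

	def compute_reward(state, next_state):
		"""Return (reward, finished_stage) for one transition."""
		# No reward if dead; death also ends the episode.
		if next_state.deaths != 0:
			return 0, False
		# Finishing the stage: one point per checkpoint passed plus a bonus,
		# and the bonus index is tallied into point_counter too, as in the diff.
		if next_state.stage >= 1:
			reward = next_state.next_point - state.next_point + 1
			for i in range(state.next_point, state.next_point + reward):
				point_counter[i] += 1
			return reward, True
		# Regular reward: computed elsewhere in train.py, elided here.
		return 0, False

	# Example: clearing the stage after checkpoint 3, ending at point 5:
	# reward = (5 - 3) + 1 = 3, and point_counter[3..5] each increment.
	print(compute_reward(State(0, 0, 3), State(1, 0, 5)))  # (3, True)

Note that the episode terminates on either condition: the final hunk changes the terminal-state check from deaths alone to `next_state.deaths != 0 or finished_stage`, so a cleared stage ends the episode just as a death does.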