Added configurable checkpoints and better stage-completion handling
parent
0b61702677
commit
03135e2ef9
|
@ -77,14 +77,17 @@ class Celeste:
|
||||||
|
|
||||||
# Targets the agent tries to reach.
|
# Targets the agent tries to reach.
|
||||||
# The last target MUST be outside the frame.
|
# The last target MUST be outside the frame.
|
||||||
|
# Format is (X, Y, range[, force_y]). A target counts as reached once the computed distance to it is <= range.
|
||||||
|
# force_y is optional (treated as False when omitted). If true, the agent's y position MUST also equal the target's Y exactly.
|
||||||
target_checkpoints = [
|
target_checkpoints = [
|
||||||
[ # Stage 1
|
[ # Stage 1
|
||||||
#(28, 88), # Start pillar
|
#(28, 88, 8), # Start pillar
|
||||||
(60, 80), # Middle pillar
|
(60, 80, 8), # Middle pillar
|
||||||
(105, 64), # Right ledge
|
(105, 64, 8), # Right ledge
|
||||||
(25, 40), # Left ledge
|
(25, 40, 8), # Left ledge
|
||||||
(110, 16), # End ledge
|
(97, 24, 5, True), # Small end ledge
|
||||||
(110, -2), # Next stage
|
(110, 16, 8), # End ledge
|
||||||
|
(110, -20, 8), # Next stage
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -208,9 +211,9 @@ class Celeste:
|
||||||
[int(self._internal_state["rx"])]
|
[int(self._internal_state["rx"])]
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(Celeste.target_checkpoints) < stage:
|
if len(Celeste.target_checkpoints) <= stage:
|
||||||
next_point_x = None
|
next_point_x = 0
|
||||||
next_point_y = None
|
next_point_y = 0
|
||||||
else:
|
else:
|
||||||
next_point_x = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][0]
|
next_point_x = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][0]
|
||||||
next_point_y = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][1]
|
next_point_y = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][1]
|
||||||
|
@ -329,46 +332,65 @@ class Celeste:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if self.state.stage <= 0:
|
||||||
|
# Calculate distance to each point
|
||||||
|
x = self.state.xpos
|
||||||
|
y = self.state.ypos
|
||||||
|
dist = np.zeros(len(Celeste.target_checkpoints[self.state.stage]), dtype=np.float16)
|
||||||
|
for i, c in enumerate(Celeste.target_checkpoints[self.state.stage]):
|
||||||
|
if i < self._next_checkpoint_idx:
|
||||||
|
dist[i] = 1000
|
||||||
|
continue
|
||||||
|
|
||||||
# Calculate distance to each point
|
# Update checkpoints
|
||||||
x = self.state.xpos
|
tx, ty = c[:2]
|
||||||
y = self.state.ypos
|
dist[i] = (math.sqrt(
|
||||||
dist = np.zeros(len(Celeste.target_checkpoints[self.state.stage]), dtype=np.float16)
|
(x-tx)*(x-tx) +
|
||||||
for i, c in enumerate(Celeste.target_checkpoints[self.state.stage]):
|
((y-ty)*(y-ty))/2
|
||||||
if i < self._next_checkpoint_idx:
|
# Possible modification:
|
||||||
dist[i] = 1000
|
# make x-distance twice as valuable as y-distance
|
||||||
continue
|
))
|
||||||
|
min_idx = int(dist.argmin())
|
||||||
# Update checkpoints
|
dist = int(dist[min_idx])
|
||||||
tx, ty = c
|
|
||||||
dist[i] = (math.sqrt(
|
|
||||||
(x-tx)*(x-tx) +
|
|
||||||
((y-ty)*(y-ty))/2
|
|
||||||
# Possible modification:
|
|
||||||
# make x-distance twice as valuable as y-distance
|
|
||||||
))
|
|
||||||
min_idx = int(dist.argmin())
|
|
||||||
dist = int(dist[min_idx])
|
|
||||||
|
|
||||||
|
|
||||||
if dist <= 8:
|
t = Celeste.target_checkpoints[self.state.stage][min_idx]
|
||||||
print(f"Got point {min_idx}")
|
range = t[2]
|
||||||
self._next_checkpoint_idx = min_idx + 1
|
if len(t) == 3:
|
||||||
self._last_checkpoint_state = self._state_counter
|
force_y = False
|
||||||
|
else:
|
||||||
|
force_y = t[3]
|
||||||
|
|
||||||
# Recalculate distance to new point
|
if force_y:
|
||||||
tx, ty = Celeste.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
|
got_point = (
|
||||||
dist = math.sqrt(
|
dist <= range and
|
||||||
(x-tx)*(x-tx) +
|
y == t[1]
|
||||||
((y-ty)*(y-ty))/2
|
)
|
||||||
)
|
else:
|
||||||
|
got_point = dist <= range
|
||||||
|
|
||||||
# Timeout if we spend too long between points
|
if got_point:
|
||||||
elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
|
self._next_checkpoint_idx = min_idx + 1
|
||||||
self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)
|
self._last_checkpoint_state = self._state_counter
|
||||||
|
|
||||||
|
# Recalculate distance to new point
|
||||||
|
tx, ty = (
|
||||||
|
Celeste.target_checkpoints
|
||||||
|
[self.state.stage]
|
||||||
|
[self._next_checkpoint_idx]
|
||||||
|
[:2]
|
||||||
|
)
|
||||||
|
dist = math.sqrt(
|
||||||
|
(x-tx)*(x-tx) +
|
||||||
|
((y-ty)*(y-ty))/2
|
||||||
|
)
|
||||||
|
|
||||||
|
# Timeout if we spend too long between points
|
||||||
|
elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
|
||||||
|
self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)
|
||||||
|
|
||||||
|
|
||||||
self._dist = dist
|
self._dist = dist
|
||||||
|
|
||||||
# Call step callbacks
|
# Call step callbacks
|
||||||
# These should call celeste.act() to set next input
|
# These should call celeste.act() to set next input
|
||||||
|
|
Reference in New Issue