Added configurable checkpoints and better stage complete handling
parent
0b61702677
commit
03135e2ef9
|
@ -77,14 +77,17 @@ class Celeste:
|
|||
|
||||
# Targets the agent tries to reach.
|
||||
# The last target MUST be outside the frame.
|
||||
# Format is X, Y, range, force_y
|
||||
# force_y is optional. If true, y_value MUST match perfectly.
|
||||
target_checkpoints = [
|
||||
[ # Stage 1
|
||||
#(28, 88), # Start pillar
|
||||
(60, 80), # Middle pillar
|
||||
(105, 64), # Right ledge
|
||||
(25, 40), # Left ledge
|
||||
(110, 16), # End ledge
|
||||
(110, -2), # Next stage
|
||||
#(28, 88, 8), # Start pillar
|
||||
(60, 80, 8), # Middle pillar
|
||||
(105, 64, 8), # Right ledge
|
||||
(25, 40, 8), # Left ledge
|
||||
(97, 24, 5, True), # Small end ledge
|
||||
(110, 16, 8), # End ledge
|
||||
(110, -20, 8), # Next stage
|
||||
]
|
||||
]
|
||||
|
||||
|
@ -208,9 +211,9 @@ class Celeste:
|
|||
[int(self._internal_state["rx"])]
|
||||
)
|
||||
|
||||
if len(Celeste.target_checkpoints) < stage:
|
||||
next_point_x = None
|
||||
next_point_y = None
|
||||
if len(Celeste.target_checkpoints) <= stage:
|
||||
next_point_x = 0
|
||||
next_point_y = 0
|
||||
else:
|
||||
next_point_x = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][0]
|
||||
next_point_y = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][1]
|
||||
|
@ -329,46 +332,65 @@ class Celeste:
|
|||
|
||||
|
||||
|
||||
if self.state.stage <= 0:
|
||||
# Calculate distance to each point
|
||||
x = self.state.xpos
|
||||
y = self.state.ypos
|
||||
dist = np.zeros(len(Celeste.target_checkpoints[self.state.stage]), dtype=np.float16)
|
||||
for i, c in enumerate(Celeste.target_checkpoints[self.state.stage]):
|
||||
if i < self._next_checkpoint_idx:
|
||||
dist[i] = 1000
|
||||
continue
|
||||
|
||||
# Calculate distance to each point
|
||||
x = self.state.xpos
|
||||
y = self.state.ypos
|
||||
dist = np.zeros(len(Celeste.target_checkpoints[self.state.stage]), dtype=np.float16)
|
||||
for i, c in enumerate(Celeste.target_checkpoints[self.state.stage]):
|
||||
if i < self._next_checkpoint_idx:
|
||||
dist[i] = 1000
|
||||
continue
|
||||
# Update checkpoints
|
||||
tx, ty = c[:2]
|
||||
dist[i] = (math.sqrt(
|
||||
(x-tx)*(x-tx) +
|
||||
((y-ty)*(y-ty))/2
|
||||
# Possible modification:
|
||||
# make x-distance twice as valuable as y-distance
|
||||
))
|
||||
min_idx = int(dist.argmin())
|
||||
dist = int(dist[min_idx])
|
||||
|
||||
# Update checkpoints
|
||||
tx, ty = c
|
||||
dist[i] = (math.sqrt(
|
||||
(x-tx)*(x-tx) +
|
||||
((y-ty)*(y-ty))/2
|
||||
# Possible modification:
|
||||
# make x-distance twice as valuable as y-distance
|
||||
))
|
||||
min_idx = int(dist.argmin())
|
||||
dist = int(dist[min_idx])
|
||||
|
||||
t = Celeste.target_checkpoints[self.state.stage][min_idx]
|
||||
range = t[2]
|
||||
if len(t) == 3:
|
||||
force_y = False
|
||||
else:
|
||||
force_y = t[3]
|
||||
|
||||
if force_y:
|
||||
got_point = (
|
||||
dist <= range and
|
||||
y == t[1]
|
||||
)
|
||||
else:
|
||||
got_point = dist <= range
|
||||
|
||||
if dist <= 8:
|
||||
print(f"Got point {min_idx}")
|
||||
self._next_checkpoint_idx = min_idx + 1
|
||||
self._last_checkpoint_state = self._state_counter
|
||||
if got_point:
|
||||
self._next_checkpoint_idx = min_idx + 1
|
||||
self._last_checkpoint_state = self._state_counter
|
||||
|
||||
# Recalculate distance to new point
|
||||
tx, ty = Celeste.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
|
||||
dist = math.sqrt(
|
||||
(x-tx)*(x-tx) +
|
||||
((y-ty)*(y-ty))/2
|
||||
)
|
||||
# Recalculate distance to new point
|
||||
tx, ty = (
|
||||
Celeste.target_checkpoints
|
||||
[self.state.stage]
|
||||
[self._next_checkpoint_idx]
|
||||
[:2]
|
||||
)
|
||||
dist = math.sqrt(
|
||||
(x-tx)*(x-tx) +
|
||||
((y-ty)*(y-ty))/2
|
||||
)
|
||||
|
||||
# Timeout if we spend too long between points
|
||||
elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
|
||||
self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)
|
||||
# Timeout if we spend too long between points
|
||||
elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
|
||||
self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)
|
||||
|
||||
|
||||
self._dist = dist
|
||||
self._dist = dist
|
||||
|
||||
# Call step callbacks
|
||||
# These should call celeste.act() to set next input
|
||||
|
|
Reference in New Issue