Added configurable checkpoints and better stage-completion handling
This commit is contained in:
		@ -77,14 +77,17 @@ class Celeste:
 | 
			
		||||
 | 
			
		||||
	# Targets the agent tries to reach.
 | 
			
		||||
	# The last target MUST be outside the frame.
 | 
			
		||||
	# Format is (X, Y, range, force_y). range is the distance threshold within which the target counts as reached.
 | 
			
		||||
	# force_y is optional and defaults to False. If true, the current y position MUST equal the target Y exactly.
 | 
			
		||||
	target_checkpoints = [
 | 
			
		||||
		[	# Stage 1
 | 
			
		||||
			#(28, 88),		# Start pillar
 | 
			
		||||
			(60, 80),		# Middle pillar
 | 
			
		||||
			(105, 64),		# Right ledge
 | 
			
		||||
			(25, 40),		# Left ledge
 | 
			
		||||
			(110, 16),		# End ledge
 | 
			
		||||
			(110, -2),		# Next stage
 | 
			
		||||
			#(28, 88,	8),			# Start pillar
 | 
			
		||||
			(60, 80,	8),		# Middle pillar
 | 
			
		||||
			(105, 64,	8),		# Right ledge
 | 
			
		||||
			(25, 40,	8),		# Left ledge
 | 
			
		||||
			(97, 24,	5, True),	# Small end ledge
 | 
			
		||||
			(110, 16,	8),		# End ledge
 | 
			
		||||
			(110, -20,	8),		# Next stage
 | 
			
		||||
		]
 | 
			
		||||
	]
 | 
			
		||||
 | 
			
		||||
@ -208,9 +211,9 @@ class Celeste:
 | 
			
		||||
				[int(self._internal_state["rx"])]
 | 
			
		||||
			)
 | 
			
		||||
 | 
			
		||||
			if len(Celeste.target_checkpoints) < stage:
 | 
			
		||||
				next_point_x = None
 | 
			
		||||
				next_point_y = None
 | 
			
		||||
			if len(Celeste.target_checkpoints) <= stage:
 | 
			
		||||
				next_point_x = 0
 | 
			
		||||
				next_point_y = 0
 | 
			
		||||
			else:
 | 
			
		||||
				next_point_x = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][0]
 | 
			
		||||
				next_point_y = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][1]
 | 
			
		||||
@ -329,46 +332,65 @@ class Celeste:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
			if self.state.stage <= 0:
 | 
			
		||||
				# Calculate distance to each point
 | 
			
		||||
				x = self.state.xpos
 | 
			
		||||
				y = self.state.ypos
 | 
			
		||||
				dist = np.zeros(len(Celeste.target_checkpoints[self.state.stage]), dtype=np.float16)
 | 
			
		||||
				for i, c in enumerate(Celeste.target_checkpoints[self.state.stage]):
 | 
			
		||||
					if i < self._next_checkpoint_idx:
 | 
			
		||||
						dist[i] = 1000
 | 
			
		||||
						continue
 | 
			
		||||
 | 
			
		||||
			# Calculate distance to each point
 | 
			
		||||
			x = self.state.xpos
 | 
			
		||||
			y = self.state.ypos
 | 
			
		||||
			dist = np.zeros(len(Celeste.target_checkpoints[self.state.stage]), dtype=np.float16)
 | 
			
		||||
			for i, c in enumerate(Celeste.target_checkpoints[self.state.stage]):
 | 
			
		||||
				if i < self._next_checkpoint_idx:
 | 
			
		||||
					dist[i] = 1000
 | 
			
		||||
					continue
 | 
			
		||||
					# Update checkpoints
 | 
			
		||||
					tx, ty = c[:2]
 | 
			
		||||
					dist[i] = (math.sqrt(
 | 
			
		||||
						(x-tx)*(x-tx) +
 | 
			
		||||
						((y-ty)*(y-ty))/2
 | 
			
		||||
						# Possible modification:
 | 
			
		||||
						# make x-distance twice as valuable as y-distance
 | 
			
		||||
					))
 | 
			
		||||
				min_idx = int(dist.argmin())
 | 
			
		||||
				dist = int(dist[min_idx])
 | 
			
		||||
 | 
			
		||||
				# Update checkpoints
 | 
			
		||||
				tx, ty = c
 | 
			
		||||
				dist[i] = (math.sqrt(
 | 
			
		||||
					(x-tx)*(x-tx) +
 | 
			
		||||
					((y-ty)*(y-ty))/2
 | 
			
		||||
					# Possible modification:
 | 
			
		||||
					# make x-distance twice as valuable as y-distance
 | 
			
		||||
				))
 | 
			
		||||
			min_idx = int(dist.argmin())
 | 
			
		||||
			dist = int(dist[min_idx])
 | 
			
		||||
				
 | 
			
		||||
				t = Celeste.target_checkpoints[self.state.stage][min_idx]
 | 
			
		||||
				range = t[2]
 | 
			
		||||
				if len(t) == 3:
 | 
			
		||||
					force_y = False
 | 
			
		||||
				else:
 | 
			
		||||
					force_y = t[3]
 | 
			
		||||
 | 
			
		||||
				if force_y:
 | 
			
		||||
					got_point = (
 | 
			
		||||
						dist <= range and
 | 
			
		||||
						y == t[1]
 | 
			
		||||
					)
 | 
			
		||||
				else:
 | 
			
		||||
					got_point = dist <= range
 | 
			
		||||
 | 
			
		||||
			if dist <= 8:
 | 
			
		||||
				print(f"Got point {min_idx}")
 | 
			
		||||
				self._next_checkpoint_idx = min_idx + 1
 | 
			
		||||
				self._last_checkpoint_state = self._state_counter
 | 
			
		||||
				if got_point:
 | 
			
		||||
					self._next_checkpoint_idx = min_idx + 1
 | 
			
		||||
					self._last_checkpoint_state = self._state_counter
 | 
			
		||||
 | 
			
		||||
				# Recalculate distance to new point
 | 
			
		||||
				tx, ty = Celeste.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
 | 
			
		||||
				dist = math.sqrt(
 | 
			
		||||
					(x-tx)*(x-tx) +
 | 
			
		||||
					((y-ty)*(y-ty))/2
 | 
			
		||||
				)
 | 
			
		||||
					# Recalculate distance to new point
 | 
			
		||||
					tx, ty = (
 | 
			
		||||
						Celeste.target_checkpoints
 | 
			
		||||
						[self.state.stage]
 | 
			
		||||
						[self._next_checkpoint_idx]
 | 
			
		||||
						[:2]
 | 
			
		||||
					)
 | 
			
		||||
					dist = math.sqrt(
 | 
			
		||||
						(x-tx)*(x-tx) +
 | 
			
		||||
						((y-ty)*(y-ty))/2
 | 
			
		||||
					)
 | 
			
		||||
			
 | 
			
		||||
			# Timeout if we spend too long between points
 | 
			
		||||
			elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
 | 
			
		||||
				self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)
 | 
			
		||||
				# Timeout if we spend too long between points
 | 
			
		||||
				elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
 | 
			
		||||
					self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
			self._dist = dist
 | 
			
		||||
				self._dist = dist
 | 
			
		||||
 | 
			
		||||
			# Call step callbacks
 | 
			
		||||
			# These should call celeste.act() to set next input
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user