Compare commits
	
		
			2 Commits
		
	
	
		
			4ff32b91ea
			...
			c372ef8cc7
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						
						
							
						
						c372ef8cc7
	
				 | 
					
					
						|||
| 
						
						
							
						
						589f41c205
	
				 | 
					
					
						
@@ -50,7 +50,6 @@ class Celeste:
 | 
				
			|||||||
	action_space = [
 | 
						action_space = [
 | 
				
			||||||
		"left",		# move left		0
 | 
							"left",		# move left		0
 | 
				
			||||||
		"right",	# move right	1
 | 
							"right",	# move right	1
 | 
				
			||||||
		#"jump",	# jump
 | 
					 | 
				
			||||||
		"jump-l",	# jump left		2
 | 
							"jump-l",	# jump left		2
 | 
				
			||||||
		"jump-r",	# jump right	3
 | 
							"jump-r",	# jump right	3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -86,6 +85,13 @@ class Celeste:
 | 
				
			|||||||
		]
 | 
							]
 | 
				
			||||||
	]
 | 
						]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						# Maps room_x, room_y coordinates to stage number.
 | 
				
			||||||
 | 
						stage_map = [
 | 
				
			||||||
 | 
							[0, 1, 2, 3, 4]
 | 
				
			||||||
 | 
						]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	def __init__(
 | 
						def __init__(
 | 
				
			||||||
			self,
 | 
								self,
 | 
				
			||||||
			pico_path,
 | 
								pico_path,
 | 
				
			||||||
@@ -194,9 +200,7 @@ class Celeste:
 | 
				
			|||||||
	def state(self):
 | 
						def state(self):
 | 
				
			||||||
		try:
 | 
							try:
 | 
				
			||||||
			stage = (
 | 
								stage = (
 | 
				
			||||||
				[
 | 
									Celeste.stage_map
 | 
				
			||||||
					[0, 1, 2, 3, 4]
 | 
					 | 
				
			||||||
				]
 | 
					 | 
				
			||||||
				[int(self._internal_state["ry"])]
 | 
									[int(self._internal_state["ry"])]
 | 
				
			||||||
				[int(self._internal_state["rx"])]
 | 
									[int(self._internal_state["rx"])]
 | 
				
			||||||
			)
 | 
								)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -341,10 +341,23 @@ def on_state_after(celeste, before_out):
 | 
				
			|||||||
		dtype = torch.long
 | 
							dtype = torch.long
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						finished_stage = False
 | 
				
			||||||
 | 
						# No reward if dead
 | 
				
			||||||
	if next_state.deaths != 0:
 | 
						if next_state.deaths != 0:
 | 
				
			||||||
		pt_next_state = None
 | 
							pt_next_state = None
 | 
				
			||||||
		reward = 0
 | 
							reward = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						# Reward for finishing stage
 | 
				
			||||||
 | 
						elif next_state.stage >= 1:
 | 
				
			||||||
 | 
							finished_stage = True
 | 
				
			||||||
 | 
							reward = next_state.next_point - state.next_point
 | 
				
			||||||
 | 
							reward += 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							# Add to point counter
 | 
				
			||||||
 | 
							for i in range(state.next_point, state.next_point + reward):
 | 
				
			||||||
 | 
								point_counter[i] += 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						# Regular reward
 | 
				
			||||||
	else:
 | 
						else:
 | 
				
			||||||
		pt_next_state = torch.tensor(
 | 
							pt_next_state = torch.tensor(
 | 
				
			||||||
			[getattr(next_state, x) for x in Celeste.state_number_map],
 | 
								[getattr(next_state, x) for x in Celeste.state_number_map],
 | 
				
			||||||
@@ -401,7 +414,7 @@ def on_state_after(celeste, before_out):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	# Move on to the next episode once we reach
 | 
						# Move on to the next episode once we reach
 | 
				
			||||||
	# a terminal state.
 | 
						# a terminal state.
 | 
				
			||||||
	if (next_state.deaths != 0):
 | 
						if (next_state.deaths != 0 or finished_stage):
 | 
				
			||||||
		s = celeste.state
 | 
							s = celeste.state
 | 
				
			||||||
		with model_train_log.open("a") as f:
 | 
							with model_train_log.open("a") as f:
 | 
				
			||||||
			f.write(json.dumps({
 | 
								f.write(json.dumps({
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -22,8 +22,6 @@ render_dir () {
 | 
				
			|||||||
		$OUTPUT_DIR/${1##*/}.mp4
 | 
							$OUTPUT_DIR/${1##*/}.mp4
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Todo: error out if exists
 | 
					# Todo: error out if exists
 | 
				
			||||||
mkdir -p $OUTPUT_DIR
 | 
					mkdir -p $OUTPUT_DIR
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -50,17 +48,18 @@ ffmpeg \
 | 
				
			|||||||
	-safe 0 \
 | 
						-safe 0 \
 | 
				
			||||||
	-i video_merge_list \
 | 
						-i video_merge_list \
 | 
				
			||||||
	-vf "scale=1024x1024:flags=neighbor" \
 | 
						-vf "scale=1024x1024:flags=neighbor" \
 | 
				
			||||||
	$OUTPUT_DIR/00-all.mp4
 | 
						$SC_ROOT/1x.mp4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
rm video_merge_list
 | 
					rm video_merge_list
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Make accelerated video
 | 
					# Make accelerated video
 | 
				
			||||||
ffmpeg \
 | 
					ffmpeg \
 | 
				
			||||||
	-loglevel error -stats -y \
 | 
						-loglevel error -stats -y \
 | 
				
			||||||
	-i $OUTPUT_DIR/00-all.mp4 \
 | 
						-i $SC_ROOT/1x.mp4 \
 | 
				
			||||||
	-framerate 60 \
 | 
						-framerate 60 \
 | 
				
			||||||
	-filter:v "setpts=0.125*PTS" \
 | 
						-filter:v "setpts=0.125*PTS" \
 | 
				
			||||||
	$SC_ROOT/8x.mp4
 | 
						$SC_ROOT/8x.mp4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
echo "Cleaning up..."
 | 
					echo "Cleaning up..."
 | 
				
			||||||
 | 
					
 | 
				
			||||||
rm -dr $OUTPUT_DIR
 | 
					rm -dr $OUTPUT_DIR
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -200,13 +200,12 @@ for ep in range(num_episodes):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							state = next_state
 | 
				
			||||||
 | 
							
 | 
				
			||||||
		
 | 
							
 | 
				
			||||||
		# Only train the network if we have enough
 | 
							# Only train the network if we have enough
 | 
				
			||||||
		# transitions in memory to do so.
 | 
							# transitions in memory to do so.
 | 
				
			||||||
		if len(memory) >= BATCH_SIZE:
 | 
							if len(memory) >= BATCH_SIZE:
 | 
				
			||||||
 | 
					 | 
				
			||||||
			state = next_state
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			# Run optimizer
 | 
								# Run optimizer
 | 
				
			||||||
			optimize.optimize_model(
 | 
								optimize.optimize_model(
 | 
				
			||||||
				memory,
 | 
									memory,
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user