Compare commits
	
		
			2 Commits
		
	
	
		
			ee232329b7
			...
			8420e719d8
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 8420e719d8 | |||
| 6b7abc49a6 | 
							
								
								
									
										119
									
								
								celeste/celeste_ai/paths.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										119
									
								
								celeste/celeste_ai/paths.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,119 @@ | ||||
| from pathlib import Path | ||||
| import torch | ||||
| import json | ||||
|  | ||||
| from celeste_ai import Celeste | ||||
| from celeste_ai import DQN | ||||
|  | ||||
|  | ||||
|  | ||||
# Root of the trained-run data. NOTE(review): run name is hard-coded.
model_data_root = Path("model_data/solved_1")

# Run on GPU when available, otherwise CPU.
compute_device = torch.device(
	"cuda" if torch.cuda.is_available() else "cpu"
)


# Celeste env properties
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)

# Policy network used for greedy action selection; weights are
# (re)loaded per archived checkpoint by next_image().
policy_net = DQN(
	n_observations,
	n_actions
).to(compute_device)

# Iterator over archived checkpoint files, consumed by next_image().
k = (model_data_root / "model_archive").iterdir()
# Number of checkpoints processed so far (progress printing only).
i = 0

# Steps of the current episode; flushed to paths.json at episode end.
state_history = []
# Path of the checkpoint currently loaded into policy_net.
current_path = None
|  | ||||
def next_image():
	"""Advance to the next archived checkpoint and load its weights.

	Pulls the next path from the module-level iterator ``k``, updates
	``current_path`` and the progress counter ``i``, then loads the
	checkpoint's policy weights into ``policy_net``.

	Returns:
		``False`` once the archive iterator is exhausted, ``True``
		otherwise.  (Callers test ``is False``, so returning ``True``
		instead of the old implicit ``None`` is backward-compatible.)
	"""
	global current_path
	global i
	i += 1

	try:
		# next(k) is the idiomatic spelling of k.__next__().
		current_path = next(k)
	except StopIteration:
		return False

	print(f"Pathing {current_path} ({i})")

	# Load the archived model weights onto the compute device.
	checkpoint = torch.load(
		current_path,
		map_location = compute_device
	)
	policy_net.load_state_dict(checkpoint["policy_state_dict"])
	return True
|  | ||||
|  | ||||
# Load the first checkpoint before the game loop starts.
next_image()
|  | ||||
def on_state_before(celeste):
	"""Pick and perform a greedy action for the current game state.

	Builds a feature tensor from the Celeste state, selects the
	highest-valued action from ``policy_net``, and sends it to the game.

	Args:
		celeste: Live game wrapper exposing ``.state`` and ``.act()``.

	Returns:
		``(state, action)`` tuple consumed by ``on_state_after``.
	"""
	state = celeste.state

	# Flatten the state object into a 1 x n_observations float tensor.
	pt_state = torch.tensor(
		[getattr(state, x) for x in Celeste.state_number_map],
		dtype = torch.float32,
		device = compute_device
	).unsqueeze(0)

	# Inference only — no gradient tracking needed.
	with torch.no_grad():
		action = policy_net(pt_state).max(1)[1].view(1, 1).item()
	str_action = Celeste.action_space[action]

	celeste.act(str_action)

	return state, action
|  | ||||
|  | ||||
def on_state_after(celeste, before_out):
	"""Record this step; at episode end, save the path and advance.

	Appends the agent's position and action to ``state_history``.  When
	the episode ends (a death, or the stage is cleared) the history is
	appended to ``paths.json``, the next archived checkpoint is loaded,
	and the game is reset.

	Args:
		celeste: Live game wrapper.
		before_out: The ``(state, action)`` pair from ``on_state_before``.

	Raises:
		Exception: once every archived checkpoint has been processed.
	"""
	global state_history

	state, action = before_out
	next_state = celeste.state
	finished_stage = next_state.stage >= 1

	state_history.append({
		"xpos": state.xpos,
		"ypos": state.ypos,
		"action": Celeste.action_space[action]
	})

	# Move on to the next episode once we reach a terminal state.
	if next_state.deaths != 0 or finished_stage:

		# One JSON object per line (JSON-lines), appended per episode.
		with (model_data_root / "paths.json").open("a") as f:
			f.write(json.dumps(
				{
					"hist": state_history,
					"current_image": str(current_path)
				}
			) + "\n")

		state_history = []

		# next_image() returns False once the archive is exhausted.
		if next_image() is False:
			raise Exception("Done.")

		print("Game over. Resetting.")
		celeste.reset()
|  | ||||
|  | ||||
|  | ||||
# Launch PICO-8 Celeste and drive it with the handlers above.
c = Celeste(
	"resources/pico-8/linux/pico8"
)

c.update_loop(
	on_state_before,
	on_state_after
)
							
								
								
									
										100
									
								
								celeste/celeste_ai/test.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								celeste/celeste_ai/test.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,100 @@ | ||||
| from pathlib import Path | ||||
| import torch | ||||
|  | ||||
| from celeste_ai import Celeste | ||||
| from celeste_ai import DQN | ||||
| from celeste_ai.util.screenshots import ScreenshotManager | ||||
|  | ||||
|  | ||||
if __name__ == "__main__":
	# Where to read/write model data.
	model_data_root = Path("model_data/current")

	model_save_path		= model_data_root / "model.torch"
	model_data_root.mkdir(parents = True, exist_ok = True)


	# NOTE(review): screenshot source is a hard-coded user path.
	sm = ScreenshotManager(
		# Where PICO-8 saves screenshots.
		# Probably your desktop.
		source = Path("/home/mark/Desktop"),
		pattern = "hackcel_*.png",
		target = model_data_root / "screenshots_test"
	).clean() # Remove old screenshots


	# Run on GPU when available, otherwise CPU.
	compute_device = torch.device(
		"cuda" if torch.cuda.is_available() else "cpu"
	)

	episode_number = 0

	# Celeste env properties
	n_observations = len(Celeste.state_number_map)
	n_actions = len(Celeste.action_space)

	# Policy network; the handlers below read this as a global, so it
	# must be defined before update_loop() starts.
	policy_net = DQN(
		n_observations,
		n_actions
	).to(compute_device)


	# Load model if one exists
	checkpoint = torch.load(
		model_save_path,
		map_location = compute_device
	)
	policy_net.load_state_dict(checkpoint["policy_state_dict"])
|  | ||||
|  | ||||
def on_state_before(celeste):
	"""Pick, print, and perform a greedy action for the current state.

	Builds a feature tensor from the Celeste state, selects the
	highest-valued action from ``policy_net``, prints it, and sends it
	to the game.

	Args:
		celeste: Live game wrapper exposing ``.state`` and ``.act()``.

	Returns:
		``(state, action)`` tuple consumed by ``on_state_after``.
	"""
	state = celeste.state

	# Flatten the state object into a 1 x n_observations float tensor.
	pt_state = torch.tensor(
		[getattr(state, x) for x in Celeste.state_number_map],
		dtype = torch.float32,
		device = compute_device
	).unsqueeze(0)

	# Inference only — no gradient tracking needed.
	with torch.no_grad():
		action = policy_net(pt_state).max(1)[1].view(1, 1).item()
	str_action = Celeste.action_space[action]

	print(str_action)
	celeste.act(str_action)

	return state, action
|  | ||||
|  | ||||
def on_state_after(celeste, before_out):
	"""At episode end, archive screenshots and reset the game.

	Args:
		celeste: Live game wrapper.
		before_out: The ``(state, action)`` pair from ``on_state_before``.
	"""
	global episode_number

	state, action = before_out
	next_state = celeste.state
	finished_stage = next_state.stage >= 1


	# Move on to the next episode once we reach a terminal state.
	if next_state.deaths != 0 or finished_stage:
		# Collect this episode's screenshots into the target directory.
		sm.move()


		print("Game over. Resetting.")
		celeste.reset()
		episode_number += 1
|  | ||||
|  | ||||
if __name__ == "__main__":
	# Launch PICO-8 Celeste and drive it with the handlers above.
	c = Celeste(
		"resources/pico-8/linux/pico8"
	)

	c.update_loop(
		on_state_before,
		on_state_after
	)
		Reference in New Issue
	
	Block a user