Cleaned up celeste wrapper
This commit is contained in:
		@@ -1,12 +1,44 @@
 | 
				
			|||||||
 | 
					from typing import NamedTuple
 | 
				
			||||||
import subprocess
 | 
					import subprocess
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
import threading
 | 
					 | 
				
			||||||
import math
 | 
					import math
 | 
				
			||||||
from tqdm import tqdm
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
class CelesteError(Exception):
 | 
					class CelesteError(Exception):
 | 
				
			||||||
	pass
 | 
						pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class CelesteState(NamedTuple):
 | 
				
			||||||
 | 
						# Stage number
 | 
				
			||||||
 | 
						stage: int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						# Player position
 | 
				
			||||||
 | 
						xpos: int
 | 
				
			||||||
 | 
						ypos: int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						# Player velocity
 | 
				
			||||||
 | 
						xvel: float
 | 
				
			||||||
 | 
						yvel: float
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						# Number of deaths since game start
 | 
				
			||||||
 | 
						deaths: int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						# Distance to next point
 | 
				
			||||||
 | 
						dist: float
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						# Index of next point
 | 
				
			||||||
 | 
						next_point: int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						# Coordinates of next point
 | 
				
			||||||
 | 
						next_point_x: int
 | 
				
			||||||
 | 
						next_point_y: int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						# Number of states received since restart
 | 
				
			||||||
 | 
						state_count: int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						# True if Madeline can dash
 | 
				
			||||||
 | 
						can_dash: bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Celeste:
 | 
					class Celeste:
 | 
				
			||||||
	action_space = [
 | 
						action_space = [
 | 
				
			||||||
		"left",		# move left
 | 
							"left",		# move left
 | 
				
			||||||
@@ -20,10 +52,25 @@ class Celeste:
 | 
				
			|||||||
		"dash-lu"	# dash left-up
 | 
							"dash-lu"	# dash left-up
 | 
				
			||||||
	]
 | 
						]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	def __init__(self):
 | 
						# Map integers to state values.
 | 
				
			||||||
 | 
						# This also determines what data is fed to the model.
 | 
				
			||||||
 | 
						state_number_map = [
 | 
				
			||||||
 | 
							"xpos",
 | 
				
			||||||
 | 
							"ypos",
 | 
				
			||||||
 | 
							"next_point_x",
 | 
				
			||||||
 | 
							"next_point_y"
 | 
				
			||||||
 | 
						]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						def __init__(
 | 
				
			||||||
 | 
								self,
 | 
				
			||||||
 | 
								*,
 | 
				
			||||||
 | 
								state_timeout = 30,
 | 
				
			||||||
 | 
								cart_name = "hackcel.p8"
 | 
				
			||||||
 | 
							):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		# Start pico-8
 | 
							# Start pico-8
 | 
				
			||||||
		self.process = subprocess.Popen(
 | 
							self._process = subprocess.Popen(
 | 
				
			||||||
			"bin/pico-8/linux/pico8",
 | 
								"resources/pico-8/linux/pico8",
 | 
				
			||||||
			shell=True,
 | 
								shell=True,
 | 
				
			||||||
			stdout=subprocess.PIPE,
 | 
								stdout=subprocess.PIPE,
 | 
				
			||||||
			stderr=subprocess.STDOUT
 | 
								stderr=subprocess.STDOUT
 | 
				
			||||||
@@ -39,26 +86,34 @@ class Celeste:
 | 
				
			|||||||
		]).decode("utf-8").strip().split("\n")
 | 
							]).decode("utf-8").strip().split("\n")
 | 
				
			||||||
		if len(winid) != 1:
 | 
							if len(winid) != 1:
 | 
				
			||||||
			raise Exception("Could not find unique PICO-8 window id")
 | 
								raise Exception("Could not find unique PICO-8 window id")
 | 
				
			||||||
		self.winid = winid[0]
 | 
							self._winid = winid[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		# Load cartridge
 | 
							# Load cartridge
 | 
				
			||||||
		self.keystring("load hackcel.p8")
 | 
							self._keystring(f"load {cart_name}")
 | 
				
			||||||
		self.keypress("Enter")
 | 
							self._keypress("Enter")
 | 
				
			||||||
		self.keystring("run")
 | 
							self._keystring("run")
 | 
				
			||||||
		self.keypress("Enter", post = 1000)
 | 
							self._keypress("Enter", post = 1000)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		# Initialize variables
 | 
					 | 
				
			||||||
		self.internal_status = {}
 | 
					 | 
				
			||||||
		self.before_out = None
 | 
					 | 
				
			||||||
		self.last_point_frame = 0
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
		# Score system
 | 
							# Parameters
 | 
				
			||||||
		self.frame_counter = 0
 | 
							self.state_timeout = state_timeout	# If we run this many states without getting a checkpoint, reset.
 | 
				
			||||||
		self.next_point = 0
 | 
							self.cart_name = cart_name			# Name of cart to load. Not used anywhere, but saved for convenience.
 | 
				
			||||||
		self.dist = 0 # distance to next point
 | 
					
 | 
				
			||||||
		self.target_points = [
 | 
							# Internal variables
 | 
				
			||||||
 | 
							self._internal_state = {}			# Raw data read from stdout
 | 
				
			||||||
 | 
							self._before_out = None				# Output of "before" callback in update loop
 | 
				
			||||||
 | 
							self._last_checkpoint_state = 0		# Index of frame at which we reached the last checkpoint
 | 
				
			||||||
 | 
							self._state_counter = 0				# Number of frames we've run since last reset
 | 
				
			||||||
 | 
							self._next_checkpoint_idx = 0		# Index of next point
 | 
				
			||||||
 | 
							self._dist = 0						# Distance to next point
 | 
				
			||||||
 | 
							self._resetting = False				# True between a call to .reset() and the first state message from pico.
 | 
				
			||||||
 | 
							self._keys = {}						# Dictionary of "key": bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							# Targets the agent tries to reach.
 | 
				
			||||||
 | 
							# The last target MUST be outside the frame.
 | 
				
			||||||
 | 
							self.target_checkpoints = [
 | 
				
			||||||
			[	# Stage 1
 | 
								[	# Stage 1
 | 
				
			||||||
				(28, 88),		# Start pillar
 | 
									#(28, 88),		# Start pillar
 | 
				
			||||||
				(60, 80),		# Middle pillar
 | 
									(60, 80),		# Middle pillar
 | 
				
			||||||
				(105, 64),		# Right ledge
 | 
									(105, 64),		# Right ledge
 | 
				
			||||||
				(25, 40),		# Left ledge
 | 
									(25, 40),		# Left ledge
 | 
				
			||||||
@@ -67,119 +122,150 @@ class Celeste:
 | 
				
			|||||||
			]
 | 
								]
 | 
				
			||||||
		]
 | 
							]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	def act(self, action):
 | 
						def act(self, action: str):
 | 
				
			||||||
		self.keyup("x")
 | 
							"""
 | 
				
			||||||
		self.keyup("c")
 | 
							Specify what keys should be down. This does NOT send key events.
 | 
				
			||||||
		self.keyup("Left")
 | 
							Celeste._apply_keys() does that at the right time.
 | 
				
			||||||
		self.keyup("Right")
 | 
					 | 
				
			||||||
		self.keyup("Down")
 | 
					 | 
				
			||||||
		self.keyup("Up")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							Args:
 | 
				
			||||||
 | 
								action (str): key name, as in Celeste.action_space
 | 
				
			||||||
 | 
							"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							self._keys = {}
 | 
				
			||||||
		if action is None:
 | 
							if action is None:
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		elif action == "left":
 | 
							elif action == "left":
 | 
				
			||||||
			self.keydown("Left")
 | 
								self._keys["Left"] = True
 | 
				
			||||||
		elif action == "right":
 | 
							elif action == "right":
 | 
				
			||||||
			self.keydown("Right")
 | 
								self._keys["Right"] = True
 | 
				
			||||||
		elif action == "jump":
 | 
							elif action == "jump":
 | 
				
			||||||
			self.keydown("c")
 | 
								self._keys["c"] = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		elif action == "dash-u":
 | 
							elif action == "dash-u":
 | 
				
			||||||
			self.keydown("Up")
 | 
								self._keys["Up"] = True
 | 
				
			||||||
			self.keydown("x")
 | 
								self._keys["x"] = True
 | 
				
			||||||
		elif action == "dash-r":
 | 
							elif action == "dash-r":
 | 
				
			||||||
			self.keydown("Right")
 | 
								self._keys["Right"] = True
 | 
				
			||||||
			self.keydown("x")
 | 
								self._keys["x"] = True
 | 
				
			||||||
		elif action == "dash-l":
 | 
							elif action == "dash-l":
 | 
				
			||||||
			self.keydown("Left")
 | 
								self._keys["Left"] = True
 | 
				
			||||||
			self.keydown("x")
 | 
								self._keys["x"] = True
 | 
				
			||||||
		elif action == "dash-ru":
 | 
							elif action == "dash-ru":
 | 
				
			||||||
			self.keydown("Up")
 | 
								self._keys["Up"] = True
 | 
				
			||||||
			self.keydown("Right")
 | 
								self._keys["Right"] = True
 | 
				
			||||||
			self.keydown("x")
 | 
								self._keys["x"] = True
 | 
				
			||||||
		elif action == "dash-lu":
 | 
							elif action == "dash-lu":
 | 
				
			||||||
			self.keydown("Up")
 | 
								self._keys["Up"] = True
 | 
				
			||||||
			self.keydown("Left")
 | 
								self._keys["Left"] = True
 | 
				
			||||||
			self.keydown("x")
 | 
								self._keys["x"] = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						def _apply_keys(self):
 | 
				
			||||||
 | 
							for i in [
 | 
				
			||||||
 | 
								"x", "c",
 | 
				
			||||||
 | 
								"Left", "Right",
 | 
				
			||||||
 | 
								"Down", "Up"
 | 
				
			||||||
 | 
							]:
 | 
				
			||||||
 | 
								if self._keys.get(i):
 | 
				
			||||||
 | 
									self._keydown(i)
 | 
				
			||||||
 | 
								else:
 | 
				
			||||||
 | 
									self._keyup(i)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	@property
 | 
						@property
 | 
				
			||||||
	def status(self):
 | 
						def state(self):
 | 
				
			||||||
		try:
 | 
							try:
 | 
				
			||||||
			return {
 | 
								stage = (
 | 
				
			||||||
				"stage": (
 | 
					 | 
				
			||||||
				[
 | 
									[
 | 
				
			||||||
					[0, 1, 2, 3, 4]
 | 
										[0, 1, 2, 3, 4]
 | 
				
			||||||
				]
 | 
									]
 | 
				
			||||||
					[int(self.internal_status["ry"])]
 | 
									[int(self._internal_state["ry"])]
 | 
				
			||||||
					[int(self.internal_status["rx"])]
 | 
									[int(self._internal_state["rx"])]
 | 
				
			||||||
				),
 | 
								)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				"xpos": int(self.internal_status["px"]),
 | 
								if len(self.target_checkpoints) < stage:
 | 
				
			||||||
				"ypos": int(self.internal_status["py"]),
 | 
									next_point_x = None
 | 
				
			||||||
				"xvel": float(self.internal_status["vx"]),
 | 
									next_point_y = None
 | 
				
			||||||
				"yvel": float(self.internal_status["vy"]),
 | 
								else:
 | 
				
			||||||
				"deaths": int(self.internal_status["dc"]),
 | 
									next_point_x = self.target_checkpoints[stage][self._next_checkpoint_idx][0]
 | 
				
			||||||
 | 
									next_point_y = self.target_checkpoints[stage][self._next_checkpoint_idx][1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								return CelesteState(
 | 
				
			||||||
 | 
									stage			= stage,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									xpos			= int(self._internal_state["px"]),
 | 
				
			||||||
 | 
									ypos			= int(self._internal_state["py"]),
 | 
				
			||||||
 | 
									xvel			= float(self._internal_state["vx"]),
 | 
				
			||||||
 | 
									yvel			= float(self._internal_state["vy"]),
 | 
				
			||||||
 | 
									deaths			= int(self._internal_state["dc"]),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									dist			= self._dist,
 | 
				
			||||||
 | 
									next_point		= self._next_checkpoint_idx,
 | 
				
			||||||
 | 
									next_point_x	= next_point_x,
 | 
				
			||||||
 | 
									next_point_y	= next_point_y,
 | 
				
			||||||
 | 
									state_count		= self._state_counter,
 | 
				
			||||||
 | 
									can_dash		= self._internal_state["ds"] == "t"
 | 
				
			||||||
 | 
								)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				"dist": self.dist,
 | 
					 | 
				
			||||||
				"next_point": self.next_point,
 | 
					 | 
				
			||||||
				"frame_count": self.frame_counter
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		except KeyError:
 | 
							except KeyError:
 | 
				
			||||||
			raise CelesteError("Not enough data to get status.")
 | 
								raise CelesteError("Not enough data to get state.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						def _keypress(self, key: str, *, post = 200):
 | 
				
			||||||
	def keypress(self, key: str, *, post = 200):
 | 
					 | 
				
			||||||
		subprocess.run([
 | 
							subprocess.run([
 | 
				
			||||||
			"xdotool",
 | 
								"xdotool",
 | 
				
			||||||
			"key",
 | 
								"key",
 | 
				
			||||||
			"--window", self.winid,
 | 
								"--window", self._winid,
 | 
				
			||||||
			key
 | 
								key
 | 
				
			||||||
		])
 | 
							])
 | 
				
			||||||
		time.sleep(post / 1000)
 | 
							time.sleep(post / 1000)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	def keydown(self, key: str):
 | 
						def _keydown(self, key: str):
 | 
				
			||||||
		subprocess.run([
 | 
							subprocess.run([
 | 
				
			||||||
			"xdotool",
 | 
								"xdotool",
 | 
				
			||||||
			"keydown",
 | 
								"keydown",
 | 
				
			||||||
			"--window", self.winid,
 | 
								"--window", self._winid,
 | 
				
			||||||
			key
 | 
								key
 | 
				
			||||||
		])
 | 
							])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	def keyup(self, key: str):
 | 
						def _keyup(self, key: str):
 | 
				
			||||||
		subprocess.run([
 | 
							subprocess.run([
 | 
				
			||||||
			"xdotool",
 | 
								"xdotool",
 | 
				
			||||||
			"keyup",
 | 
								"keyup",
 | 
				
			||||||
			"--window", self.winid,
 | 
								"--window", self._winid,
 | 
				
			||||||
			key
 | 
								key
 | 
				
			||||||
		])
 | 
							])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	def keystring(self, string, *, delay = 100, post = 200):
 | 
						def _keystring(self, string, *, delay = 100, post = 200):
 | 
				
			||||||
		subprocess.run([
 | 
							subprocess.run([
 | 
				
			||||||
			"xdotool",
 | 
								"xdotool",
 | 
				
			||||||
			"type",
 | 
								"type",
 | 
				
			||||||
			"--window", self.winid,
 | 
								"--window", self._winid,
 | 
				
			||||||
			"--delay", str(delay),
 | 
								"--delay", str(delay),
 | 
				
			||||||
			string
 | 
								string
 | 
				
			||||||
		])
 | 
							])
 | 
				
			||||||
		time.sleep(post / 1000)
 | 
							time.sleep(post / 1000)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	def reset(self):
 | 
						def reset(self):
 | 
				
			||||||
		self.internal_status = {}
 | 
							# Make sure all keys are released
 | 
				
			||||||
		self.next_point = 0
 | 
							self.act(None)
 | 
				
			||||||
		self.frame_counter = 0
 | 
							self._apply_keys()
 | 
				
			||||||
		self.before_out = None
 | 
					 | 
				
			||||||
		self.resetting = True
 | 
					 | 
				
			||||||
		self.last_point_frame = 0
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
		self.keypress("Escape")
 | 
							self._internal_state = {}
 | 
				
			||||||
		self.keystring("run")
 | 
							self._next_checkpoint_idx = 0
 | 
				
			||||||
		self.keypress("Enter", post = 1000)
 | 
							self._state_counter = 0
 | 
				
			||||||
 | 
							self._before_out = None
 | 
				
			||||||
 | 
							self._resetting = True
 | 
				
			||||||
 | 
							self._last_checkpoint_state = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		self.flush_reader()
 | 
							self._keypress("Escape")
 | 
				
			||||||
 | 
							self._keystring("run")
 | 
				
			||||||
 | 
							self._keypress("Enter", post = 1000)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	def flush_reader(self):
 | 
					
 | 
				
			||||||
		for k in iter(self.process.stdout.readline, ""):
 | 
					
 | 
				
			||||||
 | 
							# Clear all old stdout messages and
 | 
				
			||||||
 | 
							# wait for the game to restart.
 | 
				
			||||||
 | 
							for k in iter(self._process.stdout.readline, ""):
 | 
				
			||||||
			k = k.decode("utf-8")[:-1]
 | 
								k = k.decode("utf-8")[:-1]
 | 
				
			||||||
			if k == "!RESTART":
 | 
								if k == "!RESTART":
 | 
				
			||||||
				break
 | 
									break
 | 
				
			||||||
@@ -187,61 +273,68 @@ class Celeste:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	def update_loop(self, before, after):
 | 
						def update_loop(self, before, after):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		# Get state, call callback, wait for state
 | 
							# Waits for stdout from pico-8 process
 | 
				
			||||||
		# One line => one frame.
 | 
							for line in iter(self._process.stdout.readline, ""):
 | 
				
			||||||
 | 
					 | 
				
			||||||
		it = iter(self.process.stdout.readline, "")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		for line in it:
 | 
					 | 
				
			||||||
			l = line.decode("utf-8")[:-1].strip()
 | 
								l = line.decode("utf-8")[:-1].strip()
 | 
				
			||||||
			self.resetting = False
 | 
					
 | 
				
			||||||
 | 
								# Release all keys
 | 
				
			||||||
 | 
								self.act(None)
 | 
				
			||||||
 | 
								self._apply_keys()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								# Clear reset state
 | 
				
			||||||
 | 
								self._resetting = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			# This should only occur at game start
 | 
								# This should only occur at game start
 | 
				
			||||||
			if l in ["!RESTART"]:
 | 
								if l in ["!RESTART"]:
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			self.frame_counter += 1
 | 
								self._state_counter += 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			# Parse status string
 | 
								# Parse state string
 | 
				
			||||||
			for entry in l.split(";"):
 | 
								for entry in l.split(";"):
 | 
				
			||||||
				if entry == "":
 | 
									if entry == "":
 | 
				
			||||||
					continue
 | 
										continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				key, val = entry.split(":")
 | 
									key, val = entry.split(":")
 | 
				
			||||||
				self.internal_status[key] = val
 | 
									self._internal_state[key] = val
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			# Update checkpoints
 | 
								# Update checkpoints
 | 
				
			||||||
 | 
								tx, ty = self.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
 | 
				
			||||||
			tx, ty = self.target_points[self.status["stage"]][self.next_point]
 | 
								x = self.state.xpos
 | 
				
			||||||
			x = self.status["xpos"]
 | 
								y = self.state.ypos
 | 
				
			||||||
			y = self.status["ypos"]
 | 
					 | 
				
			||||||
			dist = math.sqrt(
 | 
								dist = math.sqrt(
 | 
				
			||||||
				(x-tx)*(x-tx) +
 | 
									(x-tx)*(x-tx) +
 | 
				
			||||||
				(y-ty)*(y-ty)
 | 
									((y-ty)*(y-ty))/2
 | 
				
			||||||
 | 
									# Modification in effect (squared y term is halved above):
 | 
				
			||||||
 | 
									# make x-distance twice as valuable as y-distance
 | 
				
			||||||
			)
 | 
								)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if dist <= 4 and y == ty:
 | 
								if dist <= 5:
 | 
				
			||||||
				print(f"Got point {self.next_point}")
 | 
									print(f"Got point {self._next_checkpoint_idx}")
 | 
				
			||||||
				self.next_point += 1
 | 
									self._next_checkpoint_idx += 1
 | 
				
			||||||
				self.last_point_frame = self.frame_counter
 | 
									self._last_checkpoint_state = self._state_counter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				# Recalculate distance to new point
 | 
									# Recalculate distance to new point
 | 
				
			||||||
				tx, ty = self.target_points[self.status["stage"]][self.next_point]
 | 
									tx, ty = self.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
 | 
				
			||||||
				dist = math.sqrt(
 | 
									dist = math.sqrt(
 | 
				
			||||||
					(x-tx)*(x-tx) +
 | 
										(x-tx)*(x-tx) +
 | 
				
			||||||
					(y-ty)*(y-ty)
 | 
										((y-ty)*(y-ty))/2
 | 
				
			||||||
				)
 | 
									)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			# Timeout if we spend too long between points
 | 
								# Timeout if we spend too long between points
 | 
				
			||||||
			elif self.frame_counter - self.last_point_frame > 40:
 | 
								elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
 | 
				
			||||||
				self.internal_status["dc"] = str(int(self.internal_status["dc"]) + 1)
 | 
									self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			self.dist = dist
 | 
								self._dist = dist
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			# Call step callbacks
 | 
								# Call step callbacks
 | 
				
			||||||
			if self.before_out is not None:
 | 
								# These should call celeste.act() to set next input
 | 
				
			||||||
				after(self, self.before_out)
 | 
								if self._before_out is not None:
 | 
				
			||||||
			if not self.resetting:
 | 
									after(self, self._before_out)
 | 
				
			||||||
				self.before_out = before(self)
 | 
					
 | 
				
			||||||
 | 
								# Do not run before callback if after() triggered a reset.
 | 
				
			||||||
 | 
								if not self._resetting:
 | 
				
			||||||
 | 
									self._before_out = before(self)
 | 
				
			||||||
 | 
								self._apply_keys()
 | 
				
			||||||
 | 
								
 | 
				
			||||||
							
								
								
									
										150
									
								
								celeste/main.py
									
									
									
									
									
								
							
							
						
						
									
										150
									
								
								celeste/main.py
									
									
									
									
									
								
							@@ -1,30 +1,24 @@
 | 
				
			|||||||
from collections import namedtuple
 | 
					from collections import namedtuple
 | 
				
			||||||
from collections import deque
 | 
					from collections import deque
 | 
				
			||||||
 | 
					from pathlib import Path
 | 
				
			||||||
import random
 | 
					import random
 | 
				
			||||||
import math
 | 
					import math
 | 
				
			||||||
 | 
					import json
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Glue layer
 | 
					 | 
				
			||||||
from celeste import Celeste
 | 
					from celeste import Celeste
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					run_data_path = Path("out")
 | 
				
			||||||
 | 
					run_data_path.mkdir(parents = True, exist_ok = True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
compute_device = torch.device(
 | 
					compute_device = torch.device(
 | 
				
			||||||
	"cuda" if torch.cuda.is_available() else "cpu"
 | 
						"cuda" if torch.cuda.is_available() else "cpu"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
state_number_map = [
 | 
					 | 
				
			||||||
	"xpos",
 | 
					 | 
				
			||||||
	"ypos",
 | 
					 | 
				
			||||||
	"xvel",
 | 
					 | 
				
			||||||
	"yvel",
 | 
					 | 
				
			||||||
	"next_point"
 | 
					 | 
				
			||||||
]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# Celeste env properties
 | 
					# Celeste env properties
 | 
				
			||||||
n_observations = len(state_number_map)
 | 
					n_observations = len(Celeste.state_number_map)
 | 
				
			||||||
n_actions = len(Celeste.action_space)
 | 
					n_actions = len(Celeste.action_space)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -39,7 +33,7 @@ EPS_END = 0.05
 | 
				
			|||||||
EPS_DECAY = 1000
 | 
					EPS_DECAY = 1000
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
BATCH_SIZE = 128
 | 
					BATCH_SIZE = 1_000
 | 
				
			||||||
# Learning rate of target_net.
 | 
					# Learning rate of target_net.
 | 
				
			||||||
# Controls how soft our soft update is.
 | 
					# Controls how soft our soft update is.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
@@ -64,9 +58,19 @@ GAMMA = 0.99
 | 
				
			|||||||
class DQN(torch.nn.Module):
 | 
					class DQN(torch.nn.Module):
 | 
				
			||||||
	def __init__(self, n_observations: int, n_actions: int):
 | 
						def __init__(self, n_observations: int, n_actions: int):
 | 
				
			||||||
		super(DQN, self).__init__()
 | 
							super(DQN, self).__init__()
 | 
				
			||||||
		self.layer1 = torch.nn.Linear(n_observations, 128)
 | 
							
 | 
				
			||||||
		self.layer2 = torch.nn.Linear(128, 128)
 | 
							self.layers = torch.nn.Sequential(
 | 
				
			||||||
		self.layer3 = torch.nn.Linear(128, n_actions)
 | 
								torch.nn.Linear(n_observations, 128),
 | 
				
			||||||
 | 
								torch.nn.ReLU(),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								torch.nn.Linear(128, 128),
 | 
				
			||||||
 | 
								torch.nn.ReLU(),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								torch.nn.Linear(128, 128),
 | 
				
			||||||
 | 
								torch.nn.ReLU(),
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								torch.torch.nn.Linear(128, n_actions)
 | 
				
			||||||
 | 
							)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	# Can be called with one input, or with a batch.
 | 
						# Can be called with one input, or with a batch.
 | 
				
			||||||
	#
 | 
						#
 | 
				
			||||||
@@ -77,9 +81,7 @@ class DQN(torch.nn.Module):
 | 
				
			|||||||
	# Recall that Q(s, a) is the (expected) return of taking
 | 
						# Recall that Q(s, a) is the (expected) return of taking
 | 
				
			||||||
	# action `a` at state `s`
 | 
						# action `a` at state `s`
 | 
				
			||||||
	def forward(self, x):
 | 
						def forward(self, x):
 | 
				
			||||||
		x = torch.nn.functional.relu(self.layer1(x))
 | 
							return self.layers(x)
 | 
				
			||||||
		x = torch.nn.functional.relu(self.layer2(x))
 | 
					 | 
				
			||||||
		return self.layer3(x)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -94,7 +96,7 @@ num_episodes = 100
 | 
				
			|||||||
# Memory: a deque that holds recent states as Transitions
 | 
					# Memory: a deque that holds recent states as Transitions
 | 
				
			||||||
#	Has a fixed length, drops oldest
 | 
					#	Has a fixed length, drops oldest
 | 
				
			||||||
#	element if maxlen is exceeded.
 | 
					#	element if maxlen is exceeded.
 | 
				
			||||||
memory = deque([], maxlen=10_000)
 | 
					memory = deque([], maxlen=100_000)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
policy_net = DQN(
 | 
					policy_net = DQN(
 | 
				
			||||||
@@ -112,11 +114,10 @@ target_net.load_state_dict(policy_net.state_dict())
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
optimizer = torch.optim.AdamW(
 | 
					optimizer = torch.optim.AdamW(
 | 
				
			||||||
	policy_net.parameters(),
 | 
						policy_net.parameters(),
 | 
				
			||||||
	lr = 1e-4, # Hyperparameter: learning rate
 | 
						lr = 0.01, # Hyperparameter: learning rate
 | 
				
			||||||
	amsgrad = True
 | 
						amsgrad = True
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
def select_action(state, steps_done):
 | 
					def select_action(state, steps_done):
 | 
				
			||||||
	"""
 | 
						"""
 | 
				
			||||||
	Select an action using an epsilon-greedy policy.
 | 
						Select an action using an epsilon-greedy policy.
 | 
				
			||||||
@@ -303,39 +304,68 @@ def optimize_model():
 | 
				
			|||||||
	optimizer.step()
 | 
						optimizer.step()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					episode_number = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if (run_data_path/"checkpoint.torch").is_file():
 | 
				
			||||||
 | 
						# Load model if one exists
 | 
				
			||||||
 | 
						checkpoint = torch.load((run_data_path/"checkpoint.torch"))
 | 
				
			||||||
 | 
						policy_net.load_state_dict(checkpoint["policy_state_dict"])
 | 
				
			||||||
 | 
						target_net.load_state_dict(checkpoint["target_state_dict"])
 | 
				
			||||||
 | 
						optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
 | 
				
			||||||
 | 
						memory = checkpoint["memory"]
 | 
				
			||||||
 | 
						episode_number = checkpoint["episode_number"] + 1
 | 
				
			||||||
 | 
						steps_done = checkpoint["steps_done"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def on_state_before(celeste):
 | 
					def on_state_before(celeste):
 | 
				
			||||||
	global steps_done
 | 
						global steps_done
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	# Conversion to pytorch
 | 
						# Conversion to pytorch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	state = celeste.status
 | 
						state = celeste.state
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	pt_state = torch.tensor(
 | 
						pt_state = torch.tensor(
 | 
				
			||||||
		[state[x] for x in state_number_map],
 | 
							[getattr(state, x) for x in Celeste.state_number_map],
 | 
				
			||||||
		dtype = torch.float32,
 | 
							dtype = torch.float32,
 | 
				
			||||||
		device = compute_device
 | 
							device = compute_device
 | 
				
			||||||
	).unsqueeze(0)
 | 
						).unsqueeze(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						action = None
 | 
				
			||||||
 | 
						while (action) is None or ((not state.can_dash) and (str_action not in ["left", "right"])):
 | 
				
			||||||
		action = select_action(
 | 
							action = select_action(
 | 
				
			||||||
			pt_state,
 | 
								pt_state,
 | 
				
			||||||
			steps_done
 | 
								steps_done
 | 
				
			||||||
		)
 | 
							)
 | 
				
			||||||
 | 
							str_action = Celeste.action_space[action]
 | 
				
			||||||
	steps_done += 1
 | 
						steps_done += 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	# Turn number into action string
 | 
					 | 
				
			||||||
	str_action = Celeste.action_space[action]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						# For manual testing
 | 
				
			||||||
 | 
						#str_action = ""
 | 
				
			||||||
 | 
						#while str_action not in Celeste.action_space:
 | 
				
			||||||
 | 
						#	str_action = input("action> ")
 | 
				
			||||||
 | 
						#action = Celeste.action_space.index(str_action)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						print(str_action)
 | 
				
			||||||
	celeste.act(str_action)
 | 
						celeste.act(str_action)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return state, action
 | 
						return state, action
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					image_interval = 10
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def on_state_after(celeste, before_out):
 | 
					def on_state_after(celeste, before_out):
 | 
				
			||||||
 | 
						global episode_number
 | 
				
			||||||
 | 
						global image_count
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	state, action = before_out
 | 
						state, action = before_out
 | 
				
			||||||
 | 
						next_state = celeste.state
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	pt_state = torch.tensor(
 | 
						pt_state = torch.tensor(
 | 
				
			||||||
		[state[x] for x in state_number_map],
 | 
							[getattr(state, x) for x in Celeste.state_number_map],
 | 
				
			||||||
		dtype = torch.float32,
 | 
							dtype = torch.float32,
 | 
				
			||||||
		device = compute_device
 | 
							device = compute_device
 | 
				
			||||||
	).unsqueeze(0)
 | 
						).unsqueeze(0)
 | 
				
			||||||
@@ -346,33 +376,30 @@ def on_state_after(celeste, before_out):
 | 
				
			|||||||
		dtype = torch.long
 | 
							dtype = torch.long
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	next_state = celeste.status
 | 
						if next_state.deaths != 0:
 | 
				
			||||||
 | 
					 | 
				
			||||||
	if next_state["deaths"] != 0:
 | 
					 | 
				
			||||||
		pt_next_state = None
 | 
							pt_next_state = None
 | 
				
			||||||
		reward = 0
 | 
							reward = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	else:
 | 
						else:
 | 
				
			||||||
		pt_next_state = torch.tensor(
 | 
							pt_next_state = torch.tensor(
 | 
				
			||||||
			[next_state[x] for x in state_number_map],
 | 
								[getattr(next_state, x) for x in Celeste.state_number_map],
 | 
				
			||||||
			dtype = torch.float32,
 | 
								dtype = torch.float32,
 | 
				
			||||||
			device = compute_device
 | 
								device = compute_device
 | 
				
			||||||
		).unsqueeze(0)
 | 
							).unsqueeze(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if state["next_point"] == next_state["next_point"]:
 | 
							if state.next_point == next_state.next_point:
 | 
				
			||||||
			reward = state["dist"] - next_state["dist"]
 | 
								reward = state.dist - next_state.dist
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if reward > 0:
 | 
								# Clip rewards that are too large
 | 
				
			||||||
 | 
								if reward > 1:
 | 
				
			||||||
				reward = 1
 | 
									reward = 1
 | 
				
			||||||
			elif reward < 0:
 | 
					 | 
				
			||||||
				reward = -1
 | 
					 | 
				
			||||||
			else:
 | 
								else:
 | 
				
			||||||
				reward = 0
 | 
									reward = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		else:
 | 
							else:
 | 
				
			||||||
			# Score for reaching a point
 | 
								# Score for reaching a point
 | 
				
			||||||
			reward = 10
 | 
								reward = 1
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	pt_reward = torch.tensor([reward], device = compute_device)
 | 
						pt_reward = torch.tensor([reward], device = compute_device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -387,6 +414,8 @@ def on_state_after(celeste, before_out):
 | 
				
			|||||||
		)
 | 
							)
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						print("==> ", int(reward))
 | 
				
			||||||
 | 
						print("\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	# Only train the network if we have enough
 | 
						# Only train the network if we have enough
 | 
				
			||||||
@@ -406,8 +435,51 @@ def on_state_after(celeste, before_out):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	# Move on to the next episode once we reach
 | 
						# Move on to the next episode once we reach
 | 
				
			||||||
	# a terminal state.
 | 
						# a terminal state.
 | 
				
			||||||
	if (next_state["deaths"] != 0):
 | 
						if (next_state.deaths != 0):
 | 
				
			||||||
 | 
							s = celeste.state
 | 
				
			||||||
 | 
							with open(run_data_path / "train.log", "a") as f:
 | 
				
			||||||
 | 
								f.write(json.dumps({
 | 
				
			||||||
 | 
									"checkpoints": s.next_point,
 | 
				
			||||||
 | 
									"state_count": s.state_count
 | 
				
			||||||
 | 
								}) + "\n")
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							# Save model
 | 
				
			||||||
 | 
							torch.save({
 | 
				
			||||||
 | 
								"policy_state_dict": policy_net.state_dict(),
 | 
				
			||||||
 | 
								"target_state_dict": target_net.state_dict(),
 | 
				
			||||||
 | 
								"optimizer_state_dict": optimizer.state_dict(),
 | 
				
			||||||
 | 
								"memory": memory,
 | 
				
			||||||
 | 
								"episode_number": episode_number,
 | 
				
			||||||
 | 
								"steps_done": steps_done
 | 
				
			||||||
 | 
							}, run_data_path / "checkpoint.torch")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							# Clean up screenshots
 | 
				
			||||||
 | 
							shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							target = run_data_path / Path(f"screenshots/{episode_number}")
 | 
				
			||||||
 | 
							target.mkdir(parents = True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							for s in shots:
 | 
				
			||||||
 | 
								s.rename(target / s.name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							# Save a prediction graph
 | 
				
			||||||
 | 
							if episode_number % image_interval == 0:
 | 
				
			||||||
 | 
								p =	run_data_path / Path("model_images")
 | 
				
			||||||
 | 
								p.mkdir(parents = True, exist_ok = True)
 | 
				
			||||||
 | 
								torch.save({
 | 
				
			||||||
 | 
									"policy_state_dict": policy_net.state_dict(),
 | 
				
			||||||
 | 
									"target_state_dict": target_net.state_dict(),
 | 
				
			||||||
 | 
									"optimizer_state_dict": optimizer.state_dict(),
 | 
				
			||||||
 | 
									"memory": memory,
 | 
				
			||||||
 | 
									"episode_number": episode_number,
 | 
				
			||||||
 | 
									"steps_done": steps_done
 | 
				
			||||||
 | 
								}, p / f"{episode_number}.torch")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		print("State over, resetting")
 | 
							print("State over, resetting")
 | 
				
			||||||
 | 
							episode_number += 1
 | 
				
			||||||
		celeste.reset()
 | 
							celeste.reset()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user