Renamed files, added random motion
parent
d6452f5ed8
commit
fd02c65b41
|
@ -0,0 +1,230 @@
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
import math
|
||||||
|
|
||||||
|
class CelesteError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Celeste:
|
||||||
|
action_space = [
|
||||||
|
"left", # move left
|
||||||
|
"right", # move right
|
||||||
|
"jump", # jump
|
||||||
|
|
||||||
|
"dash-u", # dash up
|
||||||
|
"dash-r", # dash right
|
||||||
|
"dash-l", # dash left
|
||||||
|
"dash-ru", # dash right-up
|
||||||
|
"dash-lu" # dash left-up
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, on_get_state):
|
||||||
|
|
||||||
|
self.on_get_state = on_get_state
|
||||||
|
|
||||||
|
# Start pico-8
|
||||||
|
self.process = subprocess.Popen(
|
||||||
|
"bin/pico-8/linux/pico8",
|
||||||
|
shell=True,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for window to open and get window id
|
||||||
|
time.sleep(2)
|
||||||
|
winid = subprocess.check_output([
|
||||||
|
"xdotool",
|
||||||
|
"search",
|
||||||
|
"--class",
|
||||||
|
"pico8"
|
||||||
|
]).decode("utf-8").strip().split("\n")
|
||||||
|
if len(winid) != 1:
|
||||||
|
raise Exception("Could not find unique PICO-8 window id")
|
||||||
|
self.winid = winid[0]
|
||||||
|
|
||||||
|
# Load cartridge
|
||||||
|
self.keystring("load hackcel.p8")
|
||||||
|
self.keypress("Enter")
|
||||||
|
self.keystring("run")
|
||||||
|
self.keypress("Enter", post = 1000)
|
||||||
|
|
||||||
|
# Initialize variables
|
||||||
|
self.internal_status = {}
|
||||||
|
self.dead = False
|
||||||
|
|
||||||
|
# Score system
|
||||||
|
self.frame_counter = 0
|
||||||
|
self.next_point = 0
|
||||||
|
self.dist = 0 # distance to next point
|
||||||
|
self.target_points = [
|
||||||
|
[ # Stage 1
|
||||||
|
(28, 88), # Start pillar
|
||||||
|
(60, 80), # Middle pillar
|
||||||
|
(105, 64), # Right ledge
|
||||||
|
(25, 40), # Left ledge
|
||||||
|
(110, 16), # End ledge
|
||||||
|
(110, -2), # Next stage
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
def act(self, action):
|
||||||
|
self.keyup("x")
|
||||||
|
self.keyup("c")
|
||||||
|
self.keyup("Left")
|
||||||
|
self.keyup("Right")
|
||||||
|
self.keyup("Down")
|
||||||
|
self.keyup("Up")
|
||||||
|
|
||||||
|
if action is None:
|
||||||
|
return
|
||||||
|
elif action == "left":
|
||||||
|
self.keydown("Left")
|
||||||
|
elif action == "right":
|
||||||
|
self.keydown("Right")
|
||||||
|
elif action == "jump":
|
||||||
|
self.keydown("c")
|
||||||
|
|
||||||
|
elif action == "dash-u":
|
||||||
|
self.keydown("Up")
|
||||||
|
self.keydown("x")
|
||||||
|
elif action == "dash-r":
|
||||||
|
self.keydown("Right")
|
||||||
|
self.keydown("x")
|
||||||
|
elif action == "dash-l":
|
||||||
|
self.keydown("Left")
|
||||||
|
self.keydown("x")
|
||||||
|
elif action == "dash-ru":
|
||||||
|
self.keydown("Up")
|
||||||
|
self.keydown("Right")
|
||||||
|
self.keydown("x")
|
||||||
|
elif action == "dash-lu":
|
||||||
|
self.keydown("Up")
|
||||||
|
self.keydown("Left")
|
||||||
|
self.keydown("x")
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def status(self):
|
||||||
|
try:
|
||||||
|
return {
|
||||||
|
"stage": (
|
||||||
|
[
|
||||||
|
[0, 1, 2, 3, 4]
|
||||||
|
]
|
||||||
|
[int(self.internal_status["ry"])]
|
||||||
|
[int(self.internal_status["rx"])]
|
||||||
|
),
|
||||||
|
|
||||||
|
"xpos": int(self.internal_status["px"]),
|
||||||
|
"ypos": int(self.internal_status["py"]),
|
||||||
|
"xvel": float(self.internal_status["vx"]),
|
||||||
|
"yvel": float(self.internal_status["vy"]),
|
||||||
|
"deaths": int(self.internal_status["dc"]),
|
||||||
|
|
||||||
|
"dist": self.dist,
|
||||||
|
"next_point": self.next_point,
|
||||||
|
"frame_count": self.frame_counter
|
||||||
|
}
|
||||||
|
except KeyError:
|
||||||
|
raise CelesteError("Not enough data to get status.")
|
||||||
|
|
||||||
|
|
||||||
|
def keypress(self, key: str, *, post = 200):
|
||||||
|
subprocess.run([
|
||||||
|
"xdotool",
|
||||||
|
"key",
|
||||||
|
"--window", self.winid,
|
||||||
|
key
|
||||||
|
])
|
||||||
|
time.sleep(post / 1000)
|
||||||
|
|
||||||
|
def keydown(self, key: str):
|
||||||
|
subprocess.run([
|
||||||
|
"xdotool",
|
||||||
|
"keydown",
|
||||||
|
"--window", self.winid,
|
||||||
|
key
|
||||||
|
])
|
||||||
|
|
||||||
|
def keyup(self, key: str):
|
||||||
|
subprocess.run([
|
||||||
|
"xdotool",
|
||||||
|
"keyup",
|
||||||
|
"--window", self.winid,
|
||||||
|
key
|
||||||
|
])
|
||||||
|
|
||||||
|
def keystring(self, string, *, delay = 100, post = 200):
|
||||||
|
subprocess.run([
|
||||||
|
"xdotool",
|
||||||
|
"type",
|
||||||
|
"--window", self.winid,
|
||||||
|
"--delay", str(delay),
|
||||||
|
string
|
||||||
|
])
|
||||||
|
time.sleep(post / 1000)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.internal_status = {}
|
||||||
|
self.next_point = 0
|
||||||
|
self.frame_counter = 0
|
||||||
|
|
||||||
|
self.keypress("Escape")
|
||||||
|
self.keystring("run")
|
||||||
|
self.keypress("Enter", post = 1000)
|
||||||
|
self.dead = False
|
||||||
|
|
||||||
|
def flush_reader(self):
|
||||||
|
for k in iter(self.process.stdout.readline, ""):
|
||||||
|
k = k.decode("utf-8")[:-1]
|
||||||
|
if k == "!RESTART":
|
||||||
|
break
|
||||||
|
|
||||||
|
def update_loop(self):
|
||||||
|
|
||||||
|
# Get state, call callback, wait for state
|
||||||
|
# One line => one frame.
|
||||||
|
|
||||||
|
for line in iter(self.process.stdout.readline, ""):
|
||||||
|
l = line.decode("utf-8")[:-1].strip()
|
||||||
|
|
||||||
|
# This should only occur at game start
|
||||||
|
if l in ["!RESTART"]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.frame_counter += 1
|
||||||
|
|
||||||
|
# Parse status string
|
||||||
|
for entry in l.split(";"):
|
||||||
|
if entry == "":
|
||||||
|
continue
|
||||||
|
|
||||||
|
key, val = entry.split(":")
|
||||||
|
self.internal_status[key] = val
|
||||||
|
|
||||||
|
|
||||||
|
# Update checkpoints
|
||||||
|
|
||||||
|
tx, ty = self.target_points[self.status["stage"]][self.next_point]
|
||||||
|
x = self.status["xpos"]
|
||||||
|
y = self.status["ypos"]
|
||||||
|
dist = math.sqrt(
|
||||||
|
(x-tx)*(x-tx) +
|
||||||
|
(y-ty)*(y-ty)
|
||||||
|
)
|
||||||
|
|
||||||
|
if dist <= 4 and y == ty:
|
||||||
|
self.next_point += 1
|
||||||
|
|
||||||
|
# Recalculate distance to new point
|
||||||
|
tx, ty = self.target_points[self.status["stage"]][self.next_point]
|
||||||
|
dist = math.sqrt(
|
||||||
|
(x-tx)*(x-tx) +
|
||||||
|
(y-ty)*(y-ty)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.dist = dist
|
||||||
|
|
||||||
|
# Call step callback
|
||||||
|
self.on_get_state(self)
|
|
@ -1,213 +1,179 @@
|
||||||
import subprocess
|
from collections import namedtuple
|
||||||
import time
|
from collections import deque
|
||||||
import threading
|
import random
|
||||||
import math
|
import math
|
||||||
|
|
||||||
class Celeste:
|
import torch
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
|
|
||||||
# Start process
|
# Glue layer
|
||||||
self.process = subprocess.Popen(
|
from celeste import Celeste
|
||||||
"bin/pico-8/linux/pico8",
|
|
||||||
shell=True,
|
|
||||||
stdout=subprocess.PIPE,
|
compute_device = torch.device(
|
||||||
stderr=subprocess.STDOUT
|
"cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Epsilon-greedy parameters
|
||||||
|
#
|
||||||
|
# Original docs:
|
||||||
|
# EPS_START is the starting value of epsilon
|
||||||
|
# EPS_END is the final value of epsilon
|
||||||
|
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
|
||||||
|
EPS_START = 0.9
|
||||||
|
EPS_END = 0.05
|
||||||
|
EPS_DECAY = 1000
|
||||||
|
|
||||||
|
|
||||||
|
# Outline our network
|
||||||
|
class DQN(torch.nn.Module):
|
||||||
|
def __init__(self, n_observations: int, n_actions: int):
|
||||||
|
super(DQN, self).__init__()
|
||||||
|
self.layer1 = torch.nn.Linear(n_observations, 128)
|
||||||
|
self.layer2 = torch.nn.Linear(128, 128)
|
||||||
|
self.layer3 = torch.nn.Linear(128, n_actions)
|
||||||
|
|
||||||
|
# Can be called with one input, or with a batch.
|
||||||
|
#
|
||||||
|
# Returns tensor(
|
||||||
|
# [ Q(s, left), Q(s, right) ], ...
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# Recall that Q(s, a) is the (expected) return of taking
|
||||||
|
# action `a` at state `s`
|
||||||
|
def forward(self, x):
|
||||||
|
x = torch.nn.functional.relu(self.layer1(x))
|
||||||
|
x = torch.nn.functional.relu(self.layer2(x))
|
||||||
|
return self.layer3(x)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Celeste env properties
|
||||||
|
n_observations = 4
|
||||||
|
n_actions = len(Celeste.action_space)
|
||||||
|
|
||||||
|
policy_net = DQN(
|
||||||
|
n_observations,
|
||||||
|
n_actions
|
||||||
|
).to(compute_device)
|
||||||
|
|
||||||
|
|
||||||
|
def select_action(state, steps_done):
|
||||||
|
"""
|
||||||
|
Select an action using an epsilon-greedy policy.
|
||||||
|
|
||||||
|
Sometimes use our model, sometimes sample one uniformly.
|
||||||
|
|
||||||
|
P(random action) starts at EPS_START and decays to EPS_END.
|
||||||
|
Decay rate is controlled by EPS_DECAY.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Random number 0 <= x < 1
|
||||||
|
sample = random.random()
|
||||||
|
|
||||||
|
# Calculate random step threshhold
|
||||||
|
eps_threshold = (
|
||||||
|
EPS_END + (EPS_START - EPS_END) *
|
||||||
|
math.exp(
|
||||||
|
-1.0 * steps_done /
|
||||||
|
EPS_DECAY
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wait for window to open and get window id
|
|
||||||
time.sleep(2)
|
|
||||||
winid = subprocess.check_output([
|
|
||||||
"xdotool",
|
|
||||||
"search",
|
|
||||||
"--class",
|
|
||||||
"pico8"
|
|
||||||
]).decode("utf-8").strip().split("\n")
|
|
||||||
if len(winid) != 1:
|
|
||||||
raise Exception("Could not find unique PICO-8 window id")
|
|
||||||
self.winid = winid[0]
|
|
||||||
|
|
||||||
# Load cartridge
|
|
||||||
self.keystring("load hackcel.p8")
|
|
||||||
self.keypress("Enter")
|
|
||||||
self.keystring("run")
|
|
||||||
self.keypress("Enter", post = 1000)
|
|
||||||
|
|
||||||
# Initialize variables
|
|
||||||
self.internal_status = {}
|
|
||||||
self.dead = False
|
|
||||||
|
|
||||||
# -1: left
|
|
||||||
# 0: not moving
|
|
||||||
# 1: moving right
|
|
||||||
self.moving = 0
|
|
||||||
|
|
||||||
# Start state update thread
|
|
||||||
self.update_thread = threading.Thread(target = self._update_loop)
|
|
||||||
self.update_thread.start()
|
|
||||||
|
|
||||||
def act(self, action):
|
|
||||||
self.keyup("x")
|
|
||||||
self.keyup("c")
|
|
||||||
self.keyup("Down")
|
|
||||||
self.keyup("Up")
|
|
||||||
if self.moving != -1:
|
|
||||||
self.keyup("Left")
|
|
||||||
if self.moving != 1:
|
|
||||||
self.keyup("Right")
|
|
||||||
|
|
||||||
if action is None:
|
|
||||||
self.moving = 0
|
|
||||||
self.keyup("Left")
|
|
||||||
self.keyup("Right")
|
|
||||||
elif action == "left":
|
|
||||||
if self.moving != -1:
|
|
||||||
self.keydown("Left")
|
|
||||||
self.moving = -1
|
|
||||||
elif action == "right":
|
|
||||||
if self.moving != 1:
|
|
||||||
self.keydown("Right")
|
|
||||||
self.moving = 1
|
|
||||||
|
|
||||||
|
|
||||||
@property
|
|
||||||
def status(self):
|
|
||||||
return {
|
|
||||||
"stage": (
|
|
||||||
[
|
|
||||||
[0, 1, 2, 3, 4]
|
|
||||||
]
|
|
||||||
[int(self.internal_status["ry"])]
|
|
||||||
[int(self.internal_status["rx"])]
|
|
||||||
),
|
|
||||||
|
|
||||||
"xpos": int(self.internal_status["px"]),
|
|
||||||
"ypos": int(self.internal_status["py"]),
|
|
||||||
"xvel": float(self.internal_status["vx"]),
|
|
||||||
"yvel": float(self.internal_status["vy"])
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Possible actions
|
|
||||||
@property
|
|
||||||
def action_space(self):
|
|
||||||
return [
|
|
||||||
"left", # move left
|
|
||||||
"rght", # move right
|
|
||||||
"jump", # jump
|
|
||||||
|
|
||||||
"dshn", # dash north
|
|
||||||
"dshe", # dash east
|
|
||||||
"dshw", # dash west
|
|
||||||
"dsne", # dash north-east
|
|
||||||
"dsnw" # dash north-west
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def keypress(self, key: str, *, post = 200):
|
|
||||||
subprocess.run([
|
|
||||||
"xdotool",
|
|
||||||
"key",
|
|
||||||
"--window", self.winid,
|
|
||||||
key
|
|
||||||
])
|
|
||||||
time.sleep(post / 1000)
|
|
||||||
|
|
||||||
def keydown(self, key: str):
|
|
||||||
subprocess.run([
|
|
||||||
"xdotool",
|
|
||||||
"keydown",
|
|
||||||
"--window", self.winid,
|
|
||||||
key
|
|
||||||
])
|
|
||||||
|
|
||||||
def keyup(self, key: str):
|
|
||||||
subprocess.run([
|
|
||||||
"xdotool",
|
|
||||||
"keyup",
|
|
||||||
"--window", self.winid,
|
|
||||||
key
|
|
||||||
])
|
|
||||||
|
|
||||||
def keystring(self, string, *, delay = 100, post = 200):
|
|
||||||
subprocess.run([
|
|
||||||
"xdotool",
|
|
||||||
"type",
|
|
||||||
"--window", self.winid,
|
|
||||||
"--delay", str(delay),
|
|
||||||
string
|
|
||||||
])
|
|
||||||
time.sleep(post / 1000)
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self.internal_status = {}
|
|
||||||
if not self.dead:
|
|
||||||
self.keypress("Escape")
|
|
||||||
self.keystring("run")
|
|
||||||
self.keypress("Enter", post = 1000)
|
|
||||||
self.dead = False
|
|
||||||
|
|
||||||
def _update_loop(self):
|
|
||||||
# Poll process for new output until finished
|
|
||||||
for line in iter(self.process.stdout.readline, ""):
|
|
||||||
l = line.decode("utf-8")[:-1]
|
|
||||||
|
|
||||||
if l in ["!RESTART"]:
|
|
||||||
continue
|
|
||||||
|
|
||||||
for entry in l.split(";"):
|
|
||||||
key, val = entry.split(":")
|
|
||||||
self.internal_status[key] = val
|
|
||||||
|
|
||||||
# Exit game on death
|
|
||||||
if "dc" in self.internal_status and self.internal_status["dc"] != "0":
|
|
||||||
self.keypress("Escape")
|
|
||||||
self.dead = True
|
|
||||||
|
|
||||||
# Flush stream reader
|
|
||||||
for k in iter(self.process.stdout.readline, ""):
|
|
||||||
k = k.decode("utf-8")[:-1]
|
|
||||||
if k == "!RESTART":
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
# Stage 1:
|
|
||||||
|
|
||||||
|
|
||||||
next_point = 0
|
|
||||||
target_points = [
|
|
||||||
(28, 88), # Start pillar
|
|
||||||
(60, 80), # Middle pillar
|
|
||||||
(105, 64), # Right ledge
|
|
||||||
(25, 40), # Left ledge
|
|
||||||
(110, 16), # End ledge
|
|
||||||
(110, -2), # Next stage
|
|
||||||
]
|
|
||||||
|
|
||||||
# += 5
|
|
||||||
|
|
||||||
c = Celeste()
|
|
||||||
while True:
|
|
||||||
if c.dead:
|
|
||||||
print("\n\nDead, resetting...")
|
|
||||||
c.reset()
|
|
||||||
|
|
||||||
|
|
||||||
tx, ty = target_points[next_point]
|
|
||||||
x = c.status["xpos"]
|
|
||||||
y = c.status["ypos"]
|
|
||||||
|
|
||||||
dist = math.sqrt(
|
|
||||||
(x-tx)*(x-tx) +
|
|
||||||
(y-ty)*(y-ty)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if dist <= 4 and y == ty:
|
if sample > eps_threshold:
|
||||||
next_point += 1
|
with torch.no_grad():
|
||||||
|
# t.max(1) will return the largest column value of each row.
|
||||||
|
# second column on max result is index of where max element was
|
||||||
|
# found, so we pick action with the larger expected reward.
|
||||||
|
return policy_net(state).max(1)[1].view(1, 1).item()
|
||||||
|
|
||||||
print(f"Target point: {next_point:02}, Dist: {dist:0.3}")
|
else:
|
||||||
|
return random.randint( 0, n_actions-1 )
|
||||||
|
|
||||||
#print()
|
|
||||||
#print(c.status)
|
|
||||||
|
|
||||||
|
last_state = None
|
||||||
|
|
||||||
|
|
||||||
|
Transition = namedtuple(
|
||||||
|
"Transition",
|
||||||
|
(
|
||||||
|
"state",
|
||||||
|
"action",
|
||||||
|
"next_state",
|
||||||
|
"reward"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def on_state(celeste):
|
||||||
|
global last_state
|
||||||
|
|
||||||
|
s = celeste.status
|
||||||
|
|
||||||
|
if last_state is None:
|
||||||
|
last_state = s
|
||||||
|
return
|
||||||
|
|
||||||
|
s_next = s["next_point"]
|
||||||
|
s_dist = s["dist"]
|
||||||
|
l_next = last_state["next_point"]
|
||||||
|
l_dist = last_state["dist"]
|
||||||
|
|
||||||
|
|
||||||
|
if l_next == s_next:
|
||||||
|
reward = l_dist - s_dist
|
||||||
|
else:
|
||||||
|
reward = 10
|
||||||
|
|
||||||
|
dead = s["deaths"] != 0
|
||||||
|
frame_count = s["frame_count"]
|
||||||
|
|
||||||
|
# Values at this point
|
||||||
|
# reward: reward for last action
|
||||||
|
# dead: true if game over
|
||||||
|
|
||||||
|
state_number_map = [
|
||||||
|
"xpos",
|
||||||
|
"ypos",
|
||||||
|
"xvel",
|
||||||
|
"yvel"
|
||||||
|
]
|
||||||
|
|
||||||
|
tf_state = torch.tensor(
|
||||||
|
[s[x] for x in state_number_map],
|
||||||
|
dtype = torch.float32,
|
||||||
|
device = compute_device
|
||||||
|
).unsqueeze(0)
|
||||||
|
|
||||||
|
tf_last = torch.tensor(
|
||||||
|
[last_state[x] for x in state_number_map],
|
||||||
|
dtype = torch.float32,
|
||||||
|
device = compute_device
|
||||||
|
).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
action = select_action(
|
||||||
|
tf_state,
|
||||||
|
frame_count
|
||||||
|
)
|
||||||
|
|
||||||
|
# Turn number into action string
|
||||||
|
action = Celeste.action_space[action]
|
||||||
|
|
||||||
|
celeste.act(action)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Update previous state
|
||||||
|
last_state = s
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
c = Celeste(
|
||||||
|
on_state
|
||||||
|
)
|
||||||
|
|
||||||
|
c.update_loop()
|
||||||
|
|
Reference in New Issue