Cleaned up celeste wrapper
parent
85d8c7a300
commit
610e5eef92
|
@ -1,12 +1,44 @@
|
||||||
|
from typing import NamedTuple
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
import threading
|
|
||||||
import math
|
import math
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
class CelesteError(Exception):
|
class CelesteError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CelesteState(NamedTuple):
|
||||||
|
# Stage number
|
||||||
|
stage: int
|
||||||
|
|
||||||
|
# Player position
|
||||||
|
xpos: int
|
||||||
|
ypos: int
|
||||||
|
|
||||||
|
# Player velocity
|
||||||
|
xvel: float
|
||||||
|
yvel: float
|
||||||
|
|
||||||
|
# Number of deaths since game start
|
||||||
|
deaths: int
|
||||||
|
|
||||||
|
# Distance to next point
|
||||||
|
dist: float
|
||||||
|
|
||||||
|
# Index of next point
|
||||||
|
next_point: int
|
||||||
|
|
||||||
|
# Coordinates of next point
|
||||||
|
next_point_x: int
|
||||||
|
next_point_y: int
|
||||||
|
|
||||||
|
# Number of states recieved since restart
|
||||||
|
state_count: int
|
||||||
|
|
||||||
|
# True if Madeline can dash
|
||||||
|
can_dash: bool
|
||||||
|
|
||||||
|
|
||||||
class Celeste:
|
class Celeste:
|
||||||
action_space = [
|
action_space = [
|
||||||
"left", # move left
|
"left", # move left
|
||||||
|
@ -20,10 +52,25 @@ class Celeste:
|
||||||
"dash-lu" # dash left-up
|
"dash-lu" # dash left-up
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self):
|
# Map integers to state values.
|
||||||
|
# This also determines what data is fed to the model.
|
||||||
|
state_number_map = [
|
||||||
|
"xpos",
|
||||||
|
"ypos",
|
||||||
|
"next_point_x",
|
||||||
|
"next_point_y"
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
state_timeout = 30,
|
||||||
|
cart_name = "hackcel.p8"
|
||||||
|
):
|
||||||
|
|
||||||
# Start pico-8
|
# Start pico-8
|
||||||
self.process = subprocess.Popen(
|
self._process = subprocess.Popen(
|
||||||
"bin/pico-8/linux/pico8",
|
"resources/pico-8/linux/pico8",
|
||||||
shell=True,
|
shell=True,
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
stderr=subprocess.STDOUT
|
stderr=subprocess.STDOUT
|
||||||
|
@ -39,26 +86,34 @@ class Celeste:
|
||||||
]).decode("utf-8").strip().split("\n")
|
]).decode("utf-8").strip().split("\n")
|
||||||
if len(winid) != 1:
|
if len(winid) != 1:
|
||||||
raise Exception("Could not find unique PICO-8 window id")
|
raise Exception("Could not find unique PICO-8 window id")
|
||||||
self.winid = winid[0]
|
self._winid = winid[0]
|
||||||
|
|
||||||
# Load cartridge
|
# Load cartridge
|
||||||
self.keystring("load hackcel.p8")
|
self._keystring(f"load {cart_name}")
|
||||||
self.keypress("Enter")
|
self._keypress("Enter")
|
||||||
self.keystring("run")
|
self._keystring("run")
|
||||||
self.keypress("Enter", post = 1000)
|
self._keypress("Enter", post = 1000)
|
||||||
|
|
||||||
# Initialize variables
|
|
||||||
self.internal_status = {}
|
|
||||||
self.before_out = None
|
|
||||||
self.last_point_frame = 0
|
|
||||||
|
|
||||||
# Score system
|
# Parameters
|
||||||
self.frame_counter = 0
|
self.state_timeout = state_timeout # If we run this many states without getting a checkpoint, reset.
|
||||||
self.next_point = 0
|
self.cart_name = cart_name # Name of cart to load. Not used anywhere, but saved for convenience.
|
||||||
self.dist = 0 # distance to next point
|
|
||||||
self.target_points = [
|
# Internal variables
|
||||||
|
self._internal_state = {} # Raw data read from stdout
|
||||||
|
self._before_out = None # Output of "before" callback in update loop
|
||||||
|
self._last_checkpoint_state = 0 # Index of frame at which we reached the last checkpoint
|
||||||
|
self._state_counter = 0 # Number of frames we've run since last reset
|
||||||
|
self._next_checkpoint_idx = 0 # Index of next point
|
||||||
|
self._dist = 0 # Distance to next point
|
||||||
|
self._resetting = False # True between a call to .reset() and the first state message from pico.
|
||||||
|
self._keys = {} # Dictionary of "key": bool
|
||||||
|
|
||||||
|
# Targets the agent tries to reach.
|
||||||
|
# The last target MUST be outside the frame.
|
||||||
|
self.target_checkpoints = [
|
||||||
[ # Stage 1
|
[ # Stage 1
|
||||||
(28, 88), # Start pillar
|
#(28, 88), # Start pillar
|
||||||
(60, 80), # Middle pillar
|
(60, 80), # Middle pillar
|
||||||
(105, 64), # Right ledge
|
(105, 64), # Right ledge
|
||||||
(25, 40), # Left ledge
|
(25, 40), # Left ledge
|
||||||
|
@ -67,119 +122,150 @@ class Celeste:
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
|
|
||||||
def act(self, action):
|
def act(self, action: str):
|
||||||
self.keyup("x")
|
"""
|
||||||
self.keyup("c")
|
Specify what keys should be down. This does NOT send key events.
|
||||||
self.keyup("Left")
|
Celeste._apply_keys() does that at the right time.
|
||||||
self.keyup("Right")
|
|
||||||
self.keyup("Down")
|
|
||||||
self.keyup("Up")
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action (str): key name, as in Celeste.action_space
|
||||||
|
"""
|
||||||
|
|
||||||
|
self._keys = {}
|
||||||
if action is None:
|
if action is None:
|
||||||
return
|
return
|
||||||
elif action == "left":
|
elif action == "left":
|
||||||
self.keydown("Left")
|
self._keys["Left"] = True
|
||||||
elif action == "right":
|
elif action == "right":
|
||||||
self.keydown("Right")
|
self._keys["Right"] = True
|
||||||
elif action == "jump":
|
elif action == "jump":
|
||||||
self.keydown("c")
|
self._keys["c"] = True
|
||||||
|
|
||||||
elif action == "dash-u":
|
elif action == "dash-u":
|
||||||
self.keydown("Up")
|
self._keys["Up"] = True
|
||||||
self.keydown("x")
|
self._keys["x"] = True
|
||||||
elif action == "dash-r":
|
elif action == "dash-r":
|
||||||
self.keydown("Right")
|
self._keys["Right"] = True
|
||||||
self.keydown("x")
|
self._keys["x"] = True
|
||||||
elif action == "dash-l":
|
elif action == "dash-l":
|
||||||
self.keydown("Left")
|
self._keys["Left"] = True
|
||||||
self.keydown("x")
|
self._keys["x"] = True
|
||||||
elif action == "dash-ru":
|
elif action == "dash-ru":
|
||||||
self.keydown("Up")
|
self._keys["Up"] = True
|
||||||
self.keydown("Right")
|
self._keys["Right"] = True
|
||||||
self.keydown("x")
|
self._keys["x"] = True
|
||||||
elif action == "dash-lu":
|
elif action == "dash-lu":
|
||||||
self.keydown("Up")
|
self._keys["Up"] = True
|
||||||
self.keydown("Left")
|
self._keys["Left"] = True
|
||||||
self.keydown("x")
|
self._keys["x"] = True
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_keys(self):
|
||||||
|
for i in [
|
||||||
|
"x", "c",
|
||||||
|
"Left", "Right",
|
||||||
|
"Down", "Up"
|
||||||
|
]:
|
||||||
|
if self._keys.get(i):
|
||||||
|
self._keydown(i)
|
||||||
|
else:
|
||||||
|
self._keyup(i)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def status(self):
|
def state(self):
|
||||||
try:
|
try:
|
||||||
return {
|
stage = (
|
||||||
"stage": (
|
|
||||||
[
|
[
|
||||||
[0, 1, 2, 3, 4]
|
[0, 1, 2, 3, 4]
|
||||||
]
|
]
|
||||||
[int(self.internal_status["ry"])]
|
[int(self._internal_state["ry"])]
|
||||||
[int(self.internal_status["rx"])]
|
[int(self._internal_state["rx"])]
|
||||||
),
|
)
|
||||||
|
|
||||||
"xpos": int(self.internal_status["px"]),
|
if len(self.target_checkpoints) < stage:
|
||||||
"ypos": int(self.internal_status["py"]),
|
next_point_x = None
|
||||||
"xvel": float(self.internal_status["vx"]),
|
next_point_y = None
|
||||||
"yvel": float(self.internal_status["vy"]),
|
else:
|
||||||
"deaths": int(self.internal_status["dc"]),
|
next_point_x = self.target_checkpoints[stage][self._next_checkpoint_idx][0]
|
||||||
|
next_point_y = self.target_checkpoints[stage][self._next_checkpoint_idx][1]
|
||||||
|
|
||||||
|
|
||||||
|
return CelesteState(
|
||||||
|
stage = stage,
|
||||||
|
|
||||||
|
xpos = int(self._internal_state["px"]),
|
||||||
|
ypos = int(self._internal_state["py"]),
|
||||||
|
xvel = float(self._internal_state["vx"]),
|
||||||
|
yvel = float(self._internal_state["vy"]),
|
||||||
|
deaths = int(self._internal_state["dc"]),
|
||||||
|
|
||||||
|
dist = self._dist,
|
||||||
|
next_point = self._next_checkpoint_idx,
|
||||||
|
next_point_x = next_point_x,
|
||||||
|
next_point_y = next_point_y,
|
||||||
|
state_count = self._state_counter,
|
||||||
|
can_dash = self._internal_state["ds"] == "t"
|
||||||
|
)
|
||||||
|
|
||||||
"dist": self.dist,
|
|
||||||
"next_point": self.next_point,
|
|
||||||
"frame_count": self.frame_counter
|
|
||||||
}
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise CelesteError("Not enough data to get status.")
|
raise CelesteError("Not enough data to get state.")
|
||||||
|
|
||||||
|
def _keypress(self, key: str, *, post = 200):
|
||||||
def keypress(self, key: str, *, post = 200):
|
|
||||||
subprocess.run([
|
subprocess.run([
|
||||||
"xdotool",
|
"xdotool",
|
||||||
"key",
|
"key",
|
||||||
"--window", self.winid,
|
"--window", self._winid,
|
||||||
key
|
key
|
||||||
])
|
])
|
||||||
time.sleep(post / 1000)
|
time.sleep(post / 1000)
|
||||||
|
|
||||||
def keydown(self, key: str):
|
def _keydown(self, key: str):
|
||||||
subprocess.run([
|
subprocess.run([
|
||||||
"xdotool",
|
"xdotool",
|
||||||
"keydown",
|
"keydown",
|
||||||
"--window", self.winid,
|
"--window", self._winid,
|
||||||
key
|
key
|
||||||
])
|
])
|
||||||
|
|
||||||
def keyup(self, key: str):
|
def _keyup(self, key: str):
|
||||||
subprocess.run([
|
subprocess.run([
|
||||||
"xdotool",
|
"xdotool",
|
||||||
"keyup",
|
"keyup",
|
||||||
"--window", self.winid,
|
"--window", self._winid,
|
||||||
key
|
key
|
||||||
])
|
])
|
||||||
|
|
||||||
def keystring(self, string, *, delay = 100, post = 200):
|
def _keystring(self, string, *, delay = 100, post = 200):
|
||||||
subprocess.run([
|
subprocess.run([
|
||||||
"xdotool",
|
"xdotool",
|
||||||
"type",
|
"type",
|
||||||
"--window", self.winid,
|
"--window", self._winid,
|
||||||
"--delay", str(delay),
|
"--delay", str(delay),
|
||||||
string
|
string
|
||||||
])
|
])
|
||||||
time.sleep(post / 1000)
|
time.sleep(post / 1000)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.internal_status = {}
|
# Make sure all keys are released
|
||||||
self.next_point = 0
|
self.act(None)
|
||||||
self.frame_counter = 0
|
self._apply_keys()
|
||||||
self.before_out = None
|
|
||||||
self.resetting = True
|
|
||||||
self.last_point_frame = 0
|
|
||||||
|
|
||||||
self.keypress("Escape")
|
self._internal_state = {}
|
||||||
self.keystring("run")
|
self._next_checkpoint_idx = 0
|
||||||
self.keypress("Enter", post = 1000)
|
self._state_counter = 0
|
||||||
|
self._before_out = None
|
||||||
|
self._resetting = True
|
||||||
|
self._last_checkpoint_state = 0
|
||||||
|
|
||||||
self.flush_reader()
|
self._keypress("Escape")
|
||||||
|
self._keystring("run")
|
||||||
|
self._keypress("Enter", post = 1000)
|
||||||
|
|
||||||
def flush_reader(self):
|
|
||||||
for k in iter(self.process.stdout.readline, ""):
|
|
||||||
|
# Clear all old stdout messages and
|
||||||
|
# wait for the game to restart.
|
||||||
|
for k in iter(self._process.stdout.readline, ""):
|
||||||
k = k.decode("utf-8")[:-1]
|
k = k.decode("utf-8")[:-1]
|
||||||
if k == "!RESTART":
|
if k == "!RESTART":
|
||||||
break
|
break
|
||||||
|
@ -187,61 +273,68 @@ class Celeste:
|
||||||
|
|
||||||
def update_loop(self, before, after):
|
def update_loop(self, before, after):
|
||||||
|
|
||||||
# Get state, call callback, wait for state
|
# Waits for stdout from pico-8 process
|
||||||
# One line => one frame.
|
for line in iter(self._process.stdout.readline, ""):
|
||||||
|
|
||||||
it = iter(self.process.stdout.readline, "")
|
|
||||||
|
|
||||||
|
|
||||||
for line in it:
|
|
||||||
l = line.decode("utf-8")[:-1].strip()
|
l = line.decode("utf-8")[:-1].strip()
|
||||||
self.resetting = False
|
|
||||||
|
# Release all keys
|
||||||
|
self.act(None)
|
||||||
|
self._apply_keys()
|
||||||
|
|
||||||
|
# Clear reset state
|
||||||
|
self._resetting = False
|
||||||
|
|
||||||
# This should only occur at game start
|
# This should only occur at game start
|
||||||
if l in ["!RESTART"]:
|
if l in ["!RESTART"]:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.frame_counter += 1
|
self._state_counter += 1
|
||||||
|
|
||||||
# Parse status string
|
# Parse state string
|
||||||
for entry in l.split(";"):
|
for entry in l.split(";"):
|
||||||
if entry == "":
|
if entry == "":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
key, val = entry.split(":")
|
key, val = entry.split(":")
|
||||||
self.internal_status[key] = val
|
self._internal_state[key] = val
|
||||||
|
|
||||||
|
|
||||||
# Update checkpoints
|
# Update checkpoints
|
||||||
|
tx, ty = self.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
|
||||||
tx, ty = self.target_points[self.status["stage"]][self.next_point]
|
x = self.state.xpos
|
||||||
x = self.status["xpos"]
|
y = self.state.ypos
|
||||||
y = self.status["ypos"]
|
|
||||||
dist = math.sqrt(
|
dist = math.sqrt(
|
||||||
(x-tx)*(x-tx) +
|
(x-tx)*(x-tx) +
|
||||||
(y-ty)*(y-ty)
|
((y-ty)*(y-ty))/2
|
||||||
|
# Possible modification:
|
||||||
|
# make x-distance twice as valuable as y-distance
|
||||||
)
|
)
|
||||||
|
|
||||||
if dist <= 4 and y == ty:
|
if dist <= 5:
|
||||||
print(f"Got point {self.next_point}")
|
print(f"Got point {self._next_checkpoint_idx}")
|
||||||
self.next_point += 1
|
self._next_checkpoint_idx += 1
|
||||||
self.last_point_frame = self.frame_counter
|
self._last_checkpoint_state = self._state_counter
|
||||||
|
|
||||||
# Recalculate distance to new point
|
# Recalculate distance to new point
|
||||||
tx, ty = self.target_points[self.status["stage"]][self.next_point]
|
tx, ty = self.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
|
||||||
dist = math.sqrt(
|
dist = math.sqrt(
|
||||||
(x-tx)*(x-tx) +
|
(x-tx)*(x-tx) +
|
||||||
(y-ty)*(y-ty)
|
((y-ty)*(y-ty))/2
|
||||||
)
|
)
|
||||||
|
|
||||||
# Timeout if we spend too long between points
|
# Timeout if we spend too long between points
|
||||||
elif self.frame_counter - self.last_point_frame > 40:
|
elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
|
||||||
self.internal_status["dc"] = str(int(self.internal_status["dc"]) + 1)
|
self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)
|
||||||
|
|
||||||
self.dist = dist
|
self._dist = dist
|
||||||
|
|
||||||
# Call step callbacks
|
# Call step callbacks
|
||||||
if self.before_out is not None:
|
# These should call celeste.act() to set next input
|
||||||
after(self, self.before_out)
|
if self._before_out is not None:
|
||||||
if not self.resetting:
|
after(self, self._before_out)
|
||||||
self.before_out = before(self)
|
|
||||||
|
# Do not run before callback if after() triggered a reset.
|
||||||
|
if not self._resetting:
|
||||||
|
self._before_out = before(self)
|
||||||
|
self._apply_keys()
|
||||||
|
|
150
celeste/main.py
150
celeste/main.py
|
@ -1,30 +1,24 @@
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from pathlib import Path
|
||||||
import random
|
import random
|
||||||
import math
|
import math
|
||||||
|
import json
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# Glue layer
|
|
||||||
from celeste import Celeste
|
from celeste import Celeste
|
||||||
|
|
||||||
|
|
||||||
|
run_data_path = Path("out")
|
||||||
|
run_data_path.mkdir(parents = True, exist_ok = True)
|
||||||
|
|
||||||
compute_device = torch.device(
|
compute_device = torch.device(
|
||||||
"cuda" if torch.cuda.is_available() else "cpu"
|
"cuda" if torch.cuda.is_available() else "cpu"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
state_number_map = [
|
|
||||||
"xpos",
|
|
||||||
"ypos",
|
|
||||||
"xvel",
|
|
||||||
"yvel",
|
|
||||||
"next_point"
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# Celeste env properties
|
# Celeste env properties
|
||||||
n_observations = len(state_number_map)
|
n_observations = len(Celeste.state_number_map)
|
||||||
n_actions = len(Celeste.action_space)
|
n_actions = len(Celeste.action_space)
|
||||||
|
|
||||||
|
|
||||||
|
@ -39,7 +33,7 @@ EPS_END = 0.05
|
||||||
EPS_DECAY = 1000
|
EPS_DECAY = 1000
|
||||||
|
|
||||||
|
|
||||||
BATCH_SIZE = 128
|
BATCH_SIZE = 1_000
|
||||||
# Learning rate of target_net.
|
# Learning rate of target_net.
|
||||||
# Controls how soft our soft update is.
|
# Controls how soft our soft update is.
|
||||||
#
|
#
|
||||||
|
@ -64,9 +58,19 @@ GAMMA = 0.99
|
||||||
class DQN(torch.nn.Module):
|
class DQN(torch.nn.Module):
|
||||||
def __init__(self, n_observations: int, n_actions: int):
|
def __init__(self, n_observations: int, n_actions: int):
|
||||||
super(DQN, self).__init__()
|
super(DQN, self).__init__()
|
||||||
self.layer1 = torch.nn.Linear(n_observations, 128)
|
|
||||||
self.layer2 = torch.nn.Linear(128, 128)
|
self.layers = torch.nn.Sequential(
|
||||||
self.layer3 = torch.nn.Linear(128, n_actions)
|
torch.nn.Linear(n_observations, 128),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
|
||||||
|
torch.nn.Linear(128, 128),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
|
||||||
|
torch.nn.Linear(128, 128),
|
||||||
|
torch.nn.ReLU(),
|
||||||
|
|
||||||
|
torch.torch.nn.Linear(128, n_actions)
|
||||||
|
)
|
||||||
|
|
||||||
# Can be called with one input, or with a batch.
|
# Can be called with one input, or with a batch.
|
||||||
#
|
#
|
||||||
|
@ -77,9 +81,7 @@ class DQN(torch.nn.Module):
|
||||||
# Recall that Q(s, a) is the (expected) return of taking
|
# Recall that Q(s, a) is the (expected) return of taking
|
||||||
# action `a` at state `s`
|
# action `a` at state `s`
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = torch.nn.functional.relu(self.layer1(x))
|
return self.layers(x)
|
||||||
x = torch.nn.functional.relu(self.layer2(x))
|
|
||||||
return self.layer3(x)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -94,7 +96,7 @@ num_episodes = 100
|
||||||
# Memory: a deque that holds recent states as Transitions
|
# Memory: a deque that holds recent states as Transitions
|
||||||
# Has a fixed length, drops oldest
|
# Has a fixed length, drops oldest
|
||||||
# element if maxlen is exceeded.
|
# element if maxlen is exceeded.
|
||||||
memory = deque([], maxlen=10_000)
|
memory = deque([], maxlen=100_000)
|
||||||
|
|
||||||
|
|
||||||
policy_net = DQN(
|
policy_net = DQN(
|
||||||
|
@ -112,11 +114,10 @@ target_net.load_state_dict(policy_net.state_dict())
|
||||||
|
|
||||||
optimizer = torch.optim.AdamW(
|
optimizer = torch.optim.AdamW(
|
||||||
policy_net.parameters(),
|
policy_net.parameters(),
|
||||||
lr = 1e-4, # Hyperparameter: learning rate
|
lr = 0.01, # Hyperparameter: learning rate
|
||||||
amsgrad = True
|
amsgrad = True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def select_action(state, steps_done):
|
def select_action(state, steps_done):
|
||||||
"""
|
"""
|
||||||
Select an action using an epsilon-greedy policy.
|
Select an action using an epsilon-greedy policy.
|
||||||
|
@ -303,39 +304,68 @@ def optimize_model():
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
|
||||||
|
episode_number = 0
|
||||||
|
|
||||||
|
|
||||||
|
if (run_data_path/"checkpoint.torch").is_file():
|
||||||
|
# Load model if one exists
|
||||||
|
checkpoint = torch.load((run_data_path/"checkpoint.torch"))
|
||||||
|
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
||||||
|
target_net.load_state_dict(checkpoint["target_state_dict"])
|
||||||
|
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||||
|
memory = checkpoint["memory"]
|
||||||
|
episode_number = checkpoint["episode_number"] + 1
|
||||||
|
steps_done = checkpoint["steps_done"]
|
||||||
|
|
||||||
|
|
||||||
def on_state_before(celeste):
|
def on_state_before(celeste):
|
||||||
global steps_done
|
global steps_done
|
||||||
|
|
||||||
# Conversion to pytorch
|
# Conversion to pytorch
|
||||||
|
|
||||||
state = celeste.status
|
state = celeste.state
|
||||||
|
|
||||||
pt_state = torch.tensor(
|
pt_state = torch.tensor(
|
||||||
[state[x] for x in state_number_map],
|
[getattr(state, x) for x in Celeste.state_number_map],
|
||||||
dtype = torch.float32,
|
dtype = torch.float32,
|
||||||
device = compute_device
|
device = compute_device
|
||||||
).unsqueeze(0)
|
).unsqueeze(0)
|
||||||
|
|
||||||
|
action = None
|
||||||
|
while (action) is None or ((not state.can_dash) and (str_action not in ["left", "right"])):
|
||||||
action = select_action(
|
action = select_action(
|
||||||
pt_state,
|
pt_state,
|
||||||
steps_done
|
steps_done
|
||||||
)
|
)
|
||||||
|
str_action = Celeste.action_space[action]
|
||||||
steps_done += 1
|
steps_done += 1
|
||||||
|
|
||||||
# Turn number into action string
|
|
||||||
str_action = Celeste.action_space[action]
|
|
||||||
|
|
||||||
|
# For manual testing
|
||||||
|
#str_action = ""
|
||||||
|
#while str_action not in Celeste.action_space:
|
||||||
|
# str_action = input("action> ")
|
||||||
|
#action = Celeste.action_space.index(str_action)
|
||||||
|
|
||||||
|
print(str_action)
|
||||||
celeste.act(str_action)
|
celeste.act(str_action)
|
||||||
|
|
||||||
return state, action
|
return state, action
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
image_interval = 10
|
||||||
|
|
||||||
|
|
||||||
def on_state_after(celeste, before_out):
|
def on_state_after(celeste, before_out):
|
||||||
|
global episode_number
|
||||||
|
global image_count
|
||||||
|
|
||||||
state, action = before_out
|
state, action = before_out
|
||||||
|
next_state = celeste.state
|
||||||
|
|
||||||
pt_state = torch.tensor(
|
pt_state = torch.tensor(
|
||||||
[state[x] for x in state_number_map],
|
[getattr(state, x) for x in Celeste.state_number_map],
|
||||||
dtype = torch.float32,
|
dtype = torch.float32,
|
||||||
device = compute_device
|
device = compute_device
|
||||||
).unsqueeze(0)
|
).unsqueeze(0)
|
||||||
|
@ -346,33 +376,30 @@ def on_state_after(celeste, before_out):
|
||||||
dtype = torch.long
|
dtype = torch.long
|
||||||
)
|
)
|
||||||
|
|
||||||
next_state = celeste.status
|
if next_state.deaths != 0:
|
||||||
|
|
||||||
if next_state["deaths"] != 0:
|
|
||||||
pt_next_state = None
|
pt_next_state = None
|
||||||
reward = 0
|
reward = 0
|
||||||
|
|
||||||
else:
|
else:
|
||||||
pt_next_state = torch.tensor(
|
pt_next_state = torch.tensor(
|
||||||
[next_state[x] for x in state_number_map],
|
[getattr(next_state, x) for x in Celeste.state_number_map],
|
||||||
dtype = torch.float32,
|
dtype = torch.float32,
|
||||||
device = compute_device
|
device = compute_device
|
||||||
).unsqueeze(0)
|
).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
if state["next_point"] == next_state["next_point"]:
|
if state.next_point == next_state.next_point:
|
||||||
reward = state["dist"] - next_state["dist"]
|
reward = state.dist - next_state.dist
|
||||||
|
|
||||||
if reward > 0:
|
# Clip rewards that are too large
|
||||||
|
if reward > 1:
|
||||||
reward = 1
|
reward = 1
|
||||||
elif reward < 0:
|
|
||||||
reward = -1
|
|
||||||
else:
|
else:
|
||||||
reward = 0
|
reward = 0
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Score for reaching a point
|
# Score for reaching a point
|
||||||
reward = 10
|
reward = 1
|
||||||
|
|
||||||
|
|
||||||
pt_reward = torch.tensor([reward], device = compute_device)
|
pt_reward = torch.tensor([reward], device = compute_device)
|
||||||
|
|
||||||
|
@ -387,6 +414,8 @@ def on_state_after(celeste, before_out):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
print("==> ", int(reward))
|
||||||
|
print("\n")
|
||||||
|
|
||||||
|
|
||||||
# Only train the network if we have enough
|
# Only train the network if we have enough
|
||||||
|
@ -406,8 +435,51 @@ def on_state_after(celeste, before_out):
|
||||||
|
|
||||||
# Move on to the next episode once we reach
|
# Move on to the next episode once we reach
|
||||||
# a terminal state.
|
# a terminal state.
|
||||||
if (next_state["deaths"] != 0):
|
if (next_state.deaths != 0):
|
||||||
|
s = celeste.state
|
||||||
|
with open(run_data_path / "train.log", "a") as f:
|
||||||
|
f.write(json.dumps({
|
||||||
|
"checkpoints": s.next_point,
|
||||||
|
"state_count": s.state_count
|
||||||
|
}) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
# Save model
|
||||||
|
torch.save({
|
||||||
|
"policy_state_dict": policy_net.state_dict(),
|
||||||
|
"target_state_dict": target_net.state_dict(),
|
||||||
|
"optimizer_state_dict": optimizer.state_dict(),
|
||||||
|
"memory": memory,
|
||||||
|
"episode_number": episode_number,
|
||||||
|
"steps_done": steps_done
|
||||||
|
}, run_data_path / "checkpoint.torch")
|
||||||
|
|
||||||
|
|
||||||
|
# Clean up screenshots
|
||||||
|
shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
|
||||||
|
|
||||||
|
target = run_data_path / Path(f"screenshots/{episode_number}")
|
||||||
|
target.mkdir(parents = True)
|
||||||
|
|
||||||
|
for s in shots:
|
||||||
|
s.rename(target / s.name)
|
||||||
|
|
||||||
|
# Save a prediction graph
|
||||||
|
if episode_number % image_interval == 0:
|
||||||
|
p = run_data_path / Path("model_images")
|
||||||
|
p.mkdir(parents = True, exist_ok = True)
|
||||||
|
torch.save({
|
||||||
|
"policy_state_dict": policy_net.state_dict(),
|
||||||
|
"target_state_dict": target_net.state_dict(),
|
||||||
|
"optimizer_state_dict": optimizer.state_dict(),
|
||||||
|
"memory": memory,
|
||||||
|
"episode_number": episode_number,
|
||||||
|
"steps_done": steps_done
|
||||||
|
}, p / f"{episode_number}.torch")
|
||||||
|
|
||||||
|
|
||||||
print("State over, resetting")
|
print("State over, resetting")
|
||||||
|
episode_number += 1
|
||||||
celeste.reset()
|
celeste.reset()
|
||||||
|
|
||||||
|
|
||||||
|
|
Reference in New Issue