Mark / celeste-ai

Added RL features

master
Mark 2023-02-15 23:38:27 -08:00
parent fd02c65b41
commit c1379a0116
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
2 changed files with 282 additions and 59 deletions


@@ -2,6 +2,7 @@ import subprocess
import time
import threading
import math
from tqdm import tqdm
class CelesteError(Exception):
pass
@@ -51,7 +52,6 @@ class Celeste:
# Initialize variables
self.internal_status = {}
self.dead = False
# Score system
self.frame_counter = 0
@@ -173,7 +173,8 @@ class Celeste:
self.keypress("Escape")
self.keystring("run")
self.keypress("Enter", post = 1000)
self.dead = False
self.flush_reader()
def flush_reader(self):
for k in iter(self.process.stdout.readline, ""):
@@ -186,7 +187,10 @@ class Celeste:
# Get state, call callback, wait for state
# One line => one frame.
for line in iter(self.process.stdout.readline, ""):
it = iter(self.process.stdout.readline, "")
for line in it:
l = line.decode("utf-8")[:-1].strip()
# This should only occur at game start
@@ -215,6 +219,7 @@ class Celeste:
)
if dist <= 4 and y == ty:
print(f"Got point {self.next_point}")
self.next_point += 1
# Recalculate distance to new point


@@ -5,7 +5,6 @@ import math
import torch
# Glue layer
from celeste import Celeste
@@ -15,6 +14,19 @@ compute_device = torch.device(
)
state_number_map = [
"xpos",
"ypos",
"xvel",
"yvel",
"next_point"
]
# Celeste env properties
n_observations = len(state_number_map)
n_actions = len(Celeste.action_space)
# Epsilon-greedy parameters
#
@@ -27,6 +39,27 @@ EPS_END = 0.05
EPS_DECAY = 1000
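# The decay schedule itself lives in select_action, whose body this diff
# does not show. A typical choice (and presumably what these constants feed,
# assuming an EPS_START constant is defined elsewhere in the file) is:
#
#   eps = EPS_END + (EPS_START - EPS_END) * math.exp(-steps_done / EPS_DECAY)
#
# With EPS_END = 0.05 and EPS_DECAY = 1000, the probability of picking a
# random action decays toward 0.05 as steps_done grows.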
BATCH_SIZE = 128
# Update rate of target_net.
# Controls how soft our soft update is.
#
# Should be between 0 and 1.
# Large values make target_net follow policy_net more quickly.
# Small values do the opposite.
#
# A value of one makes target_net
# change at the same rate as policy_net.
#
# A value of zero makes target_net
# not change at all.
TAU = 0.005
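# Each soft update (see the loop at the bottom of on_state) sets,
# for every parameter:
#
#   target = TAU * policy + (1 - TAU) * target
#
# With TAU = 0.005, target_net moves 0.5% of the way toward policy_net
# per update.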
# GAMMA is the discount factor for future rewards.
GAMMA = 0.99
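# GAMMA weights future rewards in the discounted return
#
#   G_t = r_t + GAMMA * r_(t+1) + GAMMA^2 * r_(t+2) + ...
#
# With GAMMA = 0.99, a reward 100 steps in the future is worth
# about 0.37 of an immediate reward.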
# Outline our network
class DQN(torch.nn.Module):
def __init__(self, n_observations: int, n_actions: int):
@@ -50,15 +83,39 @@ class DQN(torch.nn.Module):
# Celeste env properties
n_observations = 4
n_actions = len(Celeste.action_space)
steps_done = 0
num_episodes = 100
# Create replay memory.
#
# Transition: a container for naming data (defined in util.py)
# Memory: a deque that holds recent states as Transitions
# Has a fixed length, drops oldest
# element if maxlen is exceeded.
memory = deque([], maxlen=10_000)
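# For example, with a small maxlen the oldest transition is silently
# dropped once the deque is full:
#
#   d = deque([], maxlen = 3)
#   d.append(1); d.append(2); d.append(3); d.append(4)
#   # d is now deque([2, 3, 4], maxlen = 3)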
policy_net = DQN(
n_observations,
n_actions
).to(compute_device)
target_net = DQN(
n_observations,
n_actions
).to(compute_device)
target_net.load_state_dict(policy_net.state_dict())
optimizer = torch.optim.AdamW(
policy_net.parameters(),
lr = 1e-4, # Hyperparameter: learning rate
amsgrad = True
)
def select_action(state, steps_done):
"""
@@ -107,68 +164,229 @@ Transition = namedtuple(
)
def on_state(celeste):
global last_state
s = celeste.status
if last_state is None:
last_state = s
return
s_next = s["next_point"]
s_dist = s["dist"]
l_next = last_state["next_point"]
l_dist = last_state["dist"]
if l_next == s_next:
reward = l_dist - s_dist
else:
reward = 10
dead = s["deaths"] != 0
frame_count = s["frame_count"]
# Values at this point
# reward: reward for last action
# dead: true if game over
state_number_map = [
"xpos",
"ypos",
"xvel",
"yvel"
]
tf_state = torch.tensor(
[s[x] for x in state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0)
tf_last = torch.tensor(
[last_state[x] for x in state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0)
action = select_action(
tf_state,
frame_count
def optimize_model():
if len(memory) < BATCH_SIZE:
raise Exception(f"Not enough elements in memory for a batch of {BATCH_SIZE}")
# Get a random sample of transitions
batch = random.sample(memory, BATCH_SIZE)
# Conversion.
# Transposes batch, turning an array of Transitions
# into a Transition of arrays.
batch = Transition(*zip(*batch))
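# A tiny worked example of that transpose, with placeholder values:
#
#   batch = [Transition(s0, a0, n0, r0), Transition(s1, a1, n1, r1)]
#   Transition(*zip(*batch))
#   # => Transition(state = (s0, s1), action = (a0, a1),
#   #               next_state = (n0, n1), reward = (r0, r1))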
# Conversion.
# Combine states, actions, and rewards into their own tensors.
state_batch = torch.cat(batch.state)
action_batch = torch.cat(batch.action)
reward_batch = torch.cat(batch.reward)
# Compute a mask of non_final_states.
# Each element of this tensor corresponds to an element in the batch.
# True if this is not a final state, False if it is.
#
# We use this to select non-final states later.
non_final_mask = torch.tensor(
tuple(map(
lambda s: s is not None,
batch.next_state
))
)
non_final_next_states = torch.cat(
[s for s in batch.next_state if s is not None]
)
# How .gather works:
# if out = a.gather(1, b),
# out[i, j] = a[ i ][ b[i,j] ]
#
# a is "input," b is "index"
# If this doesn't make sense, RTFD.
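# A concrete example:
#
#   a = torch.tensor([[1, 2], [3, 4]])
#   b = torch.tensor([[0], [1]])
#   a.gather(1, b)
#   # => tensor([[1], [4]]): row 0 takes column 0, row 1 takes column 1.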
# Compute Q(s_t, a).
# - Use policy_net to compute Q(s_t) for each state in the batch.
# This gives a tensor of [ Q(state, left), Q(state, right) ]
#
# - Action batch is a tensor that looks like [ [0], [1], [1], ... ]
# listing the action that was taken in each transition.
# 0 => we went left, 1 => we went right.
#
# This aligns nicely with the output of policy_net. We use
# action_batch to index the output of policy_net's prediction.
#
# This gives us a tensor that contains the return we expect to get
# at that state if we follow the model's advice.
state_action_values = policy_net(state_batch).gather(1, action_batch)
# Compute V(s_t+1) for all next states.
# V(s_t+1) = max_a ( Q(s_t+1, a) )
# = the largest predicted return over all possible actions at state s_t+1.
next_state_values = torch.zeros(BATCH_SIZE, device = compute_device)
# Don't compute gradient for operations in this block.
# If you don't understand what this means, RTFD.
with torch.no_grad():
# Note the use of non_final_mask here.
# States that are final do not have their reward set by the line
# below, so their reward stays at zero.
#
# States that are not final get their predicted value
# set to the best value the model predicts.
#
#
# Expected values of action are selected with the "older" target net,
# and their best reward (over possible actions) is selected with max(1)[0].
next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
# Compute the expected Q values (the Bellman target for each transition):
# the immediate reward plus the discounted value of the state that follows.
expected_state_action_values = reward_batch + (next_state_values * GAMMA)
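# For example (with made-up numbers): if an action earned a reward of 3 and
# target_net predicts a best-case value of 20 for the state that followed,
# the expected Q value is 3 + 0.99 * 20 = 22.8.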
# Compute Huber loss between predicted reward and expected reward.
# PyTorch will account for this when we compute the gradient of loss.
#
# loss is a single-element tensor (i.e., a scalar).
criterion = torch.nn.SmoothL1Loss()
loss = criterion(
state_action_values,
expected_state_action_values.unsqueeze(1)
)
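# SmoothL1Loss (with its default beta of 1.0) behaves like squared error for
# small differences and like absolute error for large ones:
#
#   loss(x) = 0.5 * x^2     if |x| < 1
#           = |x| - 0.5     otherwise
#
# so a single wildly wrong Q estimate can't dominate the gradient the way
# it would with plain MSE.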
# We can now run a step of backpropagation on our model.
#
# Calling .backward() multiple times will accumulate parameter gradients.
# Thus, we reset the gradient before each step.
optimizer.zero_grad()
# Compute the gradient of loss with respect to every policy_net parameter.
# .backward() stores each gradient in that parameter's .grad attribute,
# which optimizer.step() reads below.
loss.backward()
# Prevent exploding gradients.
# Forces each gradient component into [-clip_value, +clip_value]
torch.nn.utils.clip_grad_value_( # type: ignore
policy_net.parameters(),
clip_value = 100
)
# Perform a single optimizer step.
#
# Uses the current gradient, which is stored
# in the .grad attribute of the parameter.
optimizer.step()
def on_state(celeste):
global steps_done
# Conversion to pytorch
state = celeste.status
pt_state = torch.tensor(
[state[x] for x in state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0)
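# unsqueeze(0) turns the flat state vector into a batch of one
# (shape [1, n_observations]), which is the shape policy_net expects.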
action = select_action(
pt_state,
steps_done
)
steps_done += 1
# Turn number into action string
action = Celeste.action_space[action]
str_action = Celeste.action_space[action]
pt_action = torch.tensor(
[[ action ]],
device = compute_device,
dtype = torch.long
)
celeste.act(action)
celeste.act(str_action)
next_state = celeste.status
if next_state["deaths"] != 0:
pt_next_state = None
reward = 0
else:
pt_next_state = torch.tensor(
[next_state[x] for x in state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0)
if state["next_point"] == next_state["next_point"]:
reward = state["dist"] - next_state["dist"]
else:
# Score for reaching a point
reward = 10
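# For example: if the distance to the current checkpoint shrank from 12 to 9,
# reward is 3; if it grew, reward is negative; reaching a new checkpoint
# earns a flat 10.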
pt_reward = torch.tensor([reward], device = compute_device)
# Add this state transition to memory.
memory.append(
Transition(
pt_state, # last state
pt_action,
pt_next_state, # next state
pt_reward
)
)
# Update previous state
last_state = s
# Only train the network if we have enough
# transitions in memory to do so.
if len(memory) >= BATCH_SIZE:
optimize_model()
# Soft update target_net weights
target_net_state = target_net.state_dict()
policy_net_state = policy_net.state_dict()
for key in policy_net_state:
target_net_state[key] = (
policy_net_state[key] * TAU +
target_net_state[key] * (1-TAU)
)
target_net.load_state_dict(target_net_state)
# Move on to the next episode once we reach
# a terminal state.
if (next_state["deaths"] != 0):
print("State over, resetting")
celeste.reset()