Mark / celeste-ai
Archived

Changed hyperparameters, actions, action selection, and reward system

master
Mark 2023-02-24 14:16:43 -08:00
parent 672d330b62
commit 2ff526a072
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
2 changed files with 130 additions and 76 deletions

View File

@@ -3,6 +3,8 @@ import subprocess
 import time
 import math
+import numpy as np

 class CelesteError(Exception):
 	pass
@@ -12,8 +14,12 @@ class CelesteState(NamedTuple):
 	stage: int

 	# Player position
+	# Regular position has 0,0 in top-left,
+	# centered position has 0,0 in center.
 	xpos: int
 	ypos: int
+	xpos_scaled: float
+	ypos_scaled: float

 	# Player velocity
 	xvel: float
@@ -37,28 +43,47 @@ class CelesteState(NamedTuple):
 	# True if Madeline can dash
 	can_dash: bool
+	can_dash_int: int


 class Celeste:
 	action_space = [
-		"left",		# move left
-		"right",	# move right
-		"jump",		# jump
+		"left",		# move left			0
+		"right",	# move right		1
+		#"jump",	# jump
+		"jump-l",	# jump left			2
+		"jump-r",	# jump right		3

-		"dash-u",	# dash up
-		"dash-r",	# dash right
-		"dash-l",	# dash left
-		"dash-ru",	# dash right-up
-		"dash-lu"	# dash left-up
+		"dash-u",	# dash up			4
+		"dash-r",	# dash right		5
+		"dash-l",	# dash left			6
+		"dash-ru",	# dash right-up		7
+		"dash-lu"	# dash left-up		8
 	]

 	# Map integers to state values.
 	# This also determines what data is fed to the model.
 	state_number_map = [
-		"xpos",
-		"ypos",
-		"next_point_x",
-		"next_point_y"
+		#"xpos",
+		#"ypos",
+		"xpos_scaled",
+		"ypos_scaled",
+		"can_dash_int"
+		#"next_point_x",
+		#"next_point_y"
 	]

+	# Targets the agent tries to reach.
+	# The last target MUST be outside the frame.
+	target_checkpoints = [
+		[ # Stage 1
+			#(28, 88),	# Start pillar
+			(60, 80),	# Middle pillar
+			(105, 64),	# Right ledge
+			(25, 40),	# Left ledge
+			(110, 16),	# End ledge
+			(110, -2),	# Next stage
+		]
+	]

 	def __init__(
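state_number_map controls which CelesteState fields become the network's input vector; after this commit that is the scaled position plus the dash flag. A minimal sketch of how such a vector could be assembled from the map (the conversion code in the training script is not part of this diff, so the helper below is illustrative):

# Illustrative only: builds a model input from state_number_map.
from typing import NamedTuple

class TinyState(NamedTuple):
    # Cut-down stand-in for CelesteState
    xpos_scaled: float
    ypos_scaled: float
    can_dash_int: int

state_number_map = ["xpos_scaled", "ypos_scaled", "can_dash_int"]

def state_to_vector(state: TinyState) -> list[float]:
    # Pick fields in the order given by state_number_map
    return [float(getattr(state, name)) for name in state_number_map]

print(state_to_vector(TinyState(0.47, 0.62, 1)))  # [0.47, 0.62, 1.0]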
@@ -110,19 +135,6 @@ class Celeste:
 		self._resetting = False	# True between a call to .reset() and the first state message from pico.
 		self._keys = {}			# Dictionary of "key": bool

-		# Targets the agent tries to reach.
-		# The last target MUST be outside the frame.
-		self.target_checkpoints = [
-			[ # Stage 1
-				#(28, 88),	# Start pillar
-				(60, 80),	# Middle pillar
-				(105, 64),	# Right ledge
-				(25, 40),	# Left ledge
-				(110, 16),	# End ledge
-				(110, -2),	# Next stage
-			]
-		]
-
 	def act(self, action: str):
 		"""
 		Specify what keys should be down. This does NOT send key events.
@@ -141,6 +153,12 @@ class Celeste:
 			self._keys["Right"] = True
 		elif action == "jump":
 			self._keys["c"] = True
+		elif action == "jump-l":
+			self._keys["c"] = True
+			self._keys["Left"] = True
+		elif action == "jump-r":
+			self._keys["c"] = True
+			self._keys["Right"] = True

 		elif action == "dash-u":
 			self._keys["Up"] = True
@@ -183,12 +201,12 @@ class Celeste:
 				[int(self._internal_state["rx"])]
 			)

-			if len(self.target_checkpoints) < stage:
+			if len(Celeste.target_checkpoints) < stage:
 				next_point_x = None
 				next_point_y = None
 			else:
-				next_point_x = self.target_checkpoints[stage][self._next_checkpoint_idx][0]
-				next_point_y = self.target_checkpoints[stage][self._next_checkpoint_idx][1]
+				next_point_x = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][0]
+				next_point_y = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][1]

 			return CelesteState(
@@ -196,6 +214,8 @@ class Celeste:
 				xpos = int(self._internal_state["px"]),
 				ypos = int(self._internal_state["py"]),
+				xpos_scaled = int(self._internal_state["px"]) / 128.0,
+				ypos_scaled = int(self._internal_state["py"]) / 128.0,
 				xvel = float(self._internal_state["vx"]),
 				yvel = float(self._internal_state["vy"]),
 				deaths = int(self._internal_state["dc"]),
@@ -205,7 +225,8 @@ class Celeste:
 				next_point_x = next_point_x,
 				next_point_y = next_point_y,
 				state_count = self._state_counter,
-				can_dash = self._internal_state["ds"] == "t"
+				can_dash = self._internal_state["ds"] == "t",
+				can_dash_int = 1 if self._internal_state["ds"] == "t" else 0
 			)

 		except KeyError:
@@ -299,24 +320,36 @@ class Celeste:
 			self._internal_state[key] = val

-		# Update checkpoints
-		tx, ty = self.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
-
+		# Calculate distance to each point
 		x = self.state.xpos
 		y = self.state.ypos
-		dist = math.sqrt(
-			(x-tx)*(x-tx) +
-			((y-ty)*(y-ty))/2
-			# Possible modification:
-			# make x-distance twice as valuable as y-distance
-		)
-
-		if dist <= 5:
-			print(f"Got point {self._next_checkpoint_idx}")
-			self._next_checkpoint_idx += 1
+		dist = np.zeros(len(Celeste.target_checkpoints[self.state.stage]), dtype=np.float16)
+		for i, c in enumerate(Celeste.target_checkpoints[self.state.stage]):
+			if i < self._next_checkpoint_idx:
+				dist[i] = 1000
+				continue
+
+			# Update checkpoints
+			tx, ty = c
+			dist[i] = (math.sqrt(
+				(x-tx)*(x-tx) +
+				((y-ty)*(y-ty))/2
+				# Possible modification:
+				# make x-distance twice as valuable as y-distance
+			))
+
+		min_idx = int(dist.argmin())
+		dist = int(dist[min_idx])
+
+		if dist <= 8:
+			print(f"Got point {min_idx}")
+			self._next_checkpoint_idx = min_idx + 1
 			self._last_checkpoint_state = self._state_counter
 			# Recalculate distance to new point
-			tx, ty = self.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
+			tx, ty = Celeste.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
 			dist = math.sqrt(
 				(x-tx)*(x-tx) +
 				((y-ty)*(y-ty))/2
@@ -326,6 +359,7 @@ class Celeste:
 		elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
 			self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)

+		self._dist = dist

 		# Call step callbacks
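The new checkpoint logic above measures distance to every remaining target with the vertical term weighted at half (sqrt(dx^2 + dy^2/2)), takes the nearest one, and counts it as reached within 8 units. A standalone sketch of that metric and selection, mirroring the loop in the hunk:

# Sketch of the weighted distance and nearest-remaining-checkpoint pick used above.
import math

def weighted_dist(x, y, tx, ty):
    # y-distance counts for half as much as x-distance
    return math.sqrt((x - tx) ** 2 + ((y - ty) ** 2) / 2)

checkpoints = [(60, 80), (105, 64), (25, 40), (110, 16), (110, -2)]

def nearest_remaining(x, y, next_idx):
    # Checkpoints already passed are pushed out of consideration,
    # matching the dist[i] = 1000 trick in the diff.
    dists = [
        1000 if i < next_idx else weighted_dist(x, y, tx, ty)
        for i, (tx, ty) in enumerate(checkpoints)
    ]
    i = dists.index(min(dists))
    return i, dists[i]

print(nearest_remaining(58, 82, 0))  # (0, ~2.45): inside the 8-unit capture radius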

View File

@@ -24,6 +24,12 @@ if __name__ == "__main__":
 	screenshot_dir.mkdir(parents = True, exist_ok = True)

+	# Remove old screenshots
+	shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
+	for s in shots:
+		s.unlink()
+
 	compute_device = torch.device(
 		"cuda" if torch.cuda.is_available() else "cpu"
 	)
@ -41,11 +47,15 @@ if __name__ == "__main__":
# EPS_END is the final value of epsilon # EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay # EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
EPS_START = 0.9 EPS_START = 0.9
EPS_END = 0.05 EPS_END = 0.02
EPS_DECAY = 4000 EPS_DECAY = 100
# How many times we've reached each point.
# Used to compute epsilon-greedy probability with
# the parameters above.
point_counter = [0] * len(Celeste.target_checkpoints[0])
BATCH_SIZE = 1_000 BATCH_SIZE = 100
# Learning rate of target_net. # Learning rate of target_net.
# Controls how soft our soft update is. # Controls how soft our soft update is.
# #
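EPS_START, EPS_END and EPS_DECAY presumably drive the usual exponential epsilon-greedy schedule, and the new point_counter is passed to select_action in place of the global step count, so exploration stays high near checkpoints the agent has rarely reached. A hedged sketch of that schedule (the body of select_action is not shown in this diff):

# Assumed schedule; the actual select_action in the training script may differ.
import math
import random

EPS_START, EPS_END, EPS_DECAY = 0.9, 0.02, 100

def epsilon(times_point_reached: int) -> float:
    # Decays from EPS_START toward EPS_END as a checkpoint is reached more often.
    return EPS_END + (EPS_START - EPS_END) * math.exp(-times_point_reached / EPS_DECAY)

def pick(q_values: list[float], times_point_reached: int) -> int:
    # Epsilon-greedy over a list of Q-values.
    if random.random() < epsilon(times_point_reached):
        return random.randrange(len(q_values))
    return max(range(len(q_values)), key=q_values.__getitem__)

print(round(epsilon(0), 3), round(epsilon(100), 3))  # 0.9, ~0.344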
@@ -58,7 +68,7 @@ if __name__ == "__main__":
 	#
 	# A value of zero makes target_net
 	# not change at all.
-	TAU = 0.005
+	TAU = 0.05

 	# GAMMA is the discount factor as mentioned in the previous section
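TAU sets how strongly each soft update pulls target_net toward policy_net; the standard blending rule below (assumed here, following the common DQN recipe) shows why 0.05 tracks the policy roughly ten times faster than the old 0.005:

# Standard soft-update rule; whether the training script uses exactly this form is assumed.
TAU = 0.05

def soft_update(target_sd: dict, policy_sd: dict, tau: float = TAU) -> dict:
    # theta_target <- tau * theta_policy + (1 - tau) * theta_target
    return {k: tau * policy_sd[k] + (1.0 - tau) * target_sd[k] for k in target_sd}

# Toy example with plain floats standing in for tensors:
print(soft_update({"w": 0.0}, {"w": 1.0}))  # {'w': 0.05}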
@ -90,9 +100,10 @@ if __name__ == "__main__":
target_net.load_state_dict(policy_net.state_dict()) target_net.load_state_dict(policy_net.state_dict())
learning_rate = 0.001
optimizer = torch.optim.AdamW( optimizer = torch.optim.AdamW(
policy_net.parameters(), policy_net.parameters(),
lr = 0.01, # Hyperparameter: learning rate lr = learning_rate,
amsgrad = True amsgrad = True
) )
@@ -109,6 +120,7 @@ if __name__ == "__main__":
 		memory = checkpoint["memory"]
 		episode_number = checkpoint["episode_number"] + 1
 		steps_done = checkpoint["steps_done"]
+		point_counter = checkpoint["point_counter"]

 def select_action(state, steps_done):
 	"""
@@ -144,7 +156,6 @@ def select_action(state, steps_done):

-
 def optimize_model():
 	if len(memory) < BATCH_SIZE:
 		raise Exception(f"Not enough elements in memory for a batch of {BATCH_SIZE}")
@@ -189,19 +200,8 @@ def optimize_model():
 	# out[i, j] = a[ i ][ b[i,j] ]
 	#
 	# a is "input," b is "index"
-	# If this doesn't make sense, RTFD.

 	# Compute Q(s_t, a).
-	# - Use policy_net to compute Q(s_t) for each state in the batch.
-	#   This gives a tensor of [ Q(state, left), Q(state, right) ]
-	#
-	# - Action batch is a tensor that looks like [ [0], [1], [1], ... ]
-	#   listing the action that was taken in each transition.
-	#   0 => we went left, 1 => we went right.
-	#
-	# This aligns nicely with the output of policy_net. We use
-	# action_batch to index the output of policy_net's prediction.
-	#
 	# This gives us a tensor that contains the return we expect to get
 	# at that state if we follow the model's advice.
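The gather() comment kept above, out[i, j] = a[i][b[i, j]], is easiest to see on a tiny tensor; a minimal example of picking Q(s_t, a) for each transition:

# Minimal torch.gather illustration of out[i, j] = input[i][index[i, j]].
import torch

q_values = torch.tensor([[1.0, 2.0],   # Q(s0, a0), Q(s0, a1)
                         [3.0, 4.0],
                         [5.0, 6.0]])
actions = torch.tensor([[1], [0], [1]])  # action taken in each transition

state_action_values = q_values.gather(1, actions)
print(state_action_values.squeeze(1))  # tensor([2., 3., 6.])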
@@ -214,8 +214,7 @@ def optimize_model():
 	#  = the maximum reward over all possible actions at state s_t+1.
 	next_state_values = torch.zeros(BATCH_SIZE, device = compute_device)

 	# Don't compute gradient for operations in this block.
-	# If you don't understand what this means, RTFD.
 	with torch.no_grad():

 		# Note the use of non_final_mask here.
@ -291,6 +290,15 @@ def on_state_before(celeste):
device = compute_device device = compute_device
).unsqueeze(0) ).unsqueeze(0)
action = select_action(
pt_state,
point_counter[state.next_point]
)
str_action = Celeste.action_space[action]
"""
action = None action = None
while (action) is None or ((not state.can_dash) and (str_action not in ["left", "right"])): while (action) is None or ((not state.can_dash) and (str_action not in ["left", "right"])):
action = select_action( action = select_action(
@@ -298,6 +306,8 @@ def on_state_before(celeste):
 			steps_done
 		)
 		str_action = Celeste.action_space[action]
+	"""

 	steps_done += 1
@@ -343,37 +353,37 @@ def on_state_after(celeste, before_out):
 	).unsqueeze(0)

 	if state.next_point == next_state.next_point:
-		reward = state.dist - next_state.dist
-
-		# Clip rewards that are too large
-		if reward > 1:
-			reward = 1
-		else:
-			reward = 0
-
+		reward = 0
 	else:
 		# Reward for reaching a point
-		reward = 1
+		reward = next_state.next_point - state.next_point
+
+		# Add to point counter
+		for i in range(state.next_point, state.next_point + reward):
+			point_counter[i] += 1

+	reward = reward * 10
 	pt_reward = torch.tensor([reward], device = compute_device)

 	# Add this state transition to memory.
 	memory.append(
 		Transition(
-			pt_state,		# last state
+			pt_state,
 			pt_action,
-			pt_next_state,	# next state
+			pt_next_state,
 			pt_reward
 		)
 	)

-	print("==> ", int(reward))
+	print("==> ", reward)
 	print("")

 	loss = None
 	# Only train the network if we have enough
 	# transitions in memory to do so.
 	if len(memory) >= BATCH_SIZE:
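The reward is now 0 while the agent stays between checkpoints and 10 per checkpoint gained in a single step, with point_counter recording each one reached. A pure-function sketch of that rule (names are illustrative):

# Sketch of the reward rule in the hunk above.
def compute_reward(prev_next_point: int, new_next_point: int, counter: list[int]) -> int:
    points_gained = new_next_point - prev_next_point
    if points_gained <= 0:
        return 0
    # Count every checkpoint crossed this step
    for i in range(prev_next_point, prev_next_point + points_gained):
        counter[i] += 1
    return points_gained * 10

counter = [0] * 5
print(compute_reward(1, 1, counter))           # 0
print(compute_reward(1, 3, counter), counter)  # 20 [0, 1, 1, 0, 0]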
@ -407,8 +417,18 @@ def on_state_after(celeste, before_out):
"target_state_dict": target_net.state_dict(), "target_state_dict": target_net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(), "optimizer_state_dict": optimizer.state_dict(),
"memory": memory, "memory": memory,
"point_counter": point_counter,
"episode_number": episode_number, "episode_number": episode_number,
"steps_done": steps_done "steps_done": steps_done,
# Hyperparameters
"eps_start": EPS_START,
"eps_end": EPS_END,
"eps_decay": EPS_DECAY,
"batch_size": BATCH_SIZE,
"tau": TAU,
"learning_rate": learning_rate,
"gamma": GAMMA
}, model_save_path) }, model_save_path)
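With the hyperparameters now stored alongside the weights, a resumed run could at least report when they differ from the current constants. This helper is not part of the commit; the path and keys below are illustrative only:

# Hypothetical helper: compare saved vs. current hyperparameters on resume.
import torch

def check_hyperparameters(checkpoint_path: str, current: dict) -> None:
    checkpoint = torch.load(checkpoint_path)
    for key, value in current.items():
        saved = checkpoint.get(key)
        if saved is not None and saved != value:
            print(f"Note: {key} was {saved} when this checkpoint was saved, but is now {value}")

# Example (path is illustrative):
# check_hyperparameters("model.torch", {"tau": 0.05, "batch_size": 100, "eps_decay": 100})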
@@ -421,7 +441,7 @@ def on_state_after(celeste, before_out):
 		for s in shots:
 			s.rename(target / s.name)

-	# Save a prediction graph
+	# Save a snapshot
 	if episode_number % archive_interval == 0:
 		torch.save({
 			"policy_state_dict": policy_net.state_dict(),