Changed hyperparameters, actions, action selection, and reward system

parent 672d330b62
commit 2ff526a072
@@ -3,6 +3,8 @@ import subprocess
 import time
 import math
 
+import numpy as np
+
 class CelesteError(Exception):
 	pass
 
@@ -12,8 +14,12 @@ class CelesteState(NamedTuple):
 	stage: int
 
 	# Player position
+	# Regular position has 0,0 in top-left,
+	# centered position has 0,0 in center.
 	xpos: int
 	ypos: int
+	xpos_scaled: float
+	ypos_scaled: float
 
 	# Player velocity
 	xvel: float
@@ -37,28 +43,47 @@ class CelesteState(NamedTuple):
 
 	# True if Madeline can dash
 	can_dash: bool
+	can_dash_int: int
 
 
 class Celeste:
 	action_space = [
-		"left",    # move left
-		"right",   # move right
-		"jump",    # jump
+		"left",    # move left       0
+		"right",   # move right      1
+		#"jump",   # jump
+		"jump-l",  # jump left       2
+		"jump-r",  # jump right      3
 
-		"dash-u",  # dash up
-		"dash-r",  # dash right
-		"dash-l",  # dash left
-		"dash-ru", # dash right-up
-		"dash-lu"  # dash left-up
+		"dash-u",  # dash up         4
+		"dash-r",  # dash right      5
+		"dash-l",  # dash left       6
+		"dash-ru", # dash right-up   7
+		"dash-lu"  # dash left-up    8
 	]
 
 	# Map integers to state values.
 	# This also determines what data is fed to the model.
 	state_number_map = [
-		"xpos",
-		"ypos",
-		"next_point_x",
-		"next_point_y"
+		#"xpos",
+		#"ypos",
+		"xpos_scaled",
+		"ypos_scaled",
+		"can_dash_int"
+		#"next_point_x",
+		#"next_point_y"
+	]
+
+	# Targets the agent tries to reach.
+	# The last target MUST be outside the frame.
+	target_checkpoints = [
+		[ # Stage 1
+			#(28, 88),  # Start pillar
+			(60, 80),   # Middle pillar
+			(105, 64),  # Right ledge
+			(25, 40),   # Left ledge
+			(110, 16),  # End ledge
+			(110, -2),  # Next stage
+		]
 	]
 
 	def __init__(
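
Note on the new state encoding: state_number_map now feeds xpos_scaled, ypos_scaled, and can_dash_int to the model instead of raw pixel coordinates and the checkpoint target. PICO-8's screen is 128x128, so the scaled values land in roughly [0, 1]. The code that turns a CelesteState into the network input is not part of this diff; a minimal sketch of how the map could be applied, with state_to_tensor as a hypothetical helper:

import torch

def state_to_tensor(state: CelesteState) -> torch.Tensor:
	# Read exactly the fields named in state_number_map, in order,
	# so the network's input width follows the map. (Hypothetical helper.)
	return torch.tensor(
		[float(getattr(state, f)) for f in Celeste.state_number_map],
		dtype = torch.float32
	)
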
@@ -110,19 +135,6 @@ class Celeste:
 		self._resetting = False  # True between a call to .reset() and the first state message from pico.
 		self._keys = {}  # Dictionary of "key": bool
 
-		# Targets the agent tries to reach.
-		# The last target MUST be outside the frame.
-		self.target_checkpoints = [
-			[ # Stage 1
-				#(28, 88),  # Start pillar
-				(60, 80),   # Middle pillar
-				(105, 64),  # Right ledge
-				(25, 40),   # Left ledge
-				(110, 16),  # End ledge
-				(110, -2),  # Next stage
-			]
-		]
-
 	def act(self, action: str):
 		"""
 		Specify what keys should be down. This does NOT send key events.
@@ -141,6 +153,12 @@ class Celeste:
 			self._keys["Right"] = True
 		elif action == "jump":
 			self._keys["c"] = True
+		elif action == "jump-l":
+			self._keys["c"] = True
+			self._keys["Left"] = True
+		elif action == "jump-r":
+			self._keys["c"] = True
+			self._keys["Right"] = True
 
 		elif action == "dash-u":
 			self._keys["Up"] = True
@@ -183,12 +201,12 @@ class Celeste:
 				[int(self._internal_state["rx"])]
 			)
 
-			if len(self.target_checkpoints) < stage:
+			if len(Celeste.target_checkpoints) < stage:
 				next_point_x = None
 				next_point_y = None
 			else:
-				next_point_x = self.target_checkpoints[stage][self._next_checkpoint_idx][0]
-				next_point_y = self.target_checkpoints[stage][self._next_checkpoint_idx][1]
+				next_point_x = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][0]
+				next_point_y = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][1]
 
 
 			return CelesteState(
@@ -196,6 +214,8 @@ class Celeste:
 
 				xpos = int(self._internal_state["px"]),
 				ypos = int(self._internal_state["py"]),
+				xpos_scaled = int(self._internal_state["px"]) / 128.0,
+				ypos_scaled = int(self._internal_state["py"]) / 128.0,
 				xvel = float(self._internal_state["vx"]),
 				yvel = float(self._internal_state["vy"]),
 				deaths = int(self._internal_state["dc"]),
@@ -205,7 +225,8 @@ class Celeste:
 				next_point_x = next_point_x,
 				next_point_y = next_point_y,
 				state_count = self._state_counter,
-				can_dash = self._internal_state["ds"] == "t"
+				can_dash = self._internal_state["ds"] == "t",
+				can_dash_int = 1 if self._internal_state["ds"] == "t" else 0
 			)
 
 		except KeyError:
@@ -299,24 +320,36 @@ class Celeste:
 				self._internal_state[key] = val
 
 
-		# Update checkpoints
-		tx, ty = self.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
+		# Calculate distance to each point
 		x = self.state.xpos
 		y = self.state.ypos
-		dist = math.sqrt(
-			(x-tx)*(x-tx) +
-			((y-ty)*(y-ty))/2
-			# Possible modification:
-			# make x-distance twice as valuable as y-distance
-		)
+		dist = np.zeros(len(Celeste.target_checkpoints[self.state.stage]), dtype=np.float16)
+		for i, c in enumerate(Celeste.target_checkpoints[self.state.stage]):
+			if i < self._next_checkpoint_idx:
+				dist[i] = 1000
+				continue
+
+			# Update checkpoints
+			tx, ty = c
+			dist[i] = (math.sqrt(
+				(x-tx)*(x-tx) +
+				((y-ty)*(y-ty))/2
+				# Possible modification:
+				# make x-distance twice as valuable as y-distance
+			))
+		min_idx = int(dist.argmin())
+		dist = int(dist[min_idx])
 
-		if dist <= 5:
-			print(f"Got point {self._next_checkpoint_idx}")
-			self._next_checkpoint_idx += 1
+		if dist <= 8:
+			print(f"Got point {min_idx}")
+			self._next_checkpoint_idx = min_idx + 1
 			self._last_checkpoint_state = self._state_counter
 
 			# Recalculate distance to new point
-			tx, ty = self.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
+			tx, ty = Celeste.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
 			dist = math.sqrt(
 				(x-tx)*(x-tx) +
 				((y-ty)*(y-ty))/2
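
Note on the checkpoint logic: every remaining target is now scored with a y-damped distance, sqrt((x-tx)^2 + (y-ty)^2 / 2), the closest one is taken with argmin, and the index advances when the agent comes within 8 pixels (previously 5). A standalone sketch of the same metric, for poking at values outside the game loop (the function name is an assumption, not part of this commit):

import math

def checkpoint_distance(x: float, y: float, tx: float, ty: float) -> float:
	# Vertical separation is halved, so horizontal progress dominates the score.
	return math.sqrt((x - tx) ** 2 + ((y - ty) ** 2) / 2)

# E.g. the start pillar (28, 88) to the middle pillar (60, 80):
# sqrt(32**2 + 64/2) is about 32.5, far outside the 8-pixel capture radius.
print(checkpoint_distance(28, 88, 60, 80))
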
@@ -326,6 +359,7 @@ class Celeste:
 		elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
 			self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)
 
+
 		self._dist = dist
 
 		# Call step callbacks

@@ -24,6 +24,12 @@ if __name__ == "__main__":
 	screenshot_dir.mkdir(parents = True, exist_ok = True)
 
 
+	# Remove old screenshots
+	shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
+	for s in shots:
+		s.unlink()
+
+
 	compute_device = torch.device(
 		"cuda" if torch.cuda.is_available() else "cpu"
 	)
@@ -41,11 +47,15 @@ if __name__ == "__main__":
 	# EPS_END is the final value of epsilon
 	# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
 	EPS_START = 0.9
-	EPS_END = 0.05
-	EPS_DECAY = 4000
+	EPS_END = 0.02
+	EPS_DECAY = 100
 
+	# How many times we've reached each point.
+	# Used to compute epsilon-greedy probability with
+	# the parameters above.
+	point_counter = [0] * len(Celeste.target_checkpoints[0])
 
-	BATCH_SIZE = 1_000
+	BATCH_SIZE = 100
 	# Learning rate of target_net.
 	# Controls how soft our soft update is.
 	#
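
Note on exploration: epsilon is now driven by a per-checkpoint counter rather than total steps, so exploration stays high around points the agent has rarely reached. The decay function itself is outside this hunk; a sketch assuming the usual exponential schedule that EPS_START / EPS_END / EPS_DECAY suggest (function name is illustrative):

import math

def epsilon_for(point_visits: int) -> float:
	# Decays from EPS_START toward EPS_END as a checkpoint is reached
	# more often. EPS_DECAY = 100 decays far faster than the old
	# steps_done-based schedule with EPS_DECAY = 4000.
	return EPS_END + (EPS_START - EPS_END) * math.exp(-point_visits / EPS_DECAY)
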
@@ -58,7 +68,7 @@ if __name__ == "__main__":
 	#
 	# A value of zero makes target_net
 	# not change at all.
-	TAU = 0.005
+	TAU = 0.05
 
 
 	# GAMMA is the discount factor as mentioned in the previous section
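
Note on TAU: this is the soft-update coefficient for target_net, so 0.05 makes the target track policy_net ten times faster than 0.005. The update loop is not shown in this diff; a minimal sketch of the soft update TAU controls, assuming the standard formulation theta_target <- tau * theta_policy + (1 - tau) * theta_target:

target_dict = target_net.state_dict()
policy_dict = policy_net.state_dict()
for key in policy_dict:
	# Blend each parameter of policy_net into target_net.
	target_dict[key] = TAU * policy_dict[key] + (1 - TAU) * target_dict[key]
target_net.load_state_dict(target_dict)
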
@@ -90,9 +100,10 @@ if __name__ == "__main__":
 	target_net.load_state_dict(policy_net.state_dict())
 
 
+	learning_rate = 0.001
 	optimizer = torch.optim.AdamW(
 		policy_net.parameters(),
-		lr = 0.01, # Hyperparameter: learning rate
+		lr = learning_rate,
 		amsgrad = True
 	)
 
@@ -109,6 +120,7 @@ if __name__ == "__main__":
 		memory = checkpoint["memory"]
 		episode_number = checkpoint["episode_number"] + 1
 		steps_done = checkpoint["steps_done"]
+		point_counter = checkpoint["point_counter"]
 
 def select_action(state, steps_done):
 	"""
@@ -144,7 +156,6 @@ def select_action(state, steps_done):
 
 
-
 def optimize_model():
 
 	if len(memory) < BATCH_SIZE:
 		raise Exception(f"Not enough elements in memory for a batch of {BATCH_SIZE}")
@@ -189,19 +200,8 @@ def optimize_model():
 	# out[i, j] = a[ i ][ b[i,j] ]
 	#
 	# a is "input," b is "index"
-	# If this doesn't make sense, RTFD.
 
 	# Compute Q(s_t, a).
-	# - Use policy_net to compute Q(s_t) for each state in the batch.
-	#   This gives a tensor of [ Q(state, left), Q(state, right) ]
-	#
-	# - Action batch is a tensor that looks like [ [0], [1], [1], ... ]
-	#   listing the action that was taken in each transition.
-	#   0 => we went left, 1 => we went right.
-	#
-	# This aligns nicely with the output of policy_net. We use
-	# action_batch to index the output of policy_net's prediction.
-	#
 	# This gives us a tensor that contains the return we expect to get
 	# at that state if we follow the model's advice.
 
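
Note on the comment that survives the trim: it describes torch.gather, where out[i, j] = a[i][b[i, j]]. A tiny self-contained example of that indexing, separate from the training code:

import torch

q_values = torch.tensor([[0.1, 0.9], [0.7, 0.3]])  # a: Q-value per action, one row per state
actions  = torch.tensor([[1], [0]])                 # b: action index taken in each row
print(q_values.gather(1, actions))                  # tensor([[0.9000], [0.7000]])
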
@@ -214,8 +214,7 @@ def optimize_model():
 	# = the maximum reward over all possible actions at state s_t+1.
 	next_state_values = torch.zeros(BATCH_SIZE, device = compute_device)
 
-	# Don't compute gradient for operations in this block.
-	# If you don't understand what this means, RTFD.
 	with torch.no_grad():
 
 		# Note the use of non_final_mask here.
@@ -291,6 +290,15 @@ def on_state_before(celeste):
 		device = compute_device
 	).unsqueeze(0)
 
+
+	action = select_action(
+		pt_state,
+		point_counter[state.next_point]
+	)
+	str_action = Celeste.action_space[action]
+
+
+	"""
 	action = None
 	while (action) is None or ((not state.can_dash) and (str_action not in ["left", "right"])):
 		action = select_action(
@@ -298,6 +306,8 @@ def on_state_before(celeste):
 			steps_done
 		)
 		str_action = Celeste.action_space[action]
+	"""
+
 	steps_done += 1
 
 
@@ -343,37 +353,37 @@ def on_state_after(celeste, before_out):
 	).unsqueeze(0)
 
 
 
 	if state.next_point == next_state.next_point:
-		reward = state.dist - next_state.dist
-
-		# Clip rewards that are too large
-		if reward > 1:
-			reward = 1
-		else:
-			reward = 0
+		reward = 0
 
 	else:
 		# Reward for reaching a point
-		reward = 1
+		reward = next_state.next_point - state.next_point
+
+		# Add to point counter
+		for i in range(state.next_point, state.next_point + reward):
+			point_counter[i] += 1
+
+	reward = reward * 10
 	pt_reward = torch.tensor([reward], device = compute_device)
 
 
 	# Add this state transition to memory.
 	memory.append(
 		Transition(
-			pt_state, # last state
+			pt_state,
 			pt_action,
-			pt_next_state, # next state
+			pt_next_state,
 			pt_reward
 		)
 	)
 
-	print("==> ", int(reward))
+	print("==> ", reward)
 	print("")
 
 
 	loss = None
 
 	# Only train the network if we have enough
 	# transitions in memory to do so.
 	if len(memory) >= BATCH_SIZE:
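
Note on the new reward: the shaped distance term is gone. A step that gains no checkpoint is worth 0; a step that advances the next-checkpoint index earns the number of indices gained, times 10, and bumps point_counter for each point passed. For example, moving from next_point 1 to next_point 3 in one transition gives reward = (3 - 1) * 10 = 20 and increments point_counter[1] and point_counter[2].
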
@@ -407,8 +417,18 @@ def on_state_after(celeste, before_out):
 			"target_state_dict": target_net.state_dict(),
 			"optimizer_state_dict": optimizer.state_dict(),
 			"memory": memory,
+			"point_counter": point_counter,
 			"episode_number": episode_number,
-			"steps_done": steps_done
+			"steps_done": steps_done,
+
+			# Hyperparameters
+			"eps_start": EPS_START,
+			"eps_end": EPS_END,
+			"eps_decay": EPS_DECAY,
+			"batch_size": BATCH_SIZE,
+			"tau": TAU,
+			"learning_rate": learning_rate,
+			"gamma": GAMMA
 		}, model_save_path)
 
 
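
Note on the saved checkpoint: the hyperparameters now ride along with the weights, so old runs can be identified later. A minimal sketch of reading them back from a saved file (using the model_save_path defined earlier in this script):

checkpoint = torch.load(model_save_path)
print(
	checkpoint["eps_start"], checkpoint["eps_end"], checkpoint["eps_decay"],
	checkpoint["batch_size"], checkpoint["tau"],
	checkpoint["learning_rate"], checkpoint["gamma"]
)
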
@@ -421,7 +441,7 @@ def on_state_after(celeste, before_out):
 	for s in shots:
 		s.rename(target / s.name)
 
-	# Save a prediction graph
+	# Save a snapshot
 	if episode_number % archive_interval == 0:
 		torch.save({
 			"policy_state_dict": policy_net.state_dict(),