Changed hyperparameters, actions, action selection, and reward system
parent
672d330b62
commit
2ff526a072
|
@ -3,6 +3,8 @@ import subprocess
|
|||
import time
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
class CelesteError(Exception):
|
||||
pass
|
||||
|
||||
|
@ -12,8 +14,12 @@ class CelesteState(NamedTuple):
|
|||
stage: int
|
||||
|
||||
# Player position
|
||||
# Regular position has 0,0 in top-left,
|
||||
# centered position has 0,0 in center.
|
||||
xpos: int
|
||||
ypos: int
|
||||
xpos_scaled: float
|
||||
ypos_scaled: float
|
||||
|
||||
# Player velocity
|
||||
xvel: float
|
||||
|
@ -37,28 +43,47 @@ class CelesteState(NamedTuple):
|
|||
|
||||
# True if Madeline can dash
|
||||
can_dash: bool
|
||||
can_dash_int: int
|
||||
|
||||
|
||||
class Celeste:
|
||||
action_space = [
|
||||
"left", # move left
|
||||
"right", # move right
|
||||
"jump", # jump
|
||||
"left", # move left 0
|
||||
"right", # move right 1
|
||||
#"jump", # jump
|
||||
"jump-l", # jump left 2
|
||||
"jump-r", # jump right 3
|
||||
|
||||
"dash-u", # dash up
|
||||
"dash-r", # dash right
|
||||
"dash-l", # dash left
|
||||
"dash-ru", # dash right-up
|
||||
"dash-lu" # dash left-up
|
||||
"dash-u", # dash up 4
|
||||
"dash-r", # dash right 5
|
||||
"dash-l", # dash left 6
|
||||
"dash-ru", # dash right-up 7
|
||||
"dash-lu" # dash left-up 8
|
||||
]
|
||||
|
||||
# Map integers to state values.
|
||||
# This also determines what data is fed to the model.
|
||||
state_number_map = [
|
||||
"xpos",
|
||||
"ypos",
|
||||
"next_point_x",
|
||||
"next_point_y"
|
||||
#"xpos",
|
||||
#"ypos",
|
||||
"xpos_scaled",
|
||||
"ypos_scaled",
|
||||
"can_dash_int"
|
||||
#"next_point_x",
|
||||
#"next_point_y"
|
||||
]
|
||||
|
||||
# Targets the agent tries to reach.
|
||||
# The last target MUST be outside the frame.
|
||||
target_checkpoints = [
|
||||
[ # Stage 1
|
||||
#(28, 88), # Start pillar
|
||||
(60, 80), # Middle pillar
|
||||
(105, 64), # Right ledge
|
||||
(25, 40), # Left ledge
|
||||
(110, 16), # End ledge
|
||||
(110, -2), # Next stage
|
||||
]
|
||||
]
|
||||
|
||||
def __init__(
|
||||
|
@ -110,19 +135,6 @@ class Celeste:
|
|||
self._resetting = False # True between a call to .reset() and the first state message from pico.
|
||||
self._keys = {} # Dictionary of "key": bool
|
||||
|
||||
# Targets the agent tries to reach.
|
||||
# The last target MUST be outside the frame.
|
||||
self.target_checkpoints = [
|
||||
[ # Stage 1
|
||||
#(28, 88), # Start pillar
|
||||
(60, 80), # Middle pillar
|
||||
(105, 64), # Right ledge
|
||||
(25, 40), # Left ledge
|
||||
(110, 16), # End ledge
|
||||
(110, -2), # Next stage
|
||||
]
|
||||
]
|
||||
|
||||
def act(self, action: str):
|
||||
"""
|
||||
Specify what keys should be down. This does NOT send key events.
|
||||
|
@ -141,6 +153,12 @@ class Celeste:
|
|||
self._keys["Right"] = True
|
||||
elif action == "jump":
|
||||
self._keys["c"] = True
|
||||
elif action == "jump-l":
|
||||
self._keys["c"] = True
|
||||
self._keys["Left"] = True
|
||||
elif action == "jump-r":
|
||||
self._keys["c"] = True
|
||||
self._keys["Right"] = True
|
||||
|
||||
elif action == "dash-u":
|
||||
self._keys["Up"] = True
|
||||
|
@ -183,12 +201,12 @@ class Celeste:
|
|||
[int(self._internal_state["rx"])]
|
||||
)
|
||||
|
||||
if len(self.target_checkpoints) < stage:
|
||||
if len(Celeste.target_checkpoints) < stage:
|
||||
next_point_x = None
|
||||
next_point_y = None
|
||||
else:
|
||||
next_point_x = self.target_checkpoints[stage][self._next_checkpoint_idx][0]
|
||||
next_point_y = self.target_checkpoints[stage][self._next_checkpoint_idx][1]
|
||||
next_point_x = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][0]
|
||||
next_point_y = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][1]
|
||||
|
||||
|
||||
return CelesteState(
|
||||
|
@ -196,6 +214,8 @@ class Celeste:
|
|||
|
||||
xpos = int(self._internal_state["px"]),
|
||||
ypos = int(self._internal_state["py"]),
|
||||
xpos_scaled = int(self._internal_state["px"]) / 128.0,
|
||||
ypos_scaled = int(self._internal_state["py"]) / 128.0,
|
||||
xvel = float(self._internal_state["vx"]),
|
||||
yvel = float(self._internal_state["vy"]),
|
||||
deaths = int(self._internal_state["dc"]),
|
||||
|
@ -205,7 +225,8 @@ class Celeste:
|
|||
next_point_x = next_point_x,
|
||||
next_point_y = next_point_y,
|
||||
state_count = self._state_counter,
|
||||
can_dash = self._internal_state["ds"] == "t"
|
||||
can_dash = self._internal_state["ds"] == "t",
|
||||
can_dash_int = 1 if self._internal_state["ds"] == "t" else 0
|
||||
)
|
||||
|
||||
except KeyError:
|
||||
|
@ -299,24 +320,36 @@ class Celeste:
|
|||
self._internal_state[key] = val
|
||||
|
||||
|
||||
# Update checkpoints
|
||||
tx, ty = self.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
|
||||
|
||||
|
||||
# Calculate distance to each point
|
||||
x = self.state.xpos
|
||||
y = self.state.ypos
|
||||
dist = math.sqrt(
|
||||
dist = np.zeros(len(Celeste.target_checkpoints[self.state.stage]), dtype=np.float16)
|
||||
for i, c in enumerate(Celeste.target_checkpoints[self.state.stage]):
|
||||
if i < self._next_checkpoint_idx:
|
||||
dist[i] = 1000
|
||||
continue
|
||||
|
||||
# Update checkpoints
|
||||
tx, ty = c
|
||||
dist[i] = (math.sqrt(
|
||||
(x-tx)*(x-tx) +
|
||||
((y-ty)*(y-ty))/2
|
||||
# Possible modification:
|
||||
# make x-distance twice as valuable as y-distance
|
||||
)
|
||||
))
|
||||
min_idx = int(dist.argmin())
|
||||
dist = int(dist[min_idx])
|
||||
|
||||
if dist <= 5:
|
||||
print(f"Got point {self._next_checkpoint_idx}")
|
||||
self._next_checkpoint_idx += 1
|
||||
|
||||
if dist <= 8:
|
||||
print(f"Got point {min_idx}")
|
||||
self._next_checkpoint_idx = min_idx + 1
|
||||
self._last_checkpoint_state = self._state_counter
|
||||
|
||||
# Recalculate distance to new point
|
||||
tx, ty = self.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
|
||||
tx, ty = Celeste.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
|
||||
dist = math.sqrt(
|
||||
(x-tx)*(x-tx) +
|
||||
((y-ty)*(y-ty))/2
|
||||
|
@ -326,6 +359,7 @@ class Celeste:
|
|||
elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
|
||||
self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)
|
||||
|
||||
|
||||
self._dist = dist
|
||||
|
||||
# Call step callbacks
|
||||
|
|
|
@ -24,6 +24,12 @@ if __name__ == "__main__":
|
|||
screenshot_dir.mkdir(parents = True, exist_ok = True)
|
||||
|
||||
|
||||
# Remove old screenshots
|
||||
shots = Path("/home/mark/Desktop").glob("hackcel_*.png")
|
||||
for s in shots:
|
||||
s.unlink()
|
||||
|
||||
|
||||
compute_device = torch.device(
|
||||
"cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
|
@ -41,11 +47,15 @@ if __name__ == "__main__":
|
|||
# EPS_END is the final value of epsilon
|
||||
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
|
||||
EPS_START = 0.9
|
||||
EPS_END = 0.05
|
||||
EPS_DECAY = 4000
|
||||
EPS_END = 0.02
|
||||
EPS_DECAY = 100
|
||||
|
||||
# How many times we've reached each point.
|
||||
# Used to compute epsilon-greedy probability with
|
||||
# the parameters above.
|
||||
point_counter = [0] * len(Celeste.target_checkpoints[0])
|
||||
|
||||
BATCH_SIZE = 1_000
|
||||
BATCH_SIZE = 100
|
||||
# Learning rate of target_net.
|
||||
# Controls how soft our soft update is.
|
||||
#
|
||||
|
@ -58,7 +68,7 @@ if __name__ == "__main__":
|
|||
#
|
||||
# A value of zero makes target_net
|
||||
# not change at all.
|
||||
TAU = 0.005
|
||||
TAU = 0.05
|
||||
|
||||
|
||||
# GAMMA is the discount factor as mentioned in the previous section
|
||||
|
@ -90,9 +100,10 @@ if __name__ == "__main__":
|
|||
target_net.load_state_dict(policy_net.state_dict())
|
||||
|
||||
|
||||
learning_rate = 0.001
|
||||
optimizer = torch.optim.AdamW(
|
||||
policy_net.parameters(),
|
||||
lr = 0.01, # Hyperparameter: learning rate
|
||||
lr = learning_rate,
|
||||
amsgrad = True
|
||||
)
|
||||
|
||||
|
@ -109,6 +120,7 @@ if __name__ == "__main__":
|
|||
memory = checkpoint["memory"]
|
||||
episode_number = checkpoint["episode_number"] + 1
|
||||
steps_done = checkpoint["steps_done"]
|
||||
point_counter = checkpoint["point_counter"]
|
||||
|
||||
def select_action(state, steps_done):
|
||||
"""
|
||||
|
@ -144,7 +156,6 @@ def select_action(state, steps_done):
|
|||
|
||||
|
||||
def optimize_model():
|
||||
|
||||
if len(memory) < BATCH_SIZE:
|
||||
raise Exception(f"Not enough elements in memory for a batch of {BATCH_SIZE}")
|
||||
|
||||
|
@ -189,19 +200,8 @@ def optimize_model():
|
|||
# out[i, j] = a[ i ][ b[i,j] ]
|
||||
#
|
||||
# a is "input," b is "index"
|
||||
# If this doesn't make sense, RTFD.
|
||||
|
||||
# Compute Q(s_t, a).
|
||||
# - Use policy_net to compute Q(s_t) for each state in the batch.
|
||||
# This gives a tensor of [ Q(state, left), Q(state, right) ]
|
||||
#
|
||||
# - Action batch is a tensor that looks like [ [0], [1], [1], ... ]
|
||||
# listing the action that was taken in each transition.
|
||||
# 0 => we went left, 1 => we went right.
|
||||
#
|
||||
# This aligns nicely with the output of policy_net. We use
|
||||
# action_batch to index the output of policy_net's prediction.
|
||||
#
|
||||
# This gives us a tensor that contains the return we expect to get
|
||||
# at that state if we follow the model's advice.
|
||||
|
||||
|
@ -214,8 +214,7 @@ def optimize_model():
|
|||
# = the maximum reward over all possible actions at state s_t+1.
|
||||
next_state_values = torch.zeros(BATCH_SIZE, device = compute_device)
|
||||
|
||||
# Don't compute gradient for operations in this block.
|
||||
# If you don't understand what this means, RTFD.
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
# Note the use of non_final_mask here.
|
||||
|
@ -291,6 +290,15 @@ def on_state_before(celeste):
|
|||
device = compute_device
|
||||
).unsqueeze(0)
|
||||
|
||||
|
||||
action = select_action(
|
||||
pt_state,
|
||||
point_counter[state.next_point]
|
||||
)
|
||||
str_action = Celeste.action_space[action]
|
||||
|
||||
|
||||
"""
|
||||
action = None
|
||||
while (action) is None or ((not state.can_dash) and (str_action not in ["left", "right"])):
|
||||
action = select_action(
|
||||
|
@ -298,6 +306,8 @@ def on_state_before(celeste):
|
|||
steps_done
|
||||
)
|
||||
str_action = Celeste.action_space[action]
|
||||
"""
|
||||
|
||||
steps_done += 1
|
||||
|
||||
|
||||
|
@ -343,37 +353,37 @@ def on_state_after(celeste, before_out):
|
|||
).unsqueeze(0)
|
||||
|
||||
|
||||
|
||||
if state.next_point == next_state.next_point:
|
||||
reward = state.dist - next_state.dist
|
||||
|
||||
# Clip rewards that are too large
|
||||
if reward > 1:
|
||||
reward = 1
|
||||
else:
|
||||
reward = 0
|
||||
|
||||
else:
|
||||
# Reward for reaching a point
|
||||
reward = 1
|
||||
reward = next_state.next_point - state.next_point
|
||||
|
||||
# Add to point counter
|
||||
for i in range(state.next_point, state.next_point + reward):
|
||||
point_counter[i] += 1
|
||||
|
||||
reward = reward * 10
|
||||
pt_reward = torch.tensor([reward], device = compute_device)
|
||||
|
||||
|
||||
# Add this state transition to memory.
|
||||
memory.append(
|
||||
Transition(
|
||||
pt_state, # last state
|
||||
pt_state,
|
||||
pt_action,
|
||||
pt_next_state, # next state
|
||||
pt_next_state,
|
||||
pt_reward
|
||||
)
|
||||
)
|
||||
|
||||
print("==> ", int(reward))
|
||||
print("==> ", reward)
|
||||
print("")
|
||||
|
||||
|
||||
loss = None
|
||||
|
||||
# Only train the network if we have enough
|
||||
# transitions in memory to do so.
|
||||
if len(memory) >= BATCH_SIZE:
|
||||
|
@ -407,8 +417,18 @@ def on_state_after(celeste, before_out):
|
|||
"target_state_dict": target_net.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"memory": memory,
|
||||
"point_counter": point_counter,
|
||||
"episode_number": episode_number,
|
||||
"steps_done": steps_done
|
||||
"steps_done": steps_done,
|
||||
|
||||
# Hyperparameters
|
||||
"eps_start": EPS_START,
|
||||
"eps_end": EPS_END,
|
||||
"eps_decay": EPS_DECAY,
|
||||
"batch_size": BATCH_SIZE,
|
||||
"tau": TAU,
|
||||
"learning_rate": learning_rate,
|
||||
"gamma": GAMMA
|
||||
}, model_save_path)
|
||||
|
||||
|
||||
|
@ -421,7 +441,7 @@ def on_state_after(celeste, before_out):
|
|||
for s in shots:
|
||||
s.rename(target / s.name)
|
||||
|
||||
# Save a prediction graph
|
||||
# Save a snapshot
|
||||
if episode_number % archive_interval == 0:
|
||||
torch.save({
|
||||
"policy_state_dict": policy_net.state_dict(),
|
||||
|
|
Reference in New Issue