Fixed bugs
parent ce02009d64
commit 1216378c49
@@ -49,6 +49,8 @@ class Celeste:
         # Initialize variables
         self.internal_status = {}
+        self.before_out = None
+        self.last_point_frame = 0

         # Score system
         self.frame_counter = 0
@@ -166,6 +168,9 @@ class Celeste:
         self.internal_status = {}
         self.next_point = 0
         self.frame_counter = 0
+        self.before_out = None
+        self.resetting = True
+        self.last_point_frame = 0

         self.keypress("Escape")
         self.keystring("run")
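This hunk and the one in __init__ above add the same bookkeeping attributes, so per-episode state lives on the instance instead of in locals: before_out holds the not-yet-consumed output of the step callback, resetting marks that a restart is in flight, and last_point_frame records the frame at which the last checkpoint was reached. A minimal sketch of the assumed reset path, with the attribute roles spelled out in comments (the method name and overall shape are assumptions; the attribute names come from the diff):

def reset(self):
    # Clear per-episode state so values from the previous run
    # cannot leak into the next one.
    self.internal_status = {}
    self.next_point = 0
    self.frame_counter = 0
    self.before_out = None       # output of the "before" callback, not yet consumed
    self.resetting = True        # suppress callbacks until a fresh frame arrives
    self.last_point_frame = 0    # frame at which the last checkpoint was reached

    # Restart the game: Escape to the menu, then type "run".
    self.keypress("Escape")
    self.keystring("run")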
@@ -185,13 +190,12 @@ class Celeste:
         # Get state, call callback, wait for state
         # One line => one frame.

-        before_out = None
-
         it = iter(self.process.stdout.readline, "")


         for line in it:
             l = line.decode("utf-8")[:-1].strip()
+            self.resetting = False

             # This should only occur at game start
             if l in ["!RESTART"]:
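The loop above reads one status line per frame from the game process using the two-argument form of iter(): iter(callable, sentinel) keeps calling the callable until it returns the sentinel. A self-contained illustration of that pattern, using cat as a stand-in child process (so it assumes a Unix-like system); the sentinel is b"" here on the assumption that the pipe is binary, which the decode("utf-8") above suggests:

import subprocess

# Spawn a child process whose stdout we read line by line until EOF.
proc = subprocess.Popen(["cat"], stdin=subprocess.PIPE, stdout=subprocess.PIPE)
proc.stdin.write(b"frame 1\nframe 2\n")
proc.stdin.close()

# readline() returns b"" at EOF, so the loop ends when the child closes stdout.
for raw in iter(proc.stdout.readline, b""):
    line = raw.decode("utf-8").strip()
    print(line)   # -> "frame 1", then "frame 2"

proc.wait()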
@@ -206,7 +210,7 @@ class Celeste:
                 key, val = entry.split(":")
                 self.internal_status[key] = val


             # Update checkpoints
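The two lines above split each status field into internal_status. A tiny stand-alone version of that parsing on a made-up status line; only the key:value split appears in the diff, so the comma separator and the field names (other than dc) are assumptions:

def parse_status(line: str) -> dict:
    # Each entry looks like "key:value"; keep values as strings,
    # mirroring how internal_status is used elsewhere in the diff.
    status = {}
    for entry in line.split(","):
        key, val = entry.split(":")
        status[key] = val
    return status

print(parse_status("px:12,py:45,dc:0"))
# {'px': '12', 'py': '45', 'dc': '0'}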
@@ -221,6 +225,7 @@ class Celeste:
             if dist <= 4 and y == ty:
                 print(f"Got point {self.next_point}")
                 self.next_point += 1
+                self.last_point_frame = self.frame_counter

                 # Recalculate distance to new point
                 tx, ty = self.target_points[self.status["stage"]][self.next_point]
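Recording last_point_frame here is what makes the new timeout in the next hunk work: the stall counter restarts every time a checkpoint is reached. For reference, a stripped-down version of the reach test itself, with hypothetical coordinates (the 4-unit window and the exact-row check come from the diff; everything else is illustrative):

import math

def reached(x, y, tx, ty):
    # A point counts as reached when the player is within 4 units of it
    # and sits exactly on its row.
    dist = math.sqrt((x - tx) ** 2 + (y - ty) ** 2)
    return dist <= 4 and y == ty

points = [(20, 80), (60, 64)]   # hypothetical checkpoint coordinates
next_point = 0
x, y = 18, 80                   # hypothetical player position

if reached(x, y, *points[next_point]):
    print(f"Got point {next_point}")
    next_point += 1             # start chasing the following point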
@@ -229,9 +234,14 @@ class Celeste:
                     (y-ty)*(y-ty)
                 )

+            # Timeout if we spend too long between points
+            elif self.frame_counter - self.last_point_frame > 40:
+                self.internal_status["dc"] = str(int(self.internal_status["dc"]) + 1)
+
             self.dist = dist

-            # Call step callback
-            if before_out is not None:
-                after(self, before_out)
-            before_out = before(self)
+            # Call step callbacks
+            if self.before_out is not None:
+                after(self, self.before_out)
+            if not self.resetting:
+                self.before_out = before(self)
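This is the bug fix the commit message refers to: before_out used to be a local of the read loop, so it was lost whenever the loop was re-entered, and the before callback could fire on the stale frame that arrives while the game is still restarting. Storing it on the instance and gating the call on resetting fixes both. The added elif is a stall timeout: if more than 40 frames pass without reaching a checkpoint, the death counter dc in internal_status is incremented, which presumably ends the episode the same way a real death would. A minimal, runnable sketch of the callback handshake (the callback signatures are an assumption based on the calls in the diff):

class FrameLoop:
    # `before` and `after` stand in for the step callbacks registered by
    # the training loop; this class only models the handshake itself.
    def __init__(self, before, after):
        self.before = before
        self.after = after
        self.before_out = None
        self.resetting = False

    def handle_frame(self):
        # Close out the previous frame: hand the stored `before` output
        # to `after`, which can now see both sides of the transition.
        if self.before_out is not None:
            self.after(self, self.before_out)

        # Open the next frame, but not while the game is restarting:
        # a stale frame from the old episode must not produce an action.
        if not self.resetting:
            self.before_out = self.before(self)

loop = FrameLoop(before=lambda c: "action", after=lambda c, out: print("got", out))
loop.handle_frame()   # stores "action"
loop.handle_frame()   # prints "got action", then stores a new one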
@@ -42,9 +42,9 @@ EPS_DECAY = 1000
 BATCH_SIZE = 128
 # Learning rate of target_net.
 # Controls how soft our soft update is.
 #
 # Should be between 0 and 1.
 # Large values
 # Small values do the opposite.
 #
 # A value of one makes target_net
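For context, the parameter these comments describe is the soft-update coefficient (commonly called TAU): on every optimisation step the target network is nudged toward the policy network by that fraction. A self-contained sketch of the update, target <- TAU * policy + (1 - TAU) * target, applied parameter-wise; the value 0.005 and the tiny networks are illustrative, not taken from the script:

import torch
import torch.nn as nn

TAU = 0.005   # illustrative value

policy_net = nn.Linear(4, 2)
target_net = nn.Linear(4, 2)
target_net.load_state_dict(policy_net.state_dict())

# With TAU = 1 the target_net becomes an exact copy of policy_net;
# small values make it trail the policy network slowly.
with torch.no_grad():
    for t_param, p_param in zip(target_net.parameters(), policy_net.parameters()):
        t_param.mul_(1.0 - TAU).add_(TAU * p_param)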
@@ -174,7 +174,7 @@ def optimize_model():
        raise Exception(f"Not enough elements in memory for a batch of {BATCH_SIZE}")



    # Get a random sample of transitions
    batch = random.sample(memory, BATCH_SIZE)

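An illustrative replay memory behind the sampling call above. The Transition fields are the usual DQN quadruple and the dummy fill is only there so the example runs; the real script's memory layout may differ:

import random
from collections import deque, namedtuple

Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))

BATCH_SIZE = 128
memory = deque(maxlen=10_000)

# Fill with dummy transitions so the sample below has something to draw from.
for i in range(500):
    memory.append(Transition(state=i, action=0, next_state=i + 1, reward=0.0))

if len(memory) < BATCH_SIZE:
    raise Exception(f"Not enough elements in memory for a batch of {BATCH_SIZE}")

# random.sample draws BATCH_SIZE distinct transitions without replacement,
# which breaks up the correlation between consecutive frames.
batch = random.sample(memory, BATCH_SIZE)
print(len(batch))   # 128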
@@ -238,13 +238,13 @@ def optimize_model():
    # V(s_t+1) = max_a ( Q(s_t+1, a) )
    # = the maximum reward over all possible actions at state s_t+1.
    next_state_values = torch.zeros(BATCH_SIZE, device = compute_device)

    # Don't compute gradient for operations in this block.
    # If you don't understand what this means, RTFD.
    with torch.no_grad():

        # Note the use of non_final_mask here.
        # States that are final do not have their reward set by the line
        # below, so their reward stays at zero.
        #
        # States that are not final get their predicted value
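A self-contained illustration of the masked V(s_{t+1}) computation described in the comments above: terminal states keep a value of zero, non-terminal states get the maximum predicted Q-value from the target network. Names mirror the diff; the stand-in network, shapes and GAMMA are assumptions:

import torch

BATCH_SIZE = 4
compute_device = torch.device("cpu")

target_net = torch.nn.Linear(3, 2)                        # stand-in Q network
non_final_mask = torch.tensor([True, True, False, True])  # False = terminal state
non_final_next_states = torch.randn(3, 3)                 # only the non-final ones

next_state_values = torch.zeros(BATCH_SIZE, device = compute_device)
with torch.no_grad():
    # max over actions of Q(s_{t+1}, a); masked-out (final) entries stay 0.
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values

GAMMA = 0.99
reward_batch = torch.ones(BATCH_SIZE)
expected_state_action_values = reward_batch + GAMMA * next_state_values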
@@ -274,7 +274,7 @@ def optimize_model():
        expected_state_action_values.unsqueeze(1)
    )



    # We can now run a step of backpropagation on our model.

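A sketch of the optimisation step this hunk sits in, following the usual PyTorch DQN recipe: a Smooth L1 (Huber) loss between the predicted Q(s_t, a_t) and the Bellman targets, then one optimiser step. The loss function, optimiser and gradient clipping are assumptions; only the unsqueeze(1) call is visible in the diff:

import torch
import torch.nn as nn

policy_net = nn.Linear(3, 2)
optimizer = torch.optim.AdamW(policy_net.parameters(), lr=1e-4)

state_batch = torch.randn(4, 3)
action_batch = torch.randint(0, 2, (4, 1))
expected_state_action_values = torch.randn(4)

# Q(s_t, a_t) for the actions that were actually taken.
state_action_values = policy_net(state_batch).gather(1, action_batch)

criterion = nn.SmoothL1Loss()
loss = criterion(
    state_action_values,
    expected_state_action_values.unsqueeze(1)
)

# We can now run a step of backpropagation on our model.
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)  # optional clipping
optimizer.step()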
@@ -362,10 +362,18 @@ def on_state_after(celeste, before_out):

    if state["next_point"] == next_state["next_point"]:
        reward = state["dist"] - next_state["dist"]
+
+        if reward > 0:
+            reward = 1
+        elif reward < 0:
+            reward = -1
+        else:
+            reward = 0
    else:
        # Score for reaching a point
        reward = 10


    pt_reward = torch.tensor([reward], device = compute_device)

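A stand-alone restatement of the reward rule this hunk introduces, so the shaping is easy to see at a glance: progress toward the current point is clipped to -1, 0 or +1, and actually reaching a point is worth 10. The function wrapper and the example values are illustrative:

def shape_reward(state, next_state):
    if state["next_point"] == next_state["next_point"]:
        delta = state["dist"] - next_state["dist"]
        if delta > 0:
            return 1      # moved closer to the target point
        elif delta < 0:
            return -1     # moved away from it
        else:
            return 0      # no progress either way
    else:
        # Score for reaching a point
        return 10

print(shape_reward({"next_point": 0, "dist": 12.0},
                   {"next_point": 0, "dist": 10.5}))   # -> 1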