Mark / celeste-ai
Mark 2023-02-19 12:54:27 -08:00
parent 4fbf1ea3f5
commit 97f3cabd75
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
1 changed file with 13 additions and 12 deletions


@@ -11,7 +11,7 @@ from celeste import Celeste
 if __name__ == "__main__":
     # Where to read/write model data.
-    model_data_root = Path("model_data")
+    model_data_root = Path("model_data/current")
     model_save_path = model_data_root / "model.torch"
     model_archive_dir = model_data_root / "model_archive"
@@ -40,7 +40,7 @@ if __name__ == "__main__":
     # EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
     EPS_START = 0.9
     EPS_END = 0.05
-    EPS_DECAY = 1000
+    EPS_DECAY = 4000
     BATCH_SIZE = 1_000
@@ -60,7 +60,7 @@ if __name__ == "__main__":
     # GAMMA is the discount factor as mentioned in the previous section
-    GAMMA = 0.99
+    GAMMA = 0.9
     # Outline our network
@@ -115,7 +115,7 @@ if __name__ == "__main__":
     # Memory: a deque that holds recent states as Transitions
     # Has a fixed length, drops oldest
     # element if maxlen is exceeded.
-    memory = deque([], maxlen=100_000)
+    memory = deque([], maxlen=50_000)
     policy_net = DQN(
         n_observations,
@@ -205,7 +205,7 @@ def optimize_model():
     # Compute a mask of non_final_states.
     # Each element of this tensor corresponds to an element in the batch.
-    # True if this is a final state, False if it is.
+    # True if this is a final state, False if it isn't.
     #
     # We use this to select non-final states later.
     non_final_mask = torch.tensor(
@@ -269,7 +269,6 @@ def optimize_model():
     next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
     # TODO: What does this mean?
     # "Compute expected Q values"
-
     expected_state_action_values = reward_batch + (next_state_values * GAMMA)
@@ -287,7 +286,6 @@ def optimize_model():
     )
-
     # We can now run a step of backpropagation on our model.
     # TODO: what does this do?
@@ -314,6 +312,8 @@ def optimize_model():
     # in the .grad attribute of the parameter.
     optimizer.step()
+
+    return loss

 def on_state_before(celeste):
@@ -354,7 +354,6 @@ def on_state_before(celeste):
 def on_state_after(celeste, before_out):
     global episode_number
-    global image_count

     state, action = before_out
     next_state = celeste.state
@@ -393,7 +392,7 @@ def on_state_after(celeste, before_out):
         reward = 0
     else:
-        # Score for reaching a point
+        # Reward for reaching a point
         reward = 1

     pt_reward = torch.tensor([reward], device = compute_device)
@@ -410,13 +409,14 @@ def on_state_after(celeste, before_out):
     )
     print("==> ", int(reward))
-    print("\n")
+    print("")

+    loss = None
     # Only train the network if we have enough
     # transitions in memory to do so.
     if len(memory) >= BATCH_SIZE:
-        optimize_model()
+        loss = optimize_model()

         # Soft update target_net weights
         target_net_state = target_net.state_dict()
@@ -435,7 +435,8 @@ def on_state_after(celeste, before_out):
     with model_train_log.open("a") as f:
         f.write(json.dumps({
             "checkpoints": s.next_point,
-            "state_count": s.state_count
+            "state_count": s.state_count,
+            "loss": None if loss is None else loss.item()
         }) + "\n")