Mark / celeste-ai
Mark 2023-02-19 12:54:27 -08:00
parent 4fbf1ea3f5
commit 97f3cabd75
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
1 changed file with 13 additions and 12 deletions

@@ -11,7 +11,7 @@ from celeste import Celeste
if __name__ == "__main__":
# Where to read/write model data.
-model_data_root = Path("model_data")
+model_data_root = Path("model_data/current")
model_save_path = model_data_root / "model.torch"
model_archive_dir = model_data_root / "model_archive"
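(Note: with the data root moved to model_data/current, both the model file and the archive directory now live under that folder. Those directories presumably have to exist before anything is saved; a minimal sketch of creating them up front, not necessarily how this script handles it:)

from pathlib import Path

model_data_root = Path("model_data/current")
model_save_path = model_data_root / "model.torch"
model_archive_dir = model_data_root / "model_archive"

# Create the directories if they are missing, so later saves
# don't fail on a nonexistent parent.
model_archive_dir.mkdir(parents=True, exist_ok=True)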
@@ -40,7 +40,7 @@ if __name__ == "__main__":
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
EPS_START = 0.9
EPS_END = 0.05
-EPS_DECAY = 1000
+EPS_DECAY = 4000
BATCH_SIZE = 1_000
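(Note: EPS_DECAY is the time constant of the usual DQN exponential-decay schedule for the exploration rate, so raising it from 1000 to 4000 makes epsilon approach EPS_END about four times more slowly. A small sketch of that schedule, assuming a steps_done counter as in the standard PyTorch DQN recipe:)

import math

EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 4000

def epsilon(steps_done: int) -> float:
    # Exponential decay from EPS_START toward EPS_END; a larger
    # EPS_DECAY stretches the same curve over more steps.
    return EPS_END + (EPS_START - EPS_END) * math.exp(-steps_done / EPS_DECAY)

print(epsilon(0))      # 0.9
print(epsilon(4000))   # ~0.36, one time constant into the decay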
@@ -60,7 +60,7 @@ if __name__ == "__main__":
# GAMMA is the discount factor as mentioned in the previous section
-GAMMA = 0.99
+GAMMA = 0.9
# Outline our network
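(Note: GAMMA is the discount factor in the return r_t + GAMMA*r_{t+1} + GAMMA^2*r_{t+2} + ..., so dropping it from 0.99 to 0.9 makes the agent care far less about distant reward. A quick numeric illustration of the weight a reward k steps ahead receives:)

for k in (1, 10, 50):
    print(k, round(0.99 ** k, 3), round(0.9 ** k, 3))
# 1 0.99 0.9
# 10 0.904 0.349
# 50 0.605 0.005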
@@ -115,7 +115,7 @@ if __name__ == "__main__":
# Memory: a deque that holds recent states as Transitions
# Has a fixed length, drops oldest
# element if maxlen is exceeded.
-memory = deque([], maxlen=100_000)
+memory = deque([], maxlen=50_000)
policy_net = DQN(
n_observations,
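(Note: the replay memory is a bounded deque, so once 50,000 transitions are stored the oldest entries are silently dropped as new ones arrive. A tiny illustration of that maxlen behaviour:)

from collections import deque

memory = deque([], maxlen=3)   # same idea, tiny capacity for the demo
for t in range(5):
    memory.append(t)
print(list(memory))            # [2, 3, 4] -- 0 and 1 were evicted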
@@ -205,7 +205,7 @@ def optimize_model():
# Compute a mask of non_final_states.
# Each element of this tensor corresponds to an element in the batch.
-# True if this is a final state, False if it is.
+# True if this is a final state, False if it isn't.
#
# We use this to select non-final states later.
non_final_mask = torch.tensor(
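(Note: judging by how it indexes next_state_values further down, non_final_mask is True for transitions whose next state exists and False for ones that ended the episode. A minimal sketch of building such a mask, assuming each next state is either a tensor or None as in the standard PyTorch DQN recipe:)

import torch

# Hypothetical mini-batch: None marks a transition that ended the episode.
next_states = [torch.tensor([0.1]), None, torch.tensor([0.3])]

non_final_mask = torch.tensor(
    [s is not None for s in next_states],
    dtype=torch.bool,
)
# Only the surviving states get fed through target_net later.
non_final_next_states = torch.stack([s for s in next_states if s is not None])
print(non_final_mask)   # tensor([ True, False,  True])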
@@ -269,7 +269,6 @@ def optimize_model():
next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
-# TODO: What does this mean?
# "Compute expected Q values"
expected_state_action_values = reward_batch + (next_state_values * GAMMA)
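(Note: the line above is the Bellman target, target = r + GAMMA * max_a Q_target(s', a); next_state_values already holds that max-Q term for non-final states and stays at zero for final ones. The same computation on toy tensors:)

import torch

GAMMA = 0.9
reward_batch = torch.tensor([1.0, 0.0, 1.0])
# max_a Q_target(s', a); the middle transition is final, so it stays 0.
next_state_values = torch.tensor([2.0, 0.0, 3.0])

expected_state_action_values = reward_batch + (next_state_values * GAMMA)
print(expected_state_action_values)   # tensor([2.8000, 0.0000, 3.7000])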
@@ -287,7 +286,6 @@ def optimize_model():
)
# We can now run a step of backpropagation on our model.
-# TODO: what does this do?
@@ -314,6 +312,8 @@ def optimize_model():
# in the .grad attribute of the parameter.
optimizer.step()
+return loss
def on_state_before(celeste):
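(Note: optimize_model() now returns the loss tensor so the caller can log it. The lines around optimizer.step() presumably follow the usual PyTorch update sequence; a sketch of that pattern, assuming a Smooth L1 (Huber) loss as in the standard DQN recipe:)

import torch
import torch.nn as nn

criterion = nn.SmoothL1Loss()

def training_step(policy_net, optimizer, state_action_values, expected_values):
    loss = criterion(state_action_values, expected_values.unsqueeze(1))

    optimizer.zero_grad()     # clear gradients from the previous step
    loss.backward()           # backpropagation: fills .grad on every parameter
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()          # apply the accumulated gradients

    return loss               # returned so the caller can record loss.item()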
@@ -354,7 +354,6 @@ def on_state_before(celeste):
def on_state_after(celeste, before_out):
global episode_number
-global image_count
state, action = before_out
next_state = celeste.state
@@ -393,7 +392,7 @@ def on_state_after(celeste, before_out):
reward = 0
else:
-# Score for reaching a point
+# Reward for reaching a point
reward = 1
pt_reward = torch.tensor([reward], device = compute_device)
@@ -410,13 +409,14 @@ def on_state_after(celeste, before_out):
)
print("==> ", int(reward))
print("\n")
print("")
loss = None
# Only train the network if we have enough
# transitions in memory to do so.
if len(memory) >= BATCH_SIZE:
-optimize_model()
+loss = optimize_model()
# Soft update target_net weights
target_net_state = target_net.state_dict()
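(Note: the soft update blends the policy network into the target network, theta' <- TAU*theta + (1 - TAU)*theta', instead of copying it outright. A minimal, self-contained sketch with stand-in networks and a hypothetical TAU:)

import torch.nn as nn

TAU = 0.005                      # hypothetical blend rate; the script defines its own value

policy_net = nn.Linear(4, 2)     # stand-ins for the real DQN networks
target_net = nn.Linear(4, 2)

target_net_state = target_net.state_dict()
policy_net_state = policy_net.state_dict()
for key in policy_net_state:
    # theta' <- TAU * theta + (1 - TAU) * theta'
    target_net_state[key] = policy_net_state[key] * TAU + target_net_state[key] * (1 - TAU)
target_net.load_state_dict(target_net_state)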
@@ -435,7 +435,8 @@ def on_state_after(celeste, before_out):
with model_train_log.open("a") as f:
f.write(json.dumps({
"checkpoints": s.next_point,
"state_count": s.state_count
"state_count": s.state_count,
"loss": None if loss is None else loss.item()
}) + "\n")