Tweaks
parent
4fbf1ea3f5
commit
97f3cabd75
|
@ -11,7 +11,7 @@ from celeste import Celeste
|
|||
|
||||
if __name__ == "__main__":
|
||||
# Where to read/write model data.
|
||||
model_data_root = Path("model_data")
|
||||
model_data_root = Path("model_data/current")
|
||||
|
||||
model_save_path = model_data_root / "model.torch"
|
||||
model_archive_dir = model_data_root / "model_archive"
|
||||
|
@ -40,7 +40,7 @@ if __name__ == "__main__":
|
|||
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
|
||||
EPS_START = 0.9
|
||||
EPS_END = 0.05
|
||||
EPS_DECAY = 1000
|
||||
EPS_DECAY = 4000
|
||||
|
||||
|
||||
BATCH_SIZE = 1_000
|
||||
|
@ -60,7 +60,7 @@ if __name__ == "__main__":
|
|||
|
||||
|
||||
# GAMMA is the discount factor as mentioned in the previous section
|
||||
GAMMA = 0.99
|
||||
GAMMA = 0.9
|
||||
|
||||
|
||||
# Outline our network
|
||||
|
@ -115,7 +115,7 @@ if __name__ == "__main__":
|
|||
# Memory: a deque that holds recent states as Transitions
|
||||
# Has a fixed length, drops oldest
|
||||
# element if maxlen is exceeded.
|
||||
memory = deque([], maxlen=100_000)
|
||||
memory = deque([], maxlen=50_000)
|
||||
|
||||
policy_net = DQN(
|
||||
n_observations,
|
||||
|
@ -205,7 +205,7 @@ def optimize_model():
|
|||
|
||||
# Compute a mask of non_final_states.
|
||||
# Each element of this tensor corresponds to an element in the batch.
|
||||
# True if this is a final state, False if it is.
|
||||
# True if this is a final state, False if it isn't.
|
||||
#
|
||||
# We use this to select non-final states later.
|
||||
non_final_mask = torch.tensor(
|
||||
|
@ -269,7 +269,6 @@ def optimize_model():
|
|||
next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
|
||||
|
||||
|
||||
|
||||
# TODO: What does this mean?
|
||||
# "Compute expected Q values"
|
||||
expected_state_action_values = reward_batch + (next_state_values * GAMMA)
|
||||
|
@ -287,7 +286,6 @@ def optimize_model():
|
|||
)
|
||||
|
||||
|
||||
|
||||
# We can now run a step of backpropagation on our model.
|
||||
|
||||
# TODO: what does this do?
|
||||
|
@ -314,6 +312,8 @@ def optimize_model():
|
|||
# in the .grad attribute of the parameter.
|
||||
optimizer.step()
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
|
||||
def on_state_before(celeste):
|
||||
|
@ -354,7 +354,6 @@ def on_state_before(celeste):
|
|||
|
||||
def on_state_after(celeste, before_out):
|
||||
global episode_number
|
||||
global image_count
|
||||
|
||||
state, action = before_out
|
||||
next_state = celeste.state
|
||||
|
@ -393,7 +392,7 @@ def on_state_after(celeste, before_out):
|
|||
reward = 0
|
||||
|
||||
else:
|
||||
# Score for reaching a point
|
||||
# Reward for reaching a point
|
||||
reward = 1
|
||||
|
||||
pt_reward = torch.tensor([reward], device = compute_device)
|
||||
|
@ -410,13 +409,14 @@ def on_state_after(celeste, before_out):
|
|||
)
|
||||
|
||||
print("==> ", int(reward))
|
||||
print("\n")
|
||||
print("")
|
||||
|
||||
|
||||
loss = None
|
||||
# Only train the network if we have enough
|
||||
# transitions in memory to do so.
|
||||
if len(memory) >= BATCH_SIZE:
|
||||
optimize_model()
|
||||
loss = optimize_model()
|
||||
|
||||
# Soft update target_net weights
|
||||
target_net_state = target_net.state_dict()
|
||||
|
@ -435,7 +435,8 @@ def on_state_after(celeste, before_out):
|
|||
with model_train_log.open("a") as f:
|
||||
f.write(json.dumps({
|
||||
"checkpoints": s.next_point,
|
||||
"state_count": s.state_count
|
||||
"state_count": s.state_count,
|
||||
"loss": None if loss is None else loss.item()
|
||||
}) + "\n")
|
||||
|
||||
|
||||
|
|
Reference in New Issue