Tweaks

parent 4fbf1ea3f5
commit 97f3cabd75
@@ -11,7 +11,7 @@ from celeste import Celeste
 if __name__ == "__main__":
 	# Where to read/write model data.
-	model_data_root = Path("model_data")
+	model_data_root = Path("model_data/current")

 	model_save_path = model_data_root / "model.torch"
 	model_archive_dir = model_data_root / "model_archive"

@@ -40,7 +40,7 @@ if __name__ == "__main__":
 	# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
 	EPS_START = 0.9
 	EPS_END = 0.05
-	EPS_DECAY = 1000
+	EPS_DECAY = 4000


 	BATCH_SIZE = 1_000
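Note: the diff doesn't show how epsilon is computed, but the comment matches the standard exponential schedule from the PyTorch DQN tutorial. A minimal sketch, assuming that schedule (`steps_done` is a hypothetical step counter, not a name taken from this codebase):

import math

EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 4000

def epsilon(steps_done: int) -> float:
	# Decays from EPS_START toward EPS_END; a larger EPS_DECAY
	# stretches the curve, so exploration lasts longer.
	return EPS_END + (EPS_START - EPS_END) * math.exp(-steps_done / EPS_DECAY)

# epsilon(0)     ~= 0.90
# epsilon(4000)  ~= 0.36   (one decay constant)
# epsilon(20000) ~= 0.06

Under this schedule, raising EPS_DECAY from 1000 to 4000 makes the agent explore roughly four times longer before settling into mostly greedy actions.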
@@ -60,7 +60,7 @@ if __name__ == "__main__":


 	# GAMMA is the discount factor as mentioned in the previous section
-	GAMMA = 0.99
+	GAMMA = 0.9


 	# Outline our network
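Note: a rough way to read this change. The discount factor sets how far ahead the agent effectively plans: a constant reward stream of 1 sums to 1 / (1 - GAMMA), so dropping GAMMA from 0.99 to 0.9 shrinks the effective horizon from about 100 steps to about 10. A sketch of the discounted return it weights:

def discounted_return(rewards, gamma=0.9):
	# G = r_0 + gamma * r_1 + gamma**2 * r_2 + ...
	g = 0.0
	for r in reversed(rewards):
		g = r + gamma * g
	return g

# discounted_return([1] * 1000, gamma=0.99)  ~= 100
# discounted_return([1] * 1000, gamma=0.9)   ~= 10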
@@ -115,7 +115,7 @@ if __name__ == "__main__":
 	# Memory: a deque that holds recent states as Transitions
 	# Has a fixed length, drops oldest
 	# element if maxlen is exceeded.
-	memory = deque([], maxlen=100_000)
+	memory = deque([], maxlen=50_000)

 	policy_net = DQN(
 		n_observations,
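Note: this deque is the replay memory. A minimal sketch of how such a buffer is typically filled and sampled, assuming a Transition record shaped like the tutorial's (the field names here are an assumption, not taken from this file):

import random
from collections import deque, namedtuple

Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))

memory = deque([], maxlen=50_000)

def push(*args):
	# Oldest transitions fall off the left end once maxlen is reached.
	memory.append(Transition(*args))

def sample(batch_size):
	# Uniform random minibatch for optimize_model().
	return random.sample(memory, batch_size)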
@@ -205,7 +205,7 @@ def optimize_model():

 	# Compute a mask of non_final_states.
 	# Each element of this tensor corresponds to an element in the batch.
-	# True if this is a final state, False if it is.
+	# True if this is a final state, False if it isn't.
 	#
 	# We use this to select non-final states later.
 	non_final_mask = torch.tensor(
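Note: the tensor construction is cut off by the hunk boundary. In the PyTorch DQN tutorial this code appears to follow, the mask is built as sketched below (`batch` and its fields are assumptions about this codebase, illustrated with a toy value). One caveat: in the tutorial the mask is True for *non-final* states, which matches the variable name rather than the comment above.

import torch
from collections import namedtuple

Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))
compute_device = torch.device("cpu")

# Toy batch: the second transition ended its episode (next_state is None).
batch = Transition(
	state=(torch.zeros(1, 4), torch.zeros(1, 4)),
	action=(0, 1),
	next_state=(torch.ones(1, 4), None),
	reward=(1.0, 0.0),
)

non_final_mask = torch.tensor(
	tuple(s is not None for s in batch.next_state),
	device=compute_device,
	dtype=torch.bool,
)  # tensor([ True, False])

# The matching selection of non-terminal successor states:
non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])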
@@ -269,7 +269,6 @@ def optimize_model():
 	next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]

-

 	# TODO: What does this mean?
 	# "Compute expected Q values"
 	expected_state_action_values = reward_batch + (next_state_values * GAMMA)
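Note: the TODO in this hunk asks what "compute expected Q values" means. The line is the one-step Bellman target, which is standard DQN rather than anything specific to this commit:

# For each transition (s, a, r, s') in the batch:
#
#   target Q(s, a) = r + GAMMA * max_a' Q_target_net(s', a')
#
# next_state_values already holds that max, and is zero where s' was
# terminal (thanks to non_final_mask), so the target collapses to the
# bare reward at episode ends. The training loss then pulls
# policy_net's Q(s, a) toward this target.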
@@ -287,7 +286,6 @@ def optimize_model():
 	)

-

 	# We can now run a step of backpropagation on our model.

 	# TODO: what does this do?
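Note: the TODOs around the backpropagation step can be answered from the standard tutorial flow. A sketch of the usual sequence, wrapped as a function so it stands alone; `criterion`, `state_action_values`, and `optimizer` here are assumptions about this codebase, not names confirmed by the diff:

import torch
import torch.nn as nn

def backprop_step(policy_net, optimizer, state_action_values, expected_state_action_values):
	# Huber loss between predicted Q(s, a) and the Bellman targets.
	criterion = nn.SmoothL1Loss()
	loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

	optimizer.zero_grad()  # clear stale gradients from the previous step
	loss.backward()        # fill .grad on every policy_net parameter
	torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)  # cap exploding gradients
	optimizer.step()       # apply the update
	return loss            # handed back for logging, as this commit now does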
@@ -314,6 +312,8 @@ def optimize_model():
 	# in the .grad attribute of the parameter.
 	optimizer.step()

+	return loss
+


 def on_state_before(celeste):
@@ -354,7 +354,6 @@ def on_state_before(celeste):

 def on_state_after(celeste, before_out):
 	global episode_number
-	global image_count

 	state, action = before_out
 	next_state = celeste.state
@@ -393,7 +392,7 @@ def on_state_after(celeste, before_out):
 		reward = 0

 	else:
-		# Score for reaching a point
+		# Reward for reaching a point
 		reward = 1

 	pt_reward = torch.tensor([reward], device = compute_device)
@@ -410,13 +409,14 @@ def on_state_after(celeste, before_out):
 	)

 	print("==> ", int(reward))
-	print("\n")
+	print("")


+	loss = None
 	# Only train the network if we have enough
 	# transitions in memory to do so.
 	if len(memory) >= BATCH_SIZE:
-		optimize_model()
+		loss = optimize_model()

 	# Soft update target_net weights
 	target_net_state = target_net.state_dict()
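Note: the "soft update" the comment refers to is usually the Polyak average from the DQN tutorial; the hunk shows only its first line. A sketch, with TAU assumed (the tutorial uses 0.005; its value is not visible in this diff):

TAU = 0.005  # assumed value

# Nudge each target_net weight a small step toward policy_net's,
# instead of copying wholesale, which keeps the training targets stable.
target_net_state = target_net.state_dict()
policy_net_state = policy_net.state_dict()
for key in policy_net_state:
	target_net_state[key] = (
		policy_net_state[key] * TAU + target_net_state[key] * (1 - TAU)
	)
target_net.load_state_dict(target_net_state)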
@@ -435,7 +435,8 @@ def on_state_after(celeste, before_out):
 	with model_train_log.open("a") as f:
 		f.write(json.dumps({
 			"checkpoints": s.next_point,
-			"state_count": s.state_count
+			"state_count": s.state_count,
+			"loss": None if loss is None else loss.item()
 		}) + "\n")


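Note: with the new "loss" field, each line of model_train_log is a standalone JSON object, so the file reads back as JSON Lines. A sketch of recovering the loss curve (the path here is hypothetical; the script builds model_train_log elsewhere):

import json
from pathlib import Path

model_train_log = Path("model_data/current/train_log")  # hypothetical path

records = [json.loads(line) for line in model_train_log.open()]

# Lines written before the replay buffer filled have "loss": null.
losses = [r["loss"] for r in records if r["loss"] is not None]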