import random
from collections import deque

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import util


def optimize_model(
    memory: deque,

    # PyTorch params
    compute_device,
    policy_net: nn.Module,
    target_net: nn.Module,
    optimizer,

    # Parameters:
    #
    # BATCH_SIZE is the number of transitions sampled from the replay buffer
    # GAMMA is the discount factor as mentioned in the previous section
    BATCH_SIZE = 128,
    GAMMA = 0.99
):
    if len(memory) < BATCH_SIZE:
        raise ValueError(f"Not enough elements in memory for a batch of {BATCH_SIZE}")

    # Get a random sample of transitions
    batch = random.sample(memory, BATCH_SIZE)

    # Conversion.
    # Transposes batch, turning an array of Transitions
    # into a Transition of arrays.
    batch = util.Transition(*zip(*batch))

    # Conversion.
    # Combine states, actions, and rewards into their own tensors.
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute a mask of non-final states.
    # Each element of this tensor corresponds to an element in the batch:
    # True if its next state is non-final, False if the episode ended there.
    #
    # We use this to select non-final states later. The mask must live on
    # the same device as the tensor it indexes, hence compute_device.
    non_final_mask = torch.tensor(
        tuple(map(
            lambda s: s is not None,
            batch.next_state
        )),
        device = compute_device,
        dtype = torch.bool
    )

    non_final_next_states = torch.cat(
        [s for s in batch.next_state if s is not None]
    )

    # How .gather works:
    # if out = a.gather(1, b),
    # out[i, j] = a[ i ][ b[i,j] ]
    #
    # a is "input," b is "index".
    # There is a runnable demo (_gather_demo) at the bottom of this file.
    # If it still doesn't make sense, RTFD.

    # Compute Q(s_t, a).
    # - Use policy_net to compute Q(s_t) for each state in the batch.
    #   This gives a tensor of [ Q(state, left), Q(state, right) ]
    #
    # - action_batch is a tensor that looks like [ [0], [1], [1], ... ],
    #   listing the action that was taken in each transition.
    #   0 => we went left, 1 => we went right.
    #
    #   This aligns nicely with the output of policy_net. We use
    #   action_batch to index the output of policy_net's prediction.
    #
    # This gives us a tensor that contains the return we expect to get
    # at each state if we follow the model's advice.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_t+1) for all next states.
    # V(s_t+1) = max_a ( Q(s_t+1, a) )
    #          = the maximum return over all possible actions at state s_t+1.
    next_state_values = torch.zeros(BATCH_SIZE, device = compute_device)

    # Don't compute gradients for operations in this block.
    # If you don't understand what this means, RTFD.
    with torch.no_grad():

        # Note the use of non_final_mask here.
        # Final states are not touched by the line below,
        # so their value stays at zero.
        #
        # Non-final states get their value set to the best value
        # the model predicts.
        #
        # Values are computed with the "older" target net, and the best
        # value (over possible actions) is selected with max(1)[0].
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]

    # Compute the expected Q values, i.e. the Bellman target:
    #
    #   Q(s_t, a) = r_t + GAMMA * max_a( Q(s_t+1, a) )
    #
    # For final states next_state_values is zero, so the target reduces
    # to the observed reward.
    expected_state_action_values = reward_batch + (next_state_values * GAMMA)

    # Compute the Huber loss between the predicted and expected returns.
    # PyTorch records these operations in the autograd graph, so we can
    # backpropagate through them below.
    #
    # loss is a single-element tensor (i.e., a scalar).
    criterion = nn.SmoothL1Loss()
    loss = criterion(
        state_action_values,
        expected_state_action_values.unsqueeze(1)
    )

    # We can now run a step of backpropagation on our model.
    # Calling .backward() multiple times accumulates parameter gradients,
    # so we reset them before each step.
    optimizer.zero_grad()

    # Compute the gradient of loss with respect to every policy_net
    # parameter. We never use `loss` itself again: .backward() stores
    # each gradient in the corresponding parameter's .grad attribute,
    # and optimizer.step() reads it from there.
    loss.backward()

    # Guard against exploding gradients.
    # Clamps each gradient component to [-clip_value, +clip_value].
    torch.nn.utils.clip_grad_value_(  # type: ignore
        policy_net.parameters(),
        clip_value = 100
    )

    # Perform a single optimizer step.
    #
    # Uses the current gradient, which is stored
    # in the .grad attribute of the parameter.
    optimizer.step()
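
# A runnable demo of the .gather call used in optimize_model, referenced
# from the Q(s_t, a) comment above. The Q values here are made-up numbers,
# not real network output.
#
#   out = a.gather(1, b)  =>  out[i, j] = a[ i ][ b[i, j] ]
def _gather_demo():
    q_values = torch.tensor([
        [1.0, 2.0],  # [ Q(s_0, left), Q(s_0, right) ]
        [3.0, 4.0],  # [ Q(s_1, left), Q(s_1, right) ]
    ])
    actions = torch.tensor([[0], [1]])  # went left in s_0, right in s_1

    # Row i of the result is Q(s_i, a_i) for the action actually taken.
    chosen = q_values.gather(1, actions)
    assert torch.equal(chosen, torch.tensor([[1.0], [4.0]]))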
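
# A minimal smoke test for optimize_model, runnable on CPU as-is. The tiny
# two-layer network and the fake CartPole-shaped transitions (4 observation
# features, 2 actions) are made up for illustration; a real training loop
# would also periodically copy policy_net's weights into target_net.
def _smoke_test():
    compute_device = torch.device("cpu")

    def make_net() -> nn.Module:
        return nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))

    policy_net = make_net()
    target_net = make_net()
    target_net.load_state_dict(policy_net.state_dict())
    optimizer = optim.AdamW(policy_net.parameters(), lr = 1e-4)

    memory = deque(maxlen = 10_000)
    for _ in range(256):
        memory.append(util.Transition(
            state = torch.randn(1, 4),
            action = torch.tensor([[random.randrange(2)]]),
            # next_state is None when the episode ended on this step.
            next_state = None if random.random() < 0.1 else torch.randn(1, 4),
            reward = torch.tensor([1.0]),
        ))

    optimize_model(memory, compute_device, policy_net, target_net, optimizer)


if __name__ == "__main__":
    _gather_demo()
    _smoke_test()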