import random
from collections import deque

import torch
import torch.nn as nn
import torch.optim as optim

import util


def optimize_model(
    memory: deque,

    # PyTorch params
    compute_device,
    policy_net: nn.Module,
    target_net: nn.Module,
    optimizer: optim.Optimizer,

    # Hyperparameters:
    #
    # BATCH_SIZE is the number of transitions sampled from the replay buffer.
    # GAMMA is the discount factor for future rewards.
    BATCH_SIZE = 128,
    GAMMA = 0.99
):
    if len(memory) < BATCH_SIZE:
        raise ValueError(f"Not enough elements in memory for a batch of {BATCH_SIZE}")

    # Get a random sample of transitions.
    batch = random.sample(memory, BATCH_SIZE)

    # Transpose the batch, turning an array of Transitions
    # into a Transition of arrays.
    batch = util.Transition(*zip(*batch))
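    # For example (field names as used below; their order in util may differ):
    #   [Transition(s1, a1, n1, r1), Transition(s2, a2, n2, r2)]
    # becomes
    #   Transition(state=(s1, s2), action=(a1, a2),
    #              next_state=(n1, n2), reward=(r1, r2))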

    # Combine the states, actions, and rewards into their own tensors.
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)


    # Compute a mask of non-final states.
    # Each element of this tensor corresponds to an element in the batch:
    # True if the next state is non-final, False if it is final.
    #
    # We use this to select non-final states later.
    non_final_mask = torch.tensor(
        tuple(map(
            lambda s: s is not None,
            batch.next_state
        )),
        device = compute_device,
        dtype = torch.bool
    )

    non_final_next_states = torch.cat(
        [s for s in batch.next_state if s is not None]
    )
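    # For example, if batch.next_state is (s1, None, s3), then
    #   non_final_mask        is tensor([True, False, True])
    #   non_final_next_states is torch.cat([s1, s3])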


    # How .gather works:
    # if out = a.gather(1, b), then
    #   out[i, j] = a[ i ][ b[i, j] ]
    #
    # a is the "input", b is the "index".
    # If this doesn't make sense, RTFD (torch.gather).
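    #
    # A worked 2x2 example (hypothetical values):
    #   a = [[10, 20],
    #        [30, 40]]
    #   b = [[1],
    #        [0]]
    #   a.gather(1, b) == [[20],   # a[0][ b[0,0] ] = a[0][1]
    #                      [30]]   # a[1][ b[1,0] ] = a[1][0]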

    # Compute Q(s_t, a).
    # - Use policy_net to compute Q(s_t) for each state in the batch.
    #   This gives a tensor of [ Q(state, left), Q(state, right) ].
    #
    # - action_batch is a tensor that looks like [ [0], [1], [1], ... ],
    #   listing the action that was taken in each transition.
    #   0 => we went left, 1 => we went right.
    #
    # This aligns nicely with the output of policy_net: we use
    # action_batch to index into policy_net's output.
    #
    # This gives us a tensor containing the Q value the model assigns
    # to the action that was actually taken at each state.
    state_action_values = policy_net(state_batch).gather(1, action_batch)
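    # Shapes, for reference (with two actions):
    #   policy_net(state_batch) is (BATCH_SIZE, 2)
    #   action_batch            is (BATCH_SIZE, 1)
    #   state_action_values     is (BATCH_SIZE, 1)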


    # Compute V(s_{t+1}) for all next states.
    #   V(s_{t+1}) = max_a Q(s_{t+1}, a)
    # = the maximum expected value over all possible actions at state s_{t+1}.
    next_state_values = torch.zeros(BATCH_SIZE, device = compute_device)

    # Don't compute gradients for operations in this block;
    # the target values should be treated as constants.
    # If you don't understand what this means, RTFD (torch.no_grad).
    with torch.no_grad():
        # Note the use of non_final_mask here.
        # Final states are not touched by the line below,
        # so their value stays at zero.
        #
        # Non-final states get their predicted value
        # set to the best value the model predicts.
        #
        # Expected values of actions are computed with the "older" target_net,
        # and the best value (over possible actions) is selected with .max(1)[0].
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]


    # Compute the expected Q values via the Bellman equation:
    #   Q(s_t, a_t) = r_t + GAMMA * V(s_{t+1})
    # For final states V(s_{t+1}) is zero, so the target is just the reward.
    expected_state_action_values = reward_batch + (next_state_values * GAMMA)
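    # e.g. with GAMMA = 0.99, a reward of 1.0, and a predicted next-state
    # value of 10.0, the target is 1.0 + 0.99 * 10.0 = 10.9.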


    # Compute the Huber loss between the predicted Q values and the
    # expected Q values. PyTorch will account for this when we compute
    # the gradient of the loss.
    #
    # loss is a single-element tensor (i.e., a scalar).
    criterion = nn.SmoothL1Loss()
    loss = criterion(
        state_action_values,
        expected_state_action_values.unsqueeze(1)
    )
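    # For reference, SmoothL1Loss with its default beta = 1.0 is, per element
    # (where x = prediction - target):
    #   0.5 * x**2    if |x| < 1
    #   |x| - 0.5     otherwise
    # so it behaves like MSE near zero but is robust to outliers.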


    # We can now run a step of backpropagation on our model.

    # Calling .backward() multiple times accumulates parameter gradients,
    # so we reset the gradients before each step.
    optimizer.zero_grad()

    # Compute the gradient of loss with respect to each parameter of
    # policy_net and store it in that parameter's .grad attribute.
    # That is why loss is never used again: optimizer.step() reads
    # the gradients from .grad, not from loss.
    loss.backward()

    # Prevent exploding gradients:
    # forces each gradient element into [-clip_value, +clip_value].
    torch.nn.utils.clip_grad_value_( # type: ignore
        policy_net.parameters(),
        clip_value = 100
    )
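    # (With clip_value = 100, a gradient element of -250 would be clamped
    # to -100, while one of 3.7 would be left unchanged.)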

    # Perform a single optimizer step.
    #
    # Uses the current gradients, which are stored
    # in the .grad attribute of each parameter.
    optimizer.step()
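

# A minimal smoke test, not part of the training loop. It assumes
# util.Transition is a namedtuple with the fields used above
# (state, action, next_state, reward), as in the PyTorch DQN tutorial.
# The 4-dimensional state and 2-action network below are placeholders;
# adjust both to match the real environment.
if __name__ == "__main__":
    device = torch.device("cpu")

    policy = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))
    target = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))
    target.load_state_dict(policy.state_dict())
    opt = optim.AdamW(policy.parameters(), lr = 1e-3)

    # Fill the replay buffer with random transitions. next_state is None
    # for "final" transitions, matching what optimize_model expects.
    memory = deque(maxlen = 10_000)
    for _ in range(256):
        memory.append(util.Transition(
            state = torch.randn(1, 4),
            action = torch.randint(0, 2, (1, 1)),
            next_state = None if random.random() < 0.1 else torch.randn(1, 4),
            reward = torch.randn(1)
        ))

    optimize_model(memory, device, policy, target, opt)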