import torch
import gymnasium as gym

import random
import math
import time

from collections import deque
from itertools import count
from collections import namedtuple

Transition = namedtuple(
    "Transition",
    (
        "state",
        "action",
        "next_state",
        "reward"
    )
)
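
# Example (illustrative, not from the original file): one step of experience
# is stored as
#   tr = Transition(state, action, next_state, reward)
# and its fields are read back by name, e.g. tr.reward or batch.next_state.
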

class Agent:
    def __init__(
        self,
        *,

        ## Misc parameters
        #
        # Computation backend, usually "cpu" or "cuda".
        # Automatic selection if left as None.
        # It's best to leave this as None.
        compute_device = None,
        #
        # Gymnasium environment name.
        env_name = "CartPole-v1",

        ## Modules
        network,

        ## Hyperparameters
        #
        # BATCH_SIZE is the size of the batch we train on, sampled from memory.
        # GAMMA is the discount factor for optimization.
        BATCH_SIZE = 128,
        GAMMA = 0.99,

        # Learning rate of target_net.
        # Controls how soft our soft update is.
        #
        # Should be between 0 and 1.
        # Large values make target_net track policy_net more quickly.
        # Small values do the opposite.
        #
        # A value of one makes target_net
        # change at the same rate as policy_net.
        #
        # A value of zero makes target_net
        # not change at all.
        TAU = 0.005,

        # Optimizer learning rate.
        OPT_LR = 1e-4,


        # Epsilon-greedy parameters
        #
        # Original docs:
        # EPS_START is the starting value of epsilon
        # EPS_END is the final value of epsilon
        # EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
        EPS_START = 0.9,
        EPS_END = 0.05,
        EPS_DECAY = 1000,
    ):

        ## Auto-select compute device
        if compute_device is None:
            self.compute_device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )
        else:
            self.compute_device = compute_device


        ## Initialize misc values
        self.steps_done = 0  # How many steps this agent has been trained on
        self.network = network  # Network class this agent should use
        self.env = gym.make(env_name)  # Gym environment
        self.env_name = env_name

        ## Initialize replay memory.
        # This is a deque of Transitions.
        self.memory = deque([], maxlen = 10_000)

        ## Save model hyperparameters
        self.BATCH_SIZE = BATCH_SIZE
        self.GAMMA = GAMMA
        self.TAU = TAU
        self.OPT_LR = OPT_LR
        self.EPS_START = EPS_START
        self.EPS_END = EPS_END
        self.EPS_DECAY = EPS_DECAY


        ## Create networks and optimizer
        # n_actions: size of action space
        # - 2 for cartpole: [0, 1] as "left" and "right"
        #
        # n_observations: size of observation vector
        # - 4 for cartpole:
        #     position, velocity,
        #     angle, angular velocity
        n_actions = self.env.action_space.n  # type: ignore
        state, _ = self.env.reset()
        n_observations = len(state)

        # policy_net is the network we train and use to pick actions.
        # target_net is a slowly-updated copy of policy_net, used to compute
        # the bootstrap targets in _optimize; keeping it mostly frozen
        # stabilizes training.
        self.policy_net = self.network(n_observations, n_actions).to(self.compute_device)
        self.target_net = self.network(n_observations, n_actions).to(self.compute_device)

        # Both networks should start with the same weights
        self.target_net.load_state_dict(self.policy_net.state_dict())


        ## Initialize optimizer.
        self.optimizer = torch.optim.AdamW(
            self.policy_net.parameters(),
            lr = self.OPT_LR,
            amsgrad = True
        )

    def _select_action(self, state):
        """
        Select an action using an epsilon-greedy policy.

        Sometimes use our model, sometimes sample an action uniformly.

        P(random action) starts at EPS_START and decays to EPS_END.
        Decay rate is controlled by EPS_DECAY.
        """

        # Random number 0 <= x < 1
        sample = random.random()

        # Calculate random-action threshold
        eps_threshold = (
            self.EPS_END + (self.EPS_START - self.EPS_END) *
            math.exp(
                -1.0 * self.steps_done /
                self.EPS_DECAY
            )
        )
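
        # Worked example (values assume the defaults EPS_START = 0.9,
        # EPS_END = 0.05, EPS_DECAY = 1000):
        #   steps_done =    0  ->  eps_threshold ~ 0.90
        #   steps_done = 1000  ->  eps_threshold ~ 0.36
        #   steps_done = 5000  ->  eps_threshold ~ 0.06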

        if sample > eps_threshold:
            with torch.no_grad():
                # t.max(1) will return the largest column value of each row.
                # The second element of the max result is the index of where the
                # max element was found, so we pick the action with the larger
                # expected reward.
                return self.policy_net(state).max(1)[1].view(1, 1)

        else:
            return torch.tensor(
                [ [self.env.action_space.sample()] ],
                device = self.compute_device,
                dtype = torch.long
            )

    def _optimize(self):
        if len(self.memory) < self.BATCH_SIZE:
            raise Exception(f"Not enough elements in memory for a batch of {self.BATCH_SIZE}")


        # Get a random sample of transitions
        batch = random.sample(self.memory, self.BATCH_SIZE)

        # Conversion.
        # Transposes batch, turning an array of Transitions
        # into a Transition of arrays.
        batch = Transition(*zip(*batch))

        # Conversion.
        # Combine states, actions, and rewards into their own tensors.
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)


        # Compute a mask of non-final states.
        # Each element of this tensor corresponds to an element in the batch:
        # True if that element's next state is non-final, False if it is final.
        #
        # We use this to select non-final states later.
        non_final_mask = torch.tensor(
            tuple(map(
                lambda s: s is not None,
                batch.next_state
            )),
            device = self.compute_device,
            dtype = torch.bool
        )

        non_final_next_states = torch.cat(
            [s for s in batch.next_state if s is not None]
        )
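
        # For example (illustrative), if batch.next_state were (s0, None, s2):
        #   non_final_mask        -> tensor([True, False, True])
        #   non_final_next_states -> torch.cat([s0, s2])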


        # How .gather works:
        # if out = a.gather(1, b),
        # out[i, j] = a[ i ][ b[i,j] ]
        #
        # a is "input," b is "index"
        # If this doesn't make sense, RTFD.
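        #
        # A tiny worked example (values are illustrative):
        #   a = [[10, 20],      b = [[1],
        #        [30, 40]]           [0]]
        #   a.gather(1, b) -> [[20],
        #                      [30]]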

        # Compute Q(s_t, a).
        # - Use policy_net to compute Q(s_t) for each state in the batch.
        #   This gives a tensor of [ Q(state, left), Q(state, right) ]
        #
        # - Action batch is a tensor that looks like [ [0], [1], [1], ... ]
        #   listing the action that was taken in each transition.
        #   0 => we went left, 1 => we went right.
        #
        #   This aligns nicely with the output of policy_net. We use
        #   action_batch to index the output of policy_net's prediction.
        #
        # This gives us a tensor that contains the return we expect to get
        # at that state if we follow the model's advice.

        state_action_values = self.policy_net(state_batch).gather(1, action_batch)


        # Compute V(s_t+1) for all next states.
        # V(s_t+1) = max_a ( Q(s_t+1, a) )
        #          = the maximum reward over all possible actions at state s_t+1.
        next_state_values = torch.zeros(
            self.BATCH_SIZE,
            device = self.compute_device
        )

        # Don't compute gradient for operations in this block.
        # If you don't understand what this means, RTFD.
        with torch.no_grad():

            # Note the use of non_final_mask here.
            # States that are final do not have their value set by the line
            # below, so their value stays at zero.
            #
            # States that are not final get their predicted value
            # set to the best value the model predicts.
            #
            # Expected values of actions are selected with the "older" target net,
            # and their best reward (over possible actions) is selected with max(1)[0].

            next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0]


        # Compute the expected Q values (the TD targets):
        # for each transition, target = reward + GAMMA * V(s_t+1),
        # where V(s_t+1) is zero for final states.
        expected_state_action_values = reward_batch + (next_state_values * self.GAMMA)


        # Compute Huber loss between predicted reward and expected reward.
        # PyTorch will account for this when we compute the gradient of the loss.
        #
        # loss is a single-element tensor (i.e., a scalar).
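        #
        # (For reference, with the default beta = 1.0, SmoothL1Loss computes
        #   0.5 * x**2    if |x| < 1
        #   |x| - 0.5     otherwise,
        #  where x is the elementwise difference between its two inputs,
        #  averaged over the batch.)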
        criterion = torch.nn.SmoothL1Loss()
        loss = criterion(
            state_action_values,
            expected_state_action_values.unsqueeze(1)
        )


        # We can now run a step of backpropagation on our model.

        # Calling .backward() multiple times will accumulate parameter gradients.
        # Thus, we reset the gradient before each step.
        self.optimizer.zero_grad()

        # Compute the gradient of loss with respect to every parameter of
        # policy_net. The gradients are stored in each parameter's .grad
        # attribute, which is what the optimizer reads in .step() below.
        loss.backward()


        # Prevent exploding gradients.
        # Forces gradients to be in [-clip_value, +clip_value]
        torch.nn.utils.clip_grad_value_(  # type: ignore
            self.policy_net.parameters(),
            clip_value = 100
        )

        # Perform a single optimizer step.
        #
        # Uses the current gradient, which is stored
        # in the .grad attribute of the parameter.
        self.optimizer.step()

    def train(
        self,

        # Number of training episodes.
        # Need ~400 to see results.
        num_episodes = 400,

        # If True, print progress
        verbose = False
    ) -> list[int]:
        # Returns a list of training episode durations.
        # Good for graphing.

        episode_durations = []

        for ep in range(num_episodes):

            # Reset environment and get game state
            state, _ = self.env.reset()
            state = torch.tensor(
                state,
                dtype = torch.float32,
                device = self.compute_device
            ).unsqueeze(0)


            # Iterate until game is over
            for t in count():

                # Select next action
                action = self._select_action(state)
                self.steps_done += 1


                # Perform one step of the environment with this action.
                ( next_state,   # new state
                  reward,       # number: reward as a result of action
                  terminated,   # bool: reached a terminal state (win or loss). If True, must reset.
                  truncated,    # bool: hit the time limit. If True, must reset.
                  _
                ) = self.env.step(action.item())

                # Conversion
                reward = torch.tensor([reward], device = self.compute_device)

                if terminated:
                    # If the environment reached a terminal state,
                    # observations are meaningless. Set to None.
                    next_state = None
                else:
                    # Conversion
                    next_state = torch.tensor(
                        next_state,
                        dtype = torch.float32,
                        device = self.compute_device
                    ).unsqueeze(0)


                # Add this state transition to memory.
                self.memory.append(
                    Transition(
                        state,
                        action,
                        next_state,
                        reward
                    )
                )

                # Move to the next state.
                state = next_state

                # Only train the network if we have enough
                # transitions in memory to do so.
                if len(self.memory) >= self.BATCH_SIZE:
                    # Run optimizer
                    self._optimize()


                # Soft update target_net weights
                target_net_state = self.target_net.state_dict()
                policy_net_state = self.policy_net.state_dict()
                for key in policy_net_state:
                    target_net_state[key] = (
                        policy_net_state[key] * self.TAU +
                        target_net_state[key] * (1 - self.TAU)
                    )
                self.target_net.load_state_dict(target_net_state)
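
                # (This is the standard "soft" / Polyak update:
                #   target <- TAU * policy + (1 - TAU) * target,
                #  applied to every parameter tensor in the state dict.)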

                # Move on to the next episode once we reach
                # a terminal state.
                if (terminated or truncated):
                    if verbose:
                        print(f"Episode {ep}/{num_episodes}, last duration {t+1}", end = "\r")
                    episode_durations.append(t + 1)
                    break


        return episode_durations

    def predict(self, state):
        # Return the greedy (best predicted) action for a single,
        # already-batched state tensor.
        return (
            self.policy_net(state)
            .max(1)[1]
            .view(1, 1)
            .item()
        )
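

# A minimal usage sketch (not part of the original file). It assumes a small
# feed-forward Q-network; the name "DQN" and its architecture here are
# illustrative, not defined elsewhere in this module.
if __name__ == "__main__":
    import torch.nn as nn

    class DQN(nn.Module):
        def __init__(self, n_observations, n_actions):
            super().__init__()
            self.layers = nn.Sequential(
                nn.Linear(n_observations, 128), nn.ReLU(),
                nn.Linear(128, 128), nn.ReLU(),
                nn.Linear(128, n_actions),
            )

        def forward(self, x):
            return self.layers(x)

    agent = Agent(network = DQN)
    durations = agent.train(num_episodes = 50, verbose = True)
    print(f"\nMean episode duration: {sum(durations) / len(durations):.1f}")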