import torch
import gymnasium as gym
import random
import math
import time
from collections import deque, namedtuple
from itertools import count

Transition = namedtuple(
    "Transition",
    ("state", "action", "next_state", "reward")
)


class Agent:
    def __init__(
        self,
        *,

        ## Misc parameters
        #
        # Computation backend. Usually "cpu" or "cuda".
        # Automatic selection if left as None.
        # It's best to leave this as None.
        compute_device = None,
        #
        # Gymnasium environment name.
        env_name = "CartPole-v1",


        ## Modules
        network,


        ## Hyperparameters
        #
        # BATCH_SIZE is the size of the batch we train on, sampled from memory.
        # GAMMA is the discount factor for optimization.
        BATCH_SIZE = 128,
        GAMMA = 0.99,

        # Learning rate of target_net.
        # Controls how soft our soft update is.
        #
        # Should be between 0 and 1.
        # Large values make target_net track policy_net more quickly.
        # Small values do the opposite.
        #
        # A value of one makes target_net
        # change at the same rate as policy_net.
        #
        # A value of zero makes target_net
        # not change at all.
        TAU = 0.005,

        # Optimizer learning rate.
        OPT_LR = 1e-4,

        # Epsilon-greedy parameters
        #
        # Original docs:
        # EPS_START is the starting value of epsilon
        # EPS_END is the final value of epsilon
        # EPS_DECAY controls the rate of exponential decay of epsilon,
        # higher means a slower decay
        EPS_START = 0.9,
        EPS_END = 0.05,
        EPS_DECAY = 1000,
    ):

        ## Auto-select compute device
        if compute_device is None:
            self.compute_device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )
        else:
            self.compute_device = compute_device


        ## Initialize misc values
        self.steps_done = 0             # How many steps this agent has been trained on
        self.network = network          # Network class this agent should use
        self.env = gym.make(env_name)   # Gym environment
        self.env_name = env_name

        ## Initialize replay memory.
        # This is a deque of Transitions.
        self.memory = deque([], maxlen = 10_000)


        ## Save model hyperparameters
        self.BATCH_SIZE = BATCH_SIZE
        self.GAMMA = GAMMA
        self.TAU = TAU
        self.OPT_LR = OPT_LR
        self.EPS_START = EPS_START
        self.EPS_END = EPS_END
        self.EPS_DECAY = EPS_DECAY


        ## Create networks and optimizer

        # n_actions: size of action space
        # - 2 for cartpole: [0, 1] as "left" and "right"
        #
        # n_observations: size of observation vector
        # - 4 for cartpole:
        #   position, velocity,
        #   angle, angular velocity
        n_actions = self.env.action_space.n  # type: ignore
        state, _ = self.env.reset()
        n_observations = len(state)

        # policy_net is the network we train and use to pick actions.
        # target_net is a slowly-updated copy of policy_net, used to
        # compute stable Q-value targets during optimization.
        self.policy_net = self.network(n_observations, n_actions).to(self.compute_device)
        self.target_net = self.network(n_observations, n_actions).to(self.compute_device)

        # Both networks should start with the same weights
        self.target_net.load_state_dict(self.policy_net.state_dict())

        ## Initialize optimizer.
        self.optimizer = torch.optim.AdamW(
            self.policy_net.parameters(),
            lr = self.OPT_LR,
            amsgrad = True
        )


    def _select_action(self, state):
        """
        Select an action using an epsilon-greedy policy.
        Sometimes use our model, sometimes sample one uniformly.

        P(random action) starts at EPS_START and decays to EPS_END.
        Decay rate is controlled by EPS_DECAY.
        """

        # Random number 0 <= x < 1
        sample = random.random()

        # Calculate random step threshold
        eps_threshold = (
            self.EPS_END + (self.EPS_START - self.EPS_END) *
            math.exp(-1.0 * self.steps_done / self.EPS_DECAY)
        )

        if sample > eps_threshold:
            with torch.no_grad():
                # t.max(1) will return the largest column value of each row.
                # The second element of the max result is the index of where
                # the max element was found, so we pick the action with the
                # larger expected reward.
                return self.policy_net(state).max(1)[1].view(1, 1)
        else:
            return torch.tensor(
                [[self.env.action_space.sample()]],
                device = self.compute_device,
                dtype = torch.long
            )


    def _optimize(self):
        if len(self.memory) < self.BATCH_SIZE:
            raise Exception(f"Not enough elements in memory for a batch of {self.BATCH_SIZE}")

        # Get a random sample of transitions
        batch = random.sample(self.memory, self.BATCH_SIZE)

        # Conversion.
        # Transposes batch, turning an array of Transitions
        # into a Transition of arrays.
        batch = Transition(*zip(*batch))

        # Conversion.
        # Combine states, actions, and rewards into their own tensors.
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        # Compute a mask of non-final states.
        # Each element of this tensor corresponds to an element in the batch.
        # True if the next state is non-final, False if it is final.
        #
        # We use this to select non-final states later.
        non_final_mask = torch.tensor(
            tuple(map(
                lambda s: s is not None,
                batch.next_state
            )),
            device = self.compute_device,
            dtype = torch.bool
        )

        non_final_next_states = torch.cat(
            [s for s in batch.next_state if s is not None]
        )

        # How .gather works:
        # if out = a.gather(1, b),
        # out[i, j] = a[ i ][ b[i,j] ]
        #
        # a is "input," b is "index"
        # If this doesn't make sense, RTFD.

        # Compute Q(s_t, a).
        # - Use policy_net to compute Q(s_t) for each state in the batch.
        #   This gives a tensor of [ Q(state, left), Q(state, right) ]
        #
        # - action_batch is a tensor that looks like [ [0], [1], [1], ... ],
        #   listing the action that was taken in each transition.
        #   0 => we went left, 1 => we went right.
        #
        #   This aligns nicely with the output of policy_net. We use
        #   action_batch to index the output of policy_net's prediction.
        #
        # This gives us a tensor that contains the return we expect to get
        # at that state if we follow the model's advice.
        state_action_values = self.policy_net(state_batch).gather(1, action_batch)

        # Compute V(s_t+1) for all next states.
        # V(s_t+1) = max_a ( Q(s_t+1, a) )
        #          = the maximum reward over all possible actions at state s_t+1.
        next_state_values = torch.zeros(
            self.BATCH_SIZE,
            device = self.compute_device
        )

        # Don't compute gradient for operations in this block.
        # If you don't understand what this means, RTFD.
        with torch.no_grad():

            # Note the use of non_final_mask here.
            # States that are final do not have their value set by the line
            # below, so their value stays at zero.
            #
            # States that are not final get their predicted value
            # set to the best value the model predicts.
            #
            # Expected values of actions are computed with the "older" target_net,
            # and the best value (over possible actions) is selected with max(1)[0].
            next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0]

        # Compute the expected Q values (the Bellman target):
        # Q_expected(s_t, a) = r_t + GAMMA * V(s_t+1)
        expected_state_action_values = reward_batch + (next_state_values * self.GAMMA)

        # Compute Huber loss between predicted reward and expected reward.
        # PyTorch tracks this computation so it can compute the gradient of loss.
        #
        # loss is a single-element tensor (i.e., a scalar).
        criterion = torch.nn.SmoothL1Loss()
        loss = criterion(
            state_action_values,
            expected_state_action_values.unsqueeze(1)
        )

        # We can now run a step of backpropagation on our model.

        # Calling .backward() multiple times will accumulate parameter gradients.
        # Thus, we reset the gradient before each step.
        self.optimizer.zero_grad()

        # Compute the gradient of loss with respect to policy_net's parameters.
        # The gradients are stored in each parameter's .grad attribute, which is
        # what optimizer.step() reads below; that's why loss itself is never
        # used again.
        loss.backward()

        # Prevent exploding gradients.
        # Forces gradients to be in [-clip_value, +clip_value]
        torch.nn.utils.clip_grad_value_(  # type: ignore
            self.policy_net.parameters(),
            clip_value = 100
        )

        # Perform a single optimizer step.
        #
        # Uses the current gradient, which is stored
        # in the .grad attribute of the parameter.
        self.optimizer.step()


    def train(
        self,

        # Number of training episodes.
        # Need ~400 to see results.
        num_episodes = 400,

        # If true, print progress
        verbose = False
    ) -> list[int]:

        # Returns a list of training episode durations.
        # Good for graphing.
        episode_durations = []

        for ep in range(num_episodes):

            # Reset environment and get game state
            state, _ = self.env.reset()
            state = torch.tensor(
                state,
                dtype = torch.float32,
                device = self.compute_device
            ).unsqueeze(0)

            # Iterate until game is over
            for t in count():

                # Select next action
                action = self._select_action(state)
                self.steps_done += 1

                # Perform one step of the environment with this action.
                (
                    next_state,  # new state
                    reward,      # number: reward as a result of action
                    terminated,  # bool: reached a terminal state (win or loss). If True, must reset.
                    truncated,   # bool: end of time limit. If True, must reset.
                    _
                ) = self.env.step(action.item())

                # Conversion
                reward = torch.tensor([reward], device = self.compute_device)

                if terminated:
                    # If the environment reached a terminal state,
                    # observations are meaningless. Set to None.
                    next_state = None
                else:
                    # Conversion
                    next_state = torch.tensor(
                        next_state,
                        dtype = torch.float32,
                        device = self.compute_device
                    ).unsqueeze(0)

                # Add this state transition to memory.
                self.memory.append(
                    Transition(
                        state,
                        action,
                        next_state,
                        reward
                    )
                )

                state = next_state

                # Only train the network if we have enough
                # transitions in memory to do so.
                if len(self.memory) >= self.BATCH_SIZE:

                    # Run optimizer
                    self._optimize()

                    # Soft update target_net weights
                    target_net_state = self.target_net.state_dict()
                    policy_net_state = self.policy_net.state_dict()
                    for key in policy_net_state:
                        target_net_state[key] = (
                            policy_net_state[key] * self.TAU +
                            target_net_state[key] * (1 - self.TAU)
                        )
                    self.target_net.load_state_dict(target_net_state)

                # Move on to the next episode once we reach
                # a terminal state.
                if terminated or truncated:
                    if verbose:
                        print(
                            f"Episode {ep}/{num_episodes}, last duration {t+1}",
                            end = "\r"
                        )
                    episode_durations.append(t + 1)
                    break

        return episode_durations


    def predict(self, state):
        return (
            self.policy_net(state)
            .max(1)[1]
            .view(1, 1)
            .item()
        )
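

## Example usage (a minimal sketch).
#
# The Agent expects `network` to be a class that takes
# (n_observations, n_actions) and builds a torch.nn.Module mapping a batch
# of observations to one Q-value per action. The `DQN` class and the
# __main__ block below are illustrative assumptions, not part of the Agent
# itself; any module with the same constructor signature would work.

class DQN(torch.nn.Module):
    # Hypothetical two-hidden-layer network, for demonstration only.
    def __init__(self, n_observations, n_actions):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(n_observations, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, n_actions),
        )

    def forward(self, x):
        return self.layers(x)


if __name__ == "__main__":
    # Train a small agent on CartPole and report how long episodes lasted.
    agent = Agent(network = DQN)
    durations = agent.train(num_episodes = 50, verbose = True)
    print(f"\nMean episode duration: {sum(durations) / len(durations):.1f}")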