Cleanup
parent
058292c0bd
commit
571a337ff4
|
@ -1,276 +0,0 @@
|
||||||
import gymnasium as gym
|
|
||||||
import math
|
|
||||||
import random
|
|
||||||
import matplotlib
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
from itertools import count
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.optim as optim
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
import util
|
|
||||||
import optimize as optimize
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Parameter file
|
|
||||||
|
|
||||||
# TODO: What is this?
|
|
||||||
human_render = False
|
|
||||||
|
|
||||||
# TODO: What is this$
|
|
||||||
BATCH_SIZE = 128
|
|
||||||
|
|
||||||
# Learning rate of target_net.
|
|
||||||
# Controls how soft our soft update is.
|
|
||||||
#
|
|
||||||
# Should be between 0 and 1.
|
|
||||||
# Large values
|
|
||||||
# Small values do the opposite.
|
|
||||||
#
|
|
||||||
# A value of one makes target_net
|
|
||||||
# change at the same rate as policy_net.
|
|
||||||
#
|
|
||||||
# A value of zero makes target_net
|
|
||||||
# not change at all.
|
|
||||||
TAU = 0.005
|
|
||||||
|
|
||||||
|
|
||||||
# Setup game environment
|
|
||||||
if human_render:
|
|
||||||
env = gym.make("CartPole-v1", render_mode = "human")
|
|
||||||
else:
|
|
||||||
env = gym.make("CartPole-v1")
|
|
||||||
|
|
||||||
# Setup pytorch
|
|
||||||
compute_device = torch.device(
|
|
||||||
"cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Number of training episodes.
|
|
||||||
# It will take a while to process a many of these without a GPU,
|
|
||||||
# but you will not see improvement with few training episodes.
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
num_episodes = 600
|
|
||||||
else:
|
|
||||||
num_episodes = 50
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Create replay memory.
|
|
||||||
#
|
|
||||||
# Transition: a container for naming data (defined in util.py)
|
|
||||||
# Memory: a deque that holds recent states as Transitions
|
|
||||||
# Has a fixed length, drops oldest
|
|
||||||
# element if maxlen is exceeded.
|
|
||||||
memory = deque([], maxlen=10000)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Outline our network
|
|
||||||
class DQN(nn.Module):
|
|
||||||
def __init__(self, n_observations: int, n_actions: int):
|
|
||||||
super(DQN, self).__init__()
|
|
||||||
self.layer1 = nn.Linear(n_observations, 128)
|
|
||||||
self.layer2 = nn.Linear(128, 128)
|
|
||||||
self.layer3 = nn.Linear(128, n_actions)
|
|
||||||
|
|
||||||
# Can be called with one input, or with a batch.
|
|
||||||
#
|
|
||||||
# Returns tensor(
|
|
||||||
# [ Q(s, left), Q(s, right) ], ...
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# Recall that Q(s, a) is the (expected) return of taking
|
|
||||||
# action `a` at state `s`
|
|
||||||
def forward(self, x):
|
|
||||||
x = F.relu(self.layer1(x))
|
|
||||||
x = F.relu(self.layer2(x))
|
|
||||||
return self.layer3(x)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Create networks and optimizer
|
|
||||||
|
|
||||||
# n_actions: size of action space
|
|
||||||
# - 2 for cartpole: [0, 1] as "left" and "right"
|
|
||||||
#
|
|
||||||
# n_observations: size of observation vector
|
|
||||||
# - 4 for cartpole:
|
|
||||||
# position, velocity,
|
|
||||||
# angle, angular velocity
|
|
||||||
n_actions = env.action_space.n # type: ignore
|
|
||||||
state, _ = env.reset()
|
|
||||||
n_observations = len(state)
|
|
||||||
|
|
||||||
# TODO:
|
|
||||||
# What's the difference between these two?
|
|
||||||
# What do they do?
|
|
||||||
policy_net = DQN(n_observations, n_actions).to(compute_device)
|
|
||||||
target_net = DQN(n_observations, n_actions).to(compute_device)
|
|
||||||
|
|
||||||
# Both networks start with the same weights
|
|
||||||
target_net.load_state_dict(policy_net.state_dict())
|
|
||||||
|
|
||||||
#
|
|
||||||
optimizer = optim.AdamW(
|
|
||||||
policy_net.parameters(),
|
|
||||||
lr = 1e-4, # Hyperparameter: learning rate
|
|
||||||
amsgrad = True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: What is this?
|
|
||||||
steps_done = 0
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
episode_durations = []
|
|
||||||
|
|
||||||
|
|
||||||
# TRAINING LOOP
|
|
||||||
for ep in range(num_episodes):
|
|
||||||
|
|
||||||
# Reset environment and get game state
|
|
||||||
state, _ = env.reset()
|
|
||||||
|
|
||||||
# Conversion
|
|
||||||
state = torch.tensor(
|
|
||||||
state,
|
|
||||||
dtype = torch.float32,
|
|
||||||
device = compute_device
|
|
||||||
).unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
# Iterate until game is over
|
|
||||||
for t in count():
|
|
||||||
|
|
||||||
# Select next action
|
|
||||||
action = util.select_action(
|
|
||||||
state,
|
|
||||||
steps_done = steps_done,
|
|
||||||
policy_net = policy_net,
|
|
||||||
device = compute_device,
|
|
||||||
env = env
|
|
||||||
)
|
|
||||||
steps_done += 1
|
|
||||||
|
|
||||||
|
|
||||||
# Perform one step of the environment with this action.
|
|
||||||
( next_state, # new state
|
|
||||||
reward, # number: reward as a result of action
|
|
||||||
terminated, # bool: reached a terminal state (win or loss). If True, must reset.
|
|
||||||
truncated, # bool: end of time limit. If true, must reset.
|
|
||||||
_
|
|
||||||
) = env.step(action.item())
|
|
||||||
|
|
||||||
# Conversion
|
|
||||||
reward = torch.tensor([reward], device = compute_device)
|
|
||||||
|
|
||||||
if terminated:
|
|
||||||
# If the environment reached a terminal state,
|
|
||||||
# observations are meaningless. Set to None.
|
|
||||||
next_state = None
|
|
||||||
else:
|
|
||||||
# Conversion
|
|
||||||
next_state = torch.tensor(
|
|
||||||
next_state,
|
|
||||||
dtype = torch.float32,
|
|
||||||
device = compute_device
|
|
||||||
).unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
# Add this state transition to memory.
|
|
||||||
memory.append(
|
|
||||||
util.Transition(
|
|
||||||
state,
|
|
||||||
action,
|
|
||||||
next_state,
|
|
||||||
reward
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
state = next_state
|
|
||||||
|
|
||||||
|
|
||||||
# Only train the network if we have enough
|
|
||||||
# transitions in memory to do so.
|
|
||||||
if len(memory) >= BATCH_SIZE:
|
|
||||||
# Run optimizer
|
|
||||||
optimize.optimize_model(
|
|
||||||
memory,
|
|
||||||
# Pytorch params
|
|
||||||
compute_device = compute_device,
|
|
||||||
policy_net = policy_net,
|
|
||||||
target_net = target_net,
|
|
||||||
optimizer = optimizer,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Soft update target_net weights
|
|
||||||
target_net_state = target_net.state_dict()
|
|
||||||
policy_net_state = policy_net.state_dict()
|
|
||||||
for key in policy_net_state:
|
|
||||||
target_net_state[key] = (
|
|
||||||
policy_net_state[key] * TAU +
|
|
||||||
target_net_state[key] * (1-TAU)
|
|
||||||
)
|
|
||||||
target_net.load_state_dict(target_net_state)
|
|
||||||
|
|
||||||
# Move on to the next episode once we reach
|
|
||||||
# a terminal state.
|
|
||||||
if (terminated or truncated):
|
|
||||||
print(f"Episode {ep}/{num_episodes}, last duration {t+1}", end="\r" )
|
|
||||||
episode_durations.append(t + 1)
|
|
||||||
break
|
|
||||||
|
|
||||||
print("Complete.")
|
|
||||||
|
|
||||||
durations_t = torch.tensor(episode_durations, dtype=torch.float)
|
|
||||||
plt.xlabel('Episode')
|
|
||||||
plt.ylabel('Duration')
|
|
||||||
plt.plot(durations_t.numpy())
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
env.close()
|
|
||||||
en = gym.make("CartPole-v1", render_mode = "human")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
state, _ = en.reset()
|
|
||||||
state = torch.tensor(
|
|
||||||
state,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=compute_device
|
|
||||||
).unsqueeze(0)
|
|
||||||
|
|
||||||
terminated = False
|
|
||||||
truncated = False
|
|
||||||
while not (terminated or truncated):
|
|
||||||
action = policy_net(state).max(1)[1].view(1, 1)
|
|
||||||
|
|
||||||
( state, # new state
|
|
||||||
reward, # reward as a result of action
|
|
||||||
terminated, # bool: reached a terminal state (win or loss). If True, must reset.
|
|
||||||
truncated, # bool: end of time limit. If true, must reset.
|
|
||||||
_
|
|
||||||
) = en.step(action.item())
|
|
||||||
|
|
||||||
state = torch.tensor(
|
|
||||||
state,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=compute_device
|
|
||||||
).unsqueeze(0)
|
|
||||||
|
|
||||||
en.render()
|
|
||||||
en.reset()
|
|
|
@ -1,161 +0,0 @@
|
||||||
import random
|
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.optim as optim
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
import util
|
|
||||||
|
|
||||||
def optimize_model(
|
|
||||||
memory: deque,
|
|
||||||
|
|
||||||
# Pytorch params
|
|
||||||
compute_device,
|
|
||||||
policy_net: nn.Module,
|
|
||||||
target_net: nn.Module,
|
|
||||||
optimizer,
|
|
||||||
|
|
||||||
# Parameters:
|
|
||||||
#
|
|
||||||
# BATCH_SIZE is the number of transitions sampled from the replay buffer
|
|
||||||
# GAMMA is the discount factor as mentioned in the previous section
|
|
||||||
BATCH_SIZE = 128,
|
|
||||||
GAMMA = 0.99
|
|
||||||
):
|
|
||||||
|
|
||||||
if len(memory) < BATCH_SIZE:
|
|
||||||
raise Exception(f"Not enough elements in memory for a batch of {BATCH_SIZE}")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Get a random sample of transitions
|
|
||||||
batch = random.sample(memory, BATCH_SIZE)
|
|
||||||
|
|
||||||
# Conversion.
|
|
||||||
# Transposes batch, turning an array of Transitions
|
|
||||||
# into a Transition of arrays.
|
|
||||||
batch = util.Transition(*zip(*batch))
|
|
||||||
|
|
||||||
# Conversion.
|
|
||||||
# Combine states, actions, and rewards into their own tensors.
|
|
||||||
state_batch = torch.cat(batch.state)
|
|
||||||
action_batch = torch.cat(batch.action)
|
|
||||||
reward_batch = torch.cat(batch.reward)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Compute a mask of non_final_states.
|
|
||||||
# Each element of this tensor corresponds to an element in the batch.
|
|
||||||
# True if this is a final state, False if it is.
|
|
||||||
#
|
|
||||||
# We use this to select non-final states later.
|
|
||||||
non_final_mask = torch.tensor(
|
|
||||||
tuple(map(
|
|
||||||
lambda s: s is not None,
|
|
||||||
batch.next_state
|
|
||||||
))
|
|
||||||
)
|
|
||||||
|
|
||||||
non_final_next_states = torch.cat(
|
|
||||||
[s for s in batch.next_state if s is not None]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# How .gather works:
|
|
||||||
# if out = a.gather(1, b),
|
|
||||||
# out[i, j] = a[ i ][ b[i,j] ]
|
|
||||||
#
|
|
||||||
# a is "input," b is "index"
|
|
||||||
# If this doesn't make sense, RTFD.
|
|
||||||
|
|
||||||
# Compute Q(s_t, a).
|
|
||||||
# - Use policy_net to compute Q(s_t) for each state in the batch.
|
|
||||||
# This gives a tensor of [ Q(state, left), Q(state, right) ]
|
|
||||||
#
|
|
||||||
# - Action batch is a tensor that looks like [ [0], [1], [1], ... ]
|
|
||||||
# listing the action that was taken in each transition.
|
|
||||||
# 0 => we went left, 1 => we went right.
|
|
||||||
#
|
|
||||||
# This aligns nicely with the output of policy_net. We use
|
|
||||||
# action_batch to index the output of policy_net's prediction.
|
|
||||||
#
|
|
||||||
# This gives us a tensor that contains the return we expect to get
|
|
||||||
# at that state if we follow the model's advice.
|
|
||||||
|
|
||||||
state_action_values = policy_net(state_batch).gather(1, action_batch)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Compute V(s_t+1) for all next states.
|
|
||||||
# V(s_t+1) = max_a ( Q(s_t+1, a) )
|
|
||||||
# = the maximum reward over all possible actions at state s_t+1.
|
|
||||||
next_state_values = torch.zeros(BATCH_SIZE, device = compute_device)
|
|
||||||
|
|
||||||
# Don't compute gradient for operations in this block.
|
|
||||||
# If you don't understand what this means, RTFD.
|
|
||||||
with torch.no_grad():
|
|
||||||
|
|
||||||
# Note the use of non_final_mask here.
|
|
||||||
# States that are final do not have their reward set by the line
|
|
||||||
# below, so their reward stays at zero.
|
|
||||||
#
|
|
||||||
# States that are not final get their predicted value
|
|
||||||
# set to the best value the model predicts.
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# Expected values of action are selected with the "older" target net,
|
|
||||||
# and their best reward (over possible actions) is selected with max(1)[0].
|
|
||||||
|
|
||||||
next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: What does this mean?
|
|
||||||
# "Compute expected Q values"
|
|
||||||
expected_state_action_values = reward_batch + (next_state_values * GAMMA)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Compute Huber loss between predicted reward and expected reward.
|
|
||||||
# Pytorch is will account for this when we compute the gradient of loss.
|
|
||||||
#
|
|
||||||
# loss is a single-element tensor (i.e, a scalar).
|
|
||||||
criterion = nn.SmoothL1Loss()
|
|
||||||
loss = criterion(
|
|
||||||
state_action_values,
|
|
||||||
expected_state_action_values.unsqueeze(1)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# We can now run a step of backpropagation on our model.
|
|
||||||
|
|
||||||
# TODO: what does this do?
|
|
||||||
#
|
|
||||||
# Calling .backward() multiple times will accumulate parameter gradients.
|
|
||||||
# Thus, we reset the gradient before each step.
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
# Compute the gradient of loss wrt... something?
|
|
||||||
# TODO: what does this do, we never use loss again?!
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
|
|
||||||
# Prevent vanishing and exploding gradients.
|
|
||||||
# Forces gradients to be in [-clip_value, +clip_value]
|
|
||||||
torch.nn.utils.clip_grad_value_( # type: ignore
|
|
||||||
policy_net.parameters(),
|
|
||||||
clip_value = 100
|
|
||||||
)
|
|
||||||
|
|
||||||
# Perform a single optimizer step.
|
|
||||||
#
|
|
||||||
# Uses the current gradient, which is stored
|
|
||||||
# in the .grad attribute of the parameter.
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,77 +0,0 @@
|
||||||
import matplotlib
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import math
|
|
||||||
import random
|
|
||||||
from collections import namedtuple
|
|
||||||
|
|
||||||
|
|
||||||
Transition = namedtuple(
|
|
||||||
"Transition",
|
|
||||||
(
|
|
||||||
"state",
|
|
||||||
"action",
|
|
||||||
"next_state",
|
|
||||||
"reward"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def select_action(
|
|
||||||
state,
|
|
||||||
|
|
||||||
*,
|
|
||||||
|
|
||||||
# Number of steps that have been done
|
|
||||||
steps_done: int,
|
|
||||||
|
|
||||||
# TF parameters
|
|
||||||
policy_net, # DQN policy network
|
|
||||||
device, # Render device, "gpu" or "cpu"
|
|
||||||
env, # GYM environment instance
|
|
||||||
|
|
||||||
# Epsilon parameters
|
|
||||||
#
|
|
||||||
# Original docs:
|
|
||||||
# EPS_START is the starting value of epsilon
|
|
||||||
# EPS_END is the final value of epsilon
|
|
||||||
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
|
|
||||||
EPS_START = 0.9,
|
|
||||||
EPS_END = 0.05,
|
|
||||||
EPS_DECAY = 1000
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Given a state, select an action using an epsilon-greedy policy.
|
|
||||||
|
|
||||||
Sometimes use our model, sometimes sample one uniformly.
|
|
||||||
|
|
||||||
P(random action) starts at EPS_START and decays to EPS_END.
|
|
||||||
Decay rate is controlled by EPS_DECAY.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Random number 0 <= x < 1
|
|
||||||
sample = random.random()
|
|
||||||
|
|
||||||
# Calculate random step threshhold
|
|
||||||
eps_threshold = (
|
|
||||||
EPS_END + (EPS_START - EPS_END) *
|
|
||||||
math.exp(
|
|
||||||
-1.0 * steps_done /
|
|
||||||
EPS_DECAY
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if sample > eps_threshold:
|
|
||||||
with torch.no_grad():
|
|
||||||
# t.max(1) will return the largest column value of each row.
|
|
||||||
# second column on max result is index of where max element was
|
|
||||||
# found, so we pick action with the larger expected reward.
|
|
||||||
return policy_net(state).max(1)[1].view(1, 1)
|
|
||||||
|
|
||||||
else:
|
|
||||||
return torch.tensor(
|
|
||||||
[ [env.action_space.sample()] ],
|
|
||||||
device=device,
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
|
|
@ -1,415 +0,0 @@
|
||||||
import torch
|
|
||||||
import gymnasium as gym
|
|
||||||
|
|
||||||
import random
|
|
||||||
import math
|
|
||||||
import time
|
|
||||||
|
|
||||||
from collections import deque
|
|
||||||
from itertools import count
|
|
||||||
from collections import namedtuple
|
|
||||||
|
|
||||||
|
|
||||||
Transition = namedtuple(
|
|
||||||
"Transition",
|
|
||||||
(
|
|
||||||
"state",
|
|
||||||
"action",
|
|
||||||
"next_state",
|
|
||||||
"reward"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Agent:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
|
|
||||||
## Misc parameters
|
|
||||||
#
|
|
||||||
# Computation backend. Usually "cpu" or "gpu."
|
|
||||||
# Automatic selection if left as None.
|
|
||||||
# It's best to leave this as None.
|
|
||||||
compute_device = None,
|
|
||||||
#
|
|
||||||
# Gymnasium environment name.
|
|
||||||
env_name = "CartPole-v1",
|
|
||||||
|
|
||||||
## Modules
|
|
||||||
network,
|
|
||||||
|
|
||||||
## Hyperparameters
|
|
||||||
#
|
|
||||||
# BATCH_SIZE is the of batch we should train on, sampled from memory
|
|
||||||
# GAMMA is the discount factor for optimization
|
|
||||||
BATCH_SIZE = 128,
|
|
||||||
GAMMA = 0.99,
|
|
||||||
|
|
||||||
# Learning rate of target_net.
|
|
||||||
# Controls how soft our soft update is.
|
|
||||||
#
|
|
||||||
# Should be between 0 and 1.
|
|
||||||
# Large values
|
|
||||||
# Small values do the opposite.
|
|
||||||
#
|
|
||||||
# A value of one makes target_net
|
|
||||||
# change at the same rate as policy_net.
|
|
||||||
#
|
|
||||||
# A value of zero makes target_net
|
|
||||||
# not change at all.
|
|
||||||
TAU = 0.005,
|
|
||||||
|
|
||||||
# Optimizer learning rate.
|
|
||||||
OPT_LR = 1e-4,
|
|
||||||
|
|
||||||
|
|
||||||
# Epsilon-greedy parameters
|
|
||||||
#
|
|
||||||
# Original docs:
|
|
||||||
# EPS_START is the starting value of epsilon
|
|
||||||
# EPS_END is the final value of epsilon
|
|
||||||
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
|
|
||||||
EPS_START = 0.9,
|
|
||||||
EPS_END = 0.05,
|
|
||||||
EPS_DECAY = 1000,
|
|
||||||
):
|
|
||||||
|
|
||||||
## Auto-select compute device
|
|
||||||
if compute_device is None:
|
|
||||||
self.compute_device = torch.device(
|
|
||||||
"cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.compute_device = compute_device
|
|
||||||
|
|
||||||
|
|
||||||
## Initialize misc values
|
|
||||||
self.steps_done = 0 # How many steps this agent has been trained on
|
|
||||||
self.network = network # Network class this agent should use
|
|
||||||
self.env = gym.make(env_name) # Gym environment
|
|
||||||
self.env_name = env_name
|
|
||||||
|
|
||||||
## Initialize replay memory.
|
|
||||||
# This is a deque of util.Transitions.
|
|
||||||
self.memory = deque([], maxlen = 10_000)
|
|
||||||
|
|
||||||
## Save model hyperparameters
|
|
||||||
self.BATCH_SIZE = BATCH_SIZE
|
|
||||||
self.GAMMA = GAMMA
|
|
||||||
self.TAU = TAU
|
|
||||||
self.OPT_LR = OPT_LR
|
|
||||||
self.EPS_START = EPS_START
|
|
||||||
self.EPS_END = EPS_END
|
|
||||||
self.EPS_DECAY = EPS_DECAY
|
|
||||||
|
|
||||||
|
|
||||||
## Create networks and optimizer
|
|
||||||
# n_actions: size of action space
|
|
||||||
# - 2 for cartpole: [0, 1] as "left" and "right"
|
|
||||||
#
|
|
||||||
# n_observations: size of observation vector
|
|
||||||
# - 4 for cartpole:
|
|
||||||
# position, velocity,
|
|
||||||
# angle, angular velocity
|
|
||||||
n_actions = self.env.action_space.n # type: ignore
|
|
||||||
state, _ = self.env.reset()
|
|
||||||
n_observations = len(state)
|
|
||||||
|
|
||||||
# TODO:
|
|
||||||
# What's the difference between these two?
|
|
||||||
# What do they do?
|
|
||||||
self.policy_net = self.network(n_observations, n_actions).to(self.compute_device)
|
|
||||||
self.target_net = self.network(n_observations, n_actions).to(self.compute_device)
|
|
||||||
|
|
||||||
# Both networks should start with the same weights
|
|
||||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
|
||||||
|
|
||||||
|
|
||||||
## Initialize optimizer.
|
|
||||||
self.optimizer = torch.optim.AdamW(
|
|
||||||
self.policy_net.parameters(),
|
|
||||||
lr = self.OPT_LR,
|
|
||||||
amsgrad = True
|
|
||||||
)
|
|
||||||
|
|
||||||
def _select_action(self, state):
|
|
||||||
"""
|
|
||||||
Select an action using an epsilon-greedy policy.
|
|
||||||
|
|
||||||
Sometimes use our model, sometimes sample one uniformly.
|
|
||||||
|
|
||||||
P(random action) starts at EPS_START and decays to EPS_END.
|
|
||||||
Decay rate is controlled by EPS_DECAY.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Random number 0 <= x < 1
|
|
||||||
sample = random.random()
|
|
||||||
|
|
||||||
# Calculate random step threshhold
|
|
||||||
eps_threshold = (
|
|
||||||
self.EPS_END + (self.EPS_START - self.EPS_END) *
|
|
||||||
math.exp(
|
|
||||||
-1.0 * self.steps_done /
|
|
||||||
self.EPS_DECAY
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if sample > eps_threshold:
|
|
||||||
with torch.no_grad():
|
|
||||||
# t.max(1) will return the largest column value of each row.
|
|
||||||
# second column on max result is index of where max element was
|
|
||||||
# found, so we pick action with the larger expected reward.
|
|
||||||
return self.policy_net(state).max(1)[1].view(1, 1)
|
|
||||||
|
|
||||||
else:
|
|
||||||
return torch.tensor(
|
|
||||||
[ [self.env.action_space.sample()] ],
|
|
||||||
device = self.compute_device,
|
|
||||||
dtype = torch.long
|
|
||||||
)
|
|
||||||
|
|
||||||
def _optimize(self):
|
|
||||||
if len(self.memory) < self.BATCH_SIZE:
|
|
||||||
raise Exception(f"Not enough elements in memory for a batch of {self.BATCH_SIZE}")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Get a random sample of transitions
|
|
||||||
batch = random.sample(self.memory, self.BATCH_SIZE)
|
|
||||||
|
|
||||||
# Conversion.
|
|
||||||
# Transposes batch, turning an array of Transitions
|
|
||||||
# into a Transition of arrays.
|
|
||||||
batch = Transition(*zip(*batch))
|
|
||||||
|
|
||||||
# Conversion.
|
|
||||||
# Combine states, actions, and rewards into their own tensors.
|
|
||||||
state_batch = torch.cat(batch.state)
|
|
||||||
action_batch = torch.cat(batch.action)
|
|
||||||
reward_batch = torch.cat(batch.reward)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Compute a mask of non_final_states.
|
|
||||||
# Each element of this tensor corresponds to an element in the batch.
|
|
||||||
# True if this is a final state, False if it is.
|
|
||||||
#
|
|
||||||
# We use this to select non-final states later.
|
|
||||||
non_final_mask = torch.tensor(
|
|
||||||
tuple(map(
|
|
||||||
lambda s: s is not None,
|
|
||||||
batch.next_state
|
|
||||||
))
|
|
||||||
)
|
|
||||||
|
|
||||||
non_final_next_states = torch.cat(
|
|
||||||
[s for s in batch.next_state if s is not None]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# How .gather works:
|
|
||||||
# if out = a.gather(1, b),
|
|
||||||
# out[i, j] = a[ i ][ b[i,j] ]
|
|
||||||
#
|
|
||||||
# a is "input," b is "index"
|
|
||||||
# If this doesn't make sense, RTFD.
|
|
||||||
|
|
||||||
# Compute Q(s_t, a).
|
|
||||||
# - Use policy_net to compute Q(s_t) for each state in the batch.
|
|
||||||
# This gives a tensor of [ Q(state, left), Q(state, right) ]
|
|
||||||
#
|
|
||||||
# - Action batch is a tensor that looks like [ [0], [1], [1], ... ]
|
|
||||||
# listing the action that was taken in each transition.
|
|
||||||
# 0 => we went left, 1 => we went right.
|
|
||||||
#
|
|
||||||
# This aligns nicely with the output of policy_net. We use
|
|
||||||
# action_batch to index the output of policy_net's prediction.
|
|
||||||
#
|
|
||||||
# This gives us a tensor that contains the return we expect to get
|
|
||||||
# at that state if we follow the model's advice.
|
|
||||||
|
|
||||||
state_action_values = self.policy_net(state_batch).gather(1, action_batch)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Compute V(s_t+1) for all next states.
|
|
||||||
# V(s_t+1) = max_a ( Q(s_t+1, a) )
|
|
||||||
# = the maximum reward over all possible actions at state s_t+1.
|
|
||||||
next_state_values = torch.zeros(
|
|
||||||
self.BATCH_SIZE,
|
|
||||||
device = self.compute_device
|
|
||||||
)
|
|
||||||
|
|
||||||
# Don't compute gradient for operations in this block.
|
|
||||||
# If you don't understand what this means, RTFD.
|
|
||||||
with torch.no_grad():
|
|
||||||
|
|
||||||
# Note the use of non_final_mask here.
|
|
||||||
# States that are final do not have their reward set by the line
|
|
||||||
# below, so their reward stays at zero.
|
|
||||||
#
|
|
||||||
# States that are not final get their predicted value
|
|
||||||
# set to the best value the model predicts.
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# Expected values of action are selected with the "older" target net,
|
|
||||||
# and their best reward (over possible actions) is selected with max(1)[0].
|
|
||||||
|
|
||||||
next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0]
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: What does this mean?
|
|
||||||
# "Compute expected Q values"
|
|
||||||
expected_state_action_values = reward_batch + (next_state_values * self.GAMMA)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Compute Huber loss between predicted reward and expected reward.
|
|
||||||
# Pytorch is will account for this when we compute the gradient of loss.
|
|
||||||
#
|
|
||||||
# loss is a single-element tensor (i.e, a scalar).
|
|
||||||
criterion = torch.nn.SmoothL1Loss()
|
|
||||||
loss = criterion(
|
|
||||||
state_action_values,
|
|
||||||
expected_state_action_values.unsqueeze(1)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# We can now run a step of backpropagation on our model.
|
|
||||||
|
|
||||||
# TODO: what does this do?
|
|
||||||
#
|
|
||||||
# Calling .backward() multiple times will accumulate parameter gradients.
|
|
||||||
# Thus, we reset the gradient before each step.
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
|
|
||||||
# Compute the gradient of loss wrt... something?
|
|
||||||
# TODO: what does this do, we never use loss again?!
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
|
|
||||||
# Prevent vanishing and exploding gradients.
|
|
||||||
# Forces gradients to be in [-clip_value, +clip_value]
|
|
||||||
torch.nn.utils.clip_grad_value_( # type: ignore
|
|
||||||
self.policy_net.parameters(),
|
|
||||||
clip_value = 100
|
|
||||||
)
|
|
||||||
|
|
||||||
# Perform a single optimizer step.
|
|
||||||
#
|
|
||||||
# Uses the current gradient, which is stored
|
|
||||||
# in the .grad attribute of the parameter.
|
|
||||||
self.optimizer.step()
|
|
||||||
|
|
||||||
def train(
|
|
||||||
self,
|
|
||||||
|
|
||||||
# Number of training episodes.
|
|
||||||
# Need ~400 to see results.
|
|
||||||
num_episodes = 400,
|
|
||||||
|
|
||||||
# If true, print progress
|
|
||||||
verbose = False
|
|
||||||
) -> list[int]:
|
|
||||||
# Returns a list of training episode durations.
|
|
||||||
# Good for graphing.
|
|
||||||
|
|
||||||
|
|
||||||
episode_durations = []
|
|
||||||
|
|
||||||
for ep in range(num_episodes):
|
|
||||||
|
|
||||||
# Reset environment and get game state
|
|
||||||
state, _ = self.env.reset()
|
|
||||||
state = torch.tensor(
|
|
||||||
state,
|
|
||||||
dtype = torch.float32,
|
|
||||||
device = self.compute_device
|
|
||||||
).unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
# Iterate until game is over
|
|
||||||
for t in count():
|
|
||||||
|
|
||||||
# Select next action
|
|
||||||
action = self._select_action(state)
|
|
||||||
self.steps_done += 1
|
|
||||||
|
|
||||||
|
|
||||||
# Perform one step of the environment with this action.
|
|
||||||
( next_state, # new state
|
|
||||||
reward, # number: reward as a result of action
|
|
||||||
terminated, # bool: reached a terminal state (win or loss). If True, must reset.
|
|
||||||
truncated, # bool: end of time limit. If true, must reset.
|
|
||||||
_
|
|
||||||
) = self.env.step(action.item())
|
|
||||||
|
|
||||||
# Conversion
|
|
||||||
reward = torch.tensor([reward], device = self.compute_device)
|
|
||||||
|
|
||||||
if terminated:
|
|
||||||
# If the environment reached a terminal state,
|
|
||||||
# observations are meaningless. Set to None.
|
|
||||||
next_state = None
|
|
||||||
else:
|
|
||||||
# Conversion
|
|
||||||
next_state = torch.tensor(
|
|
||||||
next_state,
|
|
||||||
dtype = torch.float32,
|
|
||||||
device = self.compute_device
|
|
||||||
).unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
# Add this state transition to memory.
|
|
||||||
self.memory.append(
|
|
||||||
Transition(
|
|
||||||
state,
|
|
||||||
action,
|
|
||||||
next_state,
|
|
||||||
reward
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
state = next_state
|
|
||||||
|
|
||||||
# Only train the network if we have enough
|
|
||||||
# transitions in memory to do so.
|
|
||||||
if len(self.memory) >= self.BATCH_SIZE:
|
|
||||||
|
|
||||||
|
|
||||||
# Run optimizer
|
|
||||||
self._optimize()
|
|
||||||
|
|
||||||
|
|
||||||
# Soft update target_net weights
|
|
||||||
target_net_state = self.target_net.state_dict()
|
|
||||||
policy_net_state = self.policy_net.state_dict()
|
|
||||||
for key in policy_net_state:
|
|
||||||
target_net_state[key] = (
|
|
||||||
policy_net_state[key] * self.TAU +
|
|
||||||
target_net_state[key] * (1-self.TAU)
|
|
||||||
)
|
|
||||||
self.target_net.load_state_dict(target_net_state)
|
|
||||||
|
|
||||||
# Move on to the next episode once we reach
|
|
||||||
# a terminal state.
|
|
||||||
if (terminated or truncated):
|
|
||||||
if verbose:
|
|
||||||
print(f"Episode {ep}/{num_episodes}, last duration {t+1}", end="\r" )
|
|
||||||
episode_durations.append(t + 1)
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
return episode_durations
|
|
||||||
|
|
||||||
def predict(self, state):
|
|
||||||
return (
|
|
||||||
self.policy_net(state)
|
|
||||||
.max(1)[1]
|
|
||||||
.view(1, 1)
|
|
||||||
.item()
|
|
||||||
)
|
|
|
@ -1,132 +0,0 @@
|
||||||
import gymnasium as gym
|
|
||||||
|
|
||||||
import matplotlib
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from agent import Agent
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Outline our network
|
|
||||||
class DQN(nn.Module):
|
|
||||||
def __init__(self, n_observations: int, n_actions: int):
|
|
||||||
super(DQN, self).__init__()
|
|
||||||
self.layer1 = nn.Linear(n_observations, 128)
|
|
||||||
self.layer2 = nn.Linear(128, 128)
|
|
||||||
self.layer3 = nn.Linear(128, n_actions)
|
|
||||||
|
|
||||||
# Can be called with one input, or with a batch.
|
|
||||||
#
|
|
||||||
# Returns tensor(
|
|
||||||
# [ Q(s, left), Q(s, right) ], ...
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# Recall that Q(s, a) is the (expected) return of taking
|
|
||||||
# action `a` at state `s`
|
|
||||||
def forward(self, x):
|
|
||||||
x = F.relu(self.layer1(x))
|
|
||||||
x = F.relu(self.layer2(x))
|
|
||||||
return self.layer3(x)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from multiprocessing import Pool
|
|
||||||
|
|
||||||
def train(i):
|
|
||||||
print(f"Running {i}")
|
|
||||||
|
|
||||||
agent = Agent(
|
|
||||||
env_name = "CartPole-v1",
|
|
||||||
network = DQN,
|
|
||||||
BATCH_SIZE = 128,
|
|
||||||
TAU = 0.005,
|
|
||||||
OPT_LR = 1e-4
|
|
||||||
)
|
|
||||||
|
|
||||||
# Train model episodes
|
|
||||||
episode_durations = agent.train(600)
|
|
||||||
|
|
||||||
#print(f"Model has been trained on {agent.steps_done} steps.")
|
|
||||||
|
|
||||||
durations_t = torch.tensor(episode_durations, dtype=torch.float)
|
|
||||||
|
|
||||||
fig, axs = plt.subplots(1, 1)
|
|
||||||
axs.plot(durations_t.numpy())
|
|
||||||
fig.savefig(f"main-{i}.png")
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
with Pool(3) as p:
|
|
||||||
p.map(train, list(range(10)))
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Make the model
|
|
||||||
#
|
|
||||||
# Should work with...
|
|
||||||
# CartPole-v1
|
|
||||||
# Acrobot-v1
|
|
||||||
agent = Agent(
|
|
||||||
env_name = "CartPole-v1",
|
|
||||||
network = DQN,
|
|
||||||
BATCH_SIZE = 128,
|
|
||||||
TAU = 0.005,
|
|
||||||
OPT_LR = 1e-4
|
|
||||||
)
|
|
||||||
|
|
||||||
# Train the model
|
|
||||||
episode_durations = agent.train(600, verbose = True)
|
|
||||||
|
|
||||||
# Plot training progress
|
|
||||||
durations_t = torch.tensor(episode_durations, dtype=torch.float)
|
|
||||||
fig, axs = plt.subplots(1, 1)
|
|
||||||
axs.plot(durations_t.numpy())
|
|
||||||
fig.savefig(f"main.png")
|
|
||||||
|
|
||||||
|
|
||||||
# Test the model
|
|
||||||
env = gym.make(
|
|
||||||
agent.env_name,
|
|
||||||
render_mode = "human"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
while True:
|
|
||||||
state, _ = env.reset()
|
|
||||||
state = torch.tensor(
|
|
||||||
state,
|
|
||||||
dtype = torch.float32,
|
|
||||||
device = agent.compute_device
|
|
||||||
).unsqueeze(0)
|
|
||||||
|
|
||||||
terminated = False
|
|
||||||
truncated = False
|
|
||||||
while not (terminated or truncated):
|
|
||||||
|
|
||||||
# Predict best action given state
|
|
||||||
action = agent.predict(state)
|
|
||||||
|
|
||||||
# Do that action, get new state
|
|
||||||
( state,
|
|
||||||
reward,
|
|
||||||
terminated,
|
|
||||||
truncated,
|
|
||||||
_
|
|
||||||
) = env.step(action)
|
|
||||||
state = torch.tensor(
|
|
||||||
state,
|
|
||||||
dtype = torch.float32,
|
|
||||||
device = agent.compute_device
|
|
||||||
).unsqueeze(0)
|
|
||||||
|
|
||||||
env.render()
|
|
||||||
|
|
||||||
# Environment needs to be reset after a session ends
|
|
||||||
env.reset()
|
|
|
@ -1,316 +0,0 @@
|
||||||
|
|
||||||
## Setup
|
|
||||||
import gymnasium as gym
|
|
||||||
import math
|
|
||||||
import random
|
|
||||||
import matplotlib
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from collections import namedtuple, deque
|
|
||||||
from itertools import count
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.optim as optim
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
env = gym.make("CartPole-v1")
|
|
||||||
|
|
||||||
# set up matplotlib
|
|
||||||
is_ipython = 'inline' in matplotlib.get_backend()
|
|
||||||
if is_ipython:
|
|
||||||
from IPython import display
|
|
||||||
|
|
||||||
plt.ion()
|
|
||||||
|
|
||||||
# if gpu is to be used
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Replay Memory
|
|
||||||
#
|
|
||||||
# We'll be using experience replay memory for training our DQN. It stores the transitions that the agent observes, allowing us to reuse this data later. By sampling from it randomly, the transitions that build up a batch are decorrelated. It has been shown that this greatly stabilizes and improves the DQN training procedure.
|
|
||||||
|
|
||||||
# For this, we're going to need two classses:
|
|
||||||
|
|
||||||
# Transition - a named tuple representing a single transition in our environment. It essentially maps (state, action) pairs to their (next_state, reward) result, with the state being the screen difference image as described later on.
|
|
||||||
|
|
||||||
# ReplayMemory - a cyclic buffer of bounded size that holds the transitions observed recently. It also implements a .sample() method for selecting a random batch of transitions for training.
|
|
||||||
|
|
||||||
|
|
||||||
Transition = namedtuple(
|
|
||||||
"Transition",
|
|
||||||
(
|
|
||||||
"state",
|
|
||||||
"action",
|
|
||||||
"next_state",
|
|
||||||
"reward"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ReplayMemory(object):
|
|
||||||
def __init__(self, capacity):
|
|
||||||
self.memory = deque([], maxlen=capacity)
|
|
||||||
|
|
||||||
def push(self, *args):
|
|
||||||
"""Save a transition"""
|
|
||||||
self.memory.append(Transition(*args))
|
|
||||||
|
|
||||||
def sample(self, batch_size):
|
|
||||||
return random.sample(self.memory, batch_size)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.memory)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# DQN Algorithm
|
|
||||||
#
|
|
||||||
#
|
|
||||||
|
|
||||||
|
|
||||||
class DQN(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, n_observations: int, n_actions: int):
|
|
||||||
super(DQN, self).__init__()
|
|
||||||
self.layer1 = nn.Linear(n_observations, 128)
|
|
||||||
self.layer2 = nn.Linear(128, 128)
|
|
||||||
self.layer3 = nn.Linear(128, n_actions)
|
|
||||||
|
|
||||||
# Called with either one element to determine next action, or a batch
|
|
||||||
# during optimization. Returns tensor([[left0exp,right0exp]...]).
|
|
||||||
def forward(self, x):
|
|
||||||
x = F.relu(self.layer1(x))
|
|
||||||
x = F.relu(self.layer2(x))
|
|
||||||
return self.layer3(x)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# BATCH_SIZE is the number of transitions sampled from the replay buffer
|
|
||||||
# GAMMA is the discount factor as mentioned in the previous section
|
|
||||||
# EPS_START is the starting value of epsilon
|
|
||||||
# EPS_END is the final value of epsilon
|
|
||||||
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
|
|
||||||
# TAU is the update rate of the target network
|
|
||||||
# LR is the learning rate of the AdamW optimizer
|
|
||||||
BATCH_SIZE = 128
|
|
||||||
GAMMA = 0.99
|
|
||||||
EPS_START = 0.9
|
|
||||||
EPS_END = 0.05
|
|
||||||
EPS_DECAY = 1000
|
|
||||||
TAU = 0.005
|
|
||||||
LR = 1e-4
|
|
||||||
|
|
||||||
# Get number of actions from gym action space
|
|
||||||
n_actions = env.action_space.n
|
|
||||||
# Get the number of state observations
|
|
||||||
state, info = env.reset()
|
|
||||||
n_observations = len(state)
|
|
||||||
|
|
||||||
policy_net = DQN(n_observations, n_actions).to(device)
|
|
||||||
target_net = DQN(n_observations, n_actions).to(device)
|
|
||||||
target_net.load_state_dict(policy_net.state_dict())
|
|
||||||
|
|
||||||
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
|
|
||||||
memory = ReplayMemory(10000)
|
|
||||||
|
|
||||||
|
|
||||||
steps_done = 0
|
|
||||||
|
|
||||||
def select_action(state):
|
|
||||||
global steps_done
|
|
||||||
sample = random.random()
|
|
||||||
eps_threshold = (
|
|
||||||
EPS_END + (EPS_START - EPS_END) *
|
|
||||||
math.exp(
|
|
||||||
-1.0 * steps_done /
|
|
||||||
EPS_DECAY
|
|
||||||
)
|
|
||||||
)
|
|
||||||
steps_done += 1
|
|
||||||
|
|
||||||
if sample > eps_threshold:
|
|
||||||
with torch.no_grad():
|
|
||||||
# t.max(1) will return the largest column value of each row.
|
|
||||||
# second column on max result is index of where max element was
|
|
||||||
# found, so we pick action with the larger expected reward.
|
|
||||||
return policy_net(state).max(1)[1].view(1, 1)
|
|
||||||
|
|
||||||
else:
|
|
||||||
return torch.tensor(
|
|
||||||
[ [env.action_space.sample()] ],
|
|
||||||
device=device,
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
episode_durations = []
|
|
||||||
|
|
||||||
|
|
||||||
def plot_durations(show_result=False):
|
|
||||||
plt.figure(1)
|
|
||||||
durations_t = torch.tensor(episode_durations, dtype=torch.float)
|
|
||||||
if show_result:
|
|
||||||
plt.title('Result')
|
|
||||||
else:
|
|
||||||
plt.clf()
|
|
||||||
plt.title('Training...')
|
|
||||||
plt.xlabel('Episode')
|
|
||||||
plt.ylabel('Duration')
|
|
||||||
plt.plot(durations_t.numpy())
|
|
||||||
# Take 100 episode averages and plot them too
|
|
||||||
if len(durations_t) >= 100:
|
|
||||||
means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
|
|
||||||
means = torch.cat((torch.zeros(99), means))
|
|
||||||
plt.plot(means.numpy())
|
|
||||||
|
|
||||||
plt.pause(0.001) # pause a bit so that plots are updated
|
|
||||||
if is_ipython:
|
|
||||||
if not show_result:
|
|
||||||
display.display(plt.gcf())
|
|
||||||
display.clear_output(wait=True)
|
|
||||||
else:
|
|
||||||
display.display(plt.gcf())
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def optimize_model():
|
|
||||||
if len(memory) < BATCH_SIZE:
|
|
||||||
return
|
|
||||||
transitions = memory.sample(BATCH_SIZE)
|
|
||||||
# Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
|
|
||||||
# detailed explanation). This converts batch-array of Transitions
|
|
||||||
# to Transition of batch-arrays.
|
|
||||||
batch = Transition(*zip(*transitions))
|
|
||||||
|
|
||||||
# Compute a mask of non-final states and concatenate the batch elements
|
|
||||||
# (a final state would've been the one after which simulation ended)
|
|
||||||
non_final_mask = torch.tensor(
|
|
||||||
tuple(
|
|
||||||
map(
|
|
||||||
lambda s: s is not None,
|
|
||||||
batch.next_state
|
|
||||||
)
|
|
||||||
),
|
|
||||||
device=device,
|
|
||||||
dtype=torch.bool
|
|
||||||
)
|
|
||||||
non_final_next_states = torch.cat(
|
|
||||||
[s for s in batch.next_state if s is not None]
|
|
||||||
)
|
|
||||||
state_batch = torch.cat(batch.state)
|
|
||||||
action_batch = torch.cat(batch.action)
|
|
||||||
reward_batch = torch.cat(batch.reward)
|
|
||||||
|
|
||||||
# Compute Q(s_t, a) - the model computes Q(s_t), then we select the
|
|
||||||
# columns of actions taken. These are the actions which would've been taken
|
|
||||||
# for each batch state according to policy_net
|
|
||||||
state_action_values = policy_net(state_batch).gather(1, action_batch)
|
|
||||||
|
|
||||||
# Compute V(s_{t+1}) for all next states.
|
|
||||||
# Expected values of actions for non_final_next_states are computed based
|
|
||||||
# on the "older" target_net; selecting their best reward with max(1)[0].
|
|
||||||
# This is merged based on the mask, such that we'll have either the expected
|
|
||||||
# state value or 0 in case the state was final.
|
|
||||||
next_state_values = torch.zeros(BATCH_SIZE, device=device)
|
|
||||||
with torch.no_grad():
|
|
||||||
next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
|
|
||||||
# Compute the expected Q values
|
|
||||||
expected_state_action_values = (next_state_values * GAMMA) + reward_batch
|
|
||||||
|
|
||||||
# Compute Huber loss
|
|
||||||
criterion = nn.SmoothL1Loss()
|
|
||||||
loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
|
|
||||||
|
|
||||||
# Optimize the model
|
|
||||||
optimizer.zero_grad()
|
|
||||||
loss.backward()
|
|
||||||
# In-place gradient clipping
|
|
||||||
torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
num_episodes = 600
|
|
||||||
else:
|
|
||||||
num_episodes = 50
|
|
||||||
|
|
||||||
for i_episode in range(num_episodes):
|
|
||||||
# Initialize the environment and get its state
|
|
||||||
state, info = env.reset()
|
|
||||||
state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
|
|
||||||
for t in count():
|
|
||||||
action = select_action(state)
|
|
||||||
observation, reward, terminated, truncated, _ = env.step(action.item())
|
|
||||||
reward = torch.tensor([reward], device=device)
|
|
||||||
done = terminated or truncated
|
|
||||||
|
|
||||||
if terminated:
|
|
||||||
next_state = None
|
|
||||||
else:
|
|
||||||
next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
|
|
||||||
|
|
||||||
# Store the transition in memory
|
|
||||||
memory.push(state, action, next_state, reward)
|
|
||||||
|
|
||||||
# Move to the next state
|
|
||||||
state = next_state
|
|
||||||
|
|
||||||
# Perform one step of the optimization (on the policy network)
|
|
||||||
optimize_model()
|
|
||||||
|
|
||||||
# Soft update of the target network's weights
|
|
||||||
# θ′ ← τ θ + (1 −τ )θ′
|
|
||||||
target_net_state_dict = target_net.state_dict()
|
|
||||||
policy_net_state_dict = policy_net.state_dict()
|
|
||||||
for key in policy_net_state_dict:
|
|
||||||
target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
|
|
||||||
target_net.load_state_dict(target_net_state_dict)
|
|
||||||
|
|
||||||
if done:
|
|
||||||
episode_durations.append(t + 1)
|
|
||||||
plot_durations()
|
|
||||||
break
|
|
||||||
|
|
||||||
print('Complete')
|
|
||||||
plot_durations(show_result=True)
|
|
||||||
plt.ioff()
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
en = gym.make("CartPole-v1", render_mode = "human")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
state, _ = en.reset()
|
|
||||||
state = torch.tensor(
|
|
||||||
state,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=device
|
|
||||||
).unsqueeze(0)
|
|
||||||
|
|
||||||
terminated = False
|
|
||||||
truncated = False
|
|
||||||
while not (terminated or truncated):
|
|
||||||
action = policy_net(state).max(1)[1].view(1, 1)
|
|
||||||
|
|
||||||
( state, # new state
|
|
||||||
reward, # reward as a result of action
|
|
||||||
terminated, # bool: reached a terminal state (win or loss). If True, must reset.
|
|
||||||
truncated, # bool: end of time limit. If true, must reset.
|
|
||||||
_
|
|
||||||
) = en.step(action.item())
|
|
||||||
|
|
||||||
state = torch.tensor(
|
|
||||||
state,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=device
|
|
||||||
).unsqueeze(0)
|
|
||||||
|
|
||||||
en.render()
|
|
||||||
en.reset()
|
|
|
@ -1,3 +0,0 @@
|
||||||
gymnasium[classic_control]==0.27.1
|
|
||||||
matplotlib==3.6.3
|
|
||||||
torch==1.13.1
|
|
Before Width: | Height: | Size: 9.5 KiB After Width: | Height: | Size: 9.5 KiB |
Before Width: | Height: | Size: 16 KiB After Width: | Height: | Size: 16 KiB |
Before Width: | Height: | Size: 442 B After Width: | Height: | Size: 442 B |
Reference in New Issue