import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt

from collections import deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from tqdm import tqdm
import util
import optimize

# TODO: Parameter file

# If True, render the environment in a window while training
# (much slower; useful for watching the agent learn).
human_render = False

# Number of transitions per training batch. Training only
# starts once memory holds at least this many.
BATCH_SIZE = 128

# Learning rate of target_net.
# Controls how soft our soft update is.
#
# Should be between 0 and 1.
# Large values make target_net track policy_net quickly.
# Small values do the opposite.
#
# A value of one makes target_net
# change at the same rate as policy_net.
#
# A value of zero makes target_net
# not change at all.
TAU = 0.005
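
# Concretely, the soft update performed at the end of each
# training step (see the training loop below) is:
#
#   target_weights <- TAU * policy_weights + (1 - TAU) * target_weights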


# Setup game environment
if human_render:
    env = gym.make("CartPole-v1", render_mode = "human")
else:
    env = gym.make("CartPole-v1")

# Setup pytorch
compute_device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)

# Number of training episodes.
# It will take a while to process many of these without a GPU,
# but you will not see improvement with only a few training episodes.
if torch.cuda.is_available():
    num_episodes = 600
else:
    num_episodes = 50

# Create replay memory.
#
# Transition: a container for naming data (defined in util.py)
# Memory: a deque that holds recent transitions.
#     It has a fixed length, and drops the oldest
#     element once maxlen is exceeded.
memory = deque([], maxlen=10000)
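
# A minimal sketch of how this buffer is likely consumed
# (the actual sampling happens inside optimize.optimize_model,
# which is defined in optimize.py and not shown here):
#
#   batch = random.sample(memory, BATCH_SIZE)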


# Outline our network
class DQN(nn.Module):
    def __init__(self, n_observations: int, n_actions: int):
        super().__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    # Can be called with one input, or with a batch.
    #
    # Returns tensor(
    #     [ Q(s, left), Q(s, right) ], ...
    # )
    #
    # Recall that Q(s, a) is the (expected) return of taking
    # action `a` at state `s`.
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)
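
# Usage sketch (shapes assume a single CartPole state tensor
# of shape [1, 4]):
#
#   q_values = policy_net(state)    # shape [1, n_actions]
#   action   = q_values.max(1)[1]   # index of the best action
#
# The demo loop at the bottom of this file picks greedy actions
# with exactly this pattern.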


## Create networks and optimizer

# n_actions: size of action space
# - 2 for cartpole: [0, 1] as "left" and "right"
#
# n_observations: size of observation vector
# - 4 for cartpole:
#     position, velocity,
#     angle, angular velocity
n_actions = env.action_space.n  # type: ignore
state, _ = env.reset()
n_observations = len(state)
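
# Sanity check, specific to CartPole-v1; adjust or remove
# for other environments.
assert n_actions == 2 and n_observations == 4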

# policy_net: the network we train. It is updated by gradient
#     descent every step and is used to select actions.
# target_net: a slowly-updated copy of policy_net, used during
#     optimization to compute stable Q-value targets.
policy_net = DQN(n_observations, n_actions).to(compute_device)
target_net = DQN(n_observations, n_actions).to(compute_device)

# Both networks start with the same weights
target_net.load_state_dict(policy_net.state_dict())

# The optimizer only updates policy_net; target_net is
# updated through the soft update in the training loop.
optimizer = optim.AdamW(
    policy_net.parameters(),
    lr = 1e-4,  # Hyperparameter: learning rate
    amsgrad = True
)

# Total number of actions taken so far, across all episodes.
# Passed to util.select_action, which presumably uses it to
# anneal the exploration rate over time.
steps_done = 0

# Length of each completed episode, for plotting.
episode_durations = []


# TRAINING LOOP
for ep in range(num_episodes):

    # Reset environment and get game state
    state, _ = env.reset()

    # Convert to a [1, n_observations] float tensor
    # on the compute device.
    state = torch.tensor(
        state,
        dtype = torch.float32,
        device = compute_device
    ).unsqueeze(0)

    # Iterate until game is over
    for t in count():

        # Select next action
        action = util.select_action(
            state,
            steps_done = steps_done,
            policy_net = policy_net,
            device = compute_device,
            env = env
        )
        steps_done += 1

        # Perform one step of the environment with this action.
        ( next_state,  # new state
          reward,      # number: reward as a result of action
          terminated,  # bool: reached a terminal state (win or loss). If True, must reset.
          truncated,   # bool: hit the time limit. If True, must reset.
          _
        ) = env.step(action.item())

        # Convert the reward to a tensor on the compute device.
        reward = torch.tensor([reward], device = compute_device)

        if terminated:
            # If the environment reached a terminal state,
            # observations are meaningless. Set to None.
            next_state = None
        else:
            # Convert to a [1, n_observations] float tensor.
            next_state = torch.tensor(
                next_state,
                dtype = torch.float32,
                device = compute_device
            ).unsqueeze(0)

        # Add this state transition to memory.
        memory.append(
            util.Transition(
                state,
                action,
                next_state,
                reward
            )
        )

        state = next_state

        # Only train the network if we have enough
        # transitions in memory to fill a batch.
        if len(memory) >= BATCH_SIZE:
            # Run one optimization step
            optimize.optimize_model(
                memory,
                # Pytorch params
                compute_device = compute_device,
                policy_net = policy_net,
                target_net = target_net,
                optimizer = optimizer,
            )
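
            # What optimize_model is assumed to do (it lives in
            # optimize.py, not shown here): sample BATCH_SIZE
            # transitions from memory, compute Q(s, a) with
            # policy_net, form bootstrapped targets with
            # target_net, and take one gradient step on the loss
            # between the two.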

        # Soft update target_net weights
        target_net_state = target_net.state_dict()
        policy_net_state = policy_net.state_dict()
        for key in policy_net_state:
            target_net_state[key] = (
                policy_net_state[key] * TAU +
                target_net_state[key] * (1 - TAU)
            )
        target_net.load_state_dict(target_net_state)

        # Move on to the next episode once we reach
        # a terminal state.
        if terminated or truncated:
            print(f"Episode {ep}/{num_episodes}, last duration {t+1}", end="\r")
            episode_durations.append(t + 1)
            break

print("Complete.")

# Plot episode durations over the course of training.
durations_t = torch.tensor(episode_durations, dtype=torch.float)
plt.xlabel('Episode')
plt.ylabel('Duration')
plt.plot(durations_t.numpy())
plt.show()

env.close()

# Demo: run the trained policy forever, rendering to a window.
demo_env = gym.make("CartPole-v1", render_mode = "human")

while True:
    state, _ = demo_env.reset()
    state = torch.tensor(
        state,
        dtype=torch.float32,
        device=compute_device
    ).unsqueeze(0)

    terminated = False
    truncated = False
    while not (terminated or truncated):
        # Always pick the greedy action: the one with the
        # largest predicted Q-value. No gradients are needed
        # when only evaluating the policy.
        with torch.no_grad():
            action = policy_net(state).max(1)[1].view(1, 1)

        ( state,       # new state
          reward,      # reward as a result of action
          terminated,  # bool: reached a terminal state (win or loss). If True, must reset.
          truncated,   # bool: hit the time limit. If True, must reset.
          _
        ) = demo_env.step(action.item())

        state = torch.tensor(
            state,
            dtype=torch.float32,
            device=compute_device
        ).unsqueeze(0)

        demo_env.render()