# celeste-ai/polecart/basic/main.py
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import deque
from itertools import count
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import util
import optimize
# TODO: Parameter file

# If True, the environment is rendered on screen while training.
# Rendering slows training down considerably, so this is normally left False.
human_render = False

# Training only starts once replay memory holds at least
# this many transitions (see the check in the training loop below).
BATCH_SIZE = 128
# Learning rate of target_net.
# Controls how soft our soft update is.
#
# Should be between 0 and 1.
# Large values make target_net track policy_net quickly;
# small values do the opposite.
#
# A value of one makes target_net
# change at the same rate as policy_net.
#
# A value of zero makes target_net
# not change at all.
TAU = 0.005
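# Concretely, the soft update applied after each optimization step below is
#
#     target_weights = TAU * policy_weights + (1 - TAU) * target_weights
#
# so with TAU = 0.005, target_net moves 0.5% of the way toward policy_net
# per update.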
# Setup game environment
if human_render:
    env = gym.make("CartPole-v1", render_mode = "human")
else:
    env = gym.make("CartPole-v1")
# Setup pytorch
compute_device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)
# Number of training episodes.
# Running many of these takes a while without a GPU,
# but you won't see much improvement with only a few.
if torch.cuda.is_available():
    num_episodes = 600
else:
    num_episodes = 50
# Create replay memory.
#
# Transition: a named container for one state transition (defined in util.py)
# Memory: a deque that holds recent Transitions.
#     It has a fixed length and drops its oldest
#     element when maxlen is exceeded.
memory = deque([], maxlen=10000)
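# A minimal sketch of what `Transition` is assumed to look like in util.py
# (only the field order used below is relied on; the real definition may differ):
#
#     Transition = namedtuple(
#         "Transition",
#         ("state", "action", "next_state", "reward")
#     )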
# Outline our network
class DQN(nn.Module):
    def __init__(self, n_observations: int, n_actions: int):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    # Can be called with one input, or with a batch.
    #
    # Returns tensor(
    #     [ Q(s, left), Q(s, right) ], ...
    # )
    #
    # Recall that Q(s, a) is the (expected) return of taking
    # action `a` at state `s`.
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)
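# For example (illustrative only), a batch of four CartPole states passed
# through this network yields one row of Q-values per state:
#
#     net = DQN(4, 2)
#     q = net(torch.zeros(4, 4))    # q.shape == torch.Size([4, 2])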
## Create networks and optimizer
# n_actions: size of action space
# - 2 for cartpole: [0, 1] as "left" and "right"
#
# n_observations: size of observation vector
# - 4 for cartpole:
# position, velocity,
# angle, angular velocity
n_actions = env.action_space.n # type: ignore
state, _ = env.reset()
n_observations = len(state)
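# Illustrative sanity check, matching the comment above;
# not required for training.
assert n_actions == 2
assert n_observations == 4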
# policy_net: the network we train directly; it selects actions.
# target_net: a slowly-updated copy of policy_net, used to compute
#     target Q-values during optimization. Updating it slowly (see TAU)
#     keeps training stable.
policy_net = DQN(n_observations, n_actions).to(compute_device)
target_net = DQN(n_observations, n_actions).to(compute_device)
# Both networks start with the same weights
target_net.load_state_dict(policy_net.state_dict())
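# Illustrative check (safe to delete): since target_net has just copied
# policy_net's weights, both networks should agree exactly on any input.
with torch.no_grad():
    _probe = torch.zeros(1, n_observations, device = compute_device)
    assert torch.equal(policy_net(_probe), target_net(_probe))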
optimizer = optim.AdamW(
    policy_net.parameters(),
    lr = 1e-4,      # Hyperparameter: learning rate
    amsgrad = True
)
# Total number of environment steps taken so far.
# Passed to util.select_action, which presumably uses it to anneal
# the exploration rate over time.
steps_done = 0
episode_durations = []
# TRAINING LOOP
for ep in range(num_episodes):

    # Reset environment and get game state
    state, _ = env.reset()

    # Conversion
    state = torch.tensor(
        state,
        dtype = torch.float32,
        device = compute_device
    ).unsqueeze(0)

    # Iterate until game is over
    for t in count():

        # Select next action
        action = util.select_action(
            state,
            steps_done = steps_done,
            policy_net = policy_net,
            device = compute_device,
            env = env
        )
        steps_done += 1
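        # (Sketch of what util.select_action is assumed to do: a standard
        # epsilon-greedy policy, i.e. with probability epsilon pick a random
        # action from env.action_space, otherwise pick the action with the
        # largest policy_net Q-value. epsilon presumably decays as
        # steps_done grows.)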
        # Perform one step of the environment with this action.
        ( next_state,   # new state
          reward,       # number: reward as a result of action
          terminated,   # bool: reached a terminal state (win or loss). If True, must reset.
          truncated,    # bool: end of time limit. If True, must reset.
          _
        ) = env.step(action.item())
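        # For reference: in CartPole-v1, `next_state` is a length-4 array
        # (cart position, cart velocity, pole angle, pole angular velocity)
        # and `reward` is a constant 1.0 for every step taken.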
        # Conversion
        reward = torch.tensor([reward], device = compute_device)

        if terminated:
            # If the environment reached a terminal state,
            # observations are meaningless. Set to None.
            next_state = None
        else:
            # Conversion
            next_state = torch.tensor(
                next_state,
                dtype = torch.float32,
                device = compute_device
            ).unsqueeze(0)

        # Add this state transition to memory.
        memory.append(
            util.Transition(
                state,
                action,
                next_state,
                reward
            )
        )
        # Move on to the next state.
        # This must happen every step, even before training starts,
        # so that select_action always sees the current state.
        state = next_state

        # Only train the network if we have enough
        # transitions in memory to do so.
        if len(memory) >= BATCH_SIZE:

            # Run optimizer
            optimize.optimize_model(
                memory,

                # Pytorch params
                compute_device = compute_device,
                policy_net = policy_net,
                target_net = target_net,
                optimizer = optimizer,
            )

            # Soft update target_net weights
            target_net_state = target_net.state_dict()
            policy_net_state = policy_net.state_dict()
            for key in policy_net_state:
                target_net_state[key] = (
                    policy_net_state[key] * TAU +
                    target_net_state[key] * (1-TAU)
                )
            target_net.load_state_dict(target_net_state)
        # Move on to the next episode once we reach
        # a terminal state.
        if (terminated or truncated):
            print(f"Episode {ep}/{num_episodes}, last duration {t+1}", end="\r")
            episode_durations.append(t + 1)
            break
print("Complete.")
durations_t = torch.tensor(episode_durations, dtype=torch.float)
plt.xlabel('Episode')
plt.ylabel('Duration')
plt.plot(durations_t.numpy())
plt.show()
env.close()
# Show off the trained agent: run it indefinitely with human rendering,
# always taking the action with the highest predicted Q-value.
en = gym.make("CartPole-v1", render_mode = "human")

while True:
    state, _ = en.reset()
    state = torch.tensor(
        state,
        dtype=torch.float32,
        device=compute_device
    ).unsqueeze(0)

    terminated = False
    truncated = False

    while not (terminated or truncated):

        # Greedy action: the index of the largest Q-value for this state.
        with torch.no_grad():
            action = policy_net(state).max(1)[1].view(1, 1)

        ( state,        # new state
          reward,       # reward as a result of action
          terminated,   # bool: reached a terminal state (win or loss). If True, must reset.
          truncated,    # bool: end of time limit. If True, must reset.
          _
        ) = en.step(action.item())

        state = torch.tensor(
            state,
            dtype=torch.float32,
            device=compute_device
        ).unsqueeze(0)

        en.render()

    en.reset()