2023-02-15 19:24:03 -08:00
|
|
|
import gymnasium as gym
|
|
|
|
|
|
|
|
import matplotlib
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
from agent import Agent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Outline our network
|
|
|
|
class DQN(nn.Module):
    """Fully connected Q-network with two hidden layers of 128 units each.

    Args:
        n_observations: Size of the input (observation) vector.
        n_actions: Size of the output vector, one value per action.
    """

    def __init__(self, n_observations: int, n_actions: int):
        super().__init__()
        # Attribute names and creation order are kept as-is: they determine
        # the parameter names exposed by the module and the order in which
        # the layers are initialised.
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    def forward(self, x):
        """Return the Q-values for `x`.

        Can be called with one input, or with a batch.

        Returns a tensor of the form

            [ Q(s, left), Q(s, right) ], ...

        where Q(s, a) is the (expected) return of taking action `a`
        at state `s`.
        """
        hidden = F.relu(self.layer1(x))
        hidden = F.relu(self.layer2(hidden))
        return self.layer3(hidden)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
from multiprocessing import Pool

def train(i):
    print(f"Running {i}")

    agent = Agent(
        env_name = "CartPole-v1",
        network = DQN,
        BATCH_SIZE = 128,
        TAU = 0.005,
        OPT_LR = 1e-4
    )

    # Train model episodes
    episode_durations = agent.train(600)

    #print(f"Model has been trained on {agent.steps_done} steps.")

    durations_t = torch.tensor(episode_durations, dtype=torch.float)

    fig, axs = plt.subplots(1, 1)
    axs.plot(durations_t.numpy())
    fig.savefig(f"main-{i}.png")

if __name__ == '__main__':
    with Pool(3) as p:
        p.map(train, list(range(10)))
|
|
|
|
"""
|
|
|
|
|
|
|
|
# Build the agent.
#
# This setup should work with the following environments:
#   CartPole-v1
#   Acrobot-v1
agent = Agent(
    env_name="CartPole-v1",
    network=DQN,
    BATCH_SIZE=128,
    TAU=0.005,
    OPT_LR=1e-4,
)

# Train the agent. The value returned by `Agent.train` is kept for plotting;
# 600 is presumably the number of episodes (confirm against `Agent.train`).
episode_durations = agent.train(600, verbose=True)
|
|
|
|
|
|
|
|
# Plot training progress: one line plot of the values returned by training,
# written to "main.png" in the current working directory.
durations_t = torch.tensor(episode_durations, dtype=torch.float)
fig, axs = plt.subplots(1, 1)
axs.plot(durations_t.numpy())
fig.savefig("main.png")

# The figure is not used again; close it so pyplot does not keep it alive
# for the rest of the (long-running) script.
plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
# Test the model: play episodes with the trained agent in a visible window.
def _as_state_tensor(observation):
    """Convert an environment observation into the tensor given to the agent.

    The observation becomes a float32 tensor on `agent.compute_device`
    with an extra leading dimension added by `unsqueeze(0)`.
    """
    return torch.tensor(
        observation,
        dtype=torch.float32,
        device=agent.compute_device,
    ).unsqueeze(0)


env = gym.make(
    agent.env_name,
    render_mode="human",
)

try:
    # Play episodes back to back until the process is interrupted.
    while True:
        # Every episode starts from a fresh reset, so no additional reset
        # is needed once an episode has finished.
        observation, _ = env.reset()
        state = _as_state_tensor(observation)

        terminated = False
        truncated = False
        while not (terminated or truncated):
            # Predict best action given state
            action = agent.predict(state)

            # Do that action, get new state.
            # With render_mode="human" the environment draws itself during
            # step(), so no explicit env.render() call is made here.
            observation, reward, terminated, truncated, _ = env.step(action)
            state = _as_state_tensor(observation)
finally:
    # Release the environment (and its render window) even when the endless
    # loop is left through an exception such as KeyboardInterrupt.
    env.close()
|