# celeste-ai/polecart/main/main.py

import gymnasium as gym
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

from agent import Agent

# Outline our network
class DQN(nn.Module):
    def __init__(self, n_observations: int, n_actions: int):
        super().__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    # Can be called with one input, or with a batch.
    #
    # Returns tensor(
    #     [ Q(s, left), Q(s, right) ], ...
    # )
    #
    # Recall that Q(s, a) is the (expected) return of taking
    # action `a` at state `s`.
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)
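
# A quick illustration of the shapes `forward` produces (a sketch only; the
# 4-in / 2-out dimensions below assume CartPole-v1's observation and action
# spaces):
#
#   net = DQN(4, 2)
#   net(torch.zeros(4))      # one state   -> tensor of shape (2,)
#   net(torch.zeros(32, 4))  # batch of 32 -> tensor of shape (32, 2)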
"""
from multiprocessing import Pool
def train(i):
print(f"Running {i}")
agent = Agent(
env_name = "CartPole-v1",
network = DQN,
BATCH_SIZE = 128,
TAU = 0.005,
OPT_LR = 1e-4
)
# Train model episodes
episode_durations = agent.train(600)
#print(f"Model has been trained on {agent.steps_done} steps.")
durations_t = torch.tensor(episode_durations, dtype=torch.float)
fig, axs = plt.subplots(1, 1)
axs.plot(durations_t.numpy())
fig.savefig(f"main-{i}.png")
if __name__ == '__main__':
with Pool(3) as p:
p.map(train, list(range(10)))
"""

# Make the model
#
# Should work with...
#  - CartPole-v1
#  - Acrobot-v1
agent = Agent(
    env_name = "CartPole-v1",
    network = DQN,
    BATCH_SIZE = 128,
    TAU = 0.005,
    OPT_LR = 1e-4
)
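
# Hyperparameter gloss (hedged; their exact semantics live in agent.py's
# `Agent`): in the usual DQN setup, BATCH_SIZE is the minibatch size drawn
# from replay memory, TAU is the soft target-network update rate, and OPT_LR
# is the optimizer's learning rate.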

# Train the model
episode_durations = agent.train(600, verbose = True)

# Plot training progress
durations_t = torch.tensor(episode_durations, dtype=torch.float)
fig, axs = plt.subplots(1, 1)
axs.plot(durations_t.numpy())
fig.savefig("main.png")

# Test the model
env = gym.make(
    agent.env_name,
    render_mode = "human"
)

while True:
    state, _ = env.reset()
    state = torch.tensor(
        state,
        dtype = torch.float32,
        device = agent.compute_device
    ).unsqueeze(0)

    terminated = False
    truncated = False
    while not (terminated or truncated):
        # Predict the best action given the current state
        action = agent.predict(state)

        # Take that action, observe the new state
        ( state,
          reward,
          terminated,
          truncated,
          _
        ) = env.step(action)

        state = torch.tensor(
            state,
            dtype = torch.float32,
            device = agent.compute_device
        ).unsqueeze(0)

        env.render()

    # The environment needs to be reset once an episode ends
    env.reset()
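
# Note (hedged): the loop above assumes `agent.predict` returns a plain action
# index that `env.step` accepts; if it instead returns a tensor, an `.item()`
# call would be needed before stepping.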