import gymnasium as gym
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

from agent import Agent


# Outline our network
class DQN(nn.Module):
    def __init__(self, n_observations: int, n_actions: int):
        super().__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    # Can be called with one input, or with a batch.
    #
    # Returns tensor(
    #     [ Q(s, left), Q(s, right) ], ...
    # )
    #
    # Recall that Q(s, a) is the (expected) return of taking
    # action `a` at state `s`.
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)


# Disabled: train several agents in parallel, saving one plot per run.
"""
from multiprocessing import Pool

def train(i):
    print(f"Running {i}")
    agent = Agent(
        env_name = "CartPole-v1",
        network = DQN,
        BATCH_SIZE = 128,
        TAU = 0.005,
        OPT_LR = 1e-4
    )

    # Train the model for 600 episodes
    episode_durations = agent.train(600)
    # print(f"Model has been trained on {agent.steps_done} steps.")

    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    fig, axs = plt.subplots(1, 1)
    axs.plot(durations_t.numpy())
    fig.savefig(f"main-{i}.png")

if __name__ == '__main__':
    with Pool(3) as p:
        p.map(train, list(range(10)))
"""


# Make the model
#
# Should work with...
#     CartPole-v1
#     Acrobot-v1
agent = Agent(
    env_name = "CartPole-v1",
    network = DQN,
    BATCH_SIZE = 128,
    TAU = 0.005,
    OPT_LR = 1e-4
)

# Train the model
episode_durations = agent.train(600, verbose = True)

# Plot training progress
durations_t = torch.tensor(episode_durations, dtype=torch.float)
fig, axs = plt.subplots(1, 1)
axs.plot(durations_t.numpy())
fig.savefig("main.png")


# Test the model. In "human" render mode, gymnasium draws a frame on
# every `step()` and `reset()`, so no explicit `render()` call is needed.
env = gym.make(
    agent.env_name,
    render_mode = "human"
)

# Each pass through this loop plays one episode. Resetting at the top
# also covers the reset needed after the previous episode ends.
while True:
    state, _ = env.reset()
    state = torch.tensor(
        state,
        dtype = torch.float32,
        device = agent.compute_device
    ).unsqueeze(0)

    terminated = False
    truncated = False
    while not (terminated or truncated):
        # Predict best action given state
        action = agent.predict(state)

        # Do that action, get new state
        (
            state,
            reward,
            terminated,
            truncated,
            _
        ) = env.step(action)
        state = torch.tensor(
            state,
            dtype = torch.float32,
            device = agent.compute_device
        ).unsqueeze(0)
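

# Optional: headless evaluation, disabled like the Pool block above since the
# demo loop never returns. This is a sketch under assumptions: it reuses only
# the Agent API already seen in this file (`agent.predict`,
# `agent.compute_device`, `agent.env_name`), and `n_episodes = 10` is an
# arbitrary choice. It reports how long the greedy policy survives on
# average, without opening a render window.
"""
def evaluate(agent, n_episodes: int = 10) -> float:
    # Mean episode length over `n_episodes` greedy rollouts, no rendering.
    eval_env = gym.make(agent.env_name)
    total_steps = 0
    for _ in range(n_episodes):
        state, _ = eval_env.reset()
        state = torch.tensor(
            state, dtype = torch.float32, device = agent.compute_device
        ).unsqueeze(0)
        terminated = truncated = False
        while not (terminated or truncated):
            action = agent.predict(state)
            state, _reward, terminated, truncated, _ = eval_env.step(action)
            state = torch.tensor(
                state, dtype = torch.float32, device = agent.compute_device
            ).unsqueeze(0)
            total_steps += 1
    eval_env.close()
    return total_steps / n_episodes

print(f"Mean episode length: {evaluate(agent):.1f}")
"""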