37 lines
565 B
Python
37 lines
565 B
Python
|
import torch
|
||
|
from collections import namedtuple
|
||
|
|
||
|
|
||
|
# Immutable record with four positional fields, in this order:
# state, action, next_state, reward.
Transition = namedtuple(
    typename="Transition",
    field_names=["state", "action", "next_state", "reward"],
)
|
||
|
|
||
|
|
||
|
class DQN(torch.nn.Module):
    """Fully connected network mapping an observation vector to per-action outputs.

    The network is a stack of four ``Linear`` layers with ``ReLU`` activations
    between them; the last layer has no activation, so its outputs are
    unbounded real values (presumably Q-value estimates, going by the class
    name -- confirm against the training code).
    """

    def __init__(self, n_observations: int, n_actions: int, hidden_size: int = 128):
        """Build the layer stack.

        Args:
            n_observations: Number of input features of the first linear layer.
            n_actions: Number of output features of the last linear layer.
            hidden_size: Width of the three hidden layers. Defaults to 128,
                the value that was previously hard-coded.
        """
        super().__init__()

        # Keep everything in a single Sequential so parameter names stay
        # ``layers.<index>.*`` and existing checkpoints remain loadable.
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(n_observations, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.ReLU(),
            # Output layer: deliberately no activation after it.
            torch.nn.Linear(hidden_size, n_actions),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layer stack and return the result.

        Args:
            x: Input tensor whose last dimension must equal ``n_observations``
                (required by the first ``Linear`` layer).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``n_actions``.
        """
        return self.layers(x)
|
||
|
|
||
|
|