from collections import namedtuple

import torch

# Immutable record with the fields (state, action, next_state, reward).
# NOTE(review): presumably one environment step kept for experience replay;
# the consumer is not visible here — confirm against the caller.
Transition = namedtuple(
    "Transition",
    ("state", "action", "next_state", "reward"),
)


class DQN(torch.nn.Module):
    """Fully connected network producing one output value per action.

    The network is a stack of four ``Linear`` layers with a ``ReLU`` after
    each of the first three:
    ``n_observations -> hidden_size -> hidden_size -> hidden_size -> n_actions``.
    No activation is applied to the final layer's output.

    Args:
        n_observations: Size of the last dimension of the input tensor.
        n_actions: Size of the last dimension of the output tensor.
        hidden_size: Width of each of the three hidden layers. Defaults to
            128, the value that was previously hard-coded.
    """

    def __init__(
        self,
        n_observations: int,
        n_actions: int,
        hidden_size: int = 128,
    ) -> None:
        super().__init__()
        # The attribute name and the position of every module inside the
        # Sequential are kept stable so parameter names in ``state_dict()``
        # ("layers.0.weight", ..., "layers.6.bias") do not change.
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(n_observations, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, n_actions),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layer stack and return the result.

        Args:
            x: Input tensor whose last dimension is ``n_observations``.
                NOTE(review): presumably a batch of observations shaped
                ``(batch, n_observations)`` — confirm against the caller.

        Returns:
            The output of the final linear layer; its last dimension is
            ``n_actions``.
        """
        return self.layers(x)