77 lines
1.5 KiB
Python
Executable File
77 lines
1.5 KiB
Python
Executable File
import matplotlib
|
|
import matplotlib.pyplot as plt
|
|
|
|
import torch
|
|
import math
|
|
import random
|
|
from collections import namedtuple
|
|
|
|
|
|
# One step of agent experience for the replay buffer: the (state, action)
# pair together with the resulting next_state and reward.
Transition = namedtuple(
    "Transition",
    ["state", "action", "next_state", "reward"],
)
|
|
|
|
|
|
def select_action(
    state,
    *,
    # Number of environment steps taken so far; drives the epsilon decay.
    steps_done: int,
    # TF parameters
    policy_net,  # DQN policy network
    device,  # Render device, "gpu" or "cpu"
    env,  # GYM environment instance
    # Epsilon schedule:
    #   EPS_START is the starting value of epsilon,
    #   EPS_END is the final value of epsilon,
    #   EPS_DECAY controls the rate of exponential decay
    #   (higher means a slower decay).
    EPS_START = 0.9,
    EPS_END = 0.05,
    EPS_DECAY = 1000
):
    """
    Choose an action for `state` with an epsilon-greedy policy.

    With probability epsilon a uniformly random action is sampled from
    the environment's action space; otherwise the greedy action under
    `policy_net` is taken.  Epsilon starts at EPS_START and decays
    exponentially toward EPS_END at a rate set by EPS_DECAY.

    Returns a (1, 1) long tensor holding the chosen action index.
    """
    # Exploration probability for the current step (exponential decay).
    threshold = EPS_END + (EPS_START - EPS_END) * math.exp(
        -1.0 * steps_done / EPS_DECAY
    )

    # Exploration branch: draw once in [0, 1) and compare against the
    # decayed threshold; at or below it, sample a uniform random action.
    if random.random() <= threshold:
        return torch.tensor(
            [[env.action_space.sample()]],
            device=device,
            dtype=torch.long,
        )

    # Exploitation branch: greedy action under the current policy.
    with torch.no_grad():
        # max over dim 1 reduces across actions; `.indices` is the argmax,
        # i.e. the action with the largest predicted Q-value.
        return policy_net(state).max(dim=1).indices.view(1, 1)