celeste-ai/polecart/basic/util.py

import math
import random
from collections import namedtuple

import matplotlib
import matplotlib.pyplot as plt
import torch

# A single transition in the environment: maps (state, action) to (next_state, reward).
Transition = namedtuple(
    "Transition",
    (
        "state",
        "action",
        "next_state",
        "reward"
    )
)
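
# Usage sketch (assumption, not code from this file): a batch of Transition
# tuples drawn from a replay memory can be transposed with zip, e.g.
#   batch = Transition(*zip(*transitions))
# so that batch.state holds all states, batch.action all actions, and so on.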


def select_action(
    state,
    *,
    # Number of steps taken so far
    steps_done: int,
    # Torch parameters
    policy_net,  # DQN policy network
    device,      # torch device, "cuda" or "cpu"
    env,         # Gym environment instance
    # Epsilon parameters
    #
    # Original docs:
    # EPS_START is the starting value of epsilon
    # EPS_END is the final value of epsilon
    # EPS_DECAY controls the rate of exponential decay of epsilon; higher means a slower decay
    EPS_START=0.9,
    EPS_END=0.05,
    EPS_DECAY=1000
):
"""
Given a state, select an action using an epsilon-greedy policy.
Sometimes use our model, sometimes sample one uniformly.
P(random action) starts at EPS_START and decays to EPS_END.
Decay rate is controlled by EPS_DECAY.
"""
    # Random number 0 <= x < 1
    sample = random.random()

    # Calculate the random-action threshold for the current step
    eps_threshold = (
        EPS_END + (EPS_START - EPS_END) *
        math.exp(-1.0 * steps_done / EPS_DECAY)
    )
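    # With the defaults (EPS_START=0.9, EPS_END=0.05, EPS_DECAY=1000) this gives
    # 0.90 at step 0, roughly 0.36 at step 1000, and approaches 0.05 after a few
    # thousand steps.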

    if sample > eps_threshold:
        # Exploit: pick the action with the largest expected reward.
        with torch.no_grad():
            # policy_net(state).max(1) returns (values, indices) along each row;
            # the index is the action whose expected reward is largest.
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        # Explore: sample a random action from the environment's action space.
        return torch.tensor(
            [[env.action_space.sample()]],
            device=device,
            dtype=torch.long
        )
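

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original file: it exercises
    # select_action with a stand-in policy network. The gymnasium import,
    # "CartPole-v1", and the small Sequential network below are assumptions
    # for illustration, not code from this repository.
    import gymnasium as gym

    env = gym.make("CartPole-v1")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    n_observations = env.observation_space.shape[0]
    n_actions = int(env.action_space.n)

    # Stand-in for the repository's DQN policy network (assumed architecture).
    policy_net = torch.nn.Sequential(
        torch.nn.Linear(n_observations, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, n_actions),
    ).to(device)

    obs, _ = env.reset()
    state = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)

    action = select_action(
        state,
        steps_done=0,
        policy_net=policy_net,
        device=device,
        env=env,
    )
    print("Selected action:", action.item())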