# celeste-ai/celeste/main.py

from collections import namedtuple
from collections import deque
import random
import math
import torch
# Glue layer
from celeste import Celeste

compute_device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)

# Epsilon-greedy parameters
#
# Original docs:
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
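
# For reference (illustrative arithmetic, using the eps_threshold formula in
# select_action below), these constants give roughly:
#   steps_done =    0  ->  eps ≈ 0.90
#   steps_done = 1000  ->  eps ≈ 0.36
#   steps_done = 3000  ->  eps ≈ 0.09
#   steps_done = 5000  ->  eps ≈ 0.06
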
# Outline our network
class DQN(torch.nn.Module):
    def __init__(self, n_observations: int, n_actions: int):
        super(DQN, self).__init__()
        self.layer1 = torch.nn.Linear(n_observations, 128)
        self.layer2 = torch.nn.Linear(128, 128)
        self.layer3 = torch.nn.Linear(128, n_actions)

    # Can be called with one input, or with a batch.
    #
    # Returns tensor(
    #     [ Q(s, a_0), Q(s, a_1), ... ],
    #     ...
    # )
    # with one Q-value per action in Celeste.action_space.
    #
    # Recall that Q(s, a) is the (expected) return of taking
    # action `a` at state `s`.
    def forward(self, x):
        x = torch.nn.functional.relu(self.layer1(x))
        x = torch.nn.functional.relu(self.layer2(x))
        return self.layer3(x)

# Celeste env properties
n_observations = 4
n_actions = len(Celeste.action_space)
policy_net = DQN(
    n_observations,
    n_actions
).to(compute_device)
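
# Shape sketch (illustrative only; `dummy`, `q_values` and `greedy` are
# made-up placeholder names, not part of the script):
#   dummy    = torch.zeros(1, n_observations, device = compute_device)
#   q_values = policy_net(dummy)         # shape (1, n_actions)
#   greedy   = q_values.max(1)[1].item() # index of the best action
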
def select_action(state, steps_done):
    """
    Select an action using an epsilon-greedy policy.
    Sometimes use our model, sometimes sample an action uniformly at random.

    P(random action) starts at EPS_START and decays to EPS_END.
    Decay rate is controlled by EPS_DECAY.
    """

    # Random number 0 <= x < 1
    sample = random.random()

    # Calculate random step threshold
    eps_threshold = (
        EPS_END + (EPS_START - EPS_END) *
        math.exp(
            -1.0 * steps_done /
            EPS_DECAY
        )
    )

    if sample > eps_threshold:
        with torch.no_grad():
            # policy_net(state).max(1) returns (values, indices) over the
            # action dimension; [1] takes the indices, i.e. the action with
            # the largest expected return for each state in the batch.
            return policy_net(state).max(1)[1].view(1, 1).item()
    else:
        return random.randint(0, n_actions - 1)
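
# Usage note: either branch returns a plain int index into
# Celeste.action_space, which is how on_state below consumes it.
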
last_state = None
Transition = namedtuple(
    "Transition",
    (
        "state",
        "action",
        "next_state",
        "reward"
    )
)
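
# Transition matches the (state, action, next_state, reward) record used by
# the standard DQN replay-memory recipe. Nothing in this file stores
# Transitions yet; presumably the training loop will.
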
def on_state(celeste):
    global last_state
    s = celeste.status

    # First state we see: nothing to compare against yet.
    if last_state is None:
        last_state = s
        return

    s_next = s["next_point"]
    s_dist = s["dist"]
    l_next = last_state["next_point"]
    l_dist = last_state["dist"]

    if l_next == s_next:
        # Still heading for the same target point:
        # reward the distance closed since the last state.
        reward = l_dist - s_dist
    else:
        # Target point changed (previous one presumably reached): fixed bonus.
        reward = 10

    dead = s["deaths"] != 0
    frame_count = s["frame_count"]

    # Values at this point:
    # reward: reward for the last action
    # dead: true if game over
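
    # Neither reward nor dead is used further down yet; presumably they are
    # intended for a future replay memory / optimization step.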

    state_number_map = [
        "xpos",
        "ypos",
        "xvel",
        "yvel"
    ]

    tf_state = torch.tensor(
        [s[x] for x in state_number_map],
        dtype = torch.float32,
        device = compute_device
    ).unsqueeze(0)

    tf_last = torch.tensor(
        [last_state[x] for x in state_number_map],
        dtype = torch.float32,
        device = compute_device
    ).unsqueeze(0)
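
    # Both tensors have shape (1, 4): unsqueeze(0) adds the batch
    # dimension that DQN.forward and select_action expect.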

    action = select_action(
        tf_state,
        frame_count
    )

    # Turn number into action string
    action = Celeste.action_space[action]
    celeste.act(action)

    # Update previous state
    last_state = s
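
# Hand the state callback to the glue layer and start the game.
# update_loop() is expected to call on_state on every status update.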
c = Celeste(
    on_state
)
c.update_loop()