# celeste-ai/celeste_ai/train.py
from collections import namedtuple
from collections import deque
from pathlib import Path
import random
import math
import json
import torch
import shutil
from celeste_ai import Celeste
from celeste_ai import DQN
from celeste_ai import Transition
from celeste_ai.util.screenshots import ScreenshotManager
if __name__ == "__main__":
# Where to read/write model data.
model_data_root = Path("model_data/current")
sm = ScreenshotManager(
# Where PICO-8 saves screenshots.
# Probably your desktop.
source = Path("/home/mark/Desktop"),
pattern = "hackcel_*.png",
target = model_data_root / "screenshots"
).clean() # Remove old screenshots
model_save_path = model_data_root / "model.torch"
model_archive_dir = model_data_root / "model_archive"
model_train_log = model_data_root / "train_log"
model_data_root.mkdir(parents = True, exist_ok = True)
model_archive_dir.mkdir(parents = True, exist_ok = True)
compute_device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
# Celeste env properties
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)
# Epsilon-greedy parameters
# Probability of choosing a random action starts at
# EPS_START and decays to EPS_END.
# EPS_DECAY controls the rate of decay.
EPS_START = 0.9
EPS_END = 0.02
EPS_DECAY = 100
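# For reference, the threshold computed in select_action() below is
#   EPS_END + (EPS_START - EPS_END) * exp(-x / EPS_DECAY)
# which, with these values, is roughly 0.90 at x = 0,
# 0.34 at x = 100, and 0.03 at x = 500.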
# Bellman equation time-discount factor
GAMMA = 0.9
# Train on this many transitions from
# replay memory each round
BATCH_SIZE = 100
# Controls target_net soft update.
# Should be between 0 and 1.
TAU = 0.05
# Optimizer learning rate
learning_rate = 0.001
# Save a snapshot of the model every n
# episodes.
model_save_interval = 10
# How many times we've reached each point.
# This is used to compute epsilon-greedy probability.
point_counter = [0] * len(Celeste.target_checkpoints[0])
n_episodes = 0 # Number of episodes we've trained on
n_steps = 0 # Number of training steps we've completed
# Create replay memory.
#
# Holds <Transition> objects, defined in
# network.py
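# Judging from how transitions are built below, each one holds
# (last_state, action, next_state, reward) as tensors, with
# next_state set to None for terminal states.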
memory = deque([], maxlen=50_000)
policy_net = DQN(n_observations, n_actions).to(compute_device)
target_net = DQN(n_observations, n_actions).to(compute_device)
target_net.load_state_dict(policy_net.state_dict())
optimizer = torch.optim.AdamW(
policy_net.parameters(),
lr = learning_rate,
amsgrad = True
)
if model_save_path.is_file():
# Load model if one exists
checkpoint = torch.load(
model_save_path,
map_location = compute_device
)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
target_net.load_state_dict(checkpoint["target_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
memory = checkpoint["memory"]
n_episodes = checkpoint["n_episodes"]
n_steps = checkpoint["n_steps"]
point_counter = checkpoint["point_counter"]
def save_model(path):
torch.save({
# Networks
"policy_state_dict": policy_net.state_dict(),
"target_state_dict": target_net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
# Training data
"memory": memory,
"point_counter": point_counter,
"n_episodes": n_episodes,
"n_steps": n_steps,
# Hyperparameters,
# for reference
"eps_start": EPS_START,
"eps_end": EPS_END,
"eps_decay": EPS_DECAY,
"batch_size": BATCH_SIZE,
"tau": TAU,
"learning_rate": learning_rate,
"gamma": GAMMA
}, path
)
def select_action(state, x) -> int:
"""
Select an action using an epsilon-greedy policy.
Sometimes use our model, sometimes sample one uniformly.
P(random action) starts at EPS_START and decays to EPS_END.
Decay rate is controlled by EPS_DECAY.
"""
# Calculate the random-action threshold
eps_threshold = (
EPS_END + (EPS_START - EPS_END) *
math.exp(-1.0 * x / EPS_DECAY)
)
if random.random() > eps_threshold:
with torch.no_grad():
# policy_net(state).max(1) returns (values, indices) along dim 1.
# Indexing with [1] selects the indices, i.e., the action with the
# largest expected return, which we take greedily.
return policy_net(state).max(1)[1].view(1, 1).item()
else:
return random.randint( 0, n_actions-1 )
def optimize_model():
if len(memory) < BATCH_SIZE:
raise Exception(f"Not enough elements in memory for a batch of {BATCH_SIZE}")
# Get a random sample of transitions
batch = random.sample(memory, BATCH_SIZE)
# Conversion.
# Transposes batch, turning an array of Transitions
# into a Transition of arrays.
batch = Transition(*zip(*batch))
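# Schematically:
#   [T(s1, a1, n1, r1), T(s2, a2, n2, r2)]
#   -> T(last_state=(s1, s2), action=(a1, a2),
#        next_state=(n1, n2), reward=(r1, r2))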
# Conversion.
# Combine states, actions, and rewards into their own tensors.
last_state_batch = torch.cat(batch.last_state)
action_batch = torch.cat(batch.action)
reward_batch = torch.cat(batch.reward)
# Compute a mask of non-final states.
# Each element of this tensor corresponds to an element in the batch:
# True if that element's next state is non-final, False if it is final
# (i.e., next_state is None).
#
# We use this to select non-final states later.
non_final_mask = torch.tensor(
tuple(map(
lambda s: s is not None,
batch.next_state
))
)
non_final_next_states = torch.cat(
[s for s in batch.next_state if s is not None]
)
# How .gather works:
# if out = a.gather(1, b),
# out[i, j] = a[ i ][ b[i,j] ]
#
# a is "input," b is "index"
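# A small worked example (illustrative values only):
#   a = [[10, 20],    b = [[1],
#        [30, 40]]         [0]]
#   a.gather(1, b) == [[20],
#                      [30]]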
# Compute Q(s_t, a).
# This gives us a tensor containing the model's predicted return
# for the action we actually took at each state in the batch.
state_action_values = policy_net(last_state_batch).gather(1, action_batch)
# Compute V(s_t+1) for all next states.
# V(s_t+1) = max_a ( Q(s_t+1, a) )
# = the maximum predicted return over all possible actions at state s_t+1.
next_state_values = torch.zeros(BATCH_SIZE, device = compute_device)
with torch.no_grad():
# Note the use of non_final_mask here.
# States that are final do not have their value set by the line
# below, so their value stays at zero.
#
# States that are not final get their predicted value
# set to the best value the model predicts.
#
#
# Expected values of action are selected with the "older" target net,
# and their best reward (over possible actions) is selected with max(1)[0].
next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
# Compute the expected Q values (the Bellman targets):
# Q_expected(s_t, a) = reward + GAMMA * max_a' Q_target(s_t+1, a')
# For final states, next_state_values is zero, so the target is just the reward.
expected_state_action_values = reward_batch + (next_state_values * GAMMA)
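# For example (illustrative numbers): with GAMMA = 0.9, a transition with
# reward 1.0 and a predicted next-state value of 2.0 gets a target of
# 1.0 + 0.9 * 2.0 = 2.8.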
# Compute Huber loss between the predicted and expected state-action values.
# PyTorch tracks this computation so we can backpropagate through it below.
#
# loss is a single-element tensor (i.e., a scalar).
criterion = torch.nn.SmoothL1Loss()
loss = criterion(
state_action_values,
expected_state_action_values.unsqueeze(1)
)
# We can now run a step of backpropagation on our model.
#
# Calling .backward() multiple times will accumulate parameter gradients.
# Thus, we reset the gradients before each step.
optimizer.zero_grad()
# Compute the gradient of loss with respect to every policy_net parameter.
# The gradients are stored in each parameter's .grad attribute, which is
# what optimizer.step() reads below; we don't need the loss value itself
# after this (except for logging).
loss.backward()
# Prevent exploding gradients.
# Clamps each gradient component to [-clip_value, +clip_value].
torch.nn.utils.clip_grad_value_( # type: ignore
policy_net.parameters(),
clip_value = 100
)
# Perform a single optimizer step.
#
# Uses the current gradient, which is stored
# in the .grad attribute of the parameter.
optimizer.step()
return loss
def on_state_before(celeste):
state = celeste.state
action = select_action(
# Put state in a tensor
torch.tensor(
[getattr(state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0),
# Random action probability is determined by
# the number of times we've reached the next point.
point_counter[state.next_point]
)
# For manual testing
#str_action = ""
#while str_action not in Celeste.action_space:
# str_action = input("action> ")
#action = Celeste.action_space.index(str_action)
print(Celeste.action_space[action])
celeste.act(action)
return (
state, # CelesteState
action # Integer
)
def compute_reward(last_state, state):
global point_counter
reward = None
# No reward if dead
if state.deaths != 0:
reward = 0
# Reward for finishing a stage
elif state.stage >= 1:
print("FINISHED STAGE!!")
# We don't set a fixed reward here because the agent may
# complete the stage before getting all points.
# The below line provides extra reward for taking shortcuts.
reward = state.next_point - last_state.next_point
reward += 1
# Add to point counter
for i in range(last_state.next_point, len(point_counter)):
point_counter[i] += 1
# Reward for reaching a checkpoint
elif last_state.next_point != state.next_point:
print(f"Got point {state.next_point}")
reward = state.next_point - last_state.next_point
# Add to point counter
for i in range(last_state.next_point, last_state.next_point + reward):
point_counter[i] += 1
# No reward otherwise
else:
reward = 0
# Strawberry reward
# (Will probably break current version of model)
#if state.berries[state.stage] and not last_state.berries[state.stage]:
# print(f"Got stage {state.stage} bonus")
# reward += 1
assert reward is not None
return reward * 10
def on_state_after(celeste, before_out):
global n_episodes
global n_steps
last_state, action = before_out
next_state = celeste.state
dead = next_state.deaths != 0
done = next_state.stage >= 1
reward = compute_reward(last_state, next_state)
if dead:
next_state = None
elif done:
# We don't set the next state to None here, because
# the optimization routine assigns zero value to
# terminal (None) states.
# Copy the last state instead. It's a hack, but it
# should work.
next_state = last_state
# Add this state transition to memory.
memory.append(
Transition(
# last state
torch.tensor(
[getattr(last_state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0),
# action
torch.tensor(
[[ action ]],
device = compute_device,
dtype = torch.long
),
# next state
# None if dead or done.
torch.tensor(
[getattr(next_state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0) if next_state is not None else None,
# reward
torch.tensor(
[reward],
device = compute_device
)
)
)
print("==> ", reward)
print("")
# Perform a training step
loss = None
if len(memory) >= BATCH_SIZE:
n_steps += 1
loss = optimize_model()
# Soft update target_net weights
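# The loop below implements
#   target_params <- TAU * policy_params + (1 - TAU) * target_params,
# so the target network slowly tracks the policy network.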
target_net_state = target_net.state_dict()
policy_net_state = policy_net.state_dict()
for key in policy_net_state:
target_net_state[key] = (
policy_net_state[key] * TAU +
target_net_state[key] * (1-TAU)
)
target_net.load_state_dict(target_net_state)
# Move on to the next episode and run
# housekeeping tasks.
if (dead or done):
s = celeste.state
n_episodes += 1
# Move screenshots
sm.move(
number = n_episodes,
overwrite = True
)
# Log this episode
with model_train_log.open("a") as f:
f.write(json.dumps({
"n_episodes": n_episodes,
"n_steps": n_steps,
"checkpoints": s.next_point,
"loss": None if loss is None else loss.item(),
"done": done
}) + "\n")
# Save a snapshot
if n_episodes % model_save_interval == 0:
save_model(model_archive_dir / f"{n_episodes}.torch")
shutil.copy(model_archive_dir / f"{n_episodes}.torch", model_save_path)
print("Game over. Resetting.")
celeste.reset()
if __name__ == "__main__":
c = Celeste(
"resources/pico-8/linux/pico8"
)
c.update_loop(
on_state_before,
on_state_after
)