Mark
/
celeste-ai
Archived
1
0
Fork 0

Cleanup & slight optimizations

master
Mark 2023-02-26 12:13:21 -08:00
parent c185965657
commit 25390f5455
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
2 changed files with 269 additions and 201 deletions

View File

@ -5,33 +5,31 @@ import random
import math import math
import json import json
import torch import torch
import shutil
from celeste_ai import Celeste from celeste_ai import Celeste
from celeste_ai import DQN from celeste_ai import DQN
from celeste_ai import Transition from celeste_ai import Transition
from celeste_ai.util.screenshots import ScreenshotManager
if __name__ == "__main__": if __name__ == "__main__":
# Where to read/write model data. # Where to read/write model data.
model_data_root = Path("model_data/current") model_data_root = Path("model_data/current")
# Where PICO-8 saves screenshots. sm = ScreenshotManager(
# Probably your desktop. # Where PICO-8 saves screenshots.
screenshot_source = Path("/home/mark/Desktop") # Probably your desktop.
source = Path("/home/mark/Desktop"),
pattern = "hackcel_*.png",
target = model_data_root / "screenshots"
).clean() # Remove old screenshots
model_save_path = model_data_root / "model.torch" model_save_path = model_data_root / "model.torch"
model_archive_dir = model_data_root / "model_archive" model_archive_dir = model_data_root / "model_archive"
model_train_log = model_data_root / "train_log" model_train_log = model_data_root / "train_log"
screenshot_dir = model_data_root / "screenshots"
model_data_root.mkdir(parents = True, exist_ok = True) model_data_root.mkdir(parents = True, exist_ok = True)
model_archive_dir.mkdir(parents = True, exist_ok = True) model_archive_dir.mkdir(parents = True, exist_ok = True)
screenshot_dir.mkdir(parents = True, exist_ok = True)
# Remove old screenshots
shots = screenshot_source.glob("hackcel_*.png")
for s in shots:
s.unlink()
compute_device = torch.device( compute_device = torch.device(
@ -45,66 +43,51 @@ if __name__ == "__main__":
# Epsilon-greedy parameters # Epsilon-greedy parameters
# # Probability of choosing a random action starts at
# Original docs: # EPS_START and decays to EPS_END.
# EPS_START is the starting value of epsilon # EPS_DECAY controls the rate of decay.
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
EPS_START = 0.9 EPS_START = 0.9
EPS_END = 0.02 EPS_END = 0.02
EPS_DECAY = 100 EPS_DECAY = 100
# How many times we've reached each point. # Bellman equation time-discount factor
# Used to compute epsilon-greedy probability with
# the parameters above.
point_counter = [0] * len(Celeste.target_checkpoints[0])
BATCH_SIZE = 100
# Learning rate of target_net.
# Controls how soft our soft update is.
#
# Should be between 0 and 1.
# Large values
# Small values do the opposite.
#
# A value of one makes target_net
# change at the same rate as policy_net.
#
# A value of zero makes target_net
# not change at all.
TAU = 0.05
# GAMMA is the discount factor as mentioned in the previous section
GAMMA = 0.9 GAMMA = 0.9
steps_done = 0 # Train on this many transitions from
num_episodes = 100 # replay memory each round
episode_number = 0 BATCH_SIZE = 100
archive_interval = 10
# Controls target_net soft update.
# Should be between 0 and 1.
TAU = 0.05
# Optimizer learning rate
learning_rate = 0.001
# Save a snapshot of the model every n
# episodes.
model_save_interval = 10
# How many times we've reached each point.
# This is used to compute epsilon-greedy probability.
point_counter = [0] * len(Celeste.target_checkpoints[0])
n_episodes = 0 # Number of episodes we've trained on
n_steps = 0 # Number of training steps we've completed
# Create replay memory. # Create replay memory.
# #
# Transition: a container for naming data (defined in util.py) # Holds <Transition> objects, defined in
# Memory: a deque that holds recent states as Transitions # network.py
# Has a fixed length, drops oldest
# element if maxlen is exceeded.
memory = deque([], maxlen=50_000) memory = deque([], maxlen=50_000)
policy_net = DQN(
n_observations,
n_actions
).to(compute_device)
target_net = DQN(
n_observations,
n_actions
).to(compute_device)
policy_net = DQN(n_observations, n_actions).to(compute_device)
target_net = DQN(n_observations, n_actions).to(compute_device)
target_net.load_state_dict(policy_net.state_dict()) target_net.load_state_dict(policy_net.state_dict())
learning_rate = 0.001
optimizer = torch.optim.AdamW( optimizer = torch.optim.AdamW(
policy_net.parameters(), policy_net.parameters(),
lr = learning_rate, lr = learning_rate,
@ -122,11 +105,43 @@ if __name__ == "__main__":
target_net.load_state_dict(checkpoint["target_state_dict"]) target_net.load_state_dict(checkpoint["target_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
memory = checkpoint["memory"] memory = checkpoint["memory"]
episode_number = checkpoint["episode_number"] + 1
steps_done = checkpoint["steps_done"] n_episodes = checkpoint["n_episodes"]
n_steps = checkpoint["n_steps"]
point_counter = checkpoint["point_counter"] point_counter = checkpoint["point_counter"]
def select_action(state, steps_done):
def save_model(path):
torch.save({
# Newtorks
"policy_state_dict": policy_net.state_dict(),
"target_state_dict": target_net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
# Training data
"memory": memory,
"point_counter": point_counter,
"n_episodes": n_episodes,
"n_steps": n_steps,
# Hyperparameters,
# for reference
"eps_start": EPS_START,
"eps_end": EPS_END,
"eps_decay": EPS_DECAY,
"batch_size": BATCH_SIZE,
"tau": TAU,
"learning_rate": learning_rate,
"gamma": GAMMA
}, path
)
def select_action(state, x) -> int:
""" """
Select an action using an epsilon-greedy policy. Select an action using an epsilon-greedy policy.
@ -136,19 +151,13 @@ def select_action(state, steps_done):
Decay rate is controlled by EPS_DECAY. Decay rate is controlled by EPS_DECAY.
""" """
# Random number 0 <= x < 1
sample = random.random()
# Calculate random step threshhold # Calculate random step threshhold
eps_threshold = ( eps_threshold = (
EPS_END + (EPS_START - EPS_END) * EPS_END + (EPS_START - EPS_END) *
math.exp( math.exp(-1.0 * x / EPS_DECAY)
-1.0 * steps_done /
EPS_DECAY
)
) )
if sample > eps_threshold: if random.random() > eps_threshold:
with torch.no_grad(): with torch.no_grad():
# t.max(1) will return the largest column value of each row. # t.max(1) will return the largest column value of each row.
# second column on max result is index of where max element was # second column on max result is index of where max element was
@ -175,7 +184,7 @@ def optimize_model():
# Conversion. # Conversion.
# Combine states, actions, and rewards into their own tensors. # Combine states, actions, and rewards into their own tensors.
state_batch = torch.cat(batch.state) last_state_batch = torch.cat(batch.last_state)
action_batch = torch.cat(batch.action) action_batch = torch.cat(batch.action)
reward_batch = torch.cat(batch.reward) reward_batch = torch.cat(batch.reward)
@ -209,7 +218,7 @@ def optimize_model():
# This gives us a tensor that contains the return we expect to get # This gives us a tensor that contains the return we expect to get
# at that state if we follow the model's advice. # at that state if we follow the model's advice.
state_action_values = policy_net(state_batch).gather(1, action_batch) state_action_values = policy_net(last_state_batch).gather(1, action_batch)
@ -282,36 +291,21 @@ def optimize_model():
def on_state_before(celeste): def on_state_before(celeste):
global steps_done
state = celeste.state state = celeste.state
pt_state = torch.tensor(
[getattr(state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0)
action = select_action( action = select_action(
pt_state, # Put state in a tensor
torch.tensor(
[getattr(state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0),
# Random action probability is determined by
# the number of times we've reached the next point.
point_counter[state.next_point] point_counter[state.next_point]
) )
str_action = Celeste.action_space[action]
"""
action = None
while (action) is None or ((not state.can_dash) and (str_action not in ["left", "right"])):
action = select_action(
pt_state,
steps_done
)
str_action = Celeste.action_space[action]
"""
steps_done += 1
# For manual testing # For manual testing
#str_action = "" #str_action = ""
@ -319,86 +313,114 @@ def on_state_before(celeste):
# str_action = input("action> ") # str_action = input("action> ")
#action = Celeste.action_space.index(str_action) #action = Celeste.action_space.index(str_action)
print(str_action) print(Celeste.action_space[action])
celeste.act(str_action) celeste.act(action)
return state, action return (
state, # CelesteState
action # Integer
def on_state_after(celeste, before_out):
global episode_number
state, action = before_out
next_state = celeste.state
pt_state = torch.tensor(
[getattr(state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0)
pt_action = torch.tensor(
[[ action ]],
device = compute_device,
dtype = torch.long
) )
finished_stage = False
def compute_reward(last_state, state):
global point_counter
reward = None
# No reward if dead # No reward if dead
if next_state.deaths != 0: if state.deaths != 0:
pt_next_state = None
reward = 0 reward = 0
# Reward for finishing a stage # Reward for finishing a stage
elif next_state.stage >= 1: elif state.stage >= 1:
finished_stage = True print("FINISHED STAGE!!")
reward = next_state.next_point - state.next_point
# We don't set a fixed reward here because the agent may
# complete the stage before getting all points.
# The below line provides extra reward for taking shortcuts.
reward = state.next_point - last_state.next_point
reward += 1 reward += 1
# Add to point counter # Add to point counter
for i in range(state.next_point, state.next_point + reward): for i in range(last_state.next_point, len(point_counter)):
point_counter[i] += 1 point_counter[i] += 1
# Regular reward # Reward for reaching a checkpoint
elif last_state.next_point != state.next_point:
print(f"Got point {state.next_point}")
reward = state.next_point - last_state.next_point
# Add to point counter
for i in range(last_state.next_point, last_state.next_point + reward):
point_counter[i] += 1
# No reward otherwise
else: else:
pt_next_state = torch.tensor( reward = 0
[getattr(next_state, x) for x in Celeste.state_number_map],
dtype = torch.float32, # Strawberry reward
device = compute_device # (Will probably break current version of model)
).unsqueeze(0) #if state.berries[state.stage] and not state.berries[state.stage]:
# print(f"Got stage {state.stage} bonus")
# reward += 1
assert reward is not None
return reward * 10
def on_state_after(celeste, before_out):
global n_episodes
global n_steps
if state.next_point == next_state.next_point: last_state, action = before_out
reward = 0 next_state = celeste.state
else: dead = next_state.deaths != 0
print(f"Got point {state.next_point}") done = next_state.stage >= 1
# Reward for reaching a point
reward = next_state.next_point - state.next_point
# Add to point counter
for i in range(state.next_point, state.next_point + reward):
point_counter[i] += 1
# Strawberry reward
if next_state.berries[state.stage] and not state.berries[state.stage]:
print(f"Got stage {state.stage} bonus")
reward += 1
reward = compute_reward(last_state, next_state)
reward = reward * 10 if dead:
pt_reward = torch.tensor([reward], device = compute_device) next_state = None
elif done:
# We don't set the next state to None because
# the optimization routine forces zero reward
# for terminal states.
# Copy last state instead. It's a hack, but it
# should work.
next_state = last_state
# Add this state transition to memory. # Add this state transition to memory.
memory.append( memory.append(
Transition( Transition(
pt_state, # last state
pt_action, torch.tensor(
pt_next_state, [getattr(last_state, x) for x in Celeste.state_number_map],
pt_reward dtype = torch.float32,
device = compute_device
).unsqueeze(0),
# action
torch.tensor(
[[ action ]],
device = compute_device,
dtype = torch.long
),
# next state
# None if dead or done.
torch.tensor(
[getattr(next_state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0) if next_state is not None else None,
# reward
torch.tensor(
[reward],
device = compute_device
)
) )
) )
@ -406,11 +428,10 @@ def on_state_after(celeste, before_out):
print("") print("")
# Perform a training step
loss = None loss = None
# Only train the network if we have enough
# transitions in memory to do so.
if len(memory) >= BATCH_SIZE: if len(memory) >= BATCH_SIZE:
n_steps += 1
loss = optimize_model() loss = optimize_model()
# Soft update target_net weights # Soft update target_net weights
@ -423,65 +444,43 @@ def on_state_after(celeste, before_out):
) )
target_net.load_state_dict(target_net_state) target_net.load_state_dict(target_net_state)
# Move on to the next episode once we reach
# a terminal state.
if (next_state.deaths != 0 or finished_stage): # Move on to the next episode and run
# housekeeping tasks.
if (dead or done):
s = celeste.state s = celeste.state
n_episodes += 1
# Move screenshots
sm.move(
number = n_episodes,
overwrite = True
)
# Log this episode
with model_train_log.open("a") as f: with model_train_log.open("a") as f:
f.write(json.dumps({ f.write(json.dumps({
"n_episodes": n_episodes,
"n_steps": n_steps,
"checkpoints": s.next_point, "checkpoints": s.next_point,
"state_count": s.state_count, "loss": None if loss is None else loss.item(),
"loss": None if loss is None else loss.item() "done": done
}) + "\n") }) + "\n")
# Save model
torch.save({
"policy_state_dict": policy_net.state_dict(),
"target_state_dict": target_net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"memory": memory,
"point_counter": point_counter,
"episode_number": episode_number,
"steps_done": steps_done,
# Hyperparameters
"eps_start": EPS_START,
"eps_end": EPS_END,
"eps_decay": EPS_DECAY,
"batch_size": BATCH_SIZE,
"tau": TAU,
"learning_rate": learning_rate,
"gamma": GAMMA
}, model_save_path)
# Clean up screenshots
shots = screenshot_source.glob("hackcel_*.png")
target = screenshot_dir / Path(f"{episode_number}")
target.mkdir(parents = True)
for s in shots:
s.rename(target / s.name)
# Save a snapshot # Save a snapshot
if episode_number % archive_interval == 0: if n_episodes % model_save_interval == 0:
torch.save({ save_model(model_archive_dir / f"{n_episodes}.torch")
"policy_state_dict": policy_net.state_dict(), shutil.copy(model_archive_dir / f"{n_episodes}.torch", model_save_path)
"target_state_dict": target_net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"memory": memory,
"episode_number": episode_number,
"steps_done": steps_done
}, model_archive_dir / f"{episode_number}.torch")
print("Game over. Resetting.") print("Game over. Resetting.")
episode_number += 1
celeste.reset() celeste.reset()
if __name__ == "__main__": if __name__ == "__main__":
c = Celeste( c = Celeste(
"resources/pico-8/linux/pico8" "resources/pico-8/linux/pico8"

View File

@ -0,0 +1,69 @@
from pathlib import Path
import shutil
class ScreenshotManager:
def __init__(
self,
# Where PICO-8 saves screenshots
source: Path,
# How PICO-8 names screenshots.
# Example: "celeste_*.png"
pattern: str,
# Where we want to move screenshots.
target: Path
):
self.source = source
self.pattern = pattern
self.target = target
self.target.mkdir(
parents = True,
exist_ok = True
)
def clean(self):
shots = self.source.glob(self.pattern)
for s in shots:
s.unlink()
return self
def move(self, number: int | None = None, overwrite = False):
shots = self.source.glob(self.pattern)
if number == None:
# Auto-select new directory number.
# Chooses next highest int directory name
number = 0
for f in self.target.iterdir():
try:
number = max(
int(f.name),
number
)
except ValueError:
continue
number += 1
else:
target = self.target / str(number)
if target.exists():
if not overwrite:
raise Exception(f"Target \"{target}\" exists!")
else:
print(f"Target \"{target}\" exists, removing.")
shutil.rmtree(target)
target.mkdir(parents = True)
for s in shots:
s.rename(target / s.name)
return self