diff --git a/celeste/celeste_ai/train.py b/celeste/celeste_ai/train.py
index 5d9fa0e..1d0d38e 100644
--- a/celeste/celeste_ai/train.py
+++ b/celeste/celeste_ai/train.py
@@ -5,33 +5,31 @@
 import random
 import math
 import json
 import torch
+import shutil
+
 from celeste_ai import Celeste
 from celeste_ai import DQN
 from celeste_ai import Transition
-
+from celeste_ai.util.screenshots import ScreenshotManager
 
 
 if __name__ == "__main__":
 	# Where to read/write model data.
 	model_data_root = Path("model_data/current")
 
-	# Where PICO-8 saves screenshots.
-	# Probably your desktop.
-	screenshot_source = Path("/home/mark/Desktop")
+	sm = ScreenshotManager(
+		# Where PICO-8 saves screenshots.
+		# Probably your desktop.
+		source = Path("/home/mark/Desktop"),
+		pattern = "hackcel_*.png",
+		target = model_data_root / "screenshots"
+	).clean() # Remove old screenshots
 
 	model_save_path = model_data_root / "model.torch"
 	model_archive_dir = model_data_root / "model_archive"
 	model_train_log = model_data_root / "train_log"
-	screenshot_dir = model_data_root / "screenshots"
 	model_data_root.mkdir(parents = True, exist_ok = True)
 	model_archive_dir.mkdir(parents = True, exist_ok = True)
-	screenshot_dir.mkdir(parents = True, exist_ok = True)
-
-
-	# Remove old screenshots
-	shots = screenshot_source.glob("hackcel_*.png")
-	for s in shots:
-		s.unlink()
 
 
 	compute_device = torch.device(
@@ -45,66 +43,51 @@ if __name__ == "__main__":
 
 
 	# Epsilon-greedy parameters
-	#
-	# Original docs:
-	# EPS_START is the starting value of epsilon
-	# EPS_END is the final value of epsilon
-	# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
+	# Probability of choosing a random action starts at
+	# EPS_START and decays to EPS_END.
+	# EPS_DECAY controls the rate of decay.
 	EPS_START = 0.9
 	EPS_END = 0.02
 	EPS_DECAY = 100
 
-	# How many times we've reached each point.
-	# Used to compute epsilon-greedy probability with
-	# the parameters above.
-	point_counter = [0] * len(Celeste.target_checkpoints[0])
-
-	BATCH_SIZE = 100
-	# Learning rate of target_net.
-	# Controls how soft our soft update is.
-	#
-	# Should be between 0 and 1.
-	# Large values
-	# Small values do the opposite.
-	#
-	# A value of one makes target_net
-	# change at the same rate as policy_net.
-	#
-	# A value of zero makes target_net
-	# not change at all.
-	TAU = 0.05
-
-
-	# GAMMA is the discount factor as mentioned in the previous section
+	# Bellman equation time-discount factor
 	GAMMA = 0.9
 
-	steps_done = 0
-	num_episodes = 100
-	episode_number = 0
-	archive_interval = 10
+	# Train on this many transitions from
+	# replay memory each round
+	BATCH_SIZE = 100
+
+	# Controls target_net soft update.
+	# Should be between 0 and 1.
+	TAU = 0.05
+
+	# Optimizer learning rate
+	learning_rate = 0.001
+
+	# Save a snapshot of the model every n
+	# episodes.
+	model_save_interval = 10
+
+
+
+	# How many times we've reached each point.
+	# This is used to compute epsilon-greedy probability.
+	point_counter = [0] * len(Celeste.target_checkpoints[0])
+
+	n_episodes = 0 # Number of episodes we've trained on
+	n_steps = 0    # Number of training steps we've completed
 
 	# Create replay memory.
 	#
-	# Transition: a container for naming data (defined in util.py)
-	# Memory: a deque that holds recent states as Transitions
-	#         Has a fixed length, drops oldest
-	#         element if maxlen is exceeded.
+	# Holds Transition objects, defined in
+	# network.py
 	memory = deque([], maxlen=50_000)
 
-	policy_net = DQN(
-		n_observations,
-		n_actions
-	).to(compute_device)
-	target_net = DQN(
-		n_observations,
-		n_actions
-	).to(compute_device)
+	policy_net = DQN(n_observations, n_actions).to(compute_device)
+	target_net = DQN(n_observations, n_actions).to(compute_device)
 	target_net.load_state_dict(policy_net.state_dict())
-
-
-	learning_rate = 0.001
 	optimizer = torch.optim.AdamW(
 		policy_net.parameters(),
 		lr = learning_rate,
@@ -122,11 +105,43 @@ if __name__ == "__main__":
 		target_net.load_state_dict(checkpoint["target_state_dict"])
 		optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
 		memory = checkpoint["memory"]
-		episode_number = checkpoint["episode_number"] + 1
-		steps_done = checkpoint["steps_done"]
+
+		n_episodes = checkpoint["n_episodes"]
+		n_steps = checkpoint["n_steps"]
 		point_counter = checkpoint["point_counter"]
 
-def select_action(state, steps_done):
+
+
+def save_model(path):
+	torch.save({
+		# Networks
+		"policy_state_dict": policy_net.state_dict(),
+		"target_state_dict": target_net.state_dict(),
+		"optimizer_state_dict": optimizer.state_dict(),
+
+		# Training data
+		"memory": memory,
+		"point_counter": point_counter,
+		"n_episodes": n_episodes,
+		"n_steps": n_steps,
+
+		# Hyperparameters,
+		# for reference
+		"eps_start": EPS_START,
+		"eps_end": EPS_END,
+		"eps_decay": EPS_DECAY,
+		"batch_size": BATCH_SIZE,
+		"tau": TAU,
+		"learning_rate": learning_rate,
+		"gamma": GAMMA
+	}, path
+	)
+
+
+
+
+def select_action(state, x) -> int:
 	"""
 	Select an action using an epsilon-greedy policy.
 
@@ -136,19 +151,13 @@ def select_action(state, steps_done):
 	Decay rate is controlled by EPS_DECAY.
 	"""
 
-	# Random number 0 <= x < 1
-	sample = random.random()
-
 	# Calculate random step threshhold
 	eps_threshold = (
 		EPS_END + (EPS_START - EPS_END) *
-		math.exp(
-			-1.0 * steps_done /
-			EPS_DECAY
-		)
+		math.exp(-1.0 * x / EPS_DECAY)
 	)
 
-	if sample > eps_threshold:
+	if random.random() > eps_threshold:
 		with torch.no_grad():
 			# t.max(1) will return the largest column value of each row.
 			# second column on max result is index of where max element was
@@ -175,7 +184,7 @@ def optimize_model():
 
 	# Conversion.
 	# Combine states, actions, and rewards into their own tensors.
-	state_batch = torch.cat(batch.state)
+	last_state_batch = torch.cat(batch.last_state)
 	action_batch = torch.cat(batch.action)
 	reward_batch = torch.cat(batch.reward)
 
@@ -209,7 +218,7 @@ def optimize_model():
 
 	# This gives us a tensor that contains the return we expect to get
 	# at that state if we follow the model's advice.
-	state_action_values = policy_net(state_batch).gather(1, action_batch)
+	state_action_values = policy_net(last_state_batch).gather(1, action_batch)
 
 
 
@@ -282,36 +291,21 @@ def optimize_model():
 
 
 def on_state_before(celeste):
-	global steps_done
-
 	state = celeste.state
 
-	pt_state = torch.tensor(
-		[getattr(state, x) for x in Celeste.state_number_map],
-		dtype = torch.float32,
-		device = compute_device
-	).unsqueeze(0)
-
 	action = select_action(
-		pt_state,
+		# Put state in a tensor
+		torch.tensor(
+			[getattr(state, x) for x in Celeste.state_number_map],
+			dtype = torch.float32,
+			device = compute_device
+		).unsqueeze(0),
+
+		# Random action probability is determined by
+		# the number of times we've reached the next point.
 		point_counter[state.next_point]
 	)
 
-	str_action = Celeste.action_space[action]
-
-
-	"""
-	action = None
-	while (action) is None or ((not state.can_dash) and (str_action not in ["left", "right"])):
-		action = select_action(
-			pt_state,
-			steps_done
-		)
-		str_action = Celeste.action_space[action]
-	"""
-
-	steps_done += 1
-
 	# For manual testing
 	#str_action = ""
@@ -319,86 +313,114 @@ def on_state_before(celeste):
 	#	str_action = input("action> ")
 	#action = Celeste.action_space.index(str_action)
 
-	print(str_action)
-	celeste.act(str_action)
+	print(Celeste.action_space[action])
+	celeste.act(action)
 
-	return state, action
-
-
-def on_state_after(celeste, before_out):
-	global episode_number
-
-	state, action = before_out
-	next_state = celeste.state
-
-	pt_state = torch.tensor(
-		[getattr(state, x) for x in Celeste.state_number_map],
-		dtype = torch.float32,
-		device = compute_device
-	).unsqueeze(0)
-
-	pt_action = torch.tensor(
-		[[ action ]],
-		device = compute_device,
-		dtype = torch.long
+	return (
+		state,   # CelesteState
+		action   # Integer
 	)
 
-	finished_stage = False
 
+def compute_reward(last_state, state):
+	global point_counter
+
+	reward = None
+
 	# No reward if dead
-	if next_state.deaths != 0:
-		pt_next_state = None
+	if state.deaths != 0:
 		reward = 0
 
 	# Reward for finishing a stage
-	elif next_state.stage >= 1:
-		finished_stage = True
-		reward = next_state.next_point - state.next_point
+	elif state.stage >= 1:
+		print("FINISHED STAGE!!")
+
+		# We don't set a fixed reward here because the agent may
+		# complete the stage before getting all points.
+		# The line below provides extra reward for taking shortcuts.
+		reward = state.next_point - last_state.next_point
 		reward += 1
 
 		# Add to point counter
-		for i in range(state.next_point, state.next_point + reward):
+		for i in range(last_state.next_point, len(point_counter)):
 			point_counter[i] += 1
 
-	# Regular reward
+	# Reward for reaching a checkpoint
+	elif last_state.next_point != state.next_point:
+		print(f"Got point {state.next_point}")
+
+		reward = state.next_point - last_state.next_point
+
+		# Add to point counter
+		for i in range(last_state.next_point, last_state.next_point + reward):
+			point_counter[i] += 1
+
+	# No reward otherwise
 	else:
-		pt_next_state = torch.tensor(
-			[getattr(next_state, x) for x in Celeste.state_number_map],
-			dtype = torch.float32,
-			device = compute_device
-		).unsqueeze(0)
+		reward = 0
+
+	# Strawberry reward
+	# (Will probably break current version of model)
+	#if state.berries[state.stage] and not last_state.berries[state.stage]:
+	#	print(f"Got stage {state.stage} bonus")
+	#	reward += 1
+
+	assert reward is not None
+	return reward * 10
 
 
+def on_state_after(celeste, before_out):
+	global n_episodes
+	global n_steps
 
-		if state.next_point == next_state.next_point:
-			reward = 0
-		else:
-			print(f"Got point {state.next_point}")
-			# Reward for reaching a point
-			reward = next_state.next_point - state.next_point
-
-			# Add to point counter
-			for i in range(state.next_point, state.next_point + reward):
-				point_counter[i] += 1
-
-		# Strawberry reward
-		if next_state.berries[state.stage] and not state.berries[state.stage]:
-			print(f"Got stage {state.stage} bonus")
-			reward += 1
+	last_state, action = before_out
+	next_state = celeste.state
+	dead = next_state.deaths != 0
+	done = next_state.stage >= 1
 
-
-
-	reward = reward * 10
-	pt_reward = torch.tensor([reward], device = compute_device)
-
+	reward = compute_reward(last_state, next_state)
+
+	if dead:
+		next_state = None
+	elif done:
+		# We don't set the next state to None because
+		# the optimization routine forces zero reward
+		# for terminal states.
+		# Copy last state instead. It's a hack, but it
+		# should work.
+		next_state = last_state
 
 	# Add this state transition to memory.
 	memory.append(
 		Transition(
-			pt_state,
-			pt_action,
-			pt_next_state,
-			pt_reward
+			# last state
+			torch.tensor(
+				[getattr(last_state, x) for x in Celeste.state_number_map],
+				dtype = torch.float32,
+				device = compute_device
+			).unsqueeze(0),
+
+			# action
+			torch.tensor(
+				[[ action ]],
+				device = compute_device,
+				dtype = torch.long
+			),
+
+			# next state
+			# None if dead or done.
+			torch.tensor(
+				[getattr(next_state, x) for x in Celeste.state_number_map],
+				dtype = torch.float32,
+				device = compute_device
+			).unsqueeze(0) if next_state is not None else None,
+
+			# reward
+			torch.tensor(
+				[reward],
+				device = compute_device
+			)
 		)
 	)
 
@@ -406,11 +428,10 @@ def on_state_after(celeste, before_out):
 	print("")
 
+	# Perform a training step
 	loss = None
-
-	# Only train the network if we have enough
-	# transitions in memory to do so.
 	if len(memory) >= BATCH_SIZE:
+		n_steps += 1
 		loss = optimize_model()
 
 		# Soft update target_net weights
@@ -423,65 +444,43 @@ def on_state_after(celeste, before_out):
 		)
 		target_net.load_state_dict(target_net_state)
 
-	# Move on to the next episode once we reach
-	# a terminal state.
-	if (next_state.deaths != 0 or finished_stage):
+
+
+	# Move on to the next episode and run
+	# housekeeping tasks.
+	if (dead or done):
 		s = celeste.state
 
+		n_episodes += 1
+
+		# Move screenshots
+		sm.move(
+			number = n_episodes,
+			overwrite = True
+		)
+
+
+		# Log this episode
 		with model_train_log.open("a") as f:
 			f.write(json.dumps({
+				"n_episodes": n_episodes,
+				"n_steps": n_steps,
 				"checkpoints": s.next_point,
-				"state_count": s.state_count,
-				"loss": None if loss is None else loss.item()
+				"loss": None if loss is None else loss.item(),
+				"done": done
 			}) + "\n")
 
-		# Save model
-		torch.save({
-			"policy_state_dict": policy_net.state_dict(),
-			"target_state_dict": target_net.state_dict(),
-			"optimizer_state_dict": optimizer.state_dict(),
-			"memory": memory,
-			"point_counter": point_counter,
-			"episode_number": episode_number,
-			"steps_done": steps_done,
-
-			# Hyperparameters
-			"eps_start": EPS_START,
-			"eps_end": EPS_END,
-			"eps_decay": EPS_DECAY,
-			"batch_size": BATCH_SIZE,
-			"tau": TAU,
-			"learning_rate": learning_rate,
-			"gamma": GAMMA
-		}, model_save_path)
-
-
-		# Clean up screenshots
-		shots = screenshot_source.glob("hackcel_*.png")
-
-		target = screenshot_dir / Path(f"{episode_number}")
-		target.mkdir(parents = True)
-
-		for s in shots:
-			s.rename(target / s.name)
-
 		# Save a snapshot
-		if episode_number % archive_interval == 0:
-			torch.save({
-				"policy_state_dict": policy_net.state_dict(),
-				"target_state_dict": target_net.state_dict(),
-				"optimizer_state_dict": optimizer.state_dict(),
-				"memory": memory,
-				"episode_number": episode_number,
-				"steps_done": steps_done
-			}, model_archive_dir / f"{episode_number}.torch")
+		if n_episodes % model_save_interval == 0:
+			save_model(model_archive_dir / f"{n_episodes}.torch")
+			shutil.copy(model_archive_dir / f"{n_episodes}.torch", model_save_path)
 
 		print("Game over. Resetting.")
Resetting.") - episode_number += 1 celeste.reset() + if __name__ == "__main__": c = Celeste( "resources/pico-8/linux/pico8" diff --git a/celeste/celeste_ai/util/screenshots.py b/celeste/celeste_ai/util/screenshots.py new file mode 100644 index 0000000..17bbcf4 --- /dev/null +++ b/celeste/celeste_ai/util/screenshots.py @@ -0,0 +1,69 @@ +from pathlib import Path +import shutil + + +class ScreenshotManager: + def __init__( + self, + + # Where PICO-8 saves screenshots + source: Path, + + # How PICO-8 names screenshots. + # Example: "celeste_*.png" + pattern: str, + + # Where we want to move screenshots. + target: Path + ): + self.source = source + self.pattern = pattern + self.target = target + self.target.mkdir( + parents = True, + exist_ok = True + ) + + + + def clean(self): + shots = self.source.glob(self.pattern) + for s in shots: + s.unlink() + return self + + + + def move(self, number: int | None = None, overwrite = False): + shots = self.source.glob(self.pattern) + + if number == None: + + # Auto-select new directory number. + # Chooses next highest int directory name + number = 0 + for f in self.target.iterdir(): + try: + number = max( + int(f.name), + number + ) + except ValueError: + continue + number += 1 + + else: + target = self.target / str(number) + + if target.exists(): + if not overwrite: + raise Exception(f"Target \"{target}\" exists!") + else: + print(f"Target \"{target}\" exists, removing.") + shutil.rmtree(target) + + target.mkdir(parents = True) + + for s in shots: + s.rename(target / s.name) + return self \ No newline at end of file