from collections import namedtuple from collections import deque from pathlib import Path import random import math import json import torch import shutil from celeste_ai import Celeste from celeste_ai import DQN from celeste_ai import Transition from celeste_ai.util.screenshots import ScreenshotManager if __name__ == "__main__": # Where to read/write model data. model_data_root = Path("model_data/current") sm = ScreenshotManager( # Where PICO-8 saves screenshots. # Probably your desktop. source = Path("/home/mark/Desktop"), pattern = "hackcel_*.png", target = model_data_root / "screenshots" ).clean() # Remove old screenshots model_save_path = model_data_root / "model.torch" model_archive_dir = model_data_root / "model_archive" model_train_log = model_data_root / "train_log" model_data_root.mkdir(parents = True, exist_ok = True) model_archive_dir.mkdir(parents = True, exist_ok = True) compute_device = torch.device( "cuda" if torch.cuda.is_available() else "cpu" ) # Celeste env properties n_observations = len(Celeste.state_number_map) n_actions = len(Celeste.action_space) # Epsilon-greedy parameters # Probability of choosing a random action starts at # EPS_START and decays to EPS_END. # EPS_DECAY controls the rate of decay. EPS_START = 0.9 EPS_END = 0.02 EPS_DECAY = 100 # Bellman equation time-discount factor GAMMA = 0.9 # Train on this many transitions from # replay memory each round BATCH_SIZE = 100 # Controls target_net soft update. # Should be between 0 and 1. TAU = 0.05 # Optimizer learning rate learning_rate = 0.001 # Save a snapshot of the model every n # episodes. model_save_interval = 10 # How many times we've reached each point. # This is used to compute epsilon-greedy probability. point_counter = [0] * len(Celeste.target_checkpoints[0]) n_episodes = 0 # Number of episodes we've trained on n_steps = 0 # Number of training steps we've completed # Create replay memory. # # Holds objects, defined in # network.py memory = deque([], maxlen=50_000) policy_net = DQN(n_observations, n_actions).to(compute_device) target_net = DQN(n_observations, n_actions).to(compute_device) target_net.load_state_dict(policy_net.state_dict()) optimizer = torch.optim.AdamW( policy_net.parameters(), lr = learning_rate, amsgrad = True ) if model_save_path.is_file(): # Load model if one exists checkpoint = torch.load( model_save_path, map_location = compute_device ) policy_net.load_state_dict(checkpoint["policy_state_dict"]) target_net.load_state_dict(checkpoint["target_state_dict"]) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) memory = checkpoint["memory"] n_episodes = checkpoint["n_episodes"] n_steps = checkpoint["n_steps"] point_counter = checkpoint["point_counter"] def save_model(path): torch.save({ # Newtorks "policy_state_dict": policy_net.state_dict(), "target_state_dict": target_net.state_dict(), "optimizer_state_dict": optimizer.state_dict(), # Training data "memory": memory, "point_counter": point_counter, "n_episodes": n_episodes, "n_steps": n_steps, # Hyperparameters, # for reference "eps_start": EPS_START, "eps_end": EPS_END, "eps_decay": EPS_DECAY, "batch_size": BATCH_SIZE, "tau": TAU, "learning_rate": learning_rate, "gamma": GAMMA }, path ) def select_action(state, x) -> int: """ Select an action using an epsilon-greedy policy. Sometimes use our model, sometimes sample one uniformly. P(random action) starts at EPS_START and decays to EPS_END. Decay rate is controlled by EPS_DECAY. """ # Calculate random step threshhold eps_threshold = ( EPS_END + (EPS_START - EPS_END) * math.exp(-1.0 * x / EPS_DECAY) ) if random.random() > eps_threshold: with torch.no_grad(): # t.max(1) will return the largest column value of each row. # second column on max result is index of where max element was # found, so we pick action with the larger expected reward. return policy_net(state).max(1)[1].view(1, 1).item() else: return random.randint( 0, n_actions-1 ) def optimize_model(): if len(memory) < BATCH_SIZE: raise Exception(f"Not enough elements in memory for a batch of {BATCH_SIZE}") # Get a random sample of transitions batch = random.sample(memory, BATCH_SIZE) # Conversion. # Transposes batch, turning an array of Transitions # into a Transition of arrays. batch = Transition(*zip(*batch)) # Conversion. # Combine states, actions, and rewards into their own tensors. last_state_batch = torch.cat(batch.last_state) action_batch = torch.cat(batch.action) reward_batch = torch.cat(batch.reward) # Compute a mask of non_final_states. # Each element of this tensor corresponds to an element in the batch. # True if this is a final state, False if it isn't. # # We use this to select non-final states later. non_final_mask = torch.tensor( tuple(map( lambda s: s is not None, batch.next_state )) ) non_final_next_states = torch.cat( [s for s in batch.next_state if s is not None] ) # How .gather works: # if out = a.gather(1, b), # out[i, j] = a[ i ][ b[i,j] ] # # a is "input," b is "index" # Compute Q(s_t, a). # This gives us a tensor that contains the return we expect to get # at that state if we follow the model's advice. state_action_values = policy_net(last_state_batch).gather(1, action_batch) # Compute V(s_t+1) for all next states. # V(s_t+1) = max_a ( Q(s_t+1, a) ) # = the maximum reward over all possible actions at state s_t+1. next_state_values = torch.zeros(BATCH_SIZE, device = compute_device) with torch.no_grad(): # Note the use of non_final_mask here. # States that are final do not have their reward set by the line # below, so their reward stays at zero. # # States that are not final get their predicted value # set to the best value the model predicts. # # # Expected values of action are selected with the "older" target net, # and their best reward (over possible actions) is selected with max(1)[0]. next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0] # TODO: What does this mean? # "Compute expected Q values" expected_state_action_values = reward_batch + (next_state_values * GAMMA) # Compute Huber loss between predicted reward and expected reward. # Pytorch is will account for this when we compute the gradient of loss. # # loss is a single-element tensor (i.e, a scalar). criterion = torch.nn.SmoothL1Loss() loss = criterion( state_action_values, expected_state_action_values.unsqueeze(1) ) # We can now run a step of backpropagation on our model. # TODO: what does this do? # # Calling .backward() multiple times will accumulate parameter gradients. # Thus, we reset the gradient before each step. optimizer.zero_grad() # Compute the gradient of loss wrt... something? # TODO: what does this do, we never use loss again?! loss.backward() # Prevent vanishing and exploding gradients. # Forces gradients to be in [-clip_value, +clip_value] torch.nn.utils.clip_grad_value_( # type: ignore policy_net.parameters(), clip_value = 100 ) # Perform a single optimizer step. # # Uses the current gradient, which is stored # in the .grad attribute of the parameter. optimizer.step() return loss def on_state_before(celeste): state = celeste.state action = select_action( # Put state in a tensor torch.tensor( [getattr(state, x) for x in Celeste.state_number_map], dtype = torch.float32, device = compute_device ).unsqueeze(0), # Random action probability is determined by # the number of times we've reached the next point. point_counter[state.next_point] ) # For manual testing #str_action = "" #while str_action not in Celeste.action_space: # str_action = input("action> ") #action = Celeste.action_space.index(str_action) print(Celeste.action_space[action]) celeste.act(action) return ( state, # CelesteState action # Integer ) def compute_reward(last_state, state): global point_counter reward = None # No reward if dead if state.deaths != 0: reward = 0 # Reward for finishing a stage elif state.stage >= 1: print("FINISHED STAGE!!") # We don't set a fixed reward here because the agent may # complete the stage before getting all points. # The below line provides extra reward for taking shortcuts. reward = state.next_point - last_state.next_point reward += 1 # Add to point counter for i in range(last_state.next_point, len(point_counter)): point_counter[i] += 1 # Reward for reaching a checkpoint elif last_state.next_point != state.next_point: print(f"Got point {state.next_point}") reward = state.next_point - last_state.next_point # Add to point counter for i in range(last_state.next_point, last_state.next_point + reward): point_counter[i] += 1 # No reward otherwise else: reward = 0 # Strawberry reward # (Will probably break current version of model) #if state.berries[state.stage] and not state.berries[state.stage]: # print(f"Got stage {state.stage} bonus") # reward += 1 assert reward is not None return reward * 10 def on_state_after(celeste, before_out): global n_episodes global n_steps last_state, action = before_out next_state = celeste.state dead = next_state.deaths != 0 done = next_state.stage >= 1 reward = compute_reward(last_state, next_state) if dead: next_state = None elif done: # We don't set the next state to None because # the optimization routine forces zero reward # for terminal states. # Copy last state instead. It's a hack, but it # should work. next_state = last_state # Add this state transition to memory. memory.append( Transition( # last state torch.tensor( [getattr(last_state, x) for x in Celeste.state_number_map], dtype = torch.float32, device = compute_device ).unsqueeze(0), # action torch.tensor( [[ action ]], device = compute_device, dtype = torch.long ), # next state # None if dead or done. torch.tensor( [getattr(next_state, x) for x in Celeste.state_number_map], dtype = torch.float32, device = compute_device ).unsqueeze(0) if next_state is not None else None, # reward torch.tensor( [reward], device = compute_device ) ) ) print("==> ", reward) print("") # Perform a training step loss = None if len(memory) >= BATCH_SIZE: n_steps += 1 loss = optimize_model() # Soft update target_net weights target_net_state = target_net.state_dict() policy_net_state = policy_net.state_dict() for key in policy_net_state: target_net_state[key] = ( policy_net_state[key] * TAU + target_net_state[key] * (1-TAU) ) target_net.load_state_dict(target_net_state) # Move on to the next episode and run # housekeeping tasks. if (dead or done): s = celeste.state n_episodes += 1 # Move screenshots sm.move( number = n_episodes, overwrite = True ) # Log this episode with model_train_log.open("a") as f: f.write(json.dumps({ "n_episodes": n_episodes, "n_steps": n_steps, "checkpoints": s.next_point, "loss": None if loss is None else loss.item(), "done": done }) + "\n") # Save a snapshot if n_episodes % model_save_interval == 0: save_model(model_archive_dir / f"{n_episodes}.torch") shutil.copy(model_archive_dir / f"{n_episodes}.torch", model_save_path) print("Game over. Resetting.") celeste.reset() if __name__ == "__main__": c = Celeste( "resources/pico-8/linux/pico8" ) c.update_loop( on_state_before, on_state_after )