from collections import namedtuple
from collections import deque
from pathlib import Path
import random
import math
import json

import torch

from celeste_ai import Celeste
from celeste_ai import DQN
from celeste_ai import Transition


if __name__ == "__main__":
    # Where to read/write model data.
    model_data_root = Path("model_data/current")

    model_save_path = model_data_root / "model.torch"
    model_archive_dir = model_data_root / "model_archive"
    model_train_log = model_data_root / "train_log"
    screenshot_dir = model_data_root / "screenshots"
    model_data_root.mkdir(parents = True, exist_ok = True)
    model_archive_dir.mkdir(parents = True, exist_ok = True)
    screenshot_dir.mkdir(parents = True, exist_ok = True)

    # Directory scanned for hackcel_*.png screenshots. Matching files are
    # deleted at startup and moved into screenshot_dir/<episode> at the end
    # of every episode.
    # NOTE(review): presumably where the game writes its screenshots; confirm.
    screenshot_source_dir = Path("/home/mark/Desktop")
    screenshot_glob = "hackcel_*.png"

    # Remove old screenshots
    for old_shot in screenshot_source_dir.glob(screenshot_glob):
        old_shot.unlink()

    compute_device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )

    # Celeste env properties
    n_observations = len(Celeste.state_number_map)
    n_actions = len(Celeste.action_space)

    # Epsilon-greedy parameters
    #
    # EPS_START is the starting value of epsilon.
    # EPS_END is the final value of epsilon.
    # EPS_DECAY controls the rate of exponential decay of epsilon;
    # higher means a slower decay. Decay is measured in the counter that
    # on_state_before passes to select_action (visits to the current point),
    # not in total steps.
    EPS_START = 0.9
    EPS_END = 0.02
    EPS_DECAY = 100

    # How many times we've reached each point.
    # on_state_before feeds the entry for the point we're heading
    # towards into select_action, so exploration decays separately
    # for each stretch of the stage.
    point_counter = [0] * len(Celeste.target_checkpoints[0])

    BATCH_SIZE = 100

    # Learning rate of target_net.
    # Controls how soft our soft update is:
    #     target = TAU * policy + (1 - TAU) * target
    #
    # Should be between 0 and 1.
    # Large values make target_net follow policy_net quickly.
    # Small values make it lag further behind.
    #
    # A value of one makes target_net a copy of policy_net
    # after every update.
    #
    # A value of zero makes target_net
    # not change at all.
    TAU = 0.05

    # Discount factor applied to the predicted value of the next state.
    GAMMA = 0.9

    steps_done = 0
    num_episodes = 100  # not currently read anywhere in this script
    episode_number = 0
    archive_interval = 10

    # Create replay memory.
    #
    # Transition: a container for naming data (defined in util.py)
    # Memory: a deque that holds recent states as Transitions.
    #         Has a fixed length; drops the oldest
    #         element if maxlen is exceeded.
    memory = deque([], maxlen=50_000)

    policy_net = DQN(
        n_observations,
        n_actions
    ).to(compute_device)

    target_net = DQN(
        n_observations,
        n_actions
    ).to(compute_device)

    target_net.load_state_dict(policy_net.state_dict())

    learning_rate = 0.001
    optimizer = torch.optim.AdamW(
        policy_net.parameters(),
        lr = learning_rate,
        amsgrad = True
    )

    if model_save_path.is_file():
        # Load model if one exists.
        #
        # weights_only=False: the checkpoint also pickles the replay memory
        # (a deque of Transitions), which PyTorch's weights-only unpickler
        # (the default in recent releases) refuses to load.
        # Only ever load checkpoints this script wrote itself.
        checkpoint = torch.load(
            model_save_path,
            map_location = compute_device,
            weights_only = False
        )
        policy_net.load_state_dict(checkpoint["policy_state_dict"])
        target_net.load_state_dict(checkpoint["target_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        memory = checkpoint["memory"]

        episode_number = checkpoint["episode_number"] + 1
        steps_done = checkpoint["steps_done"]
        point_counter = checkpoint["point_counter"]

    def state_to_tensor(state):
        """
        Convert a Celeste state into a (1, n_observations) float32 tensor.

        Reads the attributes named in Celeste.state_number_map, in that order.
        """
        return torch.tensor(
            [getattr(state, x) for x in Celeste.state_number_map],
            dtype = torch.float32,
            device = compute_device
        ).unsqueeze(0)

    def count_points(first, count):
        """Increment point_counter for `count` points starting at index `first`."""
        for i in range(first, first + count):
            point_counter[i] += 1

    def checkpoint_dict():
        """
        Build a dict with everything needed to resume training.

        Used for both the rolling save and the periodic archive, so an archive
        can be copied over model_save_path and resumed directly.
        """
        return {
            "policy_state_dict": policy_net.state_dict(),
            "target_state_dict": target_net.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "memory": memory,
            "point_counter": point_counter,
            "episode_number": episode_number,
            "steps_done": steps_done,

            # Hyperparameters
            "eps_start": EPS_START,
            "eps_end": EPS_END,
            "eps_decay": EPS_DECAY,
            "batch_size": BATCH_SIZE,
            "tau": TAU,
            "learning_rate": learning_rate,
            "gamma": GAMMA
        }

    def select_action(state, steps_done):
        """
        Select an action index using an epsilon-greedy policy.

        Sometimes use our model, sometimes sample one uniformly.
        P(random action) starts at EPS_START and decays exponentially
        towards EPS_END as `steps_done` grows; EPS_DECAY sets the rate.

        state: tensor produced by state_to_tensor.
        steps_done: decay counter (the caller passes a per-point visit count).
        Returns an int in [0, n_actions).
        """
        # Random number 0 <= x < 1
        sample = random.random()

        # Calculate random step threshold
        eps_threshold = (
            EPS_END + (EPS_START - EPS_END) *
            math.exp(
                -1.0 * steps_done /
                EPS_DECAY
            )
        )

        if sample > eps_threshold:
            with torch.no_grad():
                # t.max(1) returns (values, indices) over each row; the
                # index is the action with the largest predicted return.
                return policy_net(state).max(1)[1].view(1, 1).item()

        return random.randint(
            0, n_actions - 1
        )

    def optimize_model():
        """
        Run one gradient step on policy_net using a random replay batch.

        Returns the Huber loss tensor for the batch.
        Raises Exception if memory holds fewer than BATCH_SIZE transitions.
        """
        if len(memory) < BATCH_SIZE:
            raise Exception(f"Not enough elements in memory for a batch of {BATCH_SIZE}")

        # Get a random sample of transitions
        batch = random.sample(memory, BATCH_SIZE)

        # Transpose batch, turning a list of Transitions
        # into a Transition of tuples.
        batch = Transition(*zip(*batch))

        # Combine states, actions, and rewards into their own tensors.
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        # Mask over the batch: True where the transition has a next state
        # (i.e. it is NOT terminal), False where next_state is None.
        non_final_mask = torch.tensor(
            [s is not None for s in batch.next_state],
            device = compute_device,
            dtype = torch.bool
        )

        # How .gather works:
        # if out = a.gather(1, b),
        # out[i, j] = a[ i ][ b[i,j] ]
        #
        # a is "input," b is "index"

        # Compute Q(s_t, a): the return policy_net predicts for the
        # action that was actually taken in each sampled state.
        state_action_values = policy_net(state_batch).gather(1, action_batch)

        # Compute V(s_t+1) = max_a Q_target(s_t+1, a) for all next states.
        # Terminal states keep a value of zero.
        next_state_values = torch.zeros(BATCH_SIZE, device = compute_device)

        # torch.cat([]) raises, so skip the target pass when every
        # sampled transition is terminal.
        if non_final_mask.any():
            non_final_next_states = torch.cat(
                [s for s in batch.next_state if s is not None]
            )
            with torch.no_grad():
                # Values come from the slower-moving target_net; max(1)[0]
                # picks the best value over possible actions.
                next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]

        # Bellman target: r + GAMMA * V(s_t+1).
        # This is what Q(s_t, a) should equal if the model were exact.
        expected_state_action_values = reward_batch + (next_state_values * GAMMA)

        # Huber loss between predicted and target values.
        # loss is a single-element tensor (i.e, a scalar).
        criterion = torch.nn.SmoothL1Loss()
        loss = criterion(
            state_action_values,
            expected_state_action_values.unsqueeze(1)
        )

        # Calling .backward() multiple times accumulates parameter gradients,
        # so clear them before computing this step's gradient.
        optimizer.zero_grad()

        # Compute d(loss)/d(param) for every policy_net parameter and store
        # it in that parameter's .grad attribute, where optimizer.step() reads it.
        loss.backward()

        # Limit exploding gradients: clamp every gradient element
        # to [-clip_value, +clip_value].
        torch.nn.utils.clip_grad_value_( # type: ignore
            policy_net.parameters(),
            clip_value = 100
        )

        # Perform a single optimizer step using the gradients
        # stored in each parameter's .grad attribute.
        optimizer.step()

        return loss

    def on_state_before(celeste):
        """
        Pick and send an action for the current state.

        Returns (state, action_index) for on_state_after.
        """
        global steps_done

        state = celeste.state
        pt_state = state_to_tensor(state)

        # Epsilon decays with the number of times we've reached the point
        # we are heading towards, not with the global step count.
        action = select_action(
            pt_state,
            point_counter[state.next_point]
        )
        str_action = Celeste.action_space[action]

        steps_done += 1

        print(str_action)
        celeste.act(str_action)

        return state, action

    def on_state_after(celeste, before_out):
        """
        Compute the reward for the last action, store the transition and train.

        Ends the episode (log, save, archive, reset) when the agent dies
        or finishes the stage.
        """
        global episode_number

        state, action = before_out
        next_state = celeste.state

        pt_state = state_to_tensor(state)
        pt_action = torch.tensor(
            [[ action ]],
            device = compute_device,
            dtype = torch.long
        )

        finished_stage = False

        if next_state.deaths != 0:
            # Dying is terminal and earns no reward.
            pt_next_state = None
            reward = 0

        elif next_state.stage >= 1:
            # Finishing the stage is terminal too, so there is no next state.
            finished_stage = True
            pt_next_state = None

            reward = next_state.next_point - state.next_point
            reward += 1

            # Add to point counter
            count_points(state.next_point, reward)
            # NOTE(review): unlike the checkpoint reward below, this is not
            # multiplied by 10; confirm that is intended.

        else:
            # Regular reward
            pt_next_state = state_to_tensor(next_state)

            if state.next_point == next_state.next_point:
                reward = 0
            else:
                # Reward for reaching a point
                reward = next_state.next_point - state.next_point

                # Add to point counter
                count_points(state.next_point, reward)

                reward = reward * 10

        pt_reward = torch.tensor([reward], device = compute_device)

        # Add this state transition to memory.
        memory.append(
            Transition(
                pt_state,
                pt_action,
                pt_next_state,
                pt_reward
            )
        )

        print("==> ", reward)
        print("")

        loss = None
        # Only train the network if we have enough
        # transitions in memory to do so.
        if len(memory) >= BATCH_SIZE:
            loss = optimize_model()

            # Soft update target_net weights
            target_net_state = target_net.state_dict()
            policy_net_state = policy_net.state_dict()
            for key in policy_net_state:
                target_net_state[key] = (
                    policy_net_state[key] * TAU +
                    target_net_state[key] * (1 - TAU)
                )
            target_net.load_state_dict(target_net_state)

        # Move on to the next episode once we reach
        # a terminal state.
        if next_state.deaths != 0 or finished_stage:
            final_state = celeste.state
            with model_train_log.open("a", encoding = "utf-8") as f:
                f.write(json.dumps({
                    "checkpoints": final_state.next_point,
                    "state_count": final_state.state_count,
                    "loss": None if loss is None else loss.item()
                }) + "\n")

            # Save model
            torch.save(checkpoint_dict(), model_save_path)

            # Move this episode's screenshots into their own directory.
            # exist_ok: the directory may be left over from an earlier run.
            target = screenshot_dir / Path(f"{episode_number}")
            target.mkdir(parents = True, exist_ok = True)
            for shot in screenshot_source_dir.glob(screenshot_glob):
                shot.rename(target / shot.name)

            # Save a snapshot
            if episode_number % archive_interval == 0:
                torch.save(
                    checkpoint_dict(),
                    model_archive_dir / f"{episode_number}.torch"
                )

            print("Game over. Resetting.")
            episode_number += 1
            celeste.reset()


if __name__ == "__main__":
    c = Celeste(
        "resources/pico-8/linux/pico8"
    )

    c.update_loop(
        on_state_before,
        on_state_after
    )