from collections import namedtuple
from collections import deque
from pathlib import Path
import inspect
import json
import math
import random
import shutil

import torch

from celeste import Celeste


if __name__ == "__main__":
    # Where to read/write model data.
    model_data_root = Path("model_data/current")
    model_save_path = model_data_root / "model.torch"
    model_archive_dir = model_data_root / "model_archive"
    model_train_log = model_data_root / "train_log"
    screenshot_dir = model_data_root / "screenshots"

    # Directory scanned for "hackcel_*.png" screenshots at the end of each
    # episode; they are moved into screenshot_dir/<episode_number>.
    # NOTE(review): machine-specific path — consider making it configurable.
    screenshot_source_dir = Path("/home/mark/Desktop")

    model_data_root.mkdir(parents=True, exist_ok=True)
    model_archive_dir.mkdir(parents=True, exist_ok=True)
    screenshot_dir.mkdir(parents=True, exist_ok=True)

    compute_device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )

    # Celeste env properties
    n_observations = len(Celeste.state_number_map)
    n_actions = len(Celeste.action_space)

    # Epsilon-greedy parameters:
    #   EPS_START is the starting probability of taking a random action,
    #   EPS_END is the value it decays towards,
    #   EPS_DECAY controls the rate of exponential decay (higher = slower).
    EPS_START = 0.9
    EPS_END = 0.05
    EPS_DECAY = 4000

    # Number of transitions sampled from replay memory per training step.
    BATCH_SIZE = 1_000

    # Soft-update rate of target_net (between 0 and 1).
    # Each training step sets
    #   target = TAU * policy + (1 - TAU) * target.
    # Large values make target_net track policy_net closely; small values
    # make it change slowly. A value of one copies policy_net outright; a
    # value of zero freezes target_net.
    TAU = 0.005

    # Discount factor applied to the value of the next state.
    GAMMA = 0.9


class DQN(torch.nn.Module):
    """Fully connected Q-network.

    Maps an observation vector to one Q-value per action:
    n_observations -> 128 -> 128 -> 128 -> n_actions, with ReLU in between.
    """

    def __init__(self, n_observations: int, n_actions: int):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(n_observations, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, n_actions),
        )

    def forward(self, x):
        """Return Q-values for every action.

        Works on a single observation or a batch: the last dimension of `x`
        must be n_observations, and the last dimension of the result is
        n_actions, i.e. out[..., a] = Q(s, a), the (expected) return of
        taking action `a` in state `s`.
        """
        return self.layers(x)


# One step of experience stored in replay memory.
# next_state is None when the step ended the episode (the agent died).
Transition = namedtuple(
    "Transition",
    ("state", "action", "next_state", "reward"),
)


if __name__ == "__main__":
    steps_done = 0
    num_episodes = 100
    episode_number = 0
    archive_interval = 10

    # Replay memory: recent Transitions in a bounded deque that drops the
    # oldest element once maxlen is exceeded.
    memory = deque([], maxlen=50_000)

    policy_net = DQN(n_observations, n_actions).to(compute_device)
    target_net = DQN(n_observations, n_actions).to(compute_device)
    target_net.load_state_dict(policy_net.state_dict())

    optimizer = torch.optim.AdamW(
        policy_net.parameters(),
        lr=0.01,  # Hyperparameter: learning rate
        amsgrad=True,
    )

    if model_save_path.is_file():
        # Resume from the last saved checkpoint.
        #
        # torch.load unpickles arbitrary objects: only load checkpoints this
        # script wrote. weights_only=False is needed on torch versions where
        # it defaults to True, because the checkpoint also pickles the replay
        # memory (a deque of Transitions); older versions lack the argument.
        # map_location lets a checkpoint saved on GPU resume on CPU and vice
        # versa (it also moves the tensors stored in memory).
        load_kwargs = {"map_location": compute_device}
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = False
        checkpoint = torch.load(model_save_path, **load_kwargs)

        policy_net.load_state_dict(checkpoint["policy_state_dict"])
        target_net.load_state_dict(checkpoint["target_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        memory = checkpoint["memory"]
        episode_number = checkpoint["episode_number"] + 1
        steps_done = checkpoint["steps_done"]

    def select_action(state, steps_done):
        """Select an action index for `state` with an epsilon-greedy policy.

        With probability eps_threshold a uniformly random action is chosen;
        otherwise the action with the highest Q-value under policy_net.
        eps_threshold starts at EPS_START and decays exponentially towards
        EPS_END; EPS_DECAY controls the rate.

        Args:
            state: observation tensor with a leading batch dimension of 1
                (as built by _state_to_tensor).
            steps_done: number of actions selected so far; drives the decay.

        Returns:
            An int in [0, n_actions).
        """
        sample = random.random()  # 0 <= sample < 1
        eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(
            -1.0 * steps_done / EPS_DECAY
        )
        if sample > eps_threshold:
            with torch.no_grad():
                # max(1) returns (values, indices) over the action dimension;
                # the index of the largest Q-value is the greedy action.
                return policy_net(state).max(1)[1].view(1, 1).item()
        else:
            return random.randint(0, n_actions - 1)

    def optimize_model():
        """Run one DQN gradient step on a random minibatch from `memory`.

        Returns:
            The Huber loss of the minibatch (a single-element tensor).

        Raises:
            ValueError: if memory holds fewer than BATCH_SIZE transitions.
        """
        if len(memory) < BATCH_SIZE:
            raise ValueError(
                f"Not enough elements in memory for a batch of {BATCH_SIZE}"
            )

        transitions = random.sample(memory, BATCH_SIZE)

        # Transpose: a list of Transitions -> a Transition of tuples.
        batch = Transition(*zip(*transitions))

        state_batch = torch.cat(batch.state)    # (B, n_observations)
        action_batch = torch.cat(batch.action)  # (B, 1), int64
        reward_batch = torch.cat(batch.reward)  # (B,)

        # True where the transition did NOT end the episode (next_state is
        # not None). Built on compute_device so it can index
        # next_state_values directly.
        non_final_mask = torch.tensor(
            [s is not None for s in batch.next_state],
            dtype=torch.bool,
            device=compute_device,
        )
        non_final_next_states = [s for s in batch.next_state if s is not None]

        # Q(s_t, a_t) under policy_net. gather(1, idx) picks, per row i,
        # out[i, 0] = Q[i, idx[i, 0]], i.e. the Q-value of the action that
        # was actually taken in that transition.
        state_action_values = policy_net(state_batch).gather(1, action_batch)

        # V(s_{t+1}) = max_a Q_target(s_{t+1}, a) for non-final next states;
        # final states keep a value of zero.
        next_state_values = torch.zeros(BATCH_SIZE, device=compute_device)
        with torch.no_grad():
            # torch.cat([]) raises, so skip when every sampled transition
            # was terminal.
            if non_final_next_states:
                next_state_values[non_final_mask] = target_net(
                    torch.cat(non_final_next_states)
                ).max(1)[0]

        # Bellman target: r + GAMMA * V(s_{t+1}).
        expected_state_action_values = reward_batch + (next_state_values * GAMMA)

        # Huber loss between predicted and target Q-values.
        criterion = torch.nn.SmoothL1Loss()
        loss = criterion(
            state_action_values,
            expected_state_action_values.unsqueeze(1),
        )

        # .backward() accumulates into .grad, so clear old gradients first.
        optimizer.zero_grad()
        # Fill .grad of every policy_net parameter with d(loss)/d(param);
        # optimizer.step() below consumes those gradients.
        loss.backward()
        # Clamp every gradient element into [-100, 100] to limit
        # exploding gradients.
        torch.nn.utils.clip_grad_value_(  # type: ignore
            policy_net.parameters(),
            clip_value=100,
        )
        optimizer.step()

        return loss

    def _state_to_tensor(state):
        """Convert a Celeste state to a (1, n_observations) float32 tensor.

        Features are read with getattr, in Celeste.state_number_map order.
        """
        return torch.tensor(
            [getattr(state, x) for x in Celeste.state_number_map],
            dtype=torch.float32,
            device=compute_device,
        ).unsqueeze(0)

    def _save_checkpoint(path):
        """Write everything needed to resume training to `path`."""
        torch.save({
            "policy_state_dict": policy_net.state_dict(),
            "target_state_dict": target_net.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "memory": memory,
            "episode_number": episode_number,
            "steps_done": steps_done,
        }, path)

    def on_state_before(celeste):
        """Choose an action for the current state and send it to the game.

        Returns:
            (state, action): the state the action was chosen from and its
            index; on_state_after unpacks this as `before_out` (presumably
            forwarded by Celeste.update_loop — confirm).
        """
        global steps_done

        state = celeste.state
        pt_state = _state_to_tensor(state)

        # Re-sample until the action is usable: without a dash available,
        # only "left" and "right" are accepted. Every draw (accepted or not)
        # advances steps_done and therefore the epsilon decay.
        # NOTE(review): a rejected greedy action is re-drawn identically, so
        # escaping relies on the random branch of select_action.
        while True:
            action = select_action(pt_state, steps_done)
            str_action = Celeste.action_space[action]
            steps_done += 1
            if state.can_dash or str_action in ["left", "right"]:
                break

        # For manual testing:
        # str_action = ""
        # while str_action not in Celeste.action_space:
        #     str_action = input("action> ")
        # action = Celeste.action_space.index(str_action)

        print(str_action)
        celeste.act(str_action)

        return state, action

    def on_state_after(celeste, before_out):
        """Record the transition, train, and handle end of episode.

        Args:
            celeste: the game handle; celeste.state is the post-action state.
            before_out: the (state, action) pair returned by on_state_before.
        """
        global episode_number

        state, action = before_out
        next_state = celeste.state

        pt_state = _state_to_tensor(state)
        pt_action = torch.tensor(
            [[action]],
            device=compute_device,
            dtype=torch.long,
        )

        if next_state.deaths != 0:
            # Death ends the episode: store a terminal transition.
            # NOTE(review): dying earns the same reward (0) as making no
            # progress; a negative reward may be intended.
            pt_next_state = None
            reward = 0
        else:
            pt_next_state = _state_to_tensor(next_state)
            if state.next_point == next_state.next_point:
                # Binary progress reward: 1 if `dist` shrank by more than 1
                # this step, otherwise 0.
                # NOTE(review): progress of at most 1 earns nothing; if the
                # intent was to clip, use min(progress, 1) — confirm.
                progress = state.dist - next_state.dist
                reward = 1 if progress > 1 else 0
            else:
                # Reward for reaching a point.
                reward = 1

        pt_reward = torch.tensor(
            [reward],
            dtype=torch.float32,
            device=compute_device,
        )

        # Add this state transition to memory.
        memory.append(Transition(pt_state, pt_action, pt_next_state, pt_reward))

        print("==> ", int(reward))
        print("")

        # Only train once memory holds enough transitions for a batch.
        loss = None
        if len(memory) >= BATCH_SIZE:
            loss = optimize_model()

            # Soft update: target = TAU * policy + (1 - TAU) * target.
            target_net_state = target_net.state_dict()
            policy_net_state = policy_net.state_dict()
            for key in policy_net_state:
                target_net_state[key] = (
                    policy_net_state[key] * TAU
                    + target_net_state[key] * (1 - TAU)
                )
            target_net.load_state_dict(target_net_state)

        # Move on to the next episode once we reach a terminal state.
        if next_state.deaths != 0:
            final_state = celeste.state
            with model_train_log.open("a") as f:
                f.write(json.dumps({
                    "checkpoints": final_state.next_point,
                    "state_count": final_state.state_count,
                    "loss": None if loss is None else loss.item(),
                }) + "\n")

            # Save model
            _save_checkpoint(model_save_path)

            # Move this episode's screenshots into screenshot_dir/<episode>.
            # exist_ok: a leftover directory from an earlier run must not
            # crash training.
            episode_shot_dir = screenshot_dir / str(episode_number)
            episode_shot_dir.mkdir(parents=True, exist_ok=True)
            # list() so the directory is not iterated while files are moved
            # out of it; shutil.move (unlike Path.rename) also works across
            # filesystems.
            for shot in list(screenshot_source_dir.glob("hackcel_*.png")):
                shutil.move(str(shot), str(episode_shot_dir / shot.name))

            # Periodically keep an archived copy of the full checkpoint.
            if episode_number % archive_interval == 0:
                _save_checkpoint(model_archive_dir / f"{episode_number}.torch")

            print("Game over. Resetting.")
            episode_number += 1
            celeste.reset()


if __name__ == "__main__":
    c = Celeste()
    c.update_loop(on_state_before, on_state_after)