From 0e874bf810778225ca9a0d650cf4c47d0e19948c Mon Sep 17 00:00:00 2001
From: Mark
Date: Sat, 18 Feb 2023 19:50:43 -0800
Subject: [PATCH] Cleanup

---
 celeste/main.py  | 216 ++++++++++++++++++++++-------------------
 celeste/plots.py |  76 +++++------------
 2 files changed, 124 insertions(+), 168 deletions(-)

diff --git a/celeste/main.py b/celeste/main.py
index 4f6e5f8..441b4e4 100644
--- a/celeste/main.py
+++ b/celeste/main.py
@@ -9,58 +9,58 @@ import torch
 from celeste import Celeste
 
-# Where to read/write model data.
-model_data_root = Path("model_data")
+if __name__ == "__main__":
+	# Where to read/write model data.
+	model_data_root = Path("model_data")
 
-model_save_path = model_data_root / "model.torch"
-model_archive_dir = model_data_root / "model_archive"
-model_train_log = model_data_root / "train_log"
-screenshot_dir = model_data_root / "screenshots"
-model_data_root.mkdir(parents = True, exist_ok = True)
-model_archive_dir.mkdir(parents = True, exist_ok = True)
-screenshot_dir.mkdir(parents = True, exist_ok = True)
+	model_save_path = model_data_root / "model.torch"
+	model_archive_dir = model_data_root / "model_archive"
+	model_train_log = model_data_root / "train_log"
+	screenshot_dir = model_data_root / "screenshots"
+	model_data_root.mkdir(parents = True, exist_ok = True)
+	model_archive_dir.mkdir(parents = True, exist_ok = True)
+	screenshot_dir.mkdir(parents = True, exist_ok = True)
 
-compute_device = torch.device(
-	"cuda" if torch.cuda.is_available() else "cpu"
-)
+	compute_device = torch.device(
+		"cuda" if torch.cuda.is_available() else "cpu"
+	)
 
-# Celeste env properties
-n_observations = len(Celeste.state_number_map)
-n_actions = len(Celeste.action_space)
+	# Celeste env properties
+	n_observations = len(Celeste.state_number_map)
+	n_actions = len(Celeste.action_space)
 
-# Epsilon-greedy parameters
-#
-# Original docs:
-# EPS_START is the starting value of epsilon
-# EPS_END is the final value of epsilon
-# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
-EPS_START = 0.9
-EPS_END = 0.05
-EPS_DECAY = 1000
+	# Epsilon-greedy parameters
+	#
+	# Original docs:
+	# EPS_START is the starting value of epsilon
+	# EPS_END is the final value of epsilon
+	# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
+	EPS_START = 0.9
+	EPS_END = 0.05
+	EPS_DECAY = 1000
 
-BATCH_SIZE = 1_000
-# Learning rate of target_net.
-# Controls how soft our soft update is.
-#
-# Should be between 0 and 1.
-# Large values
-# Small values do the opposite.
-#
-# A value of one makes target_net
-# change at the same rate as policy_net.
-#
-# A value of zero makes target_net
-# not change at all.
-TAU = 0.005
+	BATCH_SIZE = 1_000
+	# Update rate of target_net.
+	# Controls how soft our soft update is.
+	#
+	# Should be between 0 and 1.
+	# Large values make target_net track policy_net closely,
+	# small values do the opposite.
+	#
+	# A value of one makes target_net
+	# change at the same rate as policy_net.
+	#
+	# A value of zero makes target_net
+	# not change at all.
+	TAU = 0.005
 
-# GAMMA is the discount factor as mentioned in the previous section
-GAMMA = 0.99
-
+	# GAMMA is the discount factor for future rewards
+	GAMMA = 0.99
 
 
 # Outline our network
@@ -92,41 +92,61 @@ class DQN(torch.nn.Module):
 	def forward(self, x):
 		return self.layers(x)
-
-
-steps_done = 0
-
-num_episodes = 100
-
-
-# Create replay memory.
-#
-# Transition: a container for naming data (defined in util.py)
-# Memory: a deque that holds recent states as Transitions
-#	Has a fixed length, drops oldest
-#	element if maxlen is exceeded.
-memory = deque([], maxlen=100_000)
-
-
-policy_net = DQN(
-	n_observations,
-	n_actions
-).to(compute_device)
-
-target_net = DQN(
-	n_observations,
-	n_actions
-).to(compute_device)
-
-target_net.load_state_dict(policy_net.state_dict())
-
-
-optimizer = torch.optim.AdamW(
-	policy_net.parameters(),
-	lr = 0.01, # Hyperparameter: learning rate
-	amsgrad = True
+Transition = namedtuple(
+	"Transition",
+	(
+		"state",
+		"action",
+		"next_state",
+		"reward"
+	)
 )
+
+if __name__ == "__main__":
+	steps_done = 0
+	num_episodes = 100
+	episode_number = 0
+	archive_interval = 10
+
+	# Create replay memory.
+	#
+	# Transition: a named container for a single transition (defined above)
+	# Memory: a deque that holds recent states as Transitions
+	#	Has a fixed length, drops oldest
+	#	element if maxlen is exceeded.
+	memory = deque([], maxlen=100_000)
+
+	policy_net = DQN(
+		n_observations,
+		n_actions
+	).to(compute_device)
+
+	target_net = DQN(
+		n_observations,
+		n_actions
+	).to(compute_device)
+
+	target_net.load_state_dict(policy_net.state_dict())
+
+
+	optimizer = torch.optim.AdamW(
+		policy_net.parameters(),
+		lr = 0.01, # Hyperparameter: learning rate
+		amsgrad = True
+	)
+
+
+	if model_save_path.is_file():
+		# Load model if one exists
+		checkpoint = torch.load(model_save_path)
+		policy_net.load_state_dict(checkpoint["policy_state_dict"])
+		target_net.load_state_dict(checkpoint["target_state_dict"])
+		optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+		memory = checkpoint["memory"]
+		episode_number = checkpoint["episode_number"] + 1
+		steps_done = checkpoint["steps_done"]
+
 def select_action(state, steps_done):
 	"""
 	Select an action using an epsilon-greedy policy.
@@ -160,24 +180,6 @@ def select_action(state, steps_done):
 		return random.randint(
 			0, n_actions-1
 		)
 
-last_state = None
-
-
-Transition = namedtuple(
-	"Transition",
-	(
-		"state",
-		"action",
-		"next_state",
-		"reward"
-	)
-)
-
-
-
-
-
-
 def optimize_model():
 
 	if len(memory) < BATCH_SIZE:
@@ -313,19 +315,6 @@ def optimize_model():
 	optimizer.step()
 
-episode_number = 0
-
-
-if model_save_path.is_file():
-	# Load model if one exists
-	checkpoint = torch.load(model_save_path)
-	policy_net.load_state_dict(checkpoint["policy_state_dict"])
-	target_net.load_state_dict(checkpoint["target_state_dict"])
-	optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
-	memory = checkpoint["memory"]
-	episode_number = checkpoint["episode_number"] + 1
-	steps_done = checkpoint["steps_done"]
-
 def on_state_before(celeste):
 	global steps_done
@@ -363,9 +352,6 @@ def on_state_before(celeste):
 
-image_interval = 10
-
-
 def on_state_after(celeste, before_out):
 	global episode_number
 	global image_count
@@ -474,7 +460,7 @@ def on_state_after(celeste, before_out):
 			s.rename(target / s.name)
 
 	# Save a prediction graph
-	if episode_number % image_interval == 0:
+	if episode_number % archive_interval == 0:
 		torch.save({
 			"policy_state_dict": policy_net.state_dict(),
 			"target_state_dict": target_net.state_dict(),
@@ -490,10 +476,10 @@ def on_state_after(celeste, before_out):
 	celeste.reset()
 
+if __name__ == "__main__":
+	c = Celeste()
 
-c = Celeste()
-
-c.update_loop(
-	on_state_before,
-	on_state_after
-)
+	c.update_loop(
+		on_state_before,
+		on_state_after
+	)
diff --git a/celeste/plots.py b/celeste/plots.py
index 62c6ce3..d2e338d 100644
--- a/celeste/plots.py
+++ b/celeste/plots.py
@@ -1,14 +1,15 @@
-from pathlib import Path
 import torch
-from celeste import Celeste
 import numpy as np
+from pathlib import Path
 import matplotlib.pyplot as plt
-from collections import namedtuple
+from multiprocessing import Pool
+from celeste import Celeste
+from main import DQN
+from main import Transition
 
-compute_device = torch.device(
-	"cuda" if torch.cuda.is_available() else "cpu"
-)
+# Use the CPU; this script runs faster as several parallel processes.
+compute_device = torch.device("cpu")
 
 
 # Celeste env properties
@@ -16,35 +17,10 @@
 n_observations = len(Celeste.state_number_map)
 n_actions = len(Celeste.action_space)
 
-# Outline our network
-class DQN(torch.nn.Module):
-	def __init__(self, n_observations: int, n_actions: int):
-		super(DQN, self).__init__()
-
-		self.layers = torch.nn.Sequential(
-			torch.nn.Linear(n_observations, 128),
-			torch.nn.ReLU(),
-
-			torch.nn.Linear(128, 128),
-			torch.nn.ReLU(),
-
-			torch.nn.Linear(128, 128),
-			torch.nn.ReLU(),
-
-			torch.torch.nn.Linear(128, n_actions)
-		)
-
-	# Can be called with one input, or with a batch.
-	#
-	# Returns tensor(
-	#		[ Q(s, left), Q(s, right) ], ...
- # ) - # - # Recall that Q(s, a) is the (expected) return of taking - # action `a` at state `s` - def forward(self, x): - return self.layers(x) +out_dir = Path("out/plots") +out_dir.mkdir(parents = True, exist_ok = True) +src_dir = Path("model_data/model_archive") policy_net = DQN( n_observations, @@ -62,18 +38,6 @@ optimizer = torch.optim.AdamW( amsgrad = True ) -Transition = namedtuple( - "Transition", - ( - "state", - "action", - "next_state", - "reward" - ) -) - - - def makeplt(i, net): p = np.zeros((128, 128), dtype=np.float32) @@ -93,10 +57,9 @@ def makeplt(i, net): return p -for i in Path("out/model_images").iterdir(): - - checkpoint = torch.load(i) +def plot(src): + checkpoint = torch.load(src) policy_net.load_state_dict(checkpoint["policy_state_dict"]) @@ -107,13 +70,20 @@ for i in Path("out/model_images").iterdir(): ax.set(adjustable="box", aspect="equal") plot = ax.pcolor( makeplt(a, policy_net), - cmap = "Greens_r", + cmap = "Greens", vmin = 0, - vmax = 20 ) ax.set_title(Celeste.action_space[a]) ax.invert_yaxis() fig.colorbar(plot) - print(i) - fig.savefig(f"out/{i.stem}.png") + print(src) + fig.savefig(out_dir / f"{src.stem}.png") plt.close() + + + +if __name__ == "__main__": + with Pool(5) as p: + p.map(plot, list(src_dir.iterdir())) + +