Cleanup
parent
36c5fcac7c
commit
0e874bf810
182
celeste/main.py
182
celeste/main.py
|
@ -9,58 +9,58 @@ import torch
|
|||
from celeste import Celeste
|
||||
|
||||
|
||||
# Where to read/write model data.
|
||||
model_data_root = Path("model_data")
|
||||
if __name__ == "__main__":
|
||||
# Where to read/write model data.
|
||||
model_data_root = Path("model_data")
|
||||
|
||||
model_save_path = model_data_root / "model.torch"
|
||||
model_archive_dir = model_data_root / "model_archive"
|
||||
model_train_log = model_data_root / "train_log"
|
||||
screenshot_dir = model_data_root / "screenshots"
|
||||
model_data_root.mkdir(parents = True, exist_ok = True)
|
||||
model_archive_dir.mkdir(parents = True, exist_ok = True)
|
||||
screenshot_dir.mkdir(parents = True, exist_ok = True)
|
||||
model_save_path = model_data_root / "model.torch"
|
||||
model_archive_dir = model_data_root / "model_archive"
|
||||
model_train_log = model_data_root / "train_log"
|
||||
screenshot_dir = model_data_root / "screenshots"
|
||||
model_data_root.mkdir(parents = True, exist_ok = True)
|
||||
model_archive_dir.mkdir(parents = True, exist_ok = True)
|
||||
screenshot_dir.mkdir(parents = True, exist_ok = True)
|
||||
|
||||
|
||||
compute_device = torch.device(
|
||||
compute_device = torch.device(
|
||||
"cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Celeste env properties
|
||||
n_observations = len(Celeste.state_number_map)
|
||||
n_actions = len(Celeste.action_space)
|
||||
# Celeste env properties
|
||||
n_observations = len(Celeste.state_number_map)
|
||||
n_actions = len(Celeste.action_space)
|
||||
|
||||
|
||||
# Epsilon-greedy parameters
|
||||
#
|
||||
# Original docs:
|
||||
# EPS_START is the starting value of epsilon
|
||||
# EPS_END is the final value of epsilon
|
||||
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
|
||||
EPS_START = 0.9
|
||||
EPS_END = 0.05
|
||||
EPS_DECAY = 1000
|
||||
# Epsilon-greedy parameters
|
||||
#
|
||||
# Original docs:
|
||||
# EPS_START is the starting value of epsilon
|
||||
# EPS_END is the final value of epsilon
|
||||
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
|
||||
EPS_START = 0.9
|
||||
EPS_END = 0.05
|
||||
EPS_DECAY = 1000
|
||||
|
||||
|
||||
BATCH_SIZE = 1_000
|
||||
# Learning rate of target_net.
|
||||
# Controls how soft our soft update is.
|
||||
#
|
||||
# Should be between 0 and 1.
|
||||
# Large values
|
||||
# Small values do the opposite.
|
||||
#
|
||||
# A value of one makes target_net
|
||||
# change at the same rate as policy_net.
|
||||
#
|
||||
# A value of zero makes target_net
|
||||
# not change at all.
|
||||
TAU = 0.005
|
||||
BATCH_SIZE = 1_000
|
||||
# Learning rate of target_net.
|
||||
# Controls how soft our soft update is.
|
||||
#
|
||||
# Should be between 0 and 1.
|
||||
# Large values
|
||||
# Small values do the opposite.
|
||||
#
|
||||
# A value of one makes target_net
|
||||
# change at the same rate as policy_net.
|
||||
#
|
||||
# A value of zero makes target_net
|
||||
# not change at all.
|
||||
TAU = 0.005
|
||||
|
||||
|
||||
# GAMMA is the discount factor as mentioned in the previous section
|
||||
GAMMA = 0.99
|
||||
|
||||
# GAMMA is the discount factor as mentioned in the previous section
|
||||
GAMMA = 0.99
|
||||
|
||||
|
||||
# Outline our network
|
||||
|
@ -92,40 +92,60 @@ class DQN(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
Transition = namedtuple(
|
||||
"Transition",
|
||||
(
|
||||
"state",
|
||||
"action",
|
||||
"next_state",
|
||||
"reward"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
steps_done = 0
|
||||
if __name__ == "__main__":
|
||||
steps_done = 0
|
||||
num_episodes = 100
|
||||
episode_number = 0
|
||||
archive_interval = 10
|
||||
|
||||
num_episodes = 100
|
||||
# Create replay memory.
|
||||
#
|
||||
# Transition: a container for naming data (defined in util.py)
|
||||
# Memory: a deque that holds recent states as Transitions
|
||||
# Has a fixed length, drops oldest
|
||||
# element if maxlen is exceeded.
|
||||
memory = deque([], maxlen=100_000)
|
||||
|
||||
|
||||
# Create replay memory.
|
||||
#
|
||||
# Transition: a container for naming data (defined in util.py)
|
||||
# Memory: a deque that holds recent states as Transitions
|
||||
# Has a fixed length, drops oldest
|
||||
# element if maxlen is exceeded.
|
||||
memory = deque([], maxlen=100_000)
|
||||
|
||||
|
||||
policy_net = DQN(
|
||||
policy_net = DQN(
|
||||
n_observations,
|
||||
n_actions
|
||||
).to(compute_device)
|
||||
).to(compute_device)
|
||||
|
||||
target_net = DQN(
|
||||
target_net = DQN(
|
||||
n_observations,
|
||||
n_actions
|
||||
).to(compute_device)
|
||||
).to(compute_device)
|
||||
|
||||
target_net.load_state_dict(policy_net.state_dict())
|
||||
target_net.load_state_dict(policy_net.state_dict())
|
||||
|
||||
|
||||
optimizer = torch.optim.AdamW(
|
||||
optimizer = torch.optim.AdamW(
|
||||
policy_net.parameters(),
|
||||
lr = 0.01, # Hyperparameter: learning rate
|
||||
amsgrad = True
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if model_save_path.is_file():
|
||||
# Load model if one exists
|
||||
checkpoint = torch.load(model_save_path)
|
||||
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
||||
target_net.load_state_dict(checkpoint["target_state_dict"])
|
||||
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||
memory = checkpoint["memory"]
|
||||
episode_number = checkpoint["episode_number"] + 1
|
||||
steps_done = checkpoint["steps_done"]
|
||||
|
||||
def select_action(state, steps_done):
|
||||
"""
|
||||
|
@ -160,24 +180,6 @@ def select_action(state, steps_done):
|
|||
return random.randint( 0, n_actions-1 )
|
||||
|
||||
|
||||
last_state = None
|
||||
|
||||
|
||||
Transition = namedtuple(
|
||||
"Transition",
|
||||
(
|
||||
"state",
|
||||
"action",
|
||||
"next_state",
|
||||
"reward"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def optimize_model():
|
||||
|
||||
if len(memory) < BATCH_SIZE:
|
||||
|
@ -313,19 +315,6 @@ def optimize_model():
|
|||
optimizer.step()
|
||||
|
||||
|
||||
episode_number = 0
|
||||
|
||||
|
||||
if model_save_path.is_file():
|
||||
# Load model if one exists
|
||||
checkpoint = torch.load(model_save_path)
|
||||
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
||||
target_net.load_state_dict(checkpoint["target_state_dict"])
|
||||
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||
memory = checkpoint["memory"]
|
||||
episode_number = checkpoint["episode_number"] + 1
|
||||
steps_done = checkpoint["steps_done"]
|
||||
|
||||
|
||||
def on_state_before(celeste):
|
||||
global steps_done
|
||||
|
@ -363,9 +352,6 @@ def on_state_before(celeste):
|
|||
|
||||
|
||||
|
||||
image_interval = 10
|
||||
|
||||
|
||||
def on_state_after(celeste, before_out):
|
||||
global episode_number
|
||||
global image_count
|
||||
|
@ -474,7 +460,7 @@ def on_state_after(celeste, before_out):
|
|||
s.rename(target / s.name)
|
||||
|
||||
# Save a prediction graph
|
||||
if episode_number % image_interval == 0:
|
||||
if episode_number % archive_interval == 0:
|
||||
torch.save({
|
||||
"policy_state_dict": policy_net.state_dict(),
|
||||
"target_state_dict": target_net.state_dict(),
|
||||
|
@ -490,10 +476,10 @@ def on_state_after(celeste, before_out):
|
|||
celeste.reset()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
c = Celeste()
|
||||
|
||||
c = Celeste()
|
||||
|
||||
c.update_loop(
|
||||
c.update_loop(
|
||||
on_state_before,
|
||||
on_state_after
|
||||
)
|
||||
)
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
from pathlib import Path
|
||||
import torch
|
||||
from celeste import Celeste
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import matplotlib.pyplot as plt
|
||||
from collections import namedtuple
|
||||
from multiprocessing import Pool
|
||||
|
||||
from celeste import Celeste
|
||||
from main import DQN
|
||||
from main import Transition
|
||||
|
||||
compute_device = torch.device(
|
||||
"cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
# Use cpu, the script is faster in parallel.
|
||||
compute_device = torch.device("cpu")
|
||||
|
||||
|
||||
# Celeste env properties
|
||||
|
@ -16,35 +17,10 @@ n_observations = len(Celeste.state_number_map)
|
|||
n_actions = len(Celeste.action_space)
|
||||
|
||||
|
||||
# Outline our network
|
||||
class DQN(torch.nn.Module):
|
||||
def __init__(self, n_observations: int, n_actions: int):
|
||||
super(DQN, self).__init__()
|
||||
|
||||
self.layers = torch.nn.Sequential(
|
||||
torch.nn.Linear(n_observations, 128),
|
||||
torch.nn.ReLU(),
|
||||
|
||||
torch.nn.Linear(128, 128),
|
||||
torch.nn.ReLU(),
|
||||
|
||||
torch.nn.Linear(128, 128),
|
||||
torch.nn.ReLU(),
|
||||
|
||||
torch.torch.nn.Linear(128, n_actions)
|
||||
)
|
||||
|
||||
# Can be called with one input, or with a batch.
|
||||
#
|
||||
# Returns tensor(
|
||||
# [ Q(s, left), Q(s, right) ], ...
|
||||
# )
|
||||
#
|
||||
# Recall that Q(s, a) is the (expected) return of taking
|
||||
# action `a` at state `s`
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
out_dir = Path("out/plots")
|
||||
out_dir.mkdir(parents = True, exist_ok = True)
|
||||
|
||||
src_dir = Path("model_data/model_archive")
|
||||
|
||||
policy_net = DQN(
|
||||
n_observations,
|
||||
|
@ -62,18 +38,6 @@ optimizer = torch.optim.AdamW(
|
|||
amsgrad = True
|
||||
)
|
||||
|
||||
Transition = namedtuple(
|
||||
"Transition",
|
||||
(
|
||||
"state",
|
||||
"action",
|
||||
"next_state",
|
||||
"reward"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
def makeplt(i, net):
|
||||
p = np.zeros((128, 128), dtype=np.float32)
|
||||
|
@ -93,10 +57,9 @@ def makeplt(i, net):
|
|||
return p
|
||||
|
||||
|
||||
for i in Path("out/model_images").iterdir():
|
||||
|
||||
|
||||
checkpoint = torch.load(i)
|
||||
def plot(src):
|
||||
checkpoint = torch.load(src)
|
||||
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
||||
|
||||
|
||||
|
@ -107,13 +70,20 @@ for i in Path("out/model_images").iterdir():
|
|||
ax.set(adjustable="box", aspect="equal")
|
||||
plot = ax.pcolor(
|
||||
makeplt(a, policy_net),
|
||||
cmap = "Greens_r",
|
||||
cmap = "Greens",
|
||||
vmin = 0,
|
||||
vmax = 20
|
||||
)
|
||||
ax.set_title(Celeste.action_space[a])
|
||||
ax.invert_yaxis()
|
||||
fig.colorbar(plot)
|
||||
print(i)
|
||||
fig.savefig(f"out/{i.stem}.png")
|
||||
print(src)
|
||||
fig.savefig(out_dir / f"{src.stem}.png")
|
||||
plt.close()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with Pool(5) as p:
|
||||
p.map(plot, list(src_dir.iterdir()))
|
||||
|
||||
|
||||
|
|
Reference in New Issue