Mark
/
celeste-ai
Archived
1
0
Fork 0
master
Mark 2023-02-18 19:50:43 -08:00
parent 36c5fcac7c
commit 0e874bf810
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
2 changed files with 124 additions and 168 deletions

View File

@ -9,58 +9,58 @@ import torch
from celeste import Celeste
# Where to read/write model data.
model_data_root = Path("model_data")
if __name__ == "__main__":
# Where to read/write model data.
model_data_root = Path("model_data")
model_save_path = model_data_root / "model.torch"
model_archive_dir = model_data_root / "model_archive"
model_train_log = model_data_root / "train_log"
screenshot_dir = model_data_root / "screenshots"
model_data_root.mkdir(parents = True, exist_ok = True)
model_archive_dir.mkdir(parents = True, exist_ok = True)
screenshot_dir.mkdir(parents = True, exist_ok = True)
model_save_path = model_data_root / "model.torch"
model_archive_dir = model_data_root / "model_archive"
model_train_log = model_data_root / "train_log"
screenshot_dir = model_data_root / "screenshots"
model_data_root.mkdir(parents = True, exist_ok = True)
model_archive_dir.mkdir(parents = True, exist_ok = True)
screenshot_dir.mkdir(parents = True, exist_ok = True)
compute_device = torch.device(
compute_device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
)
# Celeste env properties
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)
# Celeste env properties
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)
# Epsilon-greedy parameters
#
# Original docs:
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
# Epsilon-greedy parameters
#
# Original docs:
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
BATCH_SIZE = 1_000
# Learning rate of target_net.
# Controls how soft our soft update is.
#
# Should be between 0 and 1.
# Large values
# Small values do the opposite.
#
# A value of one makes target_net
# change at the same rate as policy_net.
#
# A value of zero makes target_net
# not change at all.
TAU = 0.005
BATCH_SIZE = 1_000
# Learning rate of target_net.
# Controls how soft our soft update is.
#
# Should be between 0 and 1.
# Large values
# Small values do the opposite.
#
# A value of one makes target_net
# change at the same rate as policy_net.
#
# A value of zero makes target_net
# not change at all.
TAU = 0.005
# GAMMA is the discount factor as mentioned in the previous section
GAMMA = 0.99
# GAMMA is the discount factor as mentioned in the previous section
GAMMA = 0.99
# Outline our network
@ -92,40 +92,60 @@ class DQN(torch.nn.Module):
def forward(self, x):
return self.layers(x)
Transition = namedtuple(
"Transition",
(
"state",
"action",
"next_state",
"reward"
)
)
steps_done = 0
if __name__ == "__main__":
steps_done = 0
num_episodes = 100
episode_number = 0
archive_interval = 10
num_episodes = 100
# Create replay memory.
#
# Transition: a container for naming data (defined in util.py)
# Memory: a deque that holds recent states as Transitions
# Has a fixed length, drops oldest
# element if maxlen is exceeded.
memory = deque([], maxlen=100_000)
# Create replay memory.
#
# Transition: a container for naming data (defined in util.py)
# Memory: a deque that holds recent states as Transitions
# Has a fixed length, drops oldest
# element if maxlen is exceeded.
memory = deque([], maxlen=100_000)
policy_net = DQN(
policy_net = DQN(
n_observations,
n_actions
).to(compute_device)
).to(compute_device)
target_net = DQN(
target_net = DQN(
n_observations,
n_actions
).to(compute_device)
).to(compute_device)
target_net.load_state_dict(policy_net.state_dict())
target_net.load_state_dict(policy_net.state_dict())
optimizer = torch.optim.AdamW(
optimizer = torch.optim.AdamW(
policy_net.parameters(),
lr = 0.01, # Hyperparameter: learning rate
amsgrad = True
)
)
if model_save_path.is_file():
# Load model if one exists
checkpoint = torch.load(model_save_path)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
target_net.load_state_dict(checkpoint["target_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
memory = checkpoint["memory"]
episode_number = checkpoint["episode_number"] + 1
steps_done = checkpoint["steps_done"]
def select_action(state, steps_done):
"""
@ -160,24 +180,6 @@ def select_action(state, steps_done):
return random.randint( 0, n_actions-1 )
last_state = None
Transition = namedtuple(
"Transition",
(
"state",
"action",
"next_state",
"reward"
)
)
def optimize_model():
if len(memory) < BATCH_SIZE:
@ -313,19 +315,6 @@ def optimize_model():
optimizer.step()
episode_number = 0
if model_save_path.is_file():
# Load model if one exists
checkpoint = torch.load(model_save_path)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
target_net.load_state_dict(checkpoint["target_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
memory = checkpoint["memory"]
episode_number = checkpoint["episode_number"] + 1
steps_done = checkpoint["steps_done"]
def on_state_before(celeste):
global steps_done
@ -363,9 +352,6 @@ def on_state_before(celeste):
image_interval = 10
def on_state_after(celeste, before_out):
global episode_number
global image_count
@ -474,7 +460,7 @@ def on_state_after(celeste, before_out):
s.rename(target / s.name)
# Save a prediction graph
if episode_number % image_interval == 0:
if episode_number % archive_interval == 0:
torch.save({
"policy_state_dict": policy_net.state_dict(),
"target_state_dict": target_net.state_dict(),
@ -490,10 +476,10 @@ def on_state_after(celeste, before_out):
celeste.reset()
if __name__ == "__main__":
c = Celeste()
c = Celeste()
c.update_loop(
c.update_loop(
on_state_before,
on_state_after
)
)

View File

@ -1,14 +1,15 @@
from pathlib import Path
import torch
from celeste import Celeste
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from collections import namedtuple
from multiprocessing import Pool
from celeste import Celeste
from main import DQN
from main import Transition
compute_device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
# Use cpu, the script is faster in parallel.
compute_device = torch.device("cpu")
# Celeste env properties
@ -16,35 +17,10 @@ n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)
# Outline our network
class DQN(torch.nn.Module):
def __init__(self, n_observations: int, n_actions: int):
super(DQN, self).__init__()
self.layers = torch.nn.Sequential(
torch.nn.Linear(n_observations, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 128),
torch.nn.ReLU(),
torch.torch.nn.Linear(128, n_actions)
)
# Can be called with one input, or with a batch.
#
# Returns tensor(
# [ Q(s, left), Q(s, right) ], ...
# )
#
# Recall that Q(s, a) is the (expected) return of taking
# action `a` at state `s`
def forward(self, x):
return self.layers(x)
out_dir = Path("out/plots")
out_dir.mkdir(parents = True, exist_ok = True)
src_dir = Path("model_data/model_archive")
policy_net = DQN(
n_observations,
@ -62,18 +38,6 @@ optimizer = torch.optim.AdamW(
amsgrad = True
)
Transition = namedtuple(
"Transition",
(
"state",
"action",
"next_state",
"reward"
)
)
def makeplt(i, net):
p = np.zeros((128, 128), dtype=np.float32)
@ -93,10 +57,9 @@ def makeplt(i, net):
return p
for i in Path("out/model_images").iterdir():
checkpoint = torch.load(i)
def plot(src):
checkpoint = torch.load(src)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
@ -107,13 +70,20 @@ for i in Path("out/model_images").iterdir():
ax.set(adjustable="box", aspect="equal")
plot = ax.pcolor(
makeplt(a, policy_net),
cmap = "Greens_r",
cmap = "Greens",
vmin = 0,
vmax = 20
)
ax.set_title(Celeste.action_space[a])
ax.invert_yaxis()
fig.colorbar(plot)
print(i)
fig.savefig(f"out/{i.stem}.png")
print(src)
fig.savefig(out_dir / f"{src.stem}.png")
plt.close()
if __name__ == "__main__":
with Pool(5) as p:
p.map(plot, list(src_dir.iterdir()))