Mark
/
celeste-ai
Archived
1
0
Fork 0
master
Mark 2023-02-18 19:50:43 -08:00
parent 36c5fcac7c
commit 0e874bf810
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
2 changed files with 124 additions and 168 deletions

View File

@ -9,58 +9,58 @@ import torch
from celeste import Celeste from celeste import Celeste
# Where to read/write model data. if __name__ == "__main__":
model_data_root = Path("model_data") # Where to read/write model data.
model_data_root = Path("model_data")
model_save_path = model_data_root / "model.torch" model_save_path = model_data_root / "model.torch"
model_archive_dir = model_data_root / "model_archive" model_archive_dir = model_data_root / "model_archive"
model_train_log = model_data_root / "train_log" model_train_log = model_data_root / "train_log"
screenshot_dir = model_data_root / "screenshots" screenshot_dir = model_data_root / "screenshots"
model_data_root.mkdir(parents = True, exist_ok = True) model_data_root.mkdir(parents = True, exist_ok = True)
model_archive_dir.mkdir(parents = True, exist_ok = True) model_archive_dir.mkdir(parents = True, exist_ok = True)
screenshot_dir.mkdir(parents = True, exist_ok = True) screenshot_dir.mkdir(parents = True, exist_ok = True)
compute_device = torch.device( compute_device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu" "cuda" if torch.cuda.is_available() else "cpu"
) )
# Celeste env properties # Celeste env properties
n_observations = len(Celeste.state_number_map) n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space) n_actions = len(Celeste.action_space)
# Epsilon-greedy parameters # Epsilon-greedy parameters
# #
# Original docs: # Original docs:
# EPS_START is the starting value of epsilon # EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon # EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay # EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
EPS_START = 0.9 EPS_START = 0.9
EPS_END = 0.05 EPS_END = 0.05
EPS_DECAY = 1000 EPS_DECAY = 1000
BATCH_SIZE = 1_000 BATCH_SIZE = 1_000
# Learning rate of target_net. # Learning rate of target_net.
# Controls how soft our soft update is. # Controls how soft our soft update is.
# #
# Should be between 0 and 1. # Should be between 0 and 1.
# Large values # Large values
# Small values do the opposite. # Small values do the opposite.
# #
# A value of one makes target_net # A value of one makes target_net
# change at the same rate as policy_net. # change at the same rate as policy_net.
# #
# A value of zero makes target_net # A value of zero makes target_net
# not change at all. # not change at all.
TAU = 0.005 TAU = 0.005
# GAMMA is the discount factor as mentioned in the previous section # GAMMA is the discount factor as mentioned in the previous section
GAMMA = 0.99 GAMMA = 0.99
# Outline our network # Outline our network
@ -92,40 +92,60 @@ class DQN(torch.nn.Module):
def forward(self, x): def forward(self, x):
return self.layers(x) return self.layers(x)
Transition = namedtuple(
"Transition",
(
"state",
"action",
"next_state",
"reward"
)
)
steps_done = 0 if __name__ == "__main__":
steps_done = 0
num_episodes = 100
episode_number = 0
archive_interval = 10
num_episodes = 100 # Create replay memory.
#
# Transition: a container for naming data (defined in util.py)
# Memory: a deque that holds recent states as Transitions
# Has a fixed length, drops oldest
# element if maxlen is exceeded.
memory = deque([], maxlen=100_000)
policy_net = DQN(
# Create replay memory.
#
# Transition: a container for naming data (defined in util.py)
# Memory: a deque that holds recent states as Transitions
# Has a fixed length, drops oldest
# element if maxlen is exceeded.
memory = deque([], maxlen=100_000)
policy_net = DQN(
n_observations, n_observations,
n_actions n_actions
).to(compute_device) ).to(compute_device)
target_net = DQN( target_net = DQN(
n_observations, n_observations,
n_actions n_actions
).to(compute_device) ).to(compute_device)
target_net.load_state_dict(policy_net.state_dict()) target_net.load_state_dict(policy_net.state_dict())
optimizer = torch.optim.AdamW( optimizer = torch.optim.AdamW(
policy_net.parameters(), policy_net.parameters(),
lr = 0.01, # Hyperparameter: learning rate lr = 0.01, # Hyperparameter: learning rate
amsgrad = True amsgrad = True
) )
if model_save_path.is_file():
# Load model if one exists
checkpoint = torch.load(model_save_path)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
target_net.load_state_dict(checkpoint["target_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
memory = checkpoint["memory"]
episode_number = checkpoint["episode_number"] + 1
steps_done = checkpoint["steps_done"]
def select_action(state, steps_done): def select_action(state, steps_done):
""" """
@ -160,24 +180,6 @@ def select_action(state, steps_done):
return random.randint( 0, n_actions-1 ) return random.randint( 0, n_actions-1 )
last_state = None
Transition = namedtuple(
"Transition",
(
"state",
"action",
"next_state",
"reward"
)
)
def optimize_model(): def optimize_model():
if len(memory) < BATCH_SIZE: if len(memory) < BATCH_SIZE:
@ -313,19 +315,6 @@ def optimize_model():
optimizer.step() optimizer.step()
episode_number = 0
if model_save_path.is_file():
# Load model if one exists
checkpoint = torch.load(model_save_path)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
target_net.load_state_dict(checkpoint["target_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
memory = checkpoint["memory"]
episode_number = checkpoint["episode_number"] + 1
steps_done = checkpoint["steps_done"]
def on_state_before(celeste): def on_state_before(celeste):
global steps_done global steps_done
@ -363,9 +352,6 @@ def on_state_before(celeste):
image_interval = 10
def on_state_after(celeste, before_out): def on_state_after(celeste, before_out):
global episode_number global episode_number
global image_count global image_count
@ -474,7 +460,7 @@ def on_state_after(celeste, before_out):
s.rename(target / s.name) s.rename(target / s.name)
# Save a prediction graph # Save a prediction graph
if episode_number % image_interval == 0: if episode_number % archive_interval == 0:
torch.save({ torch.save({
"policy_state_dict": policy_net.state_dict(), "policy_state_dict": policy_net.state_dict(),
"target_state_dict": target_net.state_dict(), "target_state_dict": target_net.state_dict(),
@ -490,10 +476,10 @@ def on_state_after(celeste, before_out):
celeste.reset() celeste.reset()
if __name__ == "__main__":
c = Celeste()
c = Celeste() c.update_loop(
c.update_loop(
on_state_before, on_state_before,
on_state_after on_state_after
) )

View File

@ -1,14 +1,15 @@
from pathlib import Path
import torch import torch
from celeste import Celeste
import numpy as np import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from collections import namedtuple from multiprocessing import Pool
from celeste import Celeste
from main import DQN
from main import Transition
compute_device = torch.device( # Use cpu, the script is faster in parallel.
"cuda" if torch.cuda.is_available() else "cpu" compute_device = torch.device("cpu")
)
# Celeste env properties # Celeste env properties
@ -16,35 +17,10 @@ n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space) n_actions = len(Celeste.action_space)
# Outline our network out_dir = Path("out/plots")
class DQN(torch.nn.Module): out_dir.mkdir(parents = True, exist_ok = True)
def __init__(self, n_observations: int, n_actions: int):
super(DQN, self).__init__()
self.layers = torch.nn.Sequential(
torch.nn.Linear(n_observations, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 128),
torch.nn.ReLU(),
torch.torch.nn.Linear(128, n_actions)
)
# Can be called with one input, or with a batch.
#
# Returns tensor(
# [ Q(s, left), Q(s, right) ], ...
# )
#
# Recall that Q(s, a) is the (expected) return of taking
# action `a` at state `s`
def forward(self, x):
return self.layers(x)
src_dir = Path("model_data/model_archive")
policy_net = DQN( policy_net = DQN(
n_observations, n_observations,
@ -62,18 +38,6 @@ optimizer = torch.optim.AdamW(
amsgrad = True amsgrad = True
) )
Transition = namedtuple(
"Transition",
(
"state",
"action",
"next_state",
"reward"
)
)
def makeplt(i, net): def makeplt(i, net):
p = np.zeros((128, 128), dtype=np.float32) p = np.zeros((128, 128), dtype=np.float32)
@ -93,10 +57,9 @@ def makeplt(i, net):
return p return p
for i in Path("out/model_images").iterdir():
def plot(src):
checkpoint = torch.load(i) checkpoint = torch.load(src)
policy_net.load_state_dict(checkpoint["policy_state_dict"]) policy_net.load_state_dict(checkpoint["policy_state_dict"])
@ -107,13 +70,20 @@ for i in Path("out/model_images").iterdir():
ax.set(adjustable="box", aspect="equal") ax.set(adjustable="box", aspect="equal")
plot = ax.pcolor( plot = ax.pcolor(
makeplt(a, policy_net), makeplt(a, policy_net),
cmap = "Greens_r", cmap = "Greens",
vmin = 0, vmin = 0,
vmax = 20
) )
ax.set_title(Celeste.action_space[a]) ax.set_title(Celeste.action_space[a])
ax.invert_yaxis() ax.invert_yaxis()
fig.colorbar(plot) fig.colorbar(plot)
print(i) print(src)
fig.savefig(f"out/{i.stem}.png") fig.savefig(out_dir / f"{src.stem}.png")
plt.close() plt.close()
if __name__ == "__main__":
with Pool(5) as p:
p.map(plot, list(src_dir.iterdir()))