Mark / celeste-ai
Branch: master
Commit 0e874bf810 (parent 36c5fcac7c)
Committed by Mark on 2023-02-18 19:50:43 -08:00
Signed by: Mark (GPG Key ID: AD62BB059C2AAEE4)
2 changed files with 124 additions and 168 deletions

Changed file 1 of 2

@@ -9,6 +9,7 @@ import torch
from celeste import Celeste

+if __name__ == "__main__":
    # Where to read/write model data.
    model_data_root = Path("model_data")
@@ -62,7 +63,6 @@ TAU = 0.005
GAMMA = 0.99

# Outline our network
class DQN(torch.nn.Module):
    def __init__(self, n_observations: int, n_actions: int):
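As a quick illustration of what this network computes, a batch pass through a DQN-shaped module returns one Q-value per action. The 128-unit hidden layers below mirror the duplicate DQN definition removed from the second file further down; the observation and action counts are stand-in assumptions:

import torch

# Stand-in with the same interface shape: observations in, Q-values out.
net = torch.nn.Sequential(
    torch.nn.Linear(4, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 2),   # one output per action
)
batch = torch.randn(8, 4)      # 8 observations of 4 features each
q = net(batch)                 # shape (8, 2): Q(s, a) for every action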
@@ -92,12 +92,22 @@ class DQN(torch.nn.Module):
    def forward(self, x):
        return self.layers(x)

+Transition = namedtuple(
+    "Transition",
+    (
+        "state",
+        "action",
+        "next_state",
+        "reward"
+    )
+)
+
+if __name__ == "__main__":
    steps_done = 0
    num_episodes = 100
+    episode_number = 0
+    archive_interval = 10

    # Create replay memory.
    #
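The Transition tuple and the bounded deque above are the whole replay memory; a minimal, self-contained sketch of how such a buffer is typically filled and sampled (push and sample are illustrative helpers, not functions from this repository):

import random
from collections import deque, namedtuple

Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))

memory = deque([], maxlen=100_000)  # oldest transitions fall off the left end

def push(*args):
    # Store one step of experience.
    memory.append(Transition(*args))

def sample(batch_size):
    # Uniformly sample a training batch.
    return random.sample(memory, batch_size)

for step in range(8):
    push(step, 0, step + 1, 1.0)    # (state, action, next_state, reward)
print(sample(4))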
@@ -107,7 +117,6 @@ num_episodes = 100
    # element if maxlen is exceeded.
    memory = deque([], maxlen=100_000)

    policy_net = DQN(
        n_observations,
        n_actions
@@ -127,6 +136,17 @@ optimizer = torch.optim.AdamW(
        amsgrad = True
    )

+    if model_save_path.is_file():
+        # Load model if one exists
+        checkpoint = torch.load(model_save_path)
+        policy_net.load_state_dict(checkpoint["policy_state_dict"])
+        target_net.load_state_dict(checkpoint["target_state_dict"])
+        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+        memory = checkpoint["memory"]
+        episode_number = checkpoint["episode_number"] + 1
+        steps_done = checkpoint["steps_done"]
+
def select_action(state, steps_done):
    """
    Select an action using an epsilon-greedy policy.
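For reference, a self-contained sketch of the checkpoint round trip this block performs. The dictionary keys mirror the ones loaded above and saved with torch.save further down; the stand-in network, optimizer, and file name are assumptions made only for the example:

import torch

# Hypothetical stand-ins for the real networks and optimizer.
net = torch.nn.Linear(4, 2)
opt = torch.optim.AdamW(net.parameters(), lr=1e-4, amsgrad=True)

# Save: write weights plus the bookkeeping needed to resume.
torch.save({
    "policy_state_dict": net.state_dict(),
    "optimizer_state_dict": opt.state_dict(),
    "episode_number": 41,
    "steps_done": 1000,
}, "checkpoint.torch")

# Load: restore everything, then continue from the next episode.
ckpt = torch.load("checkpoint.torch")
net.load_state_dict(ckpt["policy_state_dict"])
opt.load_state_dict(ckpt["optimizer_state_dict"])
episode_number = ckpt["episode_number"] + 1
steps_done = ckpt["steps_done"]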
@@ -160,24 +180,6 @@ def select_action(state, steps_done):
    return random.randint( 0, n_actions-1 )

-last_state = None
-
-Transition = namedtuple(
-    "Transition",
-    (
-        "state",
-        "action",
-        "next_state",
-        "reward"
-    )
-)
-
def optimize_model():
    if len(memory) < BATCH_SIZE:
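select_action's docstring above describes an epsilon-greedy policy; a sketch of the usual shape of such a selector, where the EPS_* schedule is an assumed example rather than values from this repository:

import math
import random
import torch

EPS_START, EPS_END, EPS_DECAY = 0.9, 0.05, 1000  # assumed schedule

def select_action_sketch(state, steps_done, policy_net, n_actions):
    # Exponentially decay exploration from EPS_START toward EPS_END.
    eps = EPS_END + (EPS_START - EPS_END) * math.exp(-steps_done / EPS_DECAY)
    if random.random() > eps:
        with torch.no_grad():
            # Exploit: pick the action with the highest predicted Q-value.
            return int(policy_net(state).argmax().item())
    # Explore: pick a random action.
    return random.randint(0, n_actions - 1)

net = torch.nn.Linear(4, 3)   # stand-in policy network
print(select_action_sketch(torch.randn(4), 0, net, 3))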
@@ -313,19 +315,6 @@ def optimize_model():
    optimizer.step()

-episode_number = 0
-
-if model_save_path.is_file():
-    # Load model if one exists
-    checkpoint = torch.load(model_save_path)
-    policy_net.load_state_dict(checkpoint["policy_state_dict"])
-    target_net.load_state_dict(checkpoint["target_state_dict"])
-    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
-    memory = checkpoint["memory"]
-    episode_number = checkpoint["episode_number"] + 1
-    steps_done = checkpoint["steps_done"]
-
def on_state_before(celeste):
    global steps_done
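The TAU = 0.005 constant seen earlier is the usual soft-update rate that keeps target_net trailing policy_net between optimization steps; a self-contained sketch of that update, with single-layer stand-ins for the real networks:

import torch

TAU = 0.005

policy_net = torch.nn.Linear(4, 2)   # stand-in networks
target_net = torch.nn.Linear(4, 2)
target_net.load_state_dict(policy_net.state_dict())

# Soft update: target <- TAU * policy + (1 - TAU) * target.
target_state = target_net.state_dict()
for key, policy_param in policy_net.state_dict().items():
    target_state[key] = TAU * policy_param + (1 - TAU) * target_state[key]
target_net.load_state_dict(target_state)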
@@ -363,9 +352,6 @@ def on_state_before(celeste):
-image_interval = 10
-
def on_state_after(celeste, before_out):
    global episode_number
    global image_count
@@ -474,7 +460,7 @@ def on_state_after(celeste, before_out):
            s.rename(target / s.name)

        # Save a prediction graph
-        if episode_number % image_interval == 0:
+        if episode_number % archive_interval == 0:
            torch.save({
                "policy_state_dict": policy_net.state_dict(),
                "target_state_dict": target_net.state_dict(),
@@ -490,7 +476,7 @@ def on_state_after(celeste, before_out):
    celeste.reset()

+if __name__ == "__main__":
    c = Celeste()

    c.update_loop(
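The added `if __name__ == "__main__":` guards in this file are what let the plotting script in the second file do `from main import DQN` and `from main import Transition` without starting a training run. A minimal illustration of the pattern, with hypothetical contents:

# importable_module.py (sketch)
class DQN:
    """Definitions at module level stay importable with no side effects."""

if __name__ == "__main__":
    # Runs only under `python importable_module.py`,
    # never on `import importable_module`.
    print("training loop would start here")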

Changed file 2 of 2

@@ -1,14 +1,15 @@
+from pathlib import Path
import torch
-from celeste import Celeste
import numpy as np
-from pathlib import Path
import matplotlib.pyplot as plt
-from collections import namedtuple
+from multiprocessing import Pool
+from celeste import Celeste
+from main import DQN
+from main import Transition

-compute_device = torch.device(
-    "cuda" if torch.cuda.is_available() else "cpu"
-)
+# Use cpu, the script is faster in parallel.
+compute_device = torch.device("cpu")
# Celeste env properties # Celeste env properties
@@ -16,35 +17,10 @@ n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)

-# Outline our network
-class DQN(torch.nn.Module):
-    def __init__(self, n_observations: int, n_actions: int):
-        super(DQN, self).__init__()
-        self.layers = torch.nn.Sequential(
-            torch.nn.Linear(n_observations, 128),
-            torch.nn.ReLU(),
-            torch.nn.Linear(128, 128),
-            torch.nn.ReLU(),
-            torch.nn.Linear(128, 128),
-            torch.nn.ReLU(),
-            torch.torch.nn.Linear(128, n_actions)
-        )
-
-    # Can be called with one input, or with a batch.
-    #
-    # Returns tensor(
-    #    [ Q(s, left), Q(s, right) ], ...
-    # )
-    #
-    # Recall that Q(s, a) is the (expected) return of taking
-    # action `a` at state `s`
-    def forward(self, x):
-        return self.layers(x)
+out_dir = Path("out/plots")
+out_dir.mkdir(parents = True, exist_ok = True)
+
+src_dir = Path("model_data/model_archive")

policy_net = DQN(
    n_observations,
@@ -62,18 +38,6 @@ optimizer = torch.optim.AdamW(
    amsgrad = True
)

-Transition = namedtuple(
-    "Transition",
-    (
-        "state",
-        "action",
-        "next_state",
-        "reward"
-    )
-)
-
def makeplt(i, net):
    p = np.zeros((128, 128), dtype=np.float32)
@@ -93,10 +57,9 @@ def makeplt(i, net):
    return p

-for i in Path("out/model_images").iterdir():
-    checkpoint = torch.load(i)
+def plot(src):
+    checkpoint = torch.load(src)
    policy_net.load_state_dict(checkpoint["policy_state_dict"])
@@ -107,13 +70,20 @@ for i in Path("out/model_images").iterdir():
    ax.set(adjustable="box", aspect="equal")
    plot = ax.pcolor(
        makeplt(a, policy_net),
-        cmap = "Greens_r",
+        cmap = "Greens",
        vmin = 0,
-        vmax = 20
    )
    ax.set_title(Celeste.action_space[a])
    ax.invert_yaxis()
    fig.colorbar(plot)
-    print(i)
-    fig.savefig(f"out/{i.stem}.png")
+    print(src)
+    fig.savefig(out_dir / f"{src.stem}.png")
    plt.close()
+
+if __name__ == "__main__":
+    with Pool(5) as p:
+        p.map(plot, list(src_dir.iterdir()))
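A final note on the `if __name__ == "__main__":` guard around the Pool: multiprocessing workers may re-import the module, so only code outside the guard runs in every worker, and the Pool itself must be created under the guard. A self-contained sketch of the same pattern with a stand-in task:

from multiprocessing import Pool

def work(n):
    # Stand-in for plot(); just squares its input.
    return n * n

if __name__ == "__main__":
    with Pool(5) as p:
        print(p.map(work, range(10)))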