Mark / celeste-ai (archived)
Mark 2023-02-18 19:50:43 -08:00
parent 36c5fcac7c
commit 0e874bf810
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
2 changed files with 124 additions and 168 deletions
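The visible theme of the commit: the plotting script now imports DQN and Transition from main instead of keeping its own copies, and main's training setup moves behind if __name__ == "__main__" guards, presumably so that importing the module (or spawning multiprocessing workers, as the second file does) cannot start a training run as a side effect. A minimal sketch of that guard pattern, using hypothetical module and class names:

# Illustrative sketch, not part of this commit.
# trainer.py: a hypothetical module showing the __main__ guard.
import torch

class Net(torch.nn.Module):
    # Importable definition: safe for other scripts to reuse.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)

if __name__ == "__main__":
    # Runs only for `python trainer.py`, never for `from trainer import Net`.
    net = Net()
    print(net(torch.zeros(4)))
# (end of sketch)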

View File

@@ -9,6 +9,7 @@ import torch
from celeste import Celeste
if __name__ == "__main__":
# Where to read/write model data.
model_data_root = Path("model_data")
@@ -62,7 +63,6 @@ TAU = 0.005
GAMMA = 0.99
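# Illustrative sketch, not part of this commit: in DQN training GAMMA is the
# discount factor in the bootstrapped target r + GAMMA * max_a Q_target(s', a),
# and TAU is conventionally the rate of the soft update that blends policy
# weights into the target network. The update itself sits outside these hunks,
# so its exact form here is an assumption.
import torch

def soft_update(target_net: torch.nn.Module,
                policy_net: torch.nn.Module,
                tau: float = 0.005) -> None:
    # theta_target <- tau * theta_policy + (1 - tau) * theta_target
    target_sd = target_net.state_dict()
    policy_sd = policy_net.state_dict()
    for key in target_sd:
        target_sd[key] = tau * policy_sd[key] + (1.0 - tau) * target_sd[key]
    target_net.load_state_dict(target_sd)
# (end of sketch)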
# Outline our network
class DQN(torch.nn.Module):
def __init__(self, n_observations: int, n_actions: int):
@@ -92,12 +92,22 @@ class DQN(torch.nn.Module):
def forward(self, x):
return self.layers(x)
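# Illustrative sketch, not part of this commit: since every layer is a
# torch.nn.Linear, the network accepts a single observation vector or a
# batch of them (Linear broadcasts over leading dimensions). It uses the
# DQN class defined just above; the sizes below are made up, the real ones
# come from Celeste's state and action spaces.
import torch
net = DQN(n_observations=8, n_actions=4)   # hypothetical sizes
single = net(torch.zeros(8))               # one state  -> shape (4,)
batch = net(torch.zeros(32, 8))            # 32 states  -> shape (32, 4)
# (end of sketch)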
Transition = namedtuple(
"Transition",
(
"state",
"action",
"next_state",
"reward"
)
)
if __name__ == "__main__":
steps_done = 0
num_episodes = 100
episode_number = 0
archive_interval = 10
# Create replay memory.
#
@@ -107,7 +117,6 @@ num_episodes = 100
# element if maxlen is exceeded.
memory = deque([], maxlen=100_000)
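# Illustrative sketch, not part of this commit: as the comment above says,
# a deque with maxlen silently drops its oldest entry once full, which is
# exactly the behaviour wanted from a fixed-size replay buffer of Transitions.
from collections import deque
buf = deque([], maxlen=3)
for i in range(5):
    buf.append(i)          # in the real buffer each entry is a Transition
print(buf)                 # deque([2, 3, 4], maxlen=3): 0 and 1 were discarded
# (end of sketch)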
policy_net = DQN(
n_observations,
n_actions
@@ -127,6 +136,17 @@ optimizer = torch.optim.AdamW(
amsgrad = True
)
if model_save_path.is_file():
# Load model if one exists
checkpoint = torch.load(model_save_path)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
target_net.load_state_dict(checkpoint["target_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
memory = checkpoint["memory"]
episode_number = checkpoint["episode_number"] + 1
steps_done = checkpoint["steps_done"]
def select_action(state, steps_done):
"""
Select an action using an epsilon-greedy policy.
@@ -160,24 +180,6 @@ def select_action(state, steps_done):
return random.randint( 0, n_actions-1 )
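# Illustrative sketch, not part of this commit: the line above is the
# exploration branch of select_action; the exploitation branch and the
# epsilon schedule are cut out of this hunk, so the constants below
# (EPS_START, EPS_END, EPS_DECAY) are assumptions, not the repo's values.
import math
import random
import torch

EPS_START, EPS_END, EPS_DECAY = 0.9, 0.05, 1000

def epsilon_greedy(state, steps_done, policy_net, n_actions):
    eps = EPS_END + (EPS_START - EPS_END) * math.exp(-steps_done / EPS_DECAY)
    if random.random() > eps:
        # Exploit: action with the highest predicted Q-value.
        with torch.no_grad():
            return int(policy_net(state).argmax().item())
    # Explore: uniformly random action, as in the line above.
    return random.randint(0, n_actions - 1)
# (end of sketch)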
last_state = None
Transition = namedtuple(
"Transition",
(
"state",
"action",
"next_state",
"reward"
)
)
def optimize_model():
if len(memory) < BATCH_SIZE:
@@ -313,19 +315,6 @@ def optimize_model():
optimizer.step()
episode_number = 0
if model_save_path.is_file():
# Load model if one exists
checkpoint = torch.load(model_save_path)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
target_net.load_state_dict(checkpoint["target_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
memory = checkpoint["memory"]
episode_number = checkpoint["episode_number"] + 1
steps_done = checkpoint["steps_done"]
def on_state_before(celeste):
global steps_done
@@ -363,9 +352,6 @@ def on_state_before(celeste):
image_interval = 10
def on_state_after(celeste, before_out):
global episode_number
global image_count
@@ -474,7 +460,7 @@ def on_state_after(celeste, before_out):
s.rename(target / s.name)
# Save a prediction graph
if episode_number % image_interval == 0:
if episode_number % archive_interval == 0:
torch.save({
"policy_state_dict": policy_net.state_dict(),
"target_state_dict": target_net.state_dict(),
@@ -490,7 +476,7 @@ def on_state_after(celeste, before_out):
celeste.reset()
if __name__ == "__main__":
c = Celeste()
c.update_loop(

View File

@@ -1,14 +1,15 @@
from pathlib import Path
import torch
from celeste import Celeste
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from collections import namedtuple
from multiprocessing import Pool
from celeste import Celeste
from main import DQN
from main import Transition
compute_device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
# Use cpu, the script is faster in parallel.
compute_device = torch.device("cpu")
# Celeste env properties
@@ -16,35 +17,10 @@ n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)
# Outline our network
class DQN(torch.nn.Module):
def __init__(self, n_observations: int, n_actions: int):
super(DQN, self).__init__()
self.layers = torch.nn.Sequential(
torch.nn.Linear(n_observations, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 128),
torch.nn.ReLU(),
torch.torch.nn.Linear(128, n_actions)
)
# Can be called with one input, or with a batch.
#
# Returns tensor(
# [ Q(s, left), Q(s, right) ], ...
# )
#
# Recall that Q(s, a) is the (expected) return of taking
# action `a` at state `s`
def forward(self, x):
return self.layers(x)
out_dir = Path("out/plots")
out_dir.mkdir(parents = True, exist_ok = True)
src_dir = Path("model_data/model_archive")
policy_net = DQN(
n_observations,
@@ -62,18 +38,6 @@ optimizer = torch.optim.AdamW(
amsgrad = True
)
Transition = namedtuple(
"Transition",
(
"state",
"action",
"next_state",
"reward"
)
)
def makeplt(i, net):
p = np.zeros((128, 128), dtype=np.float32)
@@ -93,10 +57,9 @@ def makeplt(i, net):
return p
for i in Path("out/model_images").iterdir():
checkpoint = torch.load(i)
def plot(src):
checkpoint = torch.load(src)
policy_net.load_state_dict(checkpoint["policy_state_dict"])
@@ -107,13 +70,20 @@ for i in Path("out/model_images").iterdir():
ax.set(adjustable="box", aspect="equal")
plot = ax.pcolor(
makeplt(a, policy_net),
cmap = "Greens_r",
cmap = "Greens",
vmin = 0,
vmax = 20
)
ax.set_title(Celeste.action_space[a])
ax.invert_yaxis()
fig.colorbar(plot)
print(i)
fig.savefig(f"out/{i.stem}.png")
print(src)
fig.savefig(out_dir / f"{src.stem}.png")
plt.close()
if __name__ == "__main__":
with Pool(5) as p:
p.map(plot, list(src_dir.iterdir()))
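# Illustrative sketch, not part of this commit: plot() must be a module-level
# function so it can be pickled into the worker processes, and the Pool is
# created under the "__main__" guard so that workers which re-import this
# script (depending on the start method) do not recursively create pools.
from multiprocessing import Pool

def work(x: int) -> int:
    # Must live at module top level, like plot() above, to be picklable.
    return x * x

if __name__ == "__main__":
    with Pool(5) as p:
        print(p.map(work, range(10)))   # [0, 1, 4, ..., 81]
# (end of sketch)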