Mark
/
celeste-ai
Archived
1
0
Fork 0
master
Mark 2023-02-18 21:10:13 -08:00
parent 6668614bbd
commit 6fe0d6e1cd
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
1 changed files with 19 additions and 32 deletions

View File

@ -8,43 +8,21 @@ from celeste import Celeste
from main import DQN from main import DQN
from main import Transition from main import Transition
# Use cpu, the script is faster in parallel. # Use cpu, this script is faster in parallel.
compute_device = torch.device("cpu") compute_device = torch.device("cpu")
# Celeste env properties
n_observations = len(Celeste.state_number_map)
n_actions = len(Celeste.action_space)
out_dir = Path("out/plots") out_dir = Path("out/plots")
out_dir.mkdir(parents = True, exist_ok = True) out_dir.mkdir(parents = True, exist_ok = True)
src_dir = Path("model_data/model_archive") src_dir = Path("model_data/current/model_archive")
policy_net = DQN(
n_observations,
n_actions
).to(compute_device)
target_net = DQN(
n_observations,
n_actions
).to(compute_device)
optimizer = torch.optim.AdamW(
policy_net.parameters(),
lr = 0.01, # Hyperparameter: learning rate
amsgrad = True
)
def makeplt(i, net): def makeplt(i, net):
p = np.zeros((128, 128), dtype=np.float32) p = np.zeros((128, 128), dtype=np.float32)
with torch.no_grad():
for r in range(len(p)): for r in range(len(p)):
for c in range(len(p[r])): for c in range(len(p[r])):
with torch.no_grad():
k = net( k = net(
torch.tensor( torch.tensor(
[c, r, 60, 80], [c, r, 60, 80],
@ -52,29 +30,38 @@ def makeplt(i, net):
device = compute_device device = compute_device
).unsqueeze(0) ).unsqueeze(0)
)[0][i].item() )[0][i].item()
p[r][c] = k p[r][c] = k
return p return p
def plot(src): def plot(src):
policy_net = DQN(
len(Celeste.state_number_map),
len(Celeste.action_space)
).to(compute_device)
checkpoint = torch.load(src) checkpoint = torch.load(src)
policy_net.load_state_dict(checkpoint["policy_state_dict"]) policy_net.load_state_dict(checkpoint["policy_state_dict"])
fig, axs = plt.subplots(2, 4, figsize = (15, 10)) fig, axs = plt.subplots(2, 4, figsize = (20, 10))
for a in range(len(axs.ravel())): for a in range(len(axs.ravel())):
ax = axs.ravel()[a] ax = axs.ravel()[a]
ax.set(adjustable="box", aspect="equal") ax.set(
adjustable = "box",
aspect = "equal"
)
ax.set_title(Celeste.action_space[a])
ax.invert_yaxis()
plot = ax.pcolor( plot = ax.pcolor(
makeplt(a, policy_net), makeplt(a, policy_net),
cmap = "Greens", cmap = "Greens",
vmin = 0, vmin = 0,
) )
ax.set_title(Celeste.action_space[a])
ax.invert_yaxis()
fig.colorbar(plot) fig.colorbar(plot)
print(src) print(src)
fig.savefig(out_dir / f"{src.stem}.png") fig.savefig(out_dir / f"{src.stem}.png")