Mark
/
celeste-ai
Archived
1
0
Fork 0

Made plotter faster

master
Mark 2023-02-19 08:49:33 -08:00
parent 6fe0d6e1cd
commit 4fbf1ea3f5
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
1 changed files with 21 additions and 22 deletions

View File

@ -17,24 +17,6 @@ out_dir.mkdir(parents = True, exist_ok = True)
src_dir = Path("model_data/current/model_archive") src_dir = Path("model_data/current/model_archive")
def makeplt(i, net):
p = np.zeros((128, 128), dtype=np.float32)
with torch.no_grad():
for r in range(len(p)):
for c in range(len(p[r])):
k = net(
torch.tensor(
[c, r, 60, 80],
dtype = torch.float32,
device = compute_device
).unsqueeze(0)
)[0][i].item()
p[r][c] = k
return p
def plot(src): def plot(src):
policy_net = DQN( policy_net = DQN(
len(Celeste.state_number_map), len(Celeste.state_number_map),
@ -47,21 +29,38 @@ def plot(src):
fig, axs = plt.subplots(2, 4, figsize = (20, 10)) fig, axs = plt.subplots(2, 4, figsize = (20, 10))
# Compute preditions
p = np.zeros((128, 128, 8), dtype=np.float32)
with torch.no_grad():
for r in range(len(p)):
for c in range(len(p[r])):
k = np.asarray(policy_net(
torch.tensor(
[c, r, 60, 80],
dtype = torch.float32,
device = compute_device
).unsqueeze(0)
)[0])
p[r][c] = k
# Plot predictions
for a in range(len(axs.ravel())): for a in range(len(axs.ravel())):
ax = axs.ravel()[a] ax = axs.ravel()[a]
ax.set( ax.set(
adjustable = "box", adjustable = "box",
aspect = "equal" aspect = "equal",
title = Celeste.action_space[a]
) )
ax.set_title(Celeste.action_space[a])
ax.invert_yaxis()
plot = ax.pcolor( plot = ax.pcolor(
makeplt(a, policy_net), p[:,:,a],
cmap = "Greens", cmap = "Greens",
vmin = 0, vmin = 0,
) )
ax.invert_yaxis()
fig.colorbar(plot) fig.colorbar(plot)
print(src) print(src)
fig.savefig(out_dir / f"{src.stem}.png") fig.savefig(out_dir / f"{src.stem}.png")