Made plotter faster
parent
6fe0d6e1cd
commit
4fbf1ea3f5
|
@ -17,24 +17,6 @@ out_dir.mkdir(parents = True, exist_ok = True)
|
|||
src_dir = Path("model_data/current/model_archive")
|
||||
|
||||
|
||||
def makeplt(i, net):
|
||||
p = np.zeros((128, 128), dtype=np.float32)
|
||||
|
||||
with torch.no_grad():
|
||||
for r in range(len(p)):
|
||||
for c in range(len(p[r])):
|
||||
k = net(
|
||||
torch.tensor(
|
||||
[c, r, 60, 80],
|
||||
dtype = torch.float32,
|
||||
device = compute_device
|
||||
).unsqueeze(0)
|
||||
)[0][i].item()
|
||||
p[r][c] = k
|
||||
return p
|
||||
|
||||
|
||||
|
||||
def plot(src):
|
||||
policy_net = DQN(
|
||||
len(Celeste.state_number_map),
|
||||
|
@ -47,21 +29,38 @@ def plot(src):
|
|||
|
||||
fig, axs = plt.subplots(2, 4, figsize = (20, 10))
|
||||
|
||||
|
||||
# Compute preditions
|
||||
p = np.zeros((128, 128, 8), dtype=np.float32)
|
||||
with torch.no_grad():
|
||||
for r in range(len(p)):
|
||||
for c in range(len(p[r])):
|
||||
k = np.asarray(policy_net(
|
||||
torch.tensor(
|
||||
[c, r, 60, 80],
|
||||
dtype = torch.float32,
|
||||
device = compute_device
|
||||
).unsqueeze(0)
|
||||
)[0])
|
||||
p[r][c] = k
|
||||
|
||||
|
||||
# Plot predictions
|
||||
for a in range(len(axs.ravel())):
|
||||
ax = axs.ravel()[a]
|
||||
ax.set(
|
||||
adjustable = "box",
|
||||
aspect = "equal"
|
||||
aspect = "equal",
|
||||
title = Celeste.action_space[a]
|
||||
)
|
||||
ax.set_title(Celeste.action_space[a])
|
||||
ax.invert_yaxis()
|
||||
|
||||
plot = ax.pcolor(
|
||||
makeplt(a, policy_net),
|
||||
p[:,:,a],
|
||||
cmap = "Greens",
|
||||
vmin = 0,
|
||||
)
|
||||
|
||||
ax.invert_yaxis()
|
||||
fig.colorbar(plot)
|
||||
print(src)
|
||||
fig.savefig(out_dir / f"{src.stem}.png")
|
||||
|
|
Reference in New Issue