Tweaked plotter for new model
parent
e9c0521ff5
commit
bae70e0cfa
|
@ -34,13 +34,16 @@ def predicted_reward(
|
||||||
|
|
||||||
|
|
||||||
# Compute predictions
|
# Compute predictions
|
||||||
p = np.zeros((128, 128, 8), dtype=np.float32)
|
p = np.zeros((128, 128, 9), dtype=np.float32)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for r in range(len(p)):
|
for r in range(len(p)):
|
||||||
for c in range(len(p[r])):
|
for c in range(len(p[r])):
|
||||||
|
x = c / 128.0
|
||||||
|
y = r / 128.0
|
||||||
|
|
||||||
k = np.asarray(policy_net(
|
k = np.asarray(policy_net(
|
||||||
torch.tensor(
|
torch.tensor(
|
||||||
[c, r, 60, 80],
|
[x, y, 0],
|
||||||
dtype = torch.float32,
|
dtype = torch.float32,
|
||||||
device = device
|
device = device
|
||||||
).unsqueeze(0)
|
).unsqueeze(0)
|
||||||
|
@ -49,8 +52,11 @@ def predicted_reward(
|
||||||
|
|
||||||
|
|
||||||
# Plot predictions
|
# Plot predictions
|
||||||
fig, axs = plt.subplots(2, 4, figsize = (20, 10))
|
fig, axs = plt.subplots(2, 5, figsize = (20, 10))
|
||||||
for a in range(len(axs.ravel())):
|
for a in range(len(axs.ravel())):
|
||||||
|
if a >= len(Celeste.action_space):
|
||||||
|
continue
|
||||||
|
|
||||||
ax = axs.ravel()[a]
|
ax = axs.ravel()[a]
|
||||||
ax.set(
|
ax.set(
|
||||||
adjustable = "box",
|
adjustable = "box",
|
||||||
|
@ -62,6 +68,7 @@ def predicted_reward(
|
||||||
p[:,:,a],
|
p[:,:,a],
|
||||||
cmap = "Greens",
|
cmap = "Greens",
|
||||||
vmin = 0,
|
vmin = 0,
|
||||||
|
#vmax = 5
|
||||||
)
|
)
|
||||||
|
|
||||||
ax.invert_yaxis()
|
ax.invert_yaxis()
|
||||||
|
|
Reference in New Issue