Mark
/
celeste-ai
Archived
1
0
Fork 0

Tweaked plotter for new model

master
Mark 2023-02-24 14:23:25 -08:00
parent e9c0521ff5
commit bae70e0cfa
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
1 changed files with 10 additions and 3 deletions

View File

@ -34,13 +34,16 @@ def predicted_reward(
# Compute preditions # Compute preditions
p = np.zeros((128, 128, 8), dtype=np.float32) p = np.zeros((128, 128, 9), dtype=np.float32)
with torch.no_grad(): with torch.no_grad():
for r in range(len(p)): for r in range(len(p)):
for c in range(len(p[r])): for c in range(len(p[r])):
x = c / 128.0
y = r / 128.0
k = np.asarray(policy_net( k = np.asarray(policy_net(
torch.tensor( torch.tensor(
[c, r, 60, 80], [x, y, 0],
dtype = torch.float32, dtype = torch.float32,
device = device device = device
).unsqueeze(0) ).unsqueeze(0)
@ -49,8 +52,11 @@ def predicted_reward(
# Plot predictions # Plot predictions
fig, axs = plt.subplots(2, 4, figsize = (20, 10)) fig, axs = plt.subplots(2, 5, figsize = (20, 10))
for a in range(len(axs.ravel())): for a in range(len(axs.ravel())):
if a >= len(Celeste.action_space):
continue
ax = axs.ravel()[a] ax = axs.ravel()[a]
ax.set( ax.set(
adjustable = "box", adjustable = "box",
@ -62,6 +68,7 @@ def predicted_reward(
p[:,:,a], p[:,:,a],
cmap = "Greens", cmap = "Greens",
vmin = 0, vmin = 0,
#vmax = 5
) )
ax.invert_yaxis() ax.invert_yaxis()