Removed "can_dash" input value
parent
f40b58508e
commit
0b61702677
|
@ -70,7 +70,7 @@ class Celeste:
|
||||||
#"ypos",
|
#"ypos",
|
||||||
"xpos_scaled",
|
"xpos_scaled",
|
||||||
"ypos_scaled",
|
"ypos_scaled",
|
||||||
"can_dash_int"
|
#"can_dash_int"
|
||||||
#"next_point_x",
|
#"next_point_x",
|
||||||
#"next_point_y"
|
#"next_point_y"
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import matplotlib as mpl
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
# All of the following are required to load
|
# All of the following are required to load
|
||||||
|
@ -34,7 +35,7 @@ def best_action(
|
||||||
|
|
||||||
|
|
||||||
# Compute preditions
|
# Compute preditions
|
||||||
p = np.zeros((128, 128, 2), dtype=np.float32)
|
p = np.zeros((128, 128), dtype=np.float32)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for r in range(len(p)):
|
for r in range(len(p)):
|
||||||
for c in range(len(p[r])):
|
for c in range(len(p[r])):
|
||||||
|
@ -43,26 +44,31 @@ def best_action(
|
||||||
|
|
||||||
k = np.asarray(policy_net(
|
k = np.asarray(policy_net(
|
||||||
torch.tensor(
|
torch.tensor(
|
||||||
[x, y, 0],
|
[x, y],
|
||||||
dtype = torch.float32,
|
dtype = torch.float32,
|
||||||
device = device
|
device = device
|
||||||
).unsqueeze(0)
|
).unsqueeze(0)
|
||||||
)[0])
|
)[0])
|
||||||
p[r][c][0] = np.argmax(k)
|
p[r][c] = np.argmax(k)
|
||||||
|
|
||||||
k = np.asarray(policy_net(
|
|
||||||
torch.tensor(
|
|
||||||
[x, y, 1],
|
|
||||||
dtype = torch.float32,
|
|
||||||
device = device
|
|
||||||
).unsqueeze(0)
|
|
||||||
)[0])
|
|
||||||
p[r][c][1] = np.argmax(k)
|
|
||||||
|
|
||||||
|
cmap = mpl.colors.ListedColormap(
|
||||||
|
[
|
||||||
|
"forestgreen",
|
||||||
|
"firebrick",
|
||||||
|
"lightgreen",
|
||||||
|
"salmon",
|
||||||
|
"darkturquoise",
|
||||||
|
"sandybrown",
|
||||||
|
"olive",
|
||||||
|
"darkorchid",
|
||||||
|
"mediumvioletred"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
# Plot predictions
|
# Plot predictions
|
||||||
fig, axs = plt.subplots(1, 2, figsize = (10, 10))
|
fig, axs = plt.subplots(1, 1, figsize = (20, 20))
|
||||||
ax = axs[0]
|
ax = axs
|
||||||
ax.set(
|
ax.set(
|
||||||
adjustable = "box",
|
adjustable = "box",
|
||||||
aspect = "equal",
|
aspect = "equal",
|
||||||
|
@ -70,30 +76,16 @@ def best_action(
|
||||||
)
|
)
|
||||||
|
|
||||||
plot = ax.pcolor(
|
plot = ax.pcolor(
|
||||||
p[:,:,0],
|
p,
|
||||||
cmap = "Set1",
|
cmap = cmap,
|
||||||
vmin = 0,
|
vmin = 0,
|
||||||
vmax = 8
|
vmax = 8
|
||||||
)
|
)
|
||||||
ax.invert_yaxis()
|
ax.invert_yaxis()
|
||||||
fig.colorbar(plot)
|
cbar = fig.colorbar(plot, ticks = list(range(0, 9)))
|
||||||
|
cbar.ax.set_yticklabels(Celeste.action_space)
|
||||||
|
|
||||||
ax = axs[1]
|
|
||||||
ax.set(
|
|
||||||
adjustable = "box",
|
|
||||||
aspect = "equal",
|
|
||||||
title = "Best Action"
|
|
||||||
)
|
|
||||||
|
|
||||||
plot = ax.pcolor(
|
|
||||||
p[:,:,0],
|
|
||||||
cmap = "Set1",
|
|
||||||
vmin = 0,
|
|
||||||
vmax = 8
|
|
||||||
)
|
|
||||||
|
|
||||||
ax.invert_yaxis()
|
|
||||||
fig.colorbar(plot)
|
|
||||||
|
|
||||||
fig.savefig(out_filename)
|
fig.savefig(out_filename)
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
|
@ -43,7 +43,7 @@ def predicted_reward(
|
||||||
|
|
||||||
k = np.asarray(policy_net(
|
k = np.asarray(policy_net(
|
||||||
torch.tensor(
|
torch.tensor(
|
||||||
[x, y, 0],
|
[x, y],
|
||||||
dtype = torch.float32,
|
dtype = torch.float32,
|
||||||
device = device
|
device = device
|
||||||
).unsqueeze(0)
|
).unsqueeze(0)
|
||||||
|
|
|
@ -47,14 +47,6 @@ plots = {
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
if plots["prediction"]:
|
|
||||||
print("Making prediction plots...")
|
|
||||||
with Pool(5) as p:
|
|
||||||
p.map(
|
|
||||||
plot_pred,
|
|
||||||
list((m / "model_archive").iterdir())
|
|
||||||
)
|
|
||||||
|
|
||||||
if plots["best"]:
|
if plots["best"]:
|
||||||
print("Making best-action plots...")
|
print("Making best-action plots...")
|
||||||
with Pool(5) as p:
|
with Pool(5) as p:
|
||||||
|
@ -63,6 +55,14 @@ if __name__ == "__main__":
|
||||||
list((m / "model_archive").iterdir())
|
list((m / "model_archive").iterdir())
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if plots["prediction"]:
|
||||||
|
print("Making prediction plots...")
|
||||||
|
with Pool(5) as p:
|
||||||
|
p.map(
|
||||||
|
plot_pred,
|
||||||
|
list((m / "model_archive").iterdir())
|
||||||
|
)
|
||||||
|
|
||||||
if plots["actual"]:
|
if plots["actual"]:
|
||||||
print("Making actual plots...")
|
print("Making actual plots...")
|
||||||
with Pool(5) as p:
|
with Pool(5) as p:
|
||||||
|
|
Reference in New Issue