Mark
/
celeste-ai
Archived
1
0
Fork 0

Removed "can_dash" input value

master
Mark 2023-02-26 12:09:05 -08:00
parent f40b58508e
commit 0b61702677
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
4 changed files with 34 additions and 42 deletions

View File

@ -70,7 +70,7 @@ class Celeste:
#"ypos",
"xpos_scaled",
"ypos_scaled",
"can_dash_int"
#"can_dash_int"
#"next_point_x",
#"next_point_y"
]

View File

@ -1,6 +1,7 @@
import torch
import numpy as np
from pathlib import Path
import matplotlib as mpl
import matplotlib.pyplot as plt
# All of the following are required to load
@ -34,7 +35,7 @@ def best_action(
# Compute preditions
p = np.zeros((128, 128, 2), dtype=np.float32)
p = np.zeros((128, 128), dtype=np.float32)
with torch.no_grad():
for r in range(len(p)):
for c in range(len(p[r])):
@ -43,26 +44,31 @@ def best_action(
k = np.asarray(policy_net(
torch.tensor(
[x, y, 0],
[x, y],
dtype = torch.float32,
device = device
).unsqueeze(0)
)[0])
p[r][c][0] = np.argmax(k)
p[r][c] = np.argmax(k)
k = np.asarray(policy_net(
torch.tensor(
[x, y, 1],
dtype = torch.float32,
device = device
).unsqueeze(0)
)[0])
p[r][c][1] = np.argmax(k)
cmap = mpl.colors.ListedColormap(
[
"forestgreen",
"firebrick",
"lightgreen",
"salmon",
"darkturquoise",
"sandybrown",
"olive",
"darkorchid",
"mediumvioletred"
]
)
# Plot predictions
fig, axs = plt.subplots(1, 2, figsize = (10, 10))
ax = axs[0]
fig, axs = plt.subplots(1, 1, figsize = (20, 20))
ax = axs
ax.set(
adjustable = "box",
aspect = "equal",
@ -70,30 +76,16 @@ def best_action(
)
plot = ax.pcolor(
p[:,:,0],
cmap = "Set1",
p,
cmap = cmap,
vmin = 0,
vmax = 8
)
ax.invert_yaxis()
fig.colorbar(plot)
cbar = fig.colorbar(plot, ticks = list(range(0, 9)))
cbar.ax.set_yticklabels(Celeste.action_space)
ax = axs[1]
ax.set(
adjustable = "box",
aspect = "equal",
title = "Best Action"
)
plot = ax.pcolor(
p[:,:,0],
cmap = "Set1",
vmin = 0,
vmax = 8
)
ax.invert_yaxis()
fig.colorbar(plot)
fig.savefig(out_filename)
plt.close()

View File

@ -43,7 +43,7 @@ def predicted_reward(
k = np.asarray(policy_net(
torch.tensor(
[x, y, 0],
[x, y],
dtype = torch.float32,
device = device
).unsqueeze(0)

View File

@ -47,14 +47,6 @@ plots = {
if __name__ == "__main__":
if plots["prediction"]:
print("Making prediction plots...")
with Pool(5) as p:
p.map(
plot_pred,
list((m / "model_archive").iterdir())
)
if plots["best"]:
print("Making best-action plots...")
with Pool(5) as p:
@ -63,6 +55,14 @@ if __name__ == "__main__":
list((m / "model_archive").iterdir())
)
if plots["prediction"]:
print("Making prediction plots...")
with Pool(5) as p:
p.map(
plot_pred,
list((m / "model_archive").iterdir())
)
if plots["actual"]:
print("Making actual plots...")
with Pool(5) as p: