cleanup
parent
6668614bbd
commit
6fe0d6e1cd
|
@ -8,43 +8,21 @@ from celeste import Celeste
|
||||||
from main import DQN
|
from main import DQN
|
||||||
from main import Transition
|
from main import Transition
|
||||||
|
|
||||||
# Use cpu, the script is faster in parallel.
|
# Use cpu, this script is faster in parallel.
|
||||||
compute_device = torch.device("cpu")
|
compute_device = torch.device("cpu")
|
||||||
|
|
||||||
|
|
||||||
# Celeste env properties
|
|
||||||
n_observations = len(Celeste.state_number_map)
|
|
||||||
n_actions = len(Celeste.action_space)
|
|
||||||
|
|
||||||
|
|
||||||
out_dir = Path("out/plots")
|
out_dir = Path("out/plots")
|
||||||
out_dir.mkdir(parents = True, exist_ok = True)
|
out_dir.mkdir(parents = True, exist_ok = True)
|
||||||
|
|
||||||
src_dir = Path("model_data/model_archive")
|
src_dir = Path("model_data/current/model_archive")
|
||||||
|
|
||||||
policy_net = DQN(
|
|
||||||
n_observations,
|
|
||||||
n_actions
|
|
||||||
).to(compute_device)
|
|
||||||
|
|
||||||
target_net = DQN(
|
|
||||||
n_observations,
|
|
||||||
n_actions
|
|
||||||
).to(compute_device)
|
|
||||||
|
|
||||||
optimizer = torch.optim.AdamW(
|
|
||||||
policy_net.parameters(),
|
|
||||||
lr = 0.01, # Hyperparameter: learning rate
|
|
||||||
amsgrad = True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def makeplt(i, net):
|
def makeplt(i, net):
|
||||||
p = np.zeros((128, 128), dtype=np.float32)
|
p = np.zeros((128, 128), dtype=np.float32)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
for r in range(len(p)):
|
for r in range(len(p)):
|
||||||
for c in range(len(p[r])):
|
for c in range(len(p[r])):
|
||||||
with torch.no_grad():
|
|
||||||
k = net(
|
k = net(
|
||||||
torch.tensor(
|
torch.tensor(
|
||||||
[c, r, 60, 80],
|
[c, r, 60, 80],
|
||||||
|
@ -52,29 +30,38 @@ def makeplt(i, net):
|
||||||
device = compute_device
|
device = compute_device
|
||||||
).unsqueeze(0)
|
).unsqueeze(0)
|
||||||
)[0][i].item()
|
)[0][i].item()
|
||||||
|
|
||||||
p[r][c] = k
|
p[r][c] = k
|
||||||
return p
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def plot(src):
|
def plot(src):
|
||||||
|
policy_net = DQN(
|
||||||
|
len(Celeste.state_number_map),
|
||||||
|
len(Celeste.action_space)
|
||||||
|
).to(compute_device)
|
||||||
|
|
||||||
checkpoint = torch.load(src)
|
checkpoint = torch.load(src)
|
||||||
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
policy_net.load_state_dict(checkpoint["policy_state_dict"])
|
||||||
|
|
||||||
|
|
||||||
fig, axs = plt.subplots(2, 4, figsize = (15, 10))
|
fig, axs = plt.subplots(2, 4, figsize = (20, 10))
|
||||||
|
|
||||||
for a in range(len(axs.ravel())):
|
for a in range(len(axs.ravel())):
|
||||||
ax = axs.ravel()[a]
|
ax = axs.ravel()[a]
|
||||||
ax.set(adjustable="box", aspect="equal")
|
ax.set(
|
||||||
|
adjustable = "box",
|
||||||
|
aspect = "equal"
|
||||||
|
)
|
||||||
|
ax.set_title(Celeste.action_space[a])
|
||||||
|
ax.invert_yaxis()
|
||||||
|
|
||||||
plot = ax.pcolor(
|
plot = ax.pcolor(
|
||||||
makeplt(a, policy_net),
|
makeplt(a, policy_net),
|
||||||
cmap = "Greens",
|
cmap = "Greens",
|
||||||
vmin = 0,
|
vmin = 0,
|
||||||
)
|
)
|
||||||
ax.set_title(Celeste.action_space[a])
|
|
||||||
ax.invert_yaxis()
|
|
||||||
fig.colorbar(plot)
|
fig.colorbar(plot)
|
||||||
print(src)
|
print(src)
|
||||||
fig.savefig(out_dir / f"{src.stem}.png")
|
fig.savefig(out_dir / f"{src.stem}.png")
|
||||||
|
|
Reference in New Issue