Added new plotting script
parent
55ac62dc47
commit
ab355475d5
|
@ -0,0 +1,47 @@
|
|||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
import celeste_ai.plotting as plotting
|
||||
from multiprocessing import Pool
|
||||
|
||||
|
||||
|
||||
m = Path("model_data/current")
|
||||
|
||||
|
||||
# Make "predicted reward" plots
|
||||
def plot_pred(src_model):
|
||||
plotting.predicted_reward(
|
||||
src_model,
|
||||
m / f"plots/predicted/{src_model.stem}.png",
|
||||
|
||||
device = torch.device("cpu")
|
||||
)
|
||||
|
||||
|
||||
# Make "actual reward" plots
|
||||
def plot_act(src_model):
|
||||
plotting.actual_reward(
|
||||
src_model,
|
||||
(60, 80),
|
||||
m / f"plots/actual/{src_model.stem}.png",
|
||||
|
||||
device = torch.device("cpu")
|
||||
)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Making prediction plots...")
|
||||
with Pool(5) as p:
|
||||
p.map(
|
||||
plot_pred,
|
||||
list((m / "model_archive").iterdir())
|
||||
)
|
||||
|
||||
print("Making actual plots...")
|
||||
with Pool(5) as p:
|
||||
p.map(
|
||||
plot_act,
|
||||
list((m / "model_archive").iterdir())
|
||||
)
|
Reference in New Issue