Mark
/
celeste-ai
Archived
1
0
Fork 0

Updated main plot script

master
Mark 2023-02-24 14:23:48 -08:00
parent c9e04dcd41
commit 19923e672c
Signed by: Mark
GPG Key ID: AD62BB059C2AAEE4
1 changed files with 42 additions and 13 deletions

View File

@ -8,6 +8,7 @@ from multiprocessing import Pool
m = Path("model_data/current") m = Path("model_data/current")
scaled = True
# Make "predicted reward" plots # Make "predicted reward" plots
def plot_pred(src_model): def plot_pred(src_model):
@ -15,9 +16,19 @@ def plot_pred(src_model):
src_model, src_model,
m / f"plots/predicted/{src_model.stem}.png", m / f"plots/predicted/{src_model.stem}.png",
device = torch.device("cpu") device = torch.device("cpu"),
scaled = scaled
) )
# Make "best action" plots
def plot_best(src_model):
plotting.best_action(
src_model,
m / f"plots/best_action/{src_model.stem}.png",
device = torch.device("cpu"),
scaled = scaled
)
# Make "actual reward" plots # Make "actual reward" plots
def plot_act(src_model): def plot_act(src_model):
@ -30,8 +41,17 @@ def plot_act(src_model):
) )
# Which plots should we make?
plots = {
"prediction": True,
"actual": False,
"best": True
}
if __name__ == "__main__": if __name__ == "__main__":
if plots["prediction"]:
print("Making prediction plots...") print("Making prediction plots...")
with Pool(5) as p: with Pool(5) as p:
p.map( p.map(
@ -39,6 +59,15 @@ if __name__ == "__main__":
list((m / "model_archive").iterdir()) list((m / "model_archive").iterdir())
) )
if plots["best"]:
print("Making best-action plots...")
with Pool(5) as p:
p.map(
plot_best,
list((m / "model_archive").iterdir())
)
if plots["actual"]:
print("Making actual plots...") print("Making actual plots...")
with Pool(5) as p: with Pool(5) as p:
p.map( p.map(