From 19923e672c6752e6a8da40bdf7904c919fbb1b17 Mon Sep 17 00:00:00 2001 From: Mark Date: Fri, 24 Feb 2023 14:23:48 -0800 Subject: [PATCH] Updated main plot script --- celeste/plot.py | 55 +++++++++++++++++++++++++++++++++++++------------ 1 file changed, 42 insertions(+), 13 deletions(-) diff --git a/celeste/plot.py b/celeste/plot.py index cd8699f..84dd7b3 100644 --- a/celeste/plot.py +++ b/celeste/plot.py @@ -8,6 +8,7 @@ from multiprocessing import Pool m = Path("model_data/current") +scaled = True # Make "predicted reward" plots def plot_pred(src_model): @@ -15,9 +16,19 @@ def plot_pred(src_model): src_model, m / f"plots/predicted/{src_model.stem}.png", - device = torch.device("cpu") + device = torch.device("cpu"), + scaled = scaled ) +# Make "best action" plots +def plot_best(src_model): + plotting.best_action( + src_model, + m / f"plots/best_action/{src_model.stem}.png", + + device = torch.device("cpu"), + scaled = scaled + ) # Make "actual reward" plots def plot_act(src_model): @@ -30,18 +41,36 @@ def plot_act(src_model): ) +# Which plots should we make? +plots = { + "prediction": True, + "actual": False, + "best": True +} + if __name__ == "__main__": - print("Making prediction plots...") - with Pool(5) as p: - p.map( - plot_pred, - list((m / "model_archive").iterdir()) - ) - print("Making actual plots...") - with Pool(5) as p: - p.map( - plot_act, - list((m / "model_archive").iterdir()) - ) \ No newline at end of file + if plots["prediction"]: + print("Making prediction plots...") + with Pool(5) as p: + p.map( + plot_pred, + list((m / "model_archive").iterdir()) + ) + + if plots["best"]: + print("Making best-action plots...") + with Pool(5) as p: + p.map( + plot_best, + list((m / "model_archive").iterdir()) + ) + + if plots["actual"]: + print("Making actual plots...") + with Pool(5) as p: + p.map( + plot_act, + list((m / "model_archive").iterdir()) + ) \ No newline at end of file