Compare commits
2 Commits
072867a7a3
...
3d25d63efe
Author | SHA1 | Date | |
---|---|---|---|
3d25d63efe | |||
05d745cc07 |
@ -1,4 +1,3 @@
|
|||||||
from .plot_actual_reward import actual_reward
|
|
||||||
from .plot_predicted_reward import predicted_reward
|
from .plot_predicted_reward import predicted_reward
|
||||||
from .plot_best_action import best_action
|
from .plot_best_action import best_action
|
||||||
|
|
||||||
|
@ -1,81 +0,0 @@
|
|||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from pathlib import Path
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from multiprocessing import Pool
|
|
||||||
|
|
||||||
# All of the following are required to load
|
|
||||||
# a pickled model.
|
|
||||||
from celeste_ai.celeste import Celeste
|
|
||||||
from celeste_ai.network import DQN
|
|
||||||
from celeste_ai.network import Transition
|
|
||||||
|
|
||||||
def actual_reward(
|
|
||||||
model_file: Path,
|
|
||||||
target_point: tuple[int, int],
|
|
||||||
out_filename: Path,
|
|
||||||
*,
|
|
||||||
device = torch.device("cpu")
|
|
||||||
):
|
|
||||||
if not model_file.is_file():
|
|
||||||
raise Exception(f"Bad model file {model_file}")
|
|
||||||
out_filename.parent.mkdir(exist_ok = True, parents = True)
|
|
||||||
|
|
||||||
|
|
||||||
checkpoint = torch.load(
|
|
||||||
model_file,
|
|
||||||
map_location = device
|
|
||||||
)
|
|
||||||
memory = checkpoint["memory"]
|
|
||||||
|
|
||||||
|
|
||||||
r = np.zeros((128, 128, 8), dtype=np.float32)
|
|
||||||
for m in memory:
|
|
||||||
x, y, x_target, y_target = list(m.state[0])
|
|
||||||
|
|
||||||
action = m.action[0].item()
|
|
||||||
x = int(x.item())
|
|
||||||
y = int(y.item())
|
|
||||||
x_target = int(x_target.item())
|
|
||||||
y_target = int(y_target.item())
|
|
||||||
|
|
||||||
# Only plot memory related to this point
|
|
||||||
if (x_target, y_target) != target_point:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if m.reward[0].item() == 1:
|
|
||||||
r[y][x][action] += 1
|
|
||||||
else:
|
|
||||||
r[y][x][action] -= 1
|
|
||||||
|
|
||||||
|
|
||||||
fig, axs = plt.subplots(2, 4, figsize = (20, 10))
|
|
||||||
|
|
||||||
|
|
||||||
for a in range(len(axs.ravel())):
|
|
||||||
ax = axs.ravel()[a]
|
|
||||||
ax.set(
|
|
||||||
adjustable = "box",
|
|
||||||
aspect = "equal",
|
|
||||||
title = Celeste.action_space[a]
|
|
||||||
)
|
|
||||||
|
|
||||||
plot = ax.pcolor(
|
|
||||||
r[:,:,a],
|
|
||||||
cmap = "seismic_r",
|
|
||||||
vmin = -10,
|
|
||||||
vmax = 10
|
|
||||||
)
|
|
||||||
|
|
||||||
# Draw target point on plot
|
|
||||||
ax.plot(
|
|
||||||
target_point[0],
|
|
||||||
target_point[1],
|
|
||||||
"k."
|
|
||||||
)
|
|
||||||
|
|
||||||
ax.invert_yaxis()
|
|
||||||
fig.colorbar(plot)
|
|
||||||
|
|
||||||
fig.savefig(out_filename)
|
|
||||||
plt.close()
|
|
@ -1,65 +1,90 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
|
|
||||||
# Where screenshots are saved
|
# Where screenshots are saved.
|
||||||
|
# SC_ROOT/SC_DIR should contain episode screenshot directories
|
||||||
SC_ROOT="model_data/current"
|
SC_ROOT="model_data/current"
|
||||||
|
SC_DIR="screenshots"
|
||||||
# WILL BE DELETED
|
|
||||||
OUTPUT_DIR="model_data/video_output"
|
|
||||||
|
|
||||||
|
|
||||||
# To make with fade in and out:
|
# Select a temporary working directory
|
||||||
# ffmpeg -framerate 30 -i %03d.png -vf "scale=1024x1024:flags=neighbor,fade=in:0:45,fade=out:1040:45" out.webm
|
# if false, uses ramdisk.
|
||||||
|
# set to true if ramdisk overflows.
|
||||||
|
if false; then
|
||||||
|
OUTPUT_DIR="model_data/video_output"
|
||||||
|
|
||||||
render_dir () {
|
# This directory will be deleted.
|
||||||
|
# Make sure it doesn't already exist.
|
||||||
|
if [ -e "$OUTPUT_DIR" ]; then
|
||||||
|
echo "$OUTPUT_DIR exists, exiting. Please delete it manually."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
mkdir -p $OUTPUT_DIR
|
||||||
|
|
||||||
|
else
|
||||||
|
OUTPUT_DIR=$(mktemp -d)
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
# Usage:
|
||||||
|
# render_episode <src_dir> <output_name>
|
||||||
|
#
|
||||||
|
# Turns a directory of frame screenshots into a video.
|
||||||
|
# Applies upscaling. We do it early, because upscaling
|
||||||
|
# after encoding will exaggerate artifacts.
|
||||||
|
render_episode () {
|
||||||
ffmpeg \
|
ffmpeg \
|
||||||
-y \
|
-y \
|
||||||
-loglevel quiet \
|
-loglevel quiet \
|
||||||
-framerate 30 \
|
-framerate 30 \
|
||||||
-i $1/hackcel_%003d.png \
|
-i "$1/hackcel_%003d.png" \
|
||||||
-c:v libx264 \
|
-c:v libx264 \
|
||||||
-pix_fmt yuv420p \
|
-crf 20 \
|
||||||
$OUTPUT_DIR/${1##*/}.mp4
|
-preset slow \
|
||||||
|
-tune animation \
|
||||||
|
-vf "scale=1024x1024:flags=neighbor" \
|
||||||
|
"$2.mp4"
|
||||||
}
|
}
|
||||||
|
|
||||||
# Todo: error out if exists
|
|
||||||
mkdir -p $OUTPUT_DIR
|
|
||||||
|
|
||||||
|
|
||||||
echo "Making episode files..."
|
echo "Making episode files..."
|
||||||
for D in $SC_ROOT/screenshots/*; do
|
for D in "$SC_ROOT/$SC_DIR"/*; do
|
||||||
if [ -d "${D}" ]; then
|
if [ -d "${D}" ]; then
|
||||||
render_dir $D
|
render_episode "$D" "$OUTPUT_DIR/${D##*/}"
|
||||||
fi
|
fi
|
||||||
done
|
done
|
||||||
|
|
||||||
echo "Done."
|
|
||||||
|
|
||||||
|
|
||||||
# Generate video for each run
|
echo "Merging..."
|
||||||
for f in $OUTPUT_DIR/*.mp4; do
|
for f in "$OUTPUT_DIR"/*.mp4; do
|
||||||
echo file \'$f\' >> video_merge_list
|
echo file \'$f\' >> "$OUTPUT_DIR/video_merge_list"
|
||||||
done
|
done
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Merge videos
|
# Merge videos
|
||||||
ffmpeg \
|
ffmpeg \
|
||||||
-loglevel error -stats -y \
|
-loglevel error -stats -y \
|
||||||
-f concat \
|
-f concat \
|
||||||
-safe 0 \
|
-safe 0 \
|
||||||
-i video_merge_list \
|
-i "$OUTPUT_DIR/video_merge_list" \
|
||||||
-vf "scale=1024x1024:flags=neighbor" \
|
"$SC_ROOT/1x.mp4"
|
||||||
$SC_ROOT/1x.mp4
|
echo ""
|
||||||
|
echo "Making accelerated video..."
|
||||||
|
|
||||||
|
|
||||||
rm video_merge_list
|
|
||||||
|
|
||||||
# Make accelerated video
|
# Make accelerated video
|
||||||
ffmpeg \
|
ffmpeg \
|
||||||
-loglevel error -stats -y \
|
-loglevel error -stats -y \
|
||||||
-i $SC_ROOT/1x.mp4 \
|
-i "$SC_ROOT/1x.mp4" \
|
||||||
-framerate 60 \
|
-framerate 60 \
|
||||||
-filter:v "setpts=0.125*PTS" \
|
-filter:v "setpts=0.125*PTS" \
|
||||||
$SC_ROOT/8x.mp4
|
"$SC_ROOT/8x.mp4"
|
||||||
|
|
||||||
echo "Cleaning up..."
|
|
||||||
|
|
||||||
|
|
||||||
|
echo "Cleaning up...."
|
||||||
rm -dr $OUTPUT_DIR
|
rm -dr $OUTPUT_DIR
|
||||||
|
@ -5,10 +5,9 @@ import celeste_ai.plotting as plotting
|
|||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
m = Path("model_data/current")
|
m = Path("model_data/current")
|
||||||
|
|
||||||
# Make "predicted reward" plots
|
|
||||||
def plot_pred(src_model):
|
def plot_pred(src_model):
|
||||||
plotting.predicted_reward(
|
plotting.predicted_reward(
|
||||||
src_model,
|
src_model,
|
||||||
@ -17,7 +16,6 @@ def plot_pred(src_model):
|
|||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make "best action" plots
|
|
||||||
def plot_best(src_model):
|
def plot_best(src_model):
|
||||||
plotting.best_action(
|
plotting.best_action(
|
||||||
src_model,
|
src_model,
|
||||||
@ -26,47 +24,14 @@ def plot_best(src_model):
|
|||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make "actual reward" plots
|
|
||||||
def plot_act(src_model):
|
|
||||||
plotting.actual_reward(
|
|
||||||
src_model,
|
|
||||||
(60, 80),
|
|
||||||
m / f"plots/actual/{src_model.stem}.png",
|
|
||||||
|
|
||||||
device = torch.device("cpu")
|
for k, v in {
|
||||||
)
|
#"prediction": plot_pred,
|
||||||
|
"best_action": plot_best,
|
||||||
|
}.items():
|
||||||
# Which plots should we make?
|
print(f"Making {k} plots...")
|
||||||
plots = {
|
with Pool(5) as p:
|
||||||
"prediction": True,
|
p.map(
|
||||||
"actual": False,
|
v,
|
||||||
"best": True
|
list((m / "model_archive").iterdir())
|
||||||
}
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
|
|
||||||
if plots["best"]:
|
|
||||||
print("Making best-action plots...")
|
|
||||||
with Pool(5) as p:
|
|
||||||
p.map(
|
|
||||||
plot_best,
|
|
||||||
list((m / "model_archive").iterdir())
|
|
||||||
)
|
|
||||||
|
|
||||||
if plots["prediction"]:
|
|
||||||
print("Making prediction plots...")
|
|
||||||
with Pool(5) as p:
|
|
||||||
p.map(
|
|
||||||
plot_pred,
|
|
||||||
list((m / "model_archive").iterdir())
|
|
||||||
)
|
|
||||||
|
|
||||||
if plots["actual"]:
|
|
||||||
print("Making actual plots...")
|
|
||||||
with Pool(5) as p:
|
|
||||||
p.map(
|
|
||||||
plot_act,
|
|
||||||
list((m / "model_archive").iterdir())
|
|
||||||
)
|
|
Reference in New Issue
Block a user