Compare commits
No commits in common. "c372ef8cc7d25a8c532a98ac8b6ffb414808be62" and "4ff32b91ea4a895c12fc24a70664e2592b753a9a" have entirely different histories.
c372ef8cc7
...
4ff32b91ea
|
@ -50,6 +50,7 @@ class Celeste:
|
||||||
action_space = [
|
action_space = [
|
||||||
"left", # move left 0
|
"left", # move left 0
|
||||||
"right", # move right 1
|
"right", # move right 1
|
||||||
|
#"jump", # jump
|
||||||
"jump-l", # jump left 2
|
"jump-l", # jump left 2
|
||||||
"jump-r", # jump right 3
|
"jump-r", # jump right 3
|
||||||
|
|
||||||
|
@ -85,13 +86,6 @@ class Celeste:
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# Maps room_x, room_y coordinates to stage number.
|
|
||||||
stage_map = [
|
|
||||||
[0, 1, 2, 3, 4]
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
pico_path,
|
pico_path,
|
||||||
|
@ -200,7 +194,9 @@ class Celeste:
|
||||||
def state(self):
|
def state(self):
|
||||||
try:
|
try:
|
||||||
stage = (
|
stage = (
|
||||||
Celeste.stage_map
|
[
|
||||||
|
[0, 1, 2, 3, 4]
|
||||||
|
]
|
||||||
[int(self._internal_state["ry"])]
|
[int(self._internal_state["ry"])]
|
||||||
[int(self._internal_state["rx"])]
|
[int(self._internal_state["rx"])]
|
||||||
)
|
)
|
||||||
|
|
|
@ -341,23 +341,10 @@ def on_state_after(celeste, before_out):
|
||||||
dtype = torch.long
|
dtype = torch.long
|
||||||
)
|
)
|
||||||
|
|
||||||
finished_stage = False
|
|
||||||
# No reward if dead
|
|
||||||
if next_state.deaths != 0:
|
if next_state.deaths != 0:
|
||||||
pt_next_state = None
|
pt_next_state = None
|
||||||
reward = 0
|
reward = 0
|
||||||
|
|
||||||
# Reward for finishing stage
|
|
||||||
elif next_state.stage >= 1:
|
|
||||||
finished_stage = True
|
|
||||||
reward = next_state.next_point - state.next_point
|
|
||||||
reward += 1
|
|
||||||
|
|
||||||
# Add to point counter
|
|
||||||
for i in range(state.next_point, state.next_point + reward):
|
|
||||||
point_counter[i] += 1
|
|
||||||
|
|
||||||
# Regular reward
|
|
||||||
else:
|
else:
|
||||||
pt_next_state = torch.tensor(
|
pt_next_state = torch.tensor(
|
||||||
[getattr(next_state, x) for x in Celeste.state_number_map],
|
[getattr(next_state, x) for x in Celeste.state_number_map],
|
||||||
|
@ -414,7 +401,7 @@ def on_state_after(celeste, before_out):
|
||||||
|
|
||||||
# Move on to the next episode once we reach
|
# Move on to the next episode once we reach
|
||||||
# a terminal state.
|
# a terminal state.
|
||||||
if (next_state.deaths != 0 or finished_stage):
|
if (next_state.deaths != 0):
|
||||||
s = celeste.state
|
s = celeste.state
|
||||||
with model_train_log.open("a") as f:
|
with model_train_log.open("a") as f:
|
||||||
f.write(json.dumps({
|
f.write(json.dumps({
|
||||||
|
|
|
@ -22,6 +22,8 @@ render_dir () {
|
||||||
$OUTPUT_DIR/${1##*/}.mp4
|
$OUTPUT_DIR/${1##*/}.mp4
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Todo: error out if exists
|
# Todo: error out if exists
|
||||||
mkdir -p $OUTPUT_DIR
|
mkdir -p $OUTPUT_DIR
|
||||||
|
|
||||||
|
@ -48,18 +50,17 @@ ffmpeg \
|
||||||
-safe 0 \
|
-safe 0 \
|
||||||
-i video_merge_list \
|
-i video_merge_list \
|
||||||
-vf "scale=1024x1024:flags=neighbor" \
|
-vf "scale=1024x1024:flags=neighbor" \
|
||||||
$SC_ROOT/1x.mp4
|
$OUTPUT_DIR/00-all.mp4
|
||||||
|
|
||||||
rm video_merge_list
|
rm video_merge_list
|
||||||
|
|
||||||
# Make accelerated video
|
# Make accelerated video
|
||||||
ffmpeg \
|
ffmpeg \
|
||||||
-loglevel error -stats -y \
|
-loglevel error -stats -y \
|
||||||
-i $SC_ROOT/1x.mp4 \
|
-i $OUTPUT_DIR/00-all.mp4 \
|
||||||
-framerate 60 \
|
-framerate 60 \
|
||||||
-filter:v "setpts=0.125*PTS" \
|
-filter:v "setpts=0.125*PTS" \
|
||||||
$SC_ROOT/8x.mp4
|
$SC_ROOT/8x.mp4
|
||||||
|
|
||||||
echo "Cleaning up..."
|
echo "Cleaning up..."
|
||||||
|
|
||||||
rm -dr $OUTPUT_DIR
|
rm -dr $OUTPUT_DIR
|
||||||
|
|
|
@ -200,12 +200,13 @@ for ep in range(num_episodes):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
state = next_state
|
|
||||||
|
|
||||||
|
|
||||||
# Only train the network if we have enough
|
# Only train the network if we have enough
|
||||||
# transitions in memory to do so.
|
# transitions in memory to do so.
|
||||||
if len(memory) >= BATCH_SIZE:
|
if len(memory) >= BATCH_SIZE:
|
||||||
|
|
||||||
|
state = next_state
|
||||||
|
|
||||||
# Run optimizer
|
# Run optimizer
|
||||||
optimize.optimize_model(
|
optimize.optimize_model(
|
||||||
memory,
|
memory,
|
||||||
|
|
Reference in New Issue