Mark / celeste-ai

Compare commits


No commits in common. "24dd65ace88ca89f2a6c051043d2f9dc5c439f49" and "f40b58508e910cbf99ebf8163c91220ee9cba1aa" have entirely different histories.

9 changed files with 297 additions and 399 deletions

View File

@@ -70,24 +70,21 @@ class Celeste:
 		#"ypos",
 		"xpos_scaled",
 		"ypos_scaled",
-		#"can_dash_int"
+		"can_dash_int"
 		#"next_point_x",
 		#"next_point_y"
 	]

 	# Targets the agent tries to reach.
 	# The last target MUST be outside the frame.
-	# Format is X, Y, range, force_y
-	# force_y is optional. If true, y_value MUST match perfectly.
 	target_checkpoints = [
 		[ # Stage 1
-			#(28, 88, 8), # Start pillar
-			(60, 80, 8), # Middle pillar
-			(105, 64, 8), # Right ledge
-			(25, 40, 8), # Left ledge
-			(97, 24, 5, True), # Small end ledge
-			(110, 16, 8), # End ledge
-			(110, -20, 8), # Next stage
+			#(28, 88), # Start pillar
+			(60, 80), # Middle pillar
+			(105, 64), # Right ledge
+			(25, 40), # Left ledge
+			(110, 16), # End ledge
+			(110, -2), # Next stage
 		]
 	]
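
Aside: the two sides of this hunk disagree on the checkpoint tuple shape. The removed comment documents checkpoints as (x, y, range, force_y) with force_y optional, while the other side stores bare (x, y) pairs and relies on a fixed 8-pixel capture radius later in the file. A minimal sketch of how the richer format is read (the helper name is illustrative, not from the repo), matching the t[2]/t[3] indexing in the update logic further down:

	def parse_checkpoint(t):
		# (x, y, range) or (x, y, range, force_y)
		x, y = t[0], t[1]
		capture_range = t[2]
		force_y = t[3] if len(t) > 3 else False
		return x, y, capture_range, force_y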
@@ -102,7 +99,7 @@ class Celeste:
 		self,
 		pico_path,
 		*,
-		state_timeout = 20,
+		state_timeout = 30,
 		cart_name = "hackcel.p8",
 	):
@@ -147,7 +144,7 @@ class Celeste:
 		self._resetting = False # True between a call to .reset() and the first state message from pico.
 		self._keys = {} # Dictionary of "key": bool

-	def act(self, action: str | int):
+	def act(self, action: str):
 		"""
 		Specify what keys should be down. This does NOT send key events.
 		Celeste._apply_keys() does that at the right time.
@@ -156,9 +153,6 @@ class Celeste:
 			action (str): key name, as in Celeste.action_space
 		"""

-		if isinstance(action, int):
-			action = Celeste.action_space[action]
-
 		self._keys = {}
 		if action is None:
 			return
@@ -214,9 +208,9 @@ class Celeste:
 			[int(self._internal_state["rx"])]
 		)

-		if len(Celeste.target_checkpoints) <= stage:
-			next_point_x = 0
-			next_point_y = 0
+		if len(Celeste.target_checkpoints) < stage:
+			next_point_x = None
+			next_point_y = None
 		else:
 			next_point_x = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][0]
 			next_point_y = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][1]
@@ -335,65 +329,46 @@ class Celeste:
-		if self.state.stage <= 0:
-
-			# Calculate distance to each point
-			x = self.state.xpos
-			y = self.state.ypos
-			dist = np.zeros(len(Celeste.target_checkpoints[self.state.stage]), dtype=np.float16)
-			for i, c in enumerate(Celeste.target_checkpoints[self.state.stage]):
-				if i < self._next_checkpoint_idx:
-					dist[i] = 1000
-					continue
-
-				# Update checkpoints
-				tx, ty = c[:2]
-				dist[i] = (math.sqrt(
-					(x-tx)*(x-tx) +
-					((y-ty)*(y-ty))/2
-					# Possible modification:
-					# make x-distance twice as valuable as y-distance
-				))
-
-			min_idx = int(dist.argmin())
-			dist = int(dist[min_idx])
-
-			t = Celeste.target_checkpoints[self.state.stage][min_idx]
-			range = t[2]
-			if len(t) == 3:
-				force_y = False
-			else:
-				force_y = t[3]
-
-			if force_y:
-				got_point = (
-					dist <= range and
-					y == t[1]
-				)
-			else:
-				got_point = dist <= range
-
-			if got_point:
-				self._next_checkpoint_idx = min_idx + 1
-				self._last_checkpoint_state = self._state_counter
-
-				# Recalculate distance to new point
-				tx, ty = (
-					Celeste.target_checkpoints
-					[self.state.stage]
-					[self._next_checkpoint_idx]
-					[:2]
-				)
-				dist = math.sqrt(
-					(x-tx)*(x-tx) +
-					((y-ty)*(y-ty))/2
-				)
-
-			# Timeout if we spend too long between points
-			elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
-				self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)
-
-			self._dist = dist
+		# Calculate distance to each point
+		x = self.state.xpos
+		y = self.state.ypos
+		dist = np.zeros(len(Celeste.target_checkpoints[self.state.stage]), dtype=np.float16)
+		for i, c in enumerate(Celeste.target_checkpoints[self.state.stage]):
+			if i < self._next_checkpoint_idx:
+				dist[i] = 1000
+				continue
+
+			# Update checkpoints
+			tx, ty = c
+			dist[i] = (math.sqrt(
+				(x-tx)*(x-tx) +
+				((y-ty)*(y-ty))/2
+				# Possible modification:
+				# make x-distance twice as valuable as y-distance
+			))
+
+		min_idx = int(dist.argmin())
+		dist = int(dist[min_idx])
+
+		if dist <= 8:
+			print(f"Got point {min_idx}")
+			self._next_checkpoint_idx = min_idx + 1
+			self._last_checkpoint_state = self._state_counter
+
+			# Recalculate distance to new point
+			tx, ty = Celeste.target_checkpoints[self.state.stage][self._next_checkpoint_idx]
+			dist = math.sqrt(
+				(x-tx)*(x-tx) +
+				((y-ty)*(y-ty))/2
+			)
+
+		# Timeout if we spend too long between points
+		elif self._state_counter - self._last_checkpoint_state > self.state_timeout:
+			self._internal_state["dc"] = str(int(self._internal_state["dc"]) + 1)
+
+		self._dist = dist

 		# Call step callbacks
 		# These should call celeste.act() to set next input
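
Aside: both sides score checkpoints with a Euclidean distance whose squared y-term is halved, so a vertical offset counts for less than an equal horizontal one. A self-contained sketch of that check, using the fixed 8-pixel radius from the right-hand side (the function name is illustrative):

	import math

	def reached_checkpoint(x, y, tx, ty, radius = 8):
		# y-distance is weighted half as much as x-distance
		d = math.sqrt((x - tx)*(x - tx) + ((y - ty)*(y - ty)) / 2)
		return d <= radius

	# e.g. near the "Middle pillar" checkpoint at (60, 80):
	print(reached_checkpoint(55, 85, 60, 80))  # True: sqrt(25 + 12.5) is about 6.1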

View File

@@ -5,7 +5,7 @@ from collections import namedtuple
 Transition = namedtuple(
 	"Transition",
 	(
-		"last_state",
+		"state",
 		"action",
 		"next_state",
 		"reward"

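Aside: only the field name changes here, but it must stay in sync with optimize_model (note the batch.last_state / batch.state change in the training script below). A small illustration of why, using the usual namedtuple re-batching trick; the zip(*...) line is an assumption about code not shown in this diff:

	from collections import namedtuple

	Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))

	transitions = [Transition(1, 2, 3, 4), Transition(5, 6, 7, 8)]
	batch = Transition(*zip(*transitions))  # group each field across the sampled batch
	print(batch.state)  # (1, 5) -- this attribute would be batch.last_state under the other naming
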
View File

@@ -1,7 +1,6 @@
 import torch
 import numpy as np
 from pathlib import Path
-import matplotlib as mpl
 import matplotlib.pyplot as plt

 # All of the following are required to load
@@ -35,7 +34,7 @@ def best_action(
 	# Compute preditions
-	p = np.zeros((128, 128), dtype=np.float32)
+	p = np.zeros((128, 128, 2), dtype=np.float32)
 	with torch.no_grad():
 		for r in range(len(p)):
 			for c in range(len(p[r])):
@@ -44,31 +43,26 @@ def best_action(
 				k = np.asarray(policy_net(
 					torch.tensor(
-						[x, y],
+						[x, y, 0],
 						dtype = torch.float32,
 						device = device
 					).unsqueeze(0)
 				)[0])
-				p[r][c] = np.argmax(k)
-
-	cmap = mpl.colors.ListedColormap(
-		[
-			"forestgreen",
-			"firebrick",
-			"lightgreen",
-			"salmon",
-			"darkturquoise",
-			"sandybrown",
-			"olive",
-			"darkorchid",
-			"mediumvioletred"
-		]
-	)
+				p[r][c][0] = np.argmax(k)
+
+				k = np.asarray(policy_net(
+					torch.tensor(
+						[x, y, 1],
+						dtype = torch.float32,
+						device = device
+					).unsqueeze(0)
+				)[0])
+				p[r][c][1] = np.argmax(k)

 	# Plot predictions
-	fig, axs = plt.subplots(1, 1, figsize = (20, 20))
-	ax = axs
+	fig, axs = plt.subplots(1, 2, figsize = (10, 10))
+	ax = axs[0]
 	ax.set(
 		adjustable = "box",
 		aspect = "equal",
@@ -76,16 +70,30 @@ def best_action(
 	)
 	plot = ax.pcolor(
-		p,
-		cmap = cmap,
+		p[:,:,0],
+		cmap = "Set1",
 		vmin = 0,
 		vmax = 8
 	)
 	ax.invert_yaxis()
-	cbar = fig.colorbar(plot, ticks = list(range(0, 9)))
-	cbar.ax.set_yticklabels(Celeste.action_space)
+	fig.colorbar(plot)
+
+	ax = axs[1]
+	ax.set(
+		adjustable = "box",
+		aspect = "equal",
+		title = "Best Action"
+	)
+	plot = ax.pcolor(
+		p[:,:,0],
+		cmap = "Set1",
+		vmin = 0,
+		vmax = 8
+	)
+	ax.invert_yaxis()
+	fig.colorbar(plot)

 	fig.savefig(out_filename)
 	plt.close()

View File

@@ -43,7 +43,7 @@ def predicted_reward(
 			k = np.asarray(policy_net(
 				torch.tensor(
-					[x, y],
+					[x, y, 0],
 					dtype = torch.float32,
 					device = device
 				).unsqueeze(0)

View File

@@ -5,31 +5,33 @@ import random
 import math
 import json
 import torch
-import shutil
 from celeste_ai import Celeste
 from celeste_ai import DQN
 from celeste_ai import Transition
-from celeste_ai.util.screenshots import ScreenshotManager


 if __name__ == "__main__":
 	# Where to read/write model data.
 	model_data_root = Path("model_data/current")

-	sm = ScreenshotManager(
-		# Where PICO-8 saves screenshots.
-		# Probably your desktop.
-		source = Path("/home/mark/Desktop"),
-		pattern = "hackcel_*.png",
-		target = model_data_root / "screenshots"
-	).clean() # Remove old screenshots
+	# Where PICO-8 saves screenshots.
+	# Probably your desktop.
+	screenshot_source = Path("/home/mark/Desktop")

 	model_save_path = model_data_root / "model.torch"
 	model_archive_dir = model_data_root / "model_archive"
 	model_train_log = model_data_root / "train_log"
+	screenshot_dir = model_data_root / "screenshots"
 	model_data_root.mkdir(parents = True, exist_ok = True)
 	model_archive_dir.mkdir(parents = True, exist_ok = True)
+	screenshot_dir.mkdir(parents = True, exist_ok = True)
+
+	# Remove old screenshots
+	shots = screenshot_source.glob("hackcel_*.png")
+	for s in shots:
+		s.unlink()
compute_device = torch.device( compute_device = torch.device(
@@ -43,51 +45,66 @@ if __name__ == "__main__":
 	# Epsilon-greedy parameters
-	# Probability of choosing a random action starts at
-	# EPS_START and decays to EPS_END.
-	# EPS_DECAY controls the rate of decay.
+	#
+	# Original docs:
+	# EPS_START is the starting value of epsilon
+	# EPS_END is the final value of epsilon
+	# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
 	EPS_START = 0.9
 	EPS_END = 0.02
 	EPS_DECAY = 100

-	# Bellman equation time-discount factor
-	GAMMA = 0.9
-
-	# Train on this many transitions from
-	# replay memory each round
-	BATCH_SIZE = 100
-
-	# Controls target_net soft update.
-	# Should be between 0 and 1.
-	TAU = 0.05
-
-	# Optimizer learning rate
-	learning_rate = 0.001
-
-	# Save a snapshot of the model every n
-	# episodes.
-	model_save_interval = 10
-
 	# How many times we've reached each point.
-	# This is used to compute epsilon-greedy probability.
+	# Used to compute epsilon-greedy probability with
+	# the parameters above.
 	point_counter = [0] * len(Celeste.target_checkpoints[0])

-	n_episodes = 0 # Number of episodes we've trained on
-	n_steps = 0 # Number of training steps we've completed
+	BATCH_SIZE = 100
+
+	# Learning rate of target_net.
+	# Controls how soft our soft update is.
+	#
+	# Should be between 0 and 1.
+	# Large values
+	# Small values do the opposite.
+	#
+	# A value of one makes target_net
+	# change at the same rate as policy_net.
+	#
+	# A value of zero makes target_net
+	# not change at all.
+	TAU = 0.05
+
+	# GAMMA is the discount factor as mentioned in the previous section
+	GAMMA = 0.9
+
+	steps_done = 0
+	num_episodes = 100
+	episode_number = 0
+	archive_interval = 10

 	# Create replay memory.
 	#
-	# Holds <Transition> objects, defined in
-	# network.py
+	# Transition: a container for naming data (defined in util.py)
+	# Memory: a deque that holds recent states as Transitions
+	# Has a fixed length, drops oldest
+	# element if maxlen is exceeded.
 	memory = deque([], maxlen=50_000)

-	policy_net = DQN(
-		n_observations,
-		n_actions
-	).to(compute_device)
-	target_net = DQN(
-		n_observations,
-		n_actions
-	).to(compute_device)
+	policy_net = DQN(n_observations, n_actions).to(compute_device)
+	target_net = DQN(n_observations, n_actions).to(compute_device)
 	target_net.load_state_dict(policy_net.state_dict())

+	learning_rate = 0.001
 	optimizer = torch.optim.AdamW(
 		policy_net.parameters(),
 		lr = learning_rate,
@@ -105,43 +122,11 @@ if __name__ == "__main__":
 	target_net.load_state_dict(checkpoint["target_state_dict"])
 	optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
 	memory = checkpoint["memory"]
-
-	n_episodes = checkpoint["n_episodes"]
-	n_steps = checkpoint["n_steps"]
+	episode_number = checkpoint["episode_number"] + 1
+	steps_done = checkpoint["steps_done"]
 	point_counter = checkpoint["point_counter"]


-def save_model(path):
-	torch.save({
-		# Newtorks
-		"policy_state_dict": policy_net.state_dict(),
-		"target_state_dict": target_net.state_dict(),
-		"optimizer_state_dict": optimizer.state_dict(),
-
-		# Training data
-		"memory": memory,
-		"point_counter": point_counter,
-		"n_episodes": n_episodes,
-		"n_steps": n_steps,
-
-		# Hyperparameters,
-		# for reference
-		"eps_start": EPS_START,
-		"eps_end": EPS_END,
-		"eps_decay": EPS_DECAY,
-		"batch_size": BATCH_SIZE,
-		"tau": TAU,
-		"learning_rate": learning_rate,
-		"gamma": GAMMA
-	}, path
-	)
-
-
-def select_action(state, x) -> int:
+def select_action(state, steps_done):
 	"""
 	Select an action using an epsilon-greedy policy.
@@ -151,13 +136,19 @@ def select_action(state, x) -> int:
 	Decay rate is controlled by EPS_DECAY.
 	"""

+	# Random number 0 <= x < 1
+	sample = random.random()
+
 	# Calculate random step threshhold
 	eps_threshold = (
 		EPS_END + (EPS_START - EPS_END) *
-		math.exp(-1.0 * x / EPS_DECAY)
+		math.exp(
+			-1.0 * steps_done /
+			EPS_DECAY
+		)
 	)

-	if random.random() > eps_threshold:
+	if sample > eps_threshold:
 		with torch.no_grad():
 			# t.max(1) will return the largest column value of each row.
 			# second column on max result is index of where max element was
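
Aside: with the values shown above (EPS_START = 0.9, EPS_END = 0.02, EPS_DECAY = 100), the threshold computed here is the probability of taking a random action at a given step, and it decays exponentially toward EPS_END. A quick self-contained check of the schedule both sides compute:

	import math

	EPS_START, EPS_END, EPS_DECAY = 0.9, 0.02, 100

	def eps_threshold(steps_done):
		return EPS_END + (EPS_START - EPS_END) * math.exp(-1.0 * steps_done / EPS_DECAY)

	for n in (0, 100, 300, 1000):
		print(n, round(eps_threshold(n), 3))
	# 0 0.9, 100 0.344, 300 0.064, 1000 0.02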
@@ -184,7 +175,7 @@ def optimize_model():
 	# Conversion.
 	# Combine states, actions, and rewards into their own tensors.
-	last_state_batch = torch.cat(batch.last_state)
+	state_batch = torch.cat(batch.state)
 	action_batch = torch.cat(batch.action)
 	reward_batch = torch.cat(batch.reward)
@@ -218,7 +209,7 @@ def optimize_model():
 	# This gives us a tensor that contains the return we expect to get
 	# at that state if we follow the model's advice.
-	state_action_values = policy_net(last_state_batch).gather(1, action_batch)
+	state_action_values = policy_net(state_batch).gather(1, action_batch)
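
Aside: the rest of optimize_model is outside this diff. Since both sides can end up with None in next_state for terminal transitions (see the memory.append change below), the expected values are presumably built with the standard non-final mask; a sketch under that assumption, reusing names that do appear in the diff (the mask variables themselves are illustrative):

	# next_state entries may be None for terminal transitions
	non_final_mask = torch.tensor(
		[s is not None for s in batch.next_state],
		dtype = torch.bool,
		device = compute_device
	)
	non_final_next_states = torch.cat(
		[s for s in batch.next_state if s is not None]
	)

	next_state_values = torch.zeros(BATCH_SIZE, device = compute_device)
	with torch.no_grad():
		next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]

	# Bellman target: r + GAMMA * max_a' Q_target(s', a'); terminal states contribute r only
	expected_state_action_values = reward_batch + GAMMA * next_state_values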
@@ -291,21 +282,36 @@ def optimize_model():
 def on_state_before(celeste):
+	global steps_done
+
 	state = celeste.state

+	pt_state = torch.tensor(
+		[getattr(state, x) for x in Celeste.state_number_map],
+		dtype = torch.float32,
+		device = compute_device
+	).unsqueeze(0)
+
 	action = select_action(
-		# Put state in a tensor
-		torch.tensor(
-			[getattr(state, x) for x in Celeste.state_number_map],
-			dtype = torch.float32,
-			device = compute_device
-		).unsqueeze(0),
-
-		# Random action probability is determined by
-		# the number of times we've reached the next point.
+		pt_state,
 		point_counter[state.next_point]
 	)
+	str_action = Celeste.action_space[action]
+
+	"""
+	action = None
+	while (action) is None or ((not state.can_dash) and (str_action not in ["left", "right"])):
+		action = select_action(
+			pt_state,
+			steps_done
+		)
+		str_action = Celeste.action_space[action]
+	"""
+
+	steps_done += 1

 	# For manual testing
 	#str_action = ""
@@ -313,114 +319,86 @@ def on_state_before(celeste):
 	#	str_action = input("action> ")
 	#action = Celeste.action_space.index(str_action)

-	print(Celeste.action_space[action])
-	celeste.act(action)
-
-	return (
-		state,  # CelesteState
-		action  # Integer
-	)
-
-
-def compute_reward(last_state, state):
-	global point_counter
-	reward = None
-
-	# No reward if dead
-	if state.deaths != 0:
-		reward = 0
-
-	# Reward for finishing a stage
-	elif state.stage >= 1:
-		print("FINISHED STAGE!!")
-
-		# We don't set a fixed reward here because the agent may
-		# complete the stage before getting all points.
-		# The below line provides extra reward for taking shortcuts.
-		reward = state.next_point - last_state.next_point
-		reward += 1
-
-		# Add to point counter
-		for i in range(last_state.next_point, len(point_counter)):
-			point_counter[i] += 1
-
-	# Reward for reaching a checkpoint
-	elif last_state.next_point != state.next_point:
-		print(f"Got point {state.next_point}")
-		reward = state.next_point - last_state.next_point
-
-		# Add to point counter
-		for i in range(last_state.next_point, last_state.next_point + reward):
-			point_counter[i] += 1
-
-	# No reward otherwise
-	else:
-		reward = 0
-
-	# Strawberry reward
-	# (Will probably break current version of model)
-	#if state.berries[state.stage] and not state.berries[state.stage]:
-	#	print(f"Got stage {state.stage} bonus")
-	#	reward += 1
-
-	assert reward is not None
-	return reward * 10
+	print(str_action)
+	celeste.act(str_action)
+
+	return state, action


 def on_state_after(celeste, before_out):
-	global n_episodes
-	global n_steps
+	global episode_number

-	last_state, action = before_out
+	state, action = before_out
 	next_state = celeste.state

-	dead = next_state.deaths != 0
-	done = next_state.stage >= 1
-
-	reward = compute_reward(last_state, next_state)
-
-	if dead:
-		next_state = None
-	elif done:
-		# We don't set the next state to None because
-		# the optimization routine forces zero reward
-		# for terminal states.
-		# Copy last state instead. It's a hack, but it
-		# should work.
-		next_state = last_state
+	pt_state = torch.tensor(
+		[getattr(state, x) for x in Celeste.state_number_map],
+		dtype = torch.float32,
+		device = compute_device
+	).unsqueeze(0)
+
+	pt_action = torch.tensor(
+		[[ action ]],
+		device = compute_device,
+		dtype = torch.long
+	)
+
+	finished_stage = False
+
+	# No reward if dead
+	if next_state.deaths != 0:
+		pt_next_state = None
+		reward = 0
+
+	# Reward for finishing a stage
+	elif next_state.stage >= 1:
+		finished_stage = True
+		reward = next_state.next_point - state.next_point
+		reward += 1
+
+		# Add to point counter
+		for i in range(state.next_point, state.next_point + reward):
+			point_counter[i] += 1
+
+	# Regular reward
+	else:
+		pt_next_state = torch.tensor(
+			[getattr(next_state, x) for x in Celeste.state_number_map],
+			dtype = torch.float32,
+			device = compute_device
+		).unsqueeze(0)
+
+		if state.next_point == next_state.next_point:
+			reward = 0
+		else:
+			print(f"Got point {state.next_point}")
+
+			# Reward for reaching a point
+			reward = next_state.next_point - state.next_point
+
+			# Add to point counter
+			for i in range(state.next_point, state.next_point + reward):
+				point_counter[i] += 1
+
+		# Strawberry reward
+		if next_state.berries[state.stage] and not state.berries[state.stage]:
+			print(f"Got stage {state.stage} bonus")
+			reward += 1
+
+	reward = reward * 10
+	pt_reward = torch.tensor([reward], device = compute_device)

 	# Add this state transition to memory.
 	memory.append(
 		Transition(
-			# last state
-			torch.tensor(
-				[getattr(last_state, x) for x in Celeste.state_number_map],
-				dtype = torch.float32,
-				device = compute_device
-			).unsqueeze(0),
-
-			# action
-			torch.tensor(
-				[[ action ]],
-				device = compute_device,
-				dtype = torch.long
-			),
-
-			# next state
-			# None if dead or done.
-			torch.tensor(
-				[getattr(next_state, x) for x in Celeste.state_number_map],
-				dtype = torch.float32,
-				device = compute_device
-			).unsqueeze(0) if next_state is not None else None,
-
-			# reward
-			torch.tensor(
-				[reward],
-				device = compute_device
-			)
+			pt_state,
+			pt_action,
+			pt_next_state,
+			pt_reward
 		)
 	)
@@ -428,10 +406,11 @@ def on_state_after(celeste, before_out):
 	print("")

-	# Perform a training step
 	loss = None
+
+	# Only train the network if we have enough
+	# transitions in memory to do so.
 	if len(memory) >= BATCH_SIZE:
-		n_steps += 1
 		loss = optimize_model()

 	# Soft update target_net weights
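
Aside: the body of the soft update sits between these two hunks and is not shown; only its result (target_net_state being loaded back into target_net) appears below. Given the TAU comments above, it is presumably the usual Polyak average; a sketch under that assumption:

	# Soft update: theta_target <- TAU * theta_policy + (1 - TAU) * theta_target
	target_net_state = target_net.state_dict()
	policy_net_state = policy_net.state_dict()
	for key in policy_net_state:
		target_net_state[key] = (
			TAU * policy_net_state[key]
			+ (1 - TAU) * target_net_state[key]
		)
	target_net.load_state_dict(target_net_state)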
@@ -444,43 +423,65 @@ def on_state_after(celeste, before_out):
 	)
 	target_net.load_state_dict(target_net_state)

-	# Move on to the next episode and run
-	# housekeeping tasks.
-	if (dead or done):
+	# Move on to the next episode once we reach
+	# a terminal state.
+	if (next_state.deaths != 0 or finished_stage):
 		s = celeste.state
-		n_episodes += 1
-
-		# Move screenshots
-		sm.move(
-			number = n_episodes,
-			overwrite = True
-		)

-		# Log this episode
 		with model_train_log.open("a") as f:
 			f.write(json.dumps({
-				"n_episodes": n_episodes,
-				"n_steps": n_steps,
 				"checkpoints": s.next_point,
-				"loss": None if loss is None else loss.item(),
-				"done": done
+				"state_count": s.state_count,
+				"loss": None if loss is None else loss.item()
 			}) + "\n")

+		# Save model
+		torch.save({
+			"policy_state_dict": policy_net.state_dict(),
+			"target_state_dict": target_net.state_dict(),
+			"optimizer_state_dict": optimizer.state_dict(),
+			"memory": memory,
+			"point_counter": point_counter,
+			"episode_number": episode_number,
+			"steps_done": steps_done,
+
+			# Hyperparameters
+			"eps_start": EPS_START,
+			"eps_end": EPS_END,
+			"eps_decay": EPS_DECAY,
+			"batch_size": BATCH_SIZE,
+			"tau": TAU,
+			"learning_rate": learning_rate,
+			"gamma": GAMMA
+		}, model_save_path)
+
+		# Clean up screenshots
+		shots = screenshot_source.glob("hackcel_*.png")
+		target = screenshot_dir / Path(f"{episode_number}")
+		target.mkdir(parents = True)
+		for s in shots:
+			s.rename(target / s.name)
+
 		# Save a snapshot
-		if n_episodes % model_save_interval == 0:
-			save_model(model_archive_dir / f"{n_episodes}.torch")
-			shutil.copy(model_archive_dir / f"{n_episodes}.torch", model_save_path)
+		if episode_number % archive_interval == 0:
+			torch.save({
+				"policy_state_dict": policy_net.state_dict(),
+				"target_state_dict": target_net.state_dict(),
+				"optimizer_state_dict": optimizer.state_dict(),
+				"memory": memory,
+				"episode_number": episode_number,
+				"steps_done": steps_done
+			}, model_archive_dir / f"{episode_number}.torch")

 		print("Game over. Resetting.")
+		episode_number += 1
 		celeste.reset()


 if __name__ == "__main__":
 	c = Celeste(
 		"resources/pico-8/linux/pico8"

View File

@@ -1,69 +0,0 @@
-from pathlib import Path
-import shutil
-
-
-class ScreenshotManager:
-	def __init__(
-		self,
-
-		# Where PICO-8 saves screenshots
-		source: Path,
-
-		# How PICO-8 names screenshots.
-		# Example: "celeste_*.png"
-		pattern: str,
-
-		# Where we want to move screenshots.
-		target: Path
-	):
-		self.source = source
-		self.pattern = pattern
-		self.target = target
-
-		self.target.mkdir(
-			parents = True,
-			exist_ok = True
-		)
-
-	def clean(self):
-		shots = self.source.glob(self.pattern)
-		for s in shots:
-			s.unlink()
-		return self
-
-	def move(self, number: int | None = None, overwrite = False):
-		shots = self.source.glob(self.pattern)
-
-		if number == None:
-			# Auto-select new directory number.
-			# Chooses next highest int directory name
-			number = 0
-			for f in self.target.iterdir():
-				try:
-					number = max(
-						int(f.name),
-						number
-					)
-				except ValueError:
-					continue
-			number += 1
-
-		else:
-			target = self.target / str(number)
-			if target.exists():
-				if not overwrite:
-					raise Exception(f"Target \"{target}\" exists!")
-				else:
-					print(f"Target \"{target}\" exists, removing.")
-					shutil.rmtree(target)
-
-		target.mkdir(parents = True)
-		for s in shots:
-			s.rename(target / s.name)
-		return self
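
Aside: this deleted helper is exactly what the left-hand train.py above depends on; a usage sketch assembled from that code (the episode number passed to move() is illustrative):

	from pathlib import Path
	from celeste_ai.util.screenshots import ScreenshotManager

	model_data_root = Path("model_data/current")

	sm = ScreenshotManager(
		source = Path("/home/mark/Desktop"),  # where PICO-8 drops screenshots
		pattern = "hackcel_*.png",
		target = model_data_root / "screenshots"
	).clean()  # remove leftovers from a previous run

	# after an episode ends, file that episode's frames under screenshots/<n>/
	sm.move(number = 1, overwrite = True)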

View File

@@ -47,14 +47,6 @@ plots = {

 if __name__ == "__main__":
-	if plots["best"]:
-		print("Making best-action plots...")
-		with Pool(5) as p:
-			p.map(
-				plot_best,
-				list((m / "model_archive").iterdir())
-			)
-
 	if plots["prediction"]:
 		print("Making prediction plots...")
 		with Pool(5) as p:
@@ -63,6 +55,14 @@ if __name__ == "__main__":
 			list((m / "model_archive").iterdir())
 		)

+	if plots["best"]:
+		print("Making best-action plots...")
+		with Pool(5) as p:
+			p.map(
+				plot_best,
+				list((m / "model_archive").iterdir())
+			)
+
 	if plots["actual"]:
 		print("Making actual plots...")
 		with Pool(5) as p:

View File

@@ -30,16 +30,6 @@ k_jump=4
 k_dash=5

--- Set to false while training or running the model.
--- Set to true to play the game manually with debug print.
--- (good for finding coordinates of checkpoints)
----
--- If true, disables most hack features:
--- - screenshots at every frame
--- - frame skipping
--- - waiting for input
-hack_human_mode = false
-
 -- If true, disable screensake
 hack_no_shake = true
@@ -1219,10 +1209,6 @@ end
 -- _update60 does 60 fps
 -- default for celeste is 30.
 function _update()
-	if hack_human_mode then
-		old_update()
-		return
-	end

 	-- Run at full speed until ready
 	if not hack_ready then
@@ -1318,10 +1304,7 @@ end
 -- Called at the same rate as _update,
 -- but not necessarily at the same time.
 function _draw()
-	if hack_human_mode then
-		old_draw()
-		return
-	end
+	--old_draw()
 end

 function old_update()