Mark / celeste-ai (archived)

Compare commits


7 Commits

9 changed files with 390 additions and 288 deletions

View File

@@ -70,21 +70,24 @@ class Celeste:
#"ypos", #"ypos",
"xpos_scaled", "xpos_scaled",
"ypos_scaled", "ypos_scaled",
"can_dash_int" #"can_dash_int"
#"next_point_x", #"next_point_x",
#"next_point_y" #"next_point_y"
] ]
# Targets the agent tries to reach. # Targets the agent tries to reach.
# The last target MUST be outside the frame. # The last target MUST be outside the frame.
# Format is (X, Y, range, force_y)
# force_y is optional. If true, the agent's y position MUST match Y exactly.
target_checkpoints = [ target_checkpoints = [
[ # Stage 1 [ # Stage 1
#(28, 88), # Start pillar #(28, 88, 8), # Start pillar
(60, 80), # Middle pillar (60, 80, 8), # Middle pillar
(105, 64), # Right ledge (105, 64, 8), # Right ledge
(25, 40), # Left ledge (25, 40, 8), # Left ledge
(110, 16), # End ledge (97, 24, 5, True), # Small end ledge
(110, -2), # Next stage (110, 16, 8), # End ledge
(110, -20, 8), # Next stage
] ]
] ]
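
A minimal sketch of how a checkpoint tuple in this format is read (the helper name is hypothetical; the x-weighted distance and the force_y check mirror the hunks further down in this file):

    import math

    def checkpoint_hit(x, y, checkpoint):
        # checkpoint is (X, Y, range) or (X, Y, range, force_y)
        cx, cy, r = checkpoint[:3]
        force_y = checkpoint[3] if len(checkpoint) > 3 else False
        # same x-weighted distance the class uses below
        dist = math.sqrt((x - cx) ** 2 + ((y - cy) ** 2) / 2)
        if force_y:
            return dist <= r and y == cy
        return dist <= r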
@@ -99,7 +102,7 @@ class Celeste:
self, self,
pico_path, pico_path,
*, *,
state_timeout = 30, state_timeout = 20,
cart_name = "hackcel.p8", cart_name = "hackcel.p8",
): ):
@@ -144,7 +147,7 @@ class Celeste:
self._resetting = False # True between a call to .reset() and the first state message from pico. self._resetting = False # True between a call to .reset() and the first state message from pico.
self._keys = {} # Dictionary of "key": bool self._keys = {} # Dictionary of "key": bool
def act(self, action: str): def act(self, action: str | int):
""" """
Specify what keys should be down. This does NOT send key events. Specify what keys should be down. This does NOT send key events.
Celeste._apply_keys() does that at the right time. Celeste._apply_keys() does that at the right time.
@@ -153,6 +156,9 @@ class Celeste:
action (str): key name, as in Celeste.action_space action (str): key name, as in Celeste.action_space
""" """
if isinstance(action, int):
action = Celeste.action_space[action]
self._keys = {} self._keys = {}
if action is None: if action is None:
return return
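
With this change act() accepts either form; a hedged usage sketch (assuming "left" is an entry of Celeste.action_space, as the commented-out code later in this diff suggests; the PICO-8 path is the one used at the bottom of the diff):

    from celeste_ai import Celeste

    c = Celeste("resources/pico-8/linux/pico8")
    c.act("left")                              # by key name, as before
    c.act(Celeste.action_space.index("left"))  # by integer index, newly accepted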
@@ -208,9 +214,9 @@ class Celeste:
[int(self._internal_state["rx"])] [int(self._internal_state["rx"])]
) )
if len(Celeste.target_checkpoints) < stage: if len(Celeste.target_checkpoints) <= stage:
next_point_x = None next_point_x = 0
next_point_y = None next_point_y = 0
else: else:
next_point_x = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][0] next_point_x = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][0]
next_point_y = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][1] next_point_y = Celeste.target_checkpoints[stage][self._next_checkpoint_idx][1]
@@ -329,7 +335,7 @@ class Celeste:
if self.state.stage <= 0:
# Calculate distance to each point # Calculate distance to each point
x = self.state.xpos x = self.state.xpos
y = self.state.ypos y = self.state.ypos
@@ -340,7 +346,7 @@ class Celeste:
continue continue
# Update checkpoints # Update checkpoints
tx, ty = c tx, ty = c[:2]
dist[i] = (math.sqrt( dist[i] = (math.sqrt(
(x-tx)*(x-tx) + (x-tx)*(x-tx) +
((y-ty)*(y-ty))/2 ((y-ty)*(y-ty))/2
@@ -351,13 +357,32 @@ class Celeste:
dist = int(dist[min_idx]) dist = int(dist[min_idx])
if dist <= 8: t = Celeste.target_checkpoints[self.state.stage][min_idx]
print(f"Got point {min_idx}") range = t[2]
if len(t) == 3:
force_y = False
else:
force_y = t[3]
if force_y:
got_point = (
dist <= range and
y == t[1]
)
else:
got_point = dist <= range
if got_point:
self._next_checkpoint_idx = min_idx + 1 self._next_checkpoint_idx = min_idx + 1
self._last_checkpoint_state = self._state_counter self._last_checkpoint_state = self._state_counter
# Recalculate distance to new point # Recalculate distance to new point
tx, ty = Celeste.target_checkpoints[self.state.stage][self._next_checkpoint_idx] tx, ty = (
Celeste.target_checkpoints
[self.state.stage]
[self._next_checkpoint_idx]
[:2]
)
dist = math.sqrt( dist = math.sqrt(
(x-tx)*(x-tx) + (x-tx)*(x-tx) +
((y-ty)*(y-ty))/2 ((y-ty)*(y-ty))/2

View File

@@ -5,7 +5,7 @@ from collections import namedtuple
Transition = namedtuple( Transition = namedtuple(
"Transition", "Transition",
( (
"state", "last_state",
"action", "action",
"next_state", "next_state",
"reward" "reward"

View File

@@ -1,6 +1,7 @@
import torch import torch
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
import matplotlib as mpl
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
# All of the following are required to load # All of the following are required to load
@@ -34,7 +35,7 @@ def best_action(
# Compute predictions # Compute predictions
p = np.zeros((128, 128, 2), dtype=np.float32) p = np.zeros((128, 128), dtype=np.float32)
with torch.no_grad(): with torch.no_grad():
for r in range(len(p)): for r in range(len(p)):
for c in range(len(p[r])): for c in range(len(p[r])):
@@ -43,26 +44,31 @@ def best_action(
k = np.asarray(policy_net( k = np.asarray(policy_net(
torch.tensor( torch.tensor(
[x, y, 0], [x, y],
dtype = torch.float32, dtype = torch.float32,
device = device device = device
).unsqueeze(0) ).unsqueeze(0)
)[0]) )[0])
p[r][c][0] = np.argmax(k) p[r][c] = np.argmax(k)
k = np.asarray(policy_net(
torch.tensor(
[x, y, 1],
dtype = torch.float32,
device = device
).unsqueeze(0)
)[0])
p[r][c][1] = np.argmax(k)
cmap = mpl.colors.ListedColormap(
[
"forestgreen",
"firebrick",
"lightgreen",
"salmon",
"darkturquoise",
"sandybrown",
"olive",
"darkorchid",
"mediumvioletred"
]
)
# Plot predictions # Plot predictions
fig, axs = plt.subplots(1, 2, figsize = (10, 10)) fig, axs = plt.subplots(1, 1, figsize = (20, 20))
ax = axs[0] ax = axs
ax.set( ax.set(
adjustable = "box", adjustable = "box",
aspect = "equal", aspect = "equal",
@@ -70,30 +76,16 @@ def best_action(
) )
plot = ax.pcolor( plot = ax.pcolor(
p[:,:,0], p,
cmap = "Set1", cmap = cmap,
vmin = 0, vmin = 0,
vmax = 8 vmax = 8
) )
ax.invert_yaxis() ax.invert_yaxis()
fig.colorbar(plot) cbar = fig.colorbar(plot, ticks = list(range(0, 9)))
cbar.ax.set_yticklabels(Celeste.action_space)
ax = axs[1]
ax.set(
adjustable = "box",
aspect = "equal",
title = "Best Action"
)
plot = ax.pcolor(
p[:,:,0],
cmap = "Set1",
vmin = 0,
vmax = 8
)
ax.invert_yaxis()
fig.colorbar(plot)
fig.savefig(out_filename) fig.savefig(out_filename)
plt.close() plt.close()

View File

@@ -43,7 +43,7 @@ def predicted_reward(
k = np.asarray(policy_net( k = np.asarray(policy_net(
torch.tensor( torch.tensor(
[x, y, 0], [x, y],
dtype = torch.float32, dtype = torch.float32,
device = device device = device
).unsqueeze(0) ).unsqueeze(0)

View File

@@ -5,33 +5,31 @@ import random
import math import math
import json import json
import torch import torch
import shutil
from celeste_ai import Celeste from celeste_ai import Celeste
from celeste_ai import DQN from celeste_ai import DQN
from celeste_ai import Transition from celeste_ai import Transition
from celeste_ai.util.screenshots import ScreenshotManager
if __name__ == "__main__": if __name__ == "__main__":
# Where to read/write model data. # Where to read/write model data.
model_data_root = Path("model_data/current") model_data_root = Path("model_data/current")
sm = ScreenshotManager(
# Where PICO-8 saves screenshots. # Where PICO-8 saves screenshots.
# Probably your desktop. # Probably your desktop.
screenshot_source = Path("/home/mark/Desktop") source = Path("/home/mark/Desktop"),
pattern = "hackcel_*.png",
target = model_data_root / "screenshots"
).clean() # Remove old screenshots
model_save_path = model_data_root / "model.torch" model_save_path = model_data_root / "model.torch"
model_archive_dir = model_data_root / "model_archive" model_archive_dir = model_data_root / "model_archive"
model_train_log = model_data_root / "train_log" model_train_log = model_data_root / "train_log"
screenshot_dir = model_data_root / "screenshots"
model_data_root.mkdir(parents = True, exist_ok = True) model_data_root.mkdir(parents = True, exist_ok = True)
model_archive_dir.mkdir(parents = True, exist_ok = True) model_archive_dir.mkdir(parents = True, exist_ok = True)
screenshot_dir.mkdir(parents = True, exist_ok = True)
# Remove old screenshots
shots = screenshot_source.glob("hackcel_*.png")
for s in shots:
s.unlink()
compute_device = torch.device( compute_device = torch.device(
@@ -45,66 +43,51 @@ if __name__ == "__main__":
# Epsilon-greedy parameters # Epsilon-greedy parameters
# # Probability of choosing a random action starts at
# Original docs: # EPS_START and decays to EPS_END.
# EPS_START is the starting value of epsilon # EPS_DECAY controls the rate of decay.
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
EPS_START = 0.9 EPS_START = 0.9
EPS_END = 0.02 EPS_END = 0.02
EPS_DECAY = 100 EPS_DECAY = 100
# How many times we've reached each point. # Bellman equation time-discount factor
# Used to compute epsilon-greedy probability with
# the parameters above.
point_counter = [0] * len(Celeste.target_checkpoints[0])
BATCH_SIZE = 100
# Learning rate of target_net.
# Controls how soft our soft update is.
#
# Should be between 0 and 1.
# Large values
# Small values do the opposite.
#
# A value of one makes target_net
# change at the same rate as policy_net.
#
# A value of zero makes target_net
# not change at all.
TAU = 0.05
# GAMMA is the discount factor as mentioned in the previous section
GAMMA = 0.9 GAMMA = 0.9
steps_done = 0 # Train on this many transitions from
num_episodes = 100 # replay memory each round
episode_number = 0 BATCH_SIZE = 100
archive_interval = 10
# Controls target_net soft update.
# Should be between 0 and 1.
TAU = 0.05
# Optimizer learning rate
learning_rate = 0.001
# Save a snapshot of the model every n
# episodes.
model_save_interval = 10
# How many times we've reached each point.
# This is used to compute epsilon-greedy probability.
point_counter = [0] * len(Celeste.target_checkpoints[0])
n_episodes = 0 # Number of episodes we've trained on
n_steps = 0 # Number of training steps we've completed
# Create replay memory. # Create replay memory.
# #
# Transition: a container for naming data (defined in util.py) # Holds <Transition> objects, defined in
# Memory: a deque that holds recent states as Transitions # network.py
# Has a fixed length, drops oldest
# element if maxlen is exceeded.
memory = deque([], maxlen=50_000) memory = deque([], maxlen=50_000)
policy_net = DQN(
n_observations,
n_actions
).to(compute_device)
target_net = DQN(
n_observations,
n_actions
).to(compute_device)
policy_net = DQN(n_observations, n_actions).to(compute_device)
target_net = DQN(n_observations, n_actions).to(compute_device)
target_net.load_state_dict(policy_net.state_dict()) target_net.load_state_dict(policy_net.state_dict())
learning_rate = 0.001
optimizer = torch.optim.AdamW( optimizer = torch.optim.AdamW(
policy_net.parameters(), policy_net.parameters(),
lr = learning_rate, lr = learning_rate,
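
For reference, a minimal sketch of the soft (Polyak) update that TAU controls, written in the standard form; the update itself sits in an unchanged stretch of this file between the hunks near the end of the diff, so only its final load_state_dict call is visible there:

    # Blend policy_net weights into target_net at rate TAU.
    target_net_state = target_net.state_dict()
    policy_net_state = policy_net.state_dict()
    for key in policy_net_state:
        target_net_state[key] = (
            policy_net_state[key] * TAU
            + target_net_state[key] * (1 - TAU)
        )
    target_net.load_state_dict(target_net_state)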
@@ -122,11 +105,43 @@ if __name__ == "__main__":
target_net.load_state_dict(checkpoint["target_state_dict"]) target_net.load_state_dict(checkpoint["target_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
memory = checkpoint["memory"] memory = checkpoint["memory"]
episode_number = checkpoint["episode_number"] + 1
steps_done = checkpoint["steps_done"] n_episodes = checkpoint["n_episodes"]
n_steps = checkpoint["n_steps"]
point_counter = checkpoint["point_counter"] point_counter = checkpoint["point_counter"]
def select_action(state, steps_done):
def save_model(path):
torch.save({
# Networks
"policy_state_dict": policy_net.state_dict(),
"target_state_dict": target_net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
# Training data
"memory": memory,
"point_counter": point_counter,
"n_episodes": n_episodes,
"n_steps": n_steps,
# Hyperparameters,
# for reference
"eps_start": EPS_START,
"eps_end": EPS_END,
"eps_decay": EPS_DECAY,
"batch_size": BATCH_SIZE,
"tau": TAU,
"learning_rate": learning_rate,
"gamma": GAMMA
}, path
)
def select_action(state, x) -> int:
""" """
Select an action using an epsilon-greedy policy. Select an action using an epsilon-greedy policy.
@@ -136,19 +151,13 @@ def select_action(state, steps_done):
Decay rate is controlled by EPS_DECAY. Decay rate is controlled by EPS_DECAY.
""" """
# Random number 0 <= x < 1
sample = random.random()
# Calculate random step threshold # Calculate random step threshold
eps_threshold = ( eps_threshold = (
EPS_END + (EPS_START - EPS_END) * EPS_END + (EPS_START - EPS_END) *
math.exp( math.exp(-1.0 * x / EPS_DECAY)
-1.0 * steps_done /
EPS_DECAY
)
) )
if sample > eps_threshold: if random.random() > eps_threshold:
with torch.no_grad(): with torch.no_grad():
# t.max(1) will return the largest column value of each row. # t.max(1) will return the largest column value of each row.
# second column on max result is index of where max element was # second column on max result is index of where max element was
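
For intuition, the threshold above works out as follows with the constants defined earlier (EPS_START = 0.9, EPS_END = 0.02, EPS_DECAY = 100). Note that in this version x is the per-checkpoint counter passed in from on_state_before, not a global step count:

    import math

    for x in (0, 50, 100, 300):
        eps = 0.02 + (0.9 - 0.02) * math.exp(-x / 100)
        print(x, round(eps, 3))  # 0 -> 0.9, 50 -> 0.554, 100 -> 0.344, 300 -> 0.064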
@@ -175,7 +184,7 @@ def optimize_model():
# Conversion. # Conversion.
# Combine states, actions, and rewards into their own tensors. # Combine states, actions, and rewards into their own tensors.
state_batch = torch.cat(batch.state) last_state_batch = torch.cat(batch.last_state)
action_batch = torch.cat(batch.action) action_batch = torch.cat(batch.action)
reward_batch = torch.cat(batch.reward) reward_batch = torch.cat(batch.reward)
@@ -209,7 +218,7 @@ def optimize_model():
# This gives us a tensor that contains the return we expect to get # This gives us a tensor that contains the return we expect to get
# at that state if we follow the model's advice. # at that state if we follow the model's advice.
state_action_values = policy_net(state_batch).gather(1, action_batch) state_action_values = policy_net(last_state_batch).gather(1, action_batch)
@@ -282,36 +291,21 @@ def optimize_model():
def on_state_before(celeste): def on_state_before(celeste):
global steps_done
state = celeste.state state = celeste.state
pt_state = torch.tensor(
action = select_action(
# Put state in a tensor
torch.tensor(
[getattr(state, x) for x in Celeste.state_number_map], [getattr(state, x) for x in Celeste.state_number_map],
dtype = torch.float32, dtype = torch.float32,
device = compute_device device = compute_device
).unsqueeze(0) ).unsqueeze(0),
# Random action probability is determined by
action = select_action( # the number of times we've reached the next point.
pt_state,
point_counter[state.next_point] point_counter[state.next_point]
) )
str_action = Celeste.action_space[action]
"""
action = None
while (action) is None or ((not state.can_dash) and (str_action not in ["left", "right"])):
action = select_action(
pt_state,
steps_done
)
str_action = Celeste.action_space[action]
"""
steps_done += 1
# For manual testing # For manual testing
#str_action = "" #str_action = ""
@@ -319,86 +313,114 @@ def on_state_before(celeste):
# str_action = input("action> ") # str_action = input("action> ")
#action = Celeste.action_space.index(str_action) #action = Celeste.action_space.index(str_action)
print(str_action) print(Celeste.action_space[action])
celeste.act(str_action) celeste.act(action)
return state, action return (
state, # CelesteState
action # Integer
def on_state_after(celeste, before_out):
global episode_number
state, action = before_out
next_state = celeste.state
pt_state = torch.tensor(
[getattr(state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0)
pt_action = torch.tensor(
[[ action ]],
device = compute_device,
dtype = torch.long
) )
finished_stage = False
def compute_reward(last_state, state):
global point_counter
reward = None
# No reward if dead # No reward if dead
if next_state.deaths != 0: if state.deaths != 0:
pt_next_state = None
reward = 0 reward = 0
# Reward for finishing a stage # Reward for finishing a stage
elif next_state.stage >= 1: elif state.stage >= 1:
finished_stage = True print("FINISHED STAGE!!")
reward = next_state.next_point - state.next_point
# We don't set a fixed reward here because the agent may
# complete the stage before getting all points.
# The below line provides extra reward for taking shortcuts.
reward = state.next_point - last_state.next_point
reward += 1 reward += 1
# Add to point counter # Add to point counter
for i in range(state.next_point, state.next_point + reward): for i in range(last_state.next_point, len(point_counter)):
point_counter[i] += 1 point_counter[i] += 1
# Regular reward # Reward for reaching a checkpoint
else: elif last_state.next_point != state.next_point:
pt_next_state = torch.tensor(
[getattr(next_state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0)
if state.next_point == next_state.next_point:
reward = 0
else:
print(f"Got point {state.next_point}") print(f"Got point {state.next_point}")
# Reward for reaching a point
reward = next_state.next_point - state.next_point reward = state.next_point - last_state.next_point
# Add to point counter # Add to point counter
for i in range(state.next_point, state.next_point + reward): for i in range(last_state.next_point, last_state.next_point + reward):
point_counter[i] += 1 point_counter[i] += 1
# No reward otherwise
else:
reward = 0
# Strawberry reward # Strawberry reward
if next_state.berries[state.stage] and not state.berries[state.stage]: # (Will probably break current version of model)
print(f"Got stage {state.stage} bonus") #if state.berries[state.stage] and not state.berries[state.stage]:
reward += 1 # print(f"Got stage {state.stage} bonus")
# reward += 1
assert reward is not None
return reward * 10
def on_state_after(celeste, before_out):
global n_episodes
global n_steps
last_state, action = before_out
next_state = celeste.state
dead = next_state.deaths != 0
done = next_state.stage >= 1
reward = reward * 10 reward = compute_reward(last_state, next_state)
pt_reward = torch.tensor([reward], device = compute_device)
if dead:
next_state = None
elif done:
# We don't set the next state to None because
# the optimization routine forces zero reward
# for terminal states.
# Copy last state instead. It's a hack, but it
# should work.
next_state = last_state
# Add this state transition to memory. # Add this state transition to memory.
memory.append( memory.append(
Transition( Transition(
pt_state, # last state
pt_action, torch.tensor(
pt_next_state, [getattr(last_state, x) for x in Celeste.state_number_map],
pt_reward dtype = torch.float32,
device = compute_device
).unsqueeze(0),
# action
torch.tensor(
[[ action ]],
device = compute_device,
dtype = torch.long
),
# next state
# None if dead or done.
torch.tensor(
[getattr(next_state, x) for x in Celeste.state_number_map],
dtype = torch.float32,
device = compute_device
).unsqueeze(0) if next_state is not None else None,
# reward
torch.tensor(
[reward],
device = compute_device
)
) )
) )
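
A worked example of the reward rules above, with invented numbers:

    # last_state.next_point = 2, state.next_point = 3, no death, stage not finished:
    #   reward = 3 - 2 = 1      -> compute_reward returns 1 * 10 = 10
    #   point_counter[2] += 1      (range(2, 2 + 1) touches only index 2)
    # if the agent dies instead:
    #   reward = 0              -> returns 0, and the stored next state is None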
@@ -406,11 +428,10 @@ def on_state_after(celeste, before_out):
print("") print("")
# Perform a training step
loss = None loss = None
# Only train the network if we have enough
# transitions in memory to do so.
if len(memory) >= BATCH_SIZE: if len(memory) >= BATCH_SIZE:
n_steps += 1
loss = optimize_model() loss = optimize_model()
# Soft update target_net weights # Soft update target_net weights
@@ -423,65 +444,43 @@ def on_state_after(celeste, before_out):
) )
target_net.load_state_dict(target_net_state) target_net.load_state_dict(target_net_state)
# Move on to the next episode once we reach
# a terminal state.
if (next_state.deaths != 0 or finished_stage): # Move on to the next episode and run
# housekeeping tasks.
if (dead or done):
s = celeste.state s = celeste.state
n_episodes += 1
# Move screenshots
sm.move(
number = n_episodes,
overwrite = True
)
# Log this episode
with model_train_log.open("a") as f: with model_train_log.open("a") as f:
f.write(json.dumps({ f.write(json.dumps({
"n_episodes": n_episodes,
"n_steps": n_steps,
"checkpoints": s.next_point, "checkpoints": s.next_point,
"state_count": s.state_count, "loss": None if loss is None else loss.item(),
"loss": None if loss is None else loss.item() "done": done
}) + "\n") }) + "\n")
# Save model
torch.save({
"policy_state_dict": policy_net.state_dict(),
"target_state_dict": target_net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"memory": memory,
"point_counter": point_counter,
"episode_number": episode_number,
"steps_done": steps_done,
# Hyperparameters
"eps_start": EPS_START,
"eps_end": EPS_END,
"eps_decay": EPS_DECAY,
"batch_size": BATCH_SIZE,
"tau": TAU,
"learning_rate": learning_rate,
"gamma": GAMMA
}, model_save_path)
# Clean up screenshots
shots = screenshot_source.glob("hackcel_*.png")
target = screenshot_dir / Path(f"{episode_number}")
target.mkdir(parents = True)
for s in shots:
s.rename(target / s.name)
# Save a snapshot # Save a snapshot
if episode_number % archive_interval == 0: if n_episodes % model_save_interval == 0:
torch.save({ save_model(model_archive_dir / f"{n_episodes}.torch")
"policy_state_dict": policy_net.state_dict(), shutil.copy(model_archive_dir / f"{n_episodes}.torch", model_save_path)
"target_state_dict": target_net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"memory": memory,
"episode_number": episode_number,
"steps_done": steps_done
}, model_archive_dir / f"{episode_number}.torch")
print("Game over. Resetting.") print("Game over. Resetting.")
episode_number += 1
celeste.reset() celeste.reset()
if __name__ == "__main__": if __name__ == "__main__":
c = Celeste( c = Celeste(
"resources/pico-8/linux/pico8" "resources/pico-8/linux/pico8"

View File

View File

@@ -0,0 +1,69 @@
from pathlib import Path
import shutil
class ScreenshotManager:
def __init__(
self,
# Where PICO-8 saves screenshots
source: Path,
# How PICO-8 names screenshots.
# Example: "celeste_*.png"
pattern: str,
# Where we want to move screenshots.
target: Path
):
self.source = source
self.pattern = pattern
self.target = target
self.target.mkdir(
parents = True,
exist_ok = True
)
def clean(self):
shots = self.source.glob(self.pattern)
for s in shots:
s.unlink()
return self
def move(self, number: int | None = None, overwrite = False):
shots = self.source.glob(self.pattern)
if number is None:
# Auto-select a new directory number:
# one more than the highest existing integer-named directory.
number = 0
for f in self.target.iterdir():
try:
number = max(
int(f.name),
number
)
except ValueError:
continue
number += 1
else:
target = self.target / str(number)
if target.exists():
if not overwrite:
raise Exception(f"Target \"{target}\" exists!")
else:
print(f"Target \"{target}\" exists, removing.")
shutil.rmtree(target)
target.mkdir(parents = True)
for s in shots:
s.rename(target / s.name)
return self
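
For reference, a usage sketch mirroring how the trainer in this diff wires the class up (the episode number is a placeholder):

    from pathlib import Path
    from celeste_ai.util.screenshots import ScreenshotManager

    sm = ScreenshotManager(
        source = Path("/home/mark/Desktop"),            # where PICO-8 saves screenshots
        pattern = "hackcel_*.png",
        target = Path("model_data/current/screenshots")
    ).clean()                                           # remove stale shots before training

    # ... after an episode ends ...
    sm.move(number = 1, overwrite = True)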

View File

@@ -47,14 +47,6 @@ plots = {
if __name__ == "__main__": if __name__ == "__main__":
if plots["prediction"]:
print("Making prediction plots...")
with Pool(5) as p:
p.map(
plot_pred,
list((m / "model_archive").iterdir())
)
if plots["best"]: if plots["best"]:
print("Making best-action plots...") print("Making best-action plots...")
with Pool(5) as p: with Pool(5) as p:
@@ -63,6 +55,14 @@ if __name__ == "__main__":
list((m / "model_archive").iterdir()) list((m / "model_archive").iterdir())
) )
if plots["prediction"]:
print("Making prediction plots...")
with Pool(5) as p:
p.map(
plot_pred,
list((m / "model_archive").iterdir())
)
if plots["actual"]: if plots["actual"]:
print("Making actual plots...") print("Making actual plots...")
with Pool(5) as p: with Pool(5) as p:

View File

@@ -30,6 +30,16 @@ k_jump=4
k_dash=5 k_dash=5
-- Set to false while training or running the model.
-- Set to true to play the game manually with debug printing.
-- (good for finding coordinates of checkpoints)
--
-- If true, disables most hack features:
-- - screenshots at every frame
-- - frame skipping
-- - waiting for input
hack_human_mode = false
-- If true, disable screenshake -- If true, disable screenshake
hack_no_shake = true hack_no_shake = true
@@ -1209,6 +1219,10 @@ end
-- _update60 does 60 fps -- _update60 does 60 fps
-- default for celeste is 30. -- default for celeste is 30.
function _update() function _update()
if hack_human_mode then
old_update()
return
end
-- Run at full speed until ready -- Run at full speed until ready
if not hack_ready then if not hack_ready then
@@ -1304,7 +1318,10 @@ end
-- Called at the same rate as _update, -- Called at the same rate as _update,
-- but not necessarily at the same time. -- but not necessarily at the same time.
function _draw() function _draw()
--old_draw() if hack_human_mode then
old_draw()
return
end
end end
function old_update() function old_update()