2023-02-15 22:24:40 -08:00
|
|
|
from collections import namedtuple
|
|
|
|
from collections import deque
|
2023-02-18 19:28:02 -08:00
|
|
|
from pathlib import Path
|
2023-02-15 22:24:40 -08:00
|
|
|
import random
|
2023-02-15 19:24:19 -08:00
|
|
|
import math
|
2023-02-18 19:28:02 -08:00
|
|
|
import json
|
2023-02-15 22:24:40 -08:00
|
|
|
import torch
|
2023-02-15 19:24:19 -08:00
|
|
|
|
2023-02-15 22:24:40 -08:00
|
|
|
from celeste import Celeste
|
|
|
|
|
|
|
|
|
2023-02-18 19:35:46 -08:00
|
|
|
# Where to read/write model data.
model_data_root = Path("model_data")

model_save_path  = model_data_root / "model.torch"
model_archive_dir = model_data_root / "model_archive"
model_train_log   = model_data_root / "train_log"
screenshot_dir    = model_data_root / "screenshots"

# Make sure every output directory exists before we use it.
for _out_dir in (model_data_root, model_archive_dir, screenshot_dir):
    _out_dir.mkdir(parents = True, exist_ok = True)
|
|
|
|
|
2023-02-18 19:28:02 -08:00
|
|
|
|
2023-02-15 22:24:40 -08:00
|
|
|
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    compute_device = torch.device("cuda")
else:
    compute_device = torch.device("cpu")
|
|
|
|
|
|
|
|
|
2023-02-15 23:38:27 -08:00
|
|
|
# Celeste env properties

# Size of the observation vector: one input feature per
# entry of Celeste.state_number_map.
n_observations = len(Celeste.state_number_map)

# Number of discrete actions the agent can choose from.
n_actions = len(Celeste.action_space)
|
|
|
|
|
2023-02-15 22:24:40 -08:00
|
|
|
|
|
|
|
# Epsilon-greedy parameters
#
# Original docs:
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000

# Number of transitions sampled from replay memory
# for each call to optimize_model().
BATCH_SIZE = 1_000

# Learning rate of target_net.
# Controls how soft our soft update is.
#
# Should be between 0 and 1.
# NOTE(review): the next comment was truncated in the original
# ("Large values"). Presumably: large values make target_net track
# policy_net more closely — confirm intent.
# Small values do the opposite.
#
# A value of one makes target_net
# change at the same rate as policy_net.
#
# A value of zero makes target_net
# not change at all.
TAU = 0.005

# GAMMA is the discount factor as mentioned in the previous section
GAMMA = 0.99
|
|
|
|
|
|
|
|
|
|
|
|
|
2023-02-15 22:24:40 -08:00
|
|
|
# Outline our network
class DQN(torch.nn.Module):
	"""
	A feed-forward deep Q-network.

	Maps an observation vector to one raw Q-value per action:
	output[..., i] = Q(state, action_i).
	"""

	def __init__(self, n_observations: int, n_actions: int):
		"""
		Args:
			n_observations: number of features in the input state vector.
			n_actions: number of discrete actions (output layer size).
		"""
		super().__init__()

		# Three 128-unit hidden layers with ReLU activations,
		# then a linear output layer (no activation: raw Q-values).
		self.layers = torch.nn.Sequential(
			torch.nn.Linear(n_observations, 128),
			torch.nn.ReLU(),

			torch.nn.Linear(128, 128),
			torch.nn.ReLU(),

			torch.nn.Linear(128, 128),
			torch.nn.ReLU(),

			# Fixed: was `torch.torch.nn.Linear` — that only worked by
			# accident (torch.torch aliases the torch module itself)
			# and was inconsistent with the layers above.
			torch.nn.Linear(128, n_actions)
		)

	# Can be called with one input, or with a batch.
	#
	# Returns tensor(
	# 	[ Q(s, a_0), Q(s, a_1), ... ], ...
	# )
	#
	# Recall that Q(s, a) is the (expected) return of taking
	# action `a` at state `s`
	def forward(self, x):
		"""Return Q-values for each action, for a single state or a batch."""
		return self.layers(x)
|
2023-02-15 22:24:40 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
2023-02-15 23:38:27 -08:00
|
|
|
# Total number of actions selected so far; used to decay epsilon
# in select_action. Restored from checkpoints below.
steps_done = 0

# NOTE(review): num_episodes is never read anywhere in this file —
# looks like dead configuration. Verify before removing.
num_episodes = 100
|
|
|
|
|
|
|
|
|
|
|
|
# Replay memory.
#
# Transition: a container for naming data (defined below).
# memory: a fixed-capacity deque of recent Transitions —
# once maxlen is exceeded, the oldest entry is dropped.
memory = deque(maxlen=100_000)
|
2023-02-15 23:38:27 -08:00
|
|
|
|
2023-02-15 22:24:40 -08:00
|
|
|
|
|
|
|
# The network we act with and train directly.
policy_net = DQN(
	n_observations,
	n_actions
).to(compute_device)

# A slowly-updated copy of policy_net, used to compute
# stable value targets during optimization (soft-updated with TAU).
target_net = DQN(
	n_observations,
	n_actions
).to(compute_device)

# Start both networks with identical weights.
target_net.load_state_dict(policy_net.state_dict())

# AdamW optimizer over policy_net's parameters only;
# target_net is never optimized directly.
optimizer = torch.optim.AdamW(
	policy_net.parameters(),
	lr = 0.01, # Hyperparameter: learning rate
	amsgrad = True
)
|
|
|
|
|
2023-02-15 22:24:40 -08:00
|
|
|
def select_action(state, steps_done):
	"""
	Select an action using an epsilon-greedy policy.

	Sometimes use our model, sometimes sample one uniformly.

	P(random action) starts at EPS_START and decays to EPS_END.
	Decay rate is controlled by EPS_DECAY.
	"""

	# Exponentially-decayed probability of exploring.
	eps_threshold = (
		EPS_END + (EPS_START - EPS_END) *
		math.exp(-1.0 * steps_done / EPS_DECAY)
	)

	# Draw a uniform number in [0, 1) and compare against the threshold.
	if random.random() <= eps_threshold:
		# Explore: pick an action uniformly at random.
		return random.randint( 0, n_actions-1 )

	# Exploit: ask policy_net for Q-values and take the best action.
	# t.max(1) returns (values, indices) over the action dimension;
	# the index of the max element is the action with the largest
	# expected return.
	with torch.no_grad():
		return policy_net(state).max(1)[1].view(1, 1).item()
|
|
|
|
|
|
|
|
|
|
|
|
# NOTE(review): last_state is initialized here but never read in this
# file — presumably leftover state tracking. Verify before removing.
last_state = None

# One step of experience: the state we were in, the action we took,
# the state we ended up in, and the reward we received.
Transition = namedtuple(
	"Transition",
	["state", "action", "next_state", "reward"]
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2023-02-15 23:38:27 -08:00
|
|
|
def optimize_model():
	"""
	Run one step of DQN optimization.

	Samples BATCH_SIZE transitions from replay memory, computes the
	Huber loss between the Q-values policy_net predicts and the
	Bellman targets derived from target_net, then performs one
	clipped gradient step on policy_net.

	Raises:
		Exception: if memory holds fewer than BATCH_SIZE transitions.
	"""

	if len(memory) < BATCH_SIZE:
		raise Exception(f"Not enough elements in memory for a batch of {BATCH_SIZE}")

	# Get a random sample of transitions
	batch = random.sample(memory, BATCH_SIZE)

	# Conversion.
	# Transposes batch, turning an array of Transitions
	# into a Transition of arrays.
	batch = Transition(*zip(*batch))

	# Conversion.
	# Combine states, actions, and rewards into their own tensors.
	state_batch  = torch.cat(batch.state)
	action_batch = torch.cat(batch.action)
	reward_batch = torch.cat(batch.reward)

	# Compute a mask of non-final states.
	# Each element of this tensor corresponds to an element in the batch.
	# True if the transition's next state is non-final (not None).
	#
	# We use this to select non-final states later.
	#
	# Fix: create the mask on compute_device with an explicit bool
	# dtype. The original built it on the CPU, which breaks the
	# masked assignment into next_state_values (a compute_device
	# tensor) when running on CUDA.
	non_final_mask = torch.tensor(
		tuple(map(
			lambda s: s is not None,
			batch.next_state
		)),
		device = compute_device,
		dtype = torch.bool
	)

	# All next states that are not terminal, concatenated into one batch.
	non_final_next_states = torch.cat(
		[s for s in batch.next_state if s is not None]
	)

	# How .gather works:
	# if out = a.gather(1, b),
	# out[i, j] = a[ i ][ b[i,j] ]
	#
	# a is "input," b is "index"

	# Compute Q(s_t, a).
	# - Use policy_net to compute Q(s_t) for each state in the batch.
	#   This gives one Q-value per action per state.
	#
	# - action_batch is a tensor like [ [0], [1], [1], ... ],
	#   listing the action that was taken in each transition.
	#
	# - Indexing policy_net's output with action_batch selects, for
	#   each transition, the return the model expects for the action
	#   that was actually taken.
	state_action_values = policy_net(state_batch).gather(1, action_batch)

	# Compute V(s_t+1) for all next states.
	# V(s_t+1) = max_a ( Q(s_t+1, a) )
	# = the maximum reward over all possible actions at state s_t+1.
	next_state_values = torch.zeros(BATCH_SIZE, device = compute_device)

	# Don't compute gradients in this block: the targets must be
	# treated as constants during backpropagation.
	with torch.no_grad():

		# Note the use of non_final_mask here.
		# States that are final do not have their reward set by the line
		# below, so their reward stays at zero.
		#
		# States that are not final get their predicted value
		# set to the best value the model predicts.
		#
		# Expected values of action are selected with the "older" target net,
		# and their best reward (over possible actions) is selected with max(1)[0].
		next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]

	# Compute the expected Q-values (the Bellman target):
	# immediate reward plus the discounted value of the next state.
	expected_state_action_values = reward_batch + (next_state_values * GAMMA)

	# Compute Huber loss between predicted reward and expected reward.
	#
	# loss is a single-element tensor (i.e, a scalar).
	criterion = torch.nn.SmoothL1Loss()
	loss = criterion(
		state_action_values,
		expected_state_action_values.unsqueeze(1)
	)

	# We can now run a step of backpropagation on our model.

	# Calling .backward() multiple times will accumulate parameter
	# gradients. Thus, we reset the gradient before each step.
	optimizer.zero_grad()

	# Compute the gradient of loss with respect to every parameter of
	# policy_net. The results accumulate into each parameter's .grad
	# attribute, which optimizer.step() reads below.
	loss.backward()

	# Prevent vanishing and exploding gradients.
	# Forces gradients to be in [-clip_value, +clip_value]
	torch.nn.utils.clip_grad_value_( # type: ignore
		policy_net.parameters(),
		clip_value = 100
	)

	# Perform a single optimizer step.
	#
	# Uses the current gradient, which is stored
	# in the .grad attribute of the parameter.
	optimizer.step()
|
|
|
|
|
|
|
|
|
2023-02-18 19:28:02 -08:00
|
|
|
# Index of the current episode; incremented after each terminal state.
episode_number = 0

if model_save_path.is_file():
	# Load model if one exists
	checkpoint = torch.load(model_save_path)
	policy_net.load_state_dict(checkpoint["policy_state_dict"])
	target_net.load_state_dict(checkpoint["target_state_dict"])
	optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
	# Restore replay memory and counters so training resumes
	# exactly where the previous run left off.
	memory = checkpoint["memory"]
	# +1: the saved episode already finished, so start on the next one.
	episode_number = checkpoint["episode_number"] + 1
	steps_done = checkpoint["steps_done"]
|
|
|
|
|
|
|
|
|
2023-02-16 12:11:04 -08:00
|
|
|
def on_state_before(celeste):
	"""
	Choose and perform an action for the current game state.

	Returns (state, action) so on_state_after can build a Transition.
	"""
	global steps_done

	# Conversion to pytorch
	state = celeste.state

	pt_state = torch.tensor(
		[getattr(state, k) for k in Celeste.state_number_map],
		dtype = torch.float32,
		device = compute_device
	).unsqueeze(0)

	# Keep sampling until we get an action that is legal right now:
	# when we cannot dash, only "left" and "right" are allowed.
	while True:
		action = select_action(
			pt_state,
			steps_done
		)
		str_action = Celeste.action_space[action]
		if state.can_dash or str_action in ("left", "right"):
			break
	steps_done += 1

	# For manual testing
	#str_action = ""
	#while str_action not in Celeste.action_space:
	#	str_action = input("action> ")
	#action = Celeste.action_space.index(str_action)

	print(str_action)
	celeste.act(str_action)

	return state, action
|
|
|
|
|
2023-02-18 19:28:02 -08:00
|
|
|
|
|
|
|
|
|
|
|
# Save an archived model snapshot every this many episodes.
image_interval = 10

def on_state_after(celeste, before_out):
	"""
	Process the result of the last action.

	Computes the reward, stores the transition in replay memory,
	trains and soft-updates the networks, and handles end-of-episode
	bookkeeping (logging, checkpointing, screenshots, reset).
	"""
	global episode_number
	# NOTE(review): image_count is declared global but never assigned or
	# read in this function — looks vestigial. Verify before removing.
	global image_count

	# (state, action) as returned by on_state_before.
	state, action = before_out
	next_state = celeste.state

	# Convert the pre-action state to a 1xN float tensor.
	pt_state = torch.tensor(
		[getattr(state, x) for x in Celeste.state_number_map],
		dtype = torch.float32,
		device = compute_device
	).unsqueeze(0)

	# Action index as a 1x1 long tensor (the shape .gather expects
	# in optimize_model).
	pt_action = torch.tensor(
		[[ action ]],
		device = compute_device,
		dtype = torch.long
	)

	if next_state.deaths != 0:
		# Terminal state: by convention the next state is None
		# and the reward is zero.
		pt_next_state = None
		reward = 0

	else:
		pt_next_state = torch.tensor(
			[getattr(next_state, x) for x in Celeste.state_number_map],
			dtype = torch.float32,
			device = compute_device
		).unsqueeze(0)


		if state.next_point == next_state.next_point:
			# Still aiming at the same checkpoint: reward is the
			# distance gained toward it.
			reward = state.dist - next_state.dist

			# Clip rewards that are too large
			# NOTE(review): the else branch zeroes any reward <= 1,
			# so partial progress earns nothing — reward is effectively
			# binary. Confirm this is intentional.
			if reward > 1:
				reward = 1
			else:
				reward = 0

		else:
			# Score for reaching a point
			reward = 1

	pt_reward = torch.tensor([reward], device = compute_device)


	# Add this state transition to memory.
	memory.append(
		Transition(
			pt_state,  # last state
			pt_action,
			pt_next_state,  # next state
			pt_reward
		)
	)

	print("==> ", int(reward))
	print("\n")


	# Only train the network if we have enough
	# transitions in memory to do so.
	if len(memory) >= BATCH_SIZE:
		optimize_model()

	# Soft update target_net weights:
	# blend TAU of policy_net into (1-TAU) of target_net, per key.
	target_net_state = target_net.state_dict()
	policy_net_state = policy_net.state_dict()
	for key in policy_net_state:
		target_net_state[key] = (
			policy_net_state[key] * TAU +
			target_net_state[key] * (1-TAU)
		)
	target_net.load_state_dict(target_net_state)

	# Move on to the next episode once we reach
	# a terminal state.
	if (next_state.deaths != 0):
		s = celeste.state

		# Append this episode's results to the training log.
		with model_train_log.open("a") as f:
			f.write(json.dumps({
				"checkpoints": s.next_point,
				"state_count": s.state_count
			}) + "\n")


		# Save model
		torch.save({
			"policy_state_dict": policy_net.state_dict(),
			"target_state_dict": target_net.state_dict(),
			"optimizer_state_dict": optimizer.state_dict(),
			"memory": memory,
			"episode_number": episode_number,
			"steps_done": steps_done
		}, model_save_path)


		# Clean up screenshots
		# NOTE(review): hard-coded user desktop path — presumably where
		# the game dumps screenshots. Confirm / make configurable.
		shots = Path("/home/mark/Desktop").glob("hackcel_*.png")

		# Move this episode's screenshots into their own folder.
		target = screenshot_dir / Path(f"{episode_number}")
		target.mkdir(parents = True)

		for s in shots:
			s.rename(target / s.name)

		# Save a prediction graph
		# (archive a full checkpoint every image_interval episodes)
		if episode_number % image_interval == 0:
			torch.save({
				"policy_state_dict": policy_net.state_dict(),
				"target_state_dict": target_net.state_dict(),
				"optimizer_state_dict": optimizer.state_dict(),
				"memory": memory,
				"episode_number": episode_number,
				"steps_done": steps_done
			}, model_archive_dir / f"{episode_number}.torch")


		print("Game over. Resetting.")
		episode_number += 1
		celeste.reset()
|
2023-02-15 22:24:40 -08:00
|
|
|
|
2023-02-15 19:24:19 -08:00
|
|
|
|
|
|
|
|
2023-02-16 12:11:04 -08:00
|
|
|
# Create the game wrapper and run the main loop, which invokes our
# hooks before and after every state update.
c = Celeste()

c.update_loop(
	on_state_before,
	on_state_after
)
|