Refactor
.editorconfig | 13 (new file)

@@ -0,0 +1,13 @@
root = true

[*]
indent_style = tab
indent_size = 4
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true

[*.md]
indent_size = 2
indent_style = space
Cargo.lock | 801 (generated; file diff suppressed because it is too large)

Cargo.toml | 10
@@ -75,10 +75,12 @@ compact_str = "0.9.0"
dary_heap = "0.3.8"
fancy-regex = "0.16.2"
indicatif = { version = "0.18.3", features = ["improved_unicode"] }
itertools = "0.14.0"
futures-util = "0.3.31"
ndarray = { version = "0.16.1", features = ["serde"] }
parking_lot = "0.12.5"
parquet = "56.2.0"
rand = "0.9.2"
rayon = "1.11.0"
reqwest = { version = "0.12.24", features = ["json", "stream"] }
serde = "1.0.228"
@@ -91,7 +93,11 @@ tracing-indicatif = "0.3.13"
tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] }
url = "2.5.7"

burn-train = { git = "https://github.com/tracel-ai/burn.git", default-features = false }

[workspace.dependencies.burn]
version = "0.19.1"
#version = "0.19.1"
git = "https://github.com/tracel-ai/burn.git"
default-features = false
features = ["std", "fusion", "ndarray", "webgpu", "cuda"]
features = ["std", "fusion", "ndarray", "webgpu", "cuda", "autodiff"]
README.md | 26 (new file)

@@ -0,0 +1,26 @@
# LLM from scratch

## Resources
- [Build a Large Language Model](https://www.manning.com/books/build-a-large-language-model-from-scratch)
- [Writing an LLM from scratch, part 28](https://www.gilesthomas.com/2025/12/llm-from-scratch-28-training-a-base-model-from-scratch)
- [nanochat](https://github.com/karpathy/nanochat)

## TODO:
- chat cli, evaluate each epoch
- better arch (read nanochat)
- count tokens
- download more data (code, full fineweb)

- Notes
  - comments

- TrainTestIterator
  - total length
  - deterministic shuffle
  - prepare in parallel
  - refactor new() into builder
  - small texts (<|bos|>?)

- Training
  - multi-device training
  - model parameters in file
@@ -10,15 +10,18 @@ workspace = true
[dependencies]
tokenizer = { workspace = true }

ahash = { workspace = true }
anstyle = { workspace = true }
anyhow = { workspace = true }
burn = { workspace = true }
burn-train = { workspace = true }
clap = { workspace = true }
futures-util = { workspace = true }
indicatif = { workspace = true }
ndarray = { workspace = true }
parking_lot = { workspace = true }
parquet = { workspace = true }
rand = { workspace = true }
rayon = { workspace = true }
reqwest = { workspace = true }
serde = { workspace = true }
@@ -62,3 +62,16 @@ pub fn progress_bytes() -> ProgressStyle {
        "⠉⠉", "⠈⠙", "⠀⠹", "⠀⢸", "⠀⣰", "⢀⣠", "⣀⣀", "⣄⡀", "⣆⠀", "⡇⠀", "⠏⠀", "⠋⠁", "⣏⣹",
    ]);
}

#[expect(clippy::unwrap_used)]
pub fn progress_persec() -> ProgressStyle {
    return ProgressStyle::default_bar()
        .template(
            " {bar:16.red/white.dim} {elapsed_precise:.dim} {pos}/{len} ({per_sec:>3}) {msg:.dim} ({eta})",
        )
        .unwrap()
        .progress_chars("---")
        .tick_strings(&[
            "⠉⠉", "⠈⠙", "⠀⠹", "⠀⢸", "⠀⣰", "⢀⣠", "⣀⣀", "⣄⡀", "⣆⠀", "⡇⠀", "⠏⠀", "⠋⠁", "⣏⣹",
        ]);
}
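These helpers only build a `ProgressStyle`; a hypothetical usage sketch (not part of the commit) showing how such a style is wired into an `indicatif` bar:

```rust
use indicatif::ProgressBar;

fn demo() {
    let pb = ProgressBar::new(1_000);
    pb.set_style(progress_persec());

    for _ in 0..1_000 {
        pb.inc(1);
    }
    pb.finish_and_clear();
}
```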
@@ -21,9 +21,8 @@ const MAX_SHARD: usize = 1822;
#[derive(Debug, Args, Clone)]

pub struct DownloadArgs {
    /// Training data dir
    #[clap(default_value = "data")]
    data_dir: PathBuf,
    /// Training data directory (will be created)
    data: PathBuf,

    /// Number of shards to download (-1 for all)
    #[arg(short = 'n', long, default_value = "-1")]
@@ -37,7 +36,7 @@ pub struct DownloadArgs {
impl DownloadArgs {
    pub fn run(self, mp: Option<MultiProgress>) -> Result<()> {
        info!("Downloading files from {BASE_URL}");
        fs::create_dir_all(&self.data_dir)?;
        fs::create_dir_all(&self.data)?;

        let num_shards_to_download = if self.num_files == -1 {
            MAX_SHARD + 1
@@ -48,7 +47,7 @@ impl DownloadArgs {
        let ids_to_download: Vec<usize> = (0..num_shards_to_download).collect();

        info!("Downloading {} shards...", ids_to_download.len(),);
        info!("Target directory: {}", self.data_dir.display());
        info!("Target directory: {}", self.data.display());

        let main_pb = mp.as_ref().map(|mp| {
            let pb = mp.add(ProgressBar::new(ids_to_download.len() as u64));
@@ -70,7 +69,7 @@ impl DownloadArgs {
        ids_to_download
            .into_par_iter()
            .for_each_with(tx, |tx, index| {
                let target = self.data_dir.clone();
                let target = self.data.clone();
                let main_pb = main_pb.clone();
                let mp_clone = mp.clone();
                let rt_handle = rt.handle().clone(); // Clone the runtime handle for each thread
@@ -1,5 +1,5 @@
mod download;
mod sample_data;
mod train_model;
mod train_tokenizer;

#[derive(Debug, clap::Subcommand)]
@@ -19,7 +19,7 @@ pub enum SubCommand {
    /// Train model
    TrainModel {
        #[command(flatten)]
        args: sample_data::TrainModelArgs,
        args: train_model::TrainModelArgs,
    },
}
@@ -1,735 +0,0 @@
use ahash::AHasher;
use anyhow::{Context, Result};
use burn::{
    Tensor,
    backend::{Autodiff, Cuda, cuda::CudaDevice},
    config::Config,
    module::{AutodiffModule, Module, Param, ParamId},
    nn::{
        Dropout, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
        loss::CrossEntropyLossConfig,
        transformer::{PositionWiseFeedForward, PositionWiseFeedForwardConfig},
    },
    optim::{AdamConfig, GradientsParams, Optimizer},
    prelude::{Backend, ToElement},
    tensor::{Bool, Distribution, Int, activation::softmax},
};
use burn_train::ClassificationOutput;
use clap::Args;
use indicatif::MultiProgress;
use ndarray::{Array1, Array2};
use std::{
    collections::VecDeque,
    f32,
    fs::File,
    hash::Hasher,
    path::{Path, PathBuf},
};
use tokenizer::Tokenizer;
use tracing::{debug, info};

use crate::data_reader::{DataReader, DataReaderError};

// Text generation routine

/*
{
    let init = "Initial context. This is ";
    let tokens = tokenizer.encode(&init);

    let n_tokens = tokens.len();
    let input: Array1<u32> = Array1::from_vec(tokens);
    let mut input: Tensor<Cuda, 1, Int> =
        Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
            .reshape([n_tokens]);

    for _ in 0..100 {
        let tokens: Vec<u32> = input.clone().to_data().convert::<u32>().into_vec().unwrap();
        println!("{:?}", tokens);
        println!("{}", tokenizer.decode(&tokens));

        // Crop idx to context size;
        let batch = input
            .clone()
            .slice([0..config.context_size])
            .unsqueeze_dim(0);

        // shape: [tokens, vocab_size]
        let logits = model.forward(batch).squeeze_dim::<2>(0);

        // shape: [vocab_size]
        let logits = logits.slice([config.context_size - 1]).squeeze_dim::<1>(0);

        let probs = softmax(logits, 0); // shape: [n_tokens]
        let id_next = probs.argmax(0); // shape: [1]

        input = Tensor::cat(vec![input.slice([1..]), id_next], 0);
    }
}
*/

struct TrainTestIterator<'a, B: Backend> {
    reader: DataReader,

    ccfg: &'a ComputeConfig,
    mcfg: &'a GptModelConfig,
    tokenizer: &'a Tokenizer,
    eval: bool,
    device: &'a B::Device,

    error: bool,

    // Tokenized input/output pairs
    pairs: VecDeque<(Vec<u32>, u32)>,
}

impl<'a, B: Backend> TrainTestIterator<'a, B> {
    pub fn new(
        data_dir: impl AsRef<Path>,
        ccfg: &'a ComputeConfig,
        mcfg: &'a GptModelConfig,
        tokenizer: &'a Tokenizer,
        eval: bool,
        device: &'a B::Device,
    ) -> Result<Self, std::io::Error> {
        let reader = DataReader::new(10, data_dir)?;

        Ok(Self {
            reader,
            ccfg,
            mcfg,
            tokenizer,
            eval,
            device,

            error: false,
            pairs: VecDeque::new(),
        })
    }
}

impl<B: Backend> Iterator for TrainTestIterator<'_, B> {
    type Item = Result<TrainBatch<B>, DataReaderError>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.error {
            return None;
        }

        let mut inputs = Vec::with_capacity(self.ccfg.batch_size);
        let mut targets = Vec::with_capacity(self.ccfg.batch_size);
        let stride = self.mcfg.context_size;

        while inputs.len() < self.ccfg.batch_size {
            match self.pairs.pop_front() {
                Some((i, t)) => {
                    // train/test split
                    {
                        let mut hasher = AHasher::default();
                        hasher.write(self.ccfg.eval_salt.as_bytes());

                        // Don't care about endianness, ahash output is unstable anyway
                        hasher.write(unsafe { std::mem::transmute(&i[..]) });
                        hasher.write_u32(t);

                        let test = // is this point in the test set?
                            hasher.finish() > (u64::MAX as f64 * self.ccfg.eval_frac).to_u64();

                        if test ^ self.eval {
                            continue;
                        }
                    }

                    inputs.push(i);
                    targets.push(t);
                }

                None => {
                    let text = match self.reader.next() {
                        None => break,
                        Some(Ok(x)) => x,
                        Some(Err(x)) => {
                            self.error = true;
                            return Some(Err(x));
                        }
                    };

                    let emb = self.tokenizer.encode(&text);

                    // Skip small texts
                    //
                    // TODO: do this better
                    // TODO: maybe using <|bos|>?
                    // TODO: non-uniform batches?
                    if emb.len() < self.mcfg.context_size {
                        continue;
                    }

                    let pairs = emb
                        .windows(self.mcfg.context_size + 1)
                        .step_by(stride)
                        .map(|x| {
                            (
                                x[..self.mcfg.context_size].to_vec(),
                                x[self.mcfg.context_size],
                            )
                        });

                    self.pairs.extend(pairs);
                }
            }
        }

        if inputs.is_empty() {
            return None;
        }

        let shape = [inputs.len(), self.mcfg.context_size];

        // Arrange data in memory
        let inputs: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| inputs[a][b]);
        let targets: Array1<u32> = Array1::from_vec(targets);

        // Create tensors on gpu
        #[expect(clippy::unwrap_used)]
        let inputs =
            Tensor::<B, 1, Int>::from_ints(inputs.as_slice().unwrap(), self.device).reshape(shape);

        #[expect(clippy::unwrap_used)]
        let targets = Tensor::<B, 1, Int>::from_ints(targets.as_slice().unwrap(), self.device);

        return Some(Ok(TrainBatch { inputs, targets }));
    }
}

#[derive(Debug, Args, Clone)]

pub struct TrainModelArgs {
    /// Path to training data
    data: PathBuf,

    /// Path to tokenizer
    #[clap(long)]
    tokenizer: PathBuf,
}

pub struct ComputeConfig {
    pub batch_size: usize,
    pub eval_frac: f64,
    pub eval_salt: String,
}

impl TrainModelArgs {
    pub fn run(self, _mp: Option<MultiProgress>) -> Result<()> {
        let device = CudaDevice::new(0);
        //let device = WgpuDevice::DiscreteGpu(0);

        let tokenizer = File::open(&self.tokenizer).context("while opening tokenizer")?;
        let tokenizer: Tokenizer =
            serde_json::from_reader(tokenizer).context("while loading tokenizer")?;

        let ccfg = ComputeConfig {
            batch_size: 10,
            eval_frac: 0.1,
            eval_salt: "salt".into(),
        };

        let mcfg = GptModelConfig {
            vocab_size: tokenizer.vocab_size(),
            context_size: 256,
            embed_dim: 768,
            n_heads: 12,
            head_dim: 64, // = 768 / 12
            n_layers: 1,
            embed_drop_rate: 0.1,
            attention_drop_rate: 0.1,
            shortcut_drop_rate: 0.1,
        };

        let mut model: GptModel<Autodiff<Cuda>> = mcfg.init(&device);

        /*
        let loader_train = DataLoaderBuilder::new(batcher.clone())
            .batch_size(ccfg.batch_size)
            //.shuffle(config.seed)
            .num_workers(5)
            .build(Loader::new(&self.data_dir).context("while initializing loader")?);

        let loader_test = DataLoaderBuilder::new(batcher)
            .batch_size(ccfg.batch_size)
            //.shuffle(config.seed)
            .num_workers(5)
            .build(Loader::new(&self.data_dir).context("while initializing loader")?);

        let learner = LearnerBuilder::new("./tmp")
            .metric_train_numeric(AccuracyMetric::new())
            .metric_valid_numeric(AccuracyMetric::new())
            .metric_train_numeric(LossMetric::new())
            .metric_valid_numeric(LossMetric::new())
            .with_file_checkpointer(CompactRecorder::new())
            .learning_strategy(LearningStrategy::SingleDevice(device.clone()))
            .num_epochs(10)
            .summary()
            .build(model, AdamConfig::new().init(), 1e-4);

        learner.fit(loader_train, loader_test);
        */

        // Initialize optimizer
        let mut optim = AdamConfig::new().init();
        let learning_rate = 1e-4;

        for epoch in 0..10 {
            debug!("Running epoch {epoch}");

            // Training phase
            let mut train_loss_sum = 0.0;
            let mut train_total = 0;

            for batch in
                TrainTestIterator::new(&self.data, &ccfg, &mcfg, &tokenizer, false, &device)
                    .context("while initializing reader")?
            {
                let batch = batch.context("while reading batch")?;

                // Forward pass with gradients
                let output = model.forward_train(batch.inputs, batch.targets);

                train_total += output.targets.dims()[0] as i32;
                train_loss_sum += output.loss.clone().into_scalar().to_f32();

                let grads = output.loss.backward();
                let grads = GradientsParams::from_grads(grads, &model);

                model = optim.step(learning_rate, model, grads);
            }

            let mut valid_loss_sum = 0.0;
            let mut valid_total = 0;

            let mut n_eval = 0;
            debug!("Evaluating batches");

            for batch in TrainTestIterator::new(&self.data, &ccfg, &mcfg, &tokenizer, true, &device)
                .context("while initializing reader")?
            {
                let batch = batch.context("while reading batch")?;
                n_eval += batch.targets.shape()[0];

                // Forward pass without gradients
                let output = model.valid().forward_train(batch.inputs, batch.targets);

                valid_total += output.targets.dims()[0] as i32;
                valid_loss_sum += output.loss.into_scalar().to_f32();
            }

            // Compute and log epoch results
            let train_loss = if train_total > 0 {
                train_loss_sum / train_total as f32
            } else {
                0.0
            };
            let valid_loss = if valid_total > 0 {
                valid_loss_sum / valid_total as f32
            } else {
                0.0
            };

            info!(message = "Ran epoch", epoch, train_loss, valid_loss, n_eval);
        }

        Ok(())
    }
}

//
// MARK: model
//

/// Multihead attention.
///
/// Equivalent to many stacked CausalAttention layers.
/// These are packed inside one big tensor for efficiency.
#[derive(Module, Debug)]
pub struct MultiheadAttention<B: Backend> {
    n_heads: usize,
    head_dim: usize,

    // Can also use Linear layers with disabled bias
    // (they may also have a better initialization routine)
    // TODO: see source code, make this equivalent
    /// Query weight matrices for each head, stacked on the last dimension.
    /// (so that shape is [tokens, n_heads * head_dim])
    ///
    /// Intuitively, this learns "what question to ask about the text"
    /// for a given query token. (e.g, "it" -> what does "it" refer to?)
    w_query: Param<Tensor<B, 2>>,

    /// Key weight matrices for each head, stacked on the last dimension.
    /// (so that shape is [tokens, n_heads * head_dim])
    ///
    /// Intuitively, this learns what properties a certain token
    /// has when it appears as a context (non-query) token.
    w_key: Param<Tensor<B, 2>>,

    /// Value weight matrices for each head, stacked on the last dimension.
    /// (so that shape is [tokens, n_heads * head_dim])
    ///
    /// Intuitively, ???
    w_value: Param<Tensor<B, 2>>,

    /// Optional final projection.
    /// Maps [total_dim, total_dim] to [total_dim, total_dim]
    w_output: Param<Tensor<B, 2>>,

    dropout: Dropout,

    /// Upper-triangular matrix of ones, excluding diagonal.
    /// Used to mask future tokens.
    utri_mask: Tensor<B, 2, Bool>,
}

impl<B: Backend> MultiheadAttention<B> {
    pub fn new(
        embedding_dim: usize,
        head_dim: usize,
        n_heads: usize,
        context_length: usize,
        dropout: f64,
        device: &B::Device,
    ) -> Self {
        let total_dim = head_dim * n_heads;

        Self {
            n_heads,
            head_dim,

            w_query: Param::uninitialized(
                ParamId::new(),
                move |device, is_require_grad| {
                    Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
                        .set_require_grad(is_require_grad)
                },
                device.clone(),
                true,
                [embedding_dim, total_dim].into(),
            ),

            w_key: Param::uninitialized(
                ParamId::new(),
                move |device, is_require_grad| {
                    Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
                        .set_require_grad(is_require_grad)
                },
                device.clone(),
                true,
                [embedding_dim, total_dim].into(),
            ),

            w_value: Param::uninitialized(
                ParamId::new(),
                move |device, is_require_grad| {
                    Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
                        .set_require_grad(is_require_grad)
                },
                device.clone(),
                true,
                [embedding_dim, total_dim].into(),
            ),

            w_output: Param::uninitialized(
                ParamId::new(),
                move |device, is_require_grad| {
                    Tensor::random([total_dim, total_dim], Distribution::Default, device)
                        .set_require_grad(is_require_grad)
                },
                device.clone(),
                true,
                [total_dim, total_dim].into(),
            ),

            dropout: Dropout { prob: dropout },

            utri_mask: Tensor::<B, 2, _>::tril_mask([context_length, context_length], 0, device),
        }
    }

    /// Compute self-attention vector for the given batch
    ///
    /// - input shape is [batch, token, token_dim]
    /// - output shape is [batch, token, n_heads * head_dim]
    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        // Works similarly to self-attention, (where attn = softmax(tok @ tok^T); context = attn @ tok)
        // But adds an "inner latent space" using Wq, Wk, and Wv.
        //
        // Multiple heads are batched into one tensor.

        let batch = input.dims()[0];
        let tokens = input.dims()[1];

        let w_query = self
            .w_query
            .val()
            .unsqueeze_dim::<3>(0)
            .expand([batch as i64, -1, -1]);

        let w_key = self
            .w_key
            .val()
            .unsqueeze_dim::<3>(0)
            .expand([batch as i64, -1, -1]);

        let w_value = self
            .w_value
            .val()
            .unsqueeze_dim::<3>(0)
            .expand([batch as i64, -1, -1]);

        let w_output = self
            .w_output
            .val()
            .unsqueeze_dim::<3>(0)
            .expand([batch as i64, -1, -1]);

        // Map batch to inner latent space.
        // shape: [batch, token, inner_dim]
        let queries = input.clone().matmul(w_query);
        let keys = input.clone().matmul(w_key);
        let values = input.clone().matmul(w_value);

        // Split head dimensions
        let keys = keys.reshape([batch, tokens, self.n_heads, self.head_dim]);
        let values = values.reshape([batch, tokens, self.n_heads, self.head_dim]);
        let queries = queries.reshape([batch, tokens, self.n_heads, self.head_dim]);

        // from: [batch, tok, head, head_dim]
        // to: [batch, head, tok, head_dim]
        let keys = keys.swap_dims(1, 2);
        let values = values.swap_dims(1, 2);
        let queries = queries.swap_dims(1, 2);

        // Compute attention scores for each head
        // (dot-product similarity of each query token to each context token, per head)
        //
        // lhs shape: [batch, head, tok, head_dim]
        // rhs shape: [batch, head, head_dim, tok]
        // output shape: [batch, head, query_token, context_token]
        let attn_scores = queries.matmul(keys.clone().swap_dims(2, 3));

        let mask = self
            .utri_mask
            .clone()
            .slice([0..tokens, 0..tokens])
            .unsqueeze_dim::<3>(0)
            .unsqueeze_dim::<4>(0)
            .expand(attn_scores.shape());

        // Mask out future tokens by filling
        // upper-triangular with -inf, which becomes 0.0 after softmax.
        let attn_scores = attn_scores.mask_fill(mask, f32::NEG_INFINITY);

        // Normalize attn weights.
        //
        // Divide by sqrt(inner_dim) because...
        // - dot products get larger with larger dimensions
        // - this causes softmax to "saturate", making all other values very small
        // - which makes gradients vanish during training
        let attn_weights = softmax(attn_scores / (keys.shape()[3] as f32).sqrt(), 3);
        let attn_weights = self.dropout.forward(attn_weights);

        // lhs shape: [batch, head, query_token, context_token]
        // rhs shape: [batch, head, tok, head_dim]
        // matmul shape: [batch, head, tok, head_dim]
        // out shape: [batch, tok, head, head_dim]
        let context_vec = attn_weights.matmul(values).swap_dims(1, 2);

        // shape: [batch, tok, stacked_dim]
        let context_vec = context_vec.reshape([batch, tokens, self.n_heads * self.head_dim]);

        // Apply final projection (optional)
        let context_vec = context_vec.matmul(w_output);

        return context_vec;
    }
}

#[derive(Config, Debug)]
pub struct GptModelConfig {
    /// Number of tokens
    pub vocab_size: u32,

    /// Maximum number of input tokens with positional embeddings
    pub context_size: usize,

    /// Dimension of each token's embedding
    pub embed_dim: usize,

    /// Number of attention heads
    pub n_heads: usize,

    /// Dimension of each attn head
    pub head_dim: usize,

    /// Number of transformer blocks
    pub n_layers: usize,

    pub embed_drop_rate: f64,
    pub attention_drop_rate: f64,
    pub shortcut_drop_rate: f64,
}

impl GptModelConfig {
    pub fn init<B: Backend>(&self, device: &B::Device) -> GptModel<B> {
        let out_head_shape = [self.embed_dim, self.vocab_size as usize];

        GptModel {
            embedder_tok: EmbeddingConfig::new(self.vocab_size as usize, self.embed_dim)
                .init(device),

            embedder_pos: EmbeddingConfig::new(self.context_size, self.embed_dim).init(device),

            embedder_drop: Dropout {
                prob: self.embed_drop_rate,
            },

            trf_blocks: (0..self.n_layers)
                .map(|_| TransformerBlock::new(&self, device))
                .collect(),

            final_norm: LayerNormConfig::new(self.embed_dim).init(device),

            out_head: Param::uninitialized(
                ParamId::new(),
                move |device, is_require_grad| {
                    Tensor::random(out_head_shape, Distribution::Default, device)
                        .set_require_grad(is_require_grad)
                },
                device.clone(),
                true,
                out_head_shape.into(),
            ),
        }
    }
}

#[derive(Debug, Clone)]
pub struct TrainBatch<B: Backend> {
    pub inputs: Tensor<B, 2, Int>,

    /// Correct next token for each input
    pub targets: Tensor<B, 1, Int>,
}

#[derive(Module, Debug)]
pub struct GptModel<B: Backend> {
    embedder_tok: Embedding<B>,
    embedder_pos: Embedding<B>,
    embedder_drop: Dropout,

    trf_blocks: Vec<TransformerBlock<B>>,
    final_norm: LayerNorm<B>,
    out_head: Param<Tensor<B, 2>>,
}

impl<B: Backend> GptModel<B> {
    pub fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 3> {
        let n_tokens = input.shape()[1];

        let embed_tok = self.embedder_tok.forward(input.clone());
        let embed_pos = self
            .embedder_pos
            .forward(Tensor::arange(0..n_tokens as i64, &input.device()).unsqueeze_dim(0));

        let x = embed_tok + embed_pos;
        let x = self.embedder_drop.forward(x);
        let x = self.trf_blocks.iter().fold(x, |x, l| l.forward(x));
        let x = self.final_norm.forward(x);
        let logits = x.matmul(self.out_head.val().unsqueeze_dim(0));

        return logits;
    }

    pub fn forward_train(
        &self,
        inputs: Tensor<B, 2, Int>,
        targets: Tensor<B, 1, Int>,
    ) -> ClassificationOutput<B> {
        // shape: [batch, n_tokens, n_vocabulary]
        let output = self.forward(inputs);

        // Get last token
        // shape: [batch, n_vocabulary]
        let output = output.slice_dim(1, -1).squeeze_dim::<2>(1);

        let loss = CrossEntropyLossConfig::new()
            .init(&targets.device())
            .forward(output.clone(), targets.clone());

        ClassificationOutput {
            loss,
            output,
            targets,
        }
    }
}

#[derive(Module, Debug)]
pub struct TransformerBlock<B: Backend> {
    attention: MultiheadAttention<B>,

    /// TODO: wtf?
    ff: PositionWiseFeedForward<B>,

    /// TODO: wtf?
    norm_a: LayerNorm<B>,
    norm_b: LayerNorm<B>,

    drop_shortcut: Dropout,
}

impl<B: Backend> TransformerBlock<B> {
    pub fn new(cfg: &GptModelConfig, device: &B::Device) -> Self {
        Self {
            attention: MultiheadAttention::new(
                cfg.embed_dim,
                cfg.head_dim,
                cfg.n_heads,
                cfg.context_size,
                cfg.attention_drop_rate,
                device,
            ),

            ff: PositionWiseFeedForwardConfig::new(cfg.embed_dim, 4 * cfg.embed_dim)
                .with_dropout(0.0)
                .init(device),

            norm_a: LayerNormConfig::new(cfg.embed_dim).init(device),
            norm_b: LayerNormConfig::new(cfg.embed_dim).init(device),

            drop_shortcut: Dropout {
                prob: cfg.shortcut_drop_rate,
            },
        }
    }

    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        let input = {
            let shortcut = input.clone();
            let x = self.norm_a.forward(input);
            let x = self.attention.forward(x);
            let x = self.drop_shortcut.forward(x);
            x + shortcut
        };

        let input = {
            // TODO: wtf?
            let shortcut = input.clone();
            let x = self.norm_b.forward(input);
            let x = self.ff.forward(x);
            let x = self.drop_shortcut.forward(x);
            x + shortcut
        };

        return input;
    }
}
crates/llmfs/src/command/train_model.rs | 300 (new file)

@@ -0,0 +1,300 @@
use anyhow::{Context, Result};
use burn::{
    backend::Autodiff,
    module::{AutodiffModule, Module},
    optim::{AdamConfig, GradientsParams, Optimizer},
    prelude::ToElement,
    record::{FullPrecisionSettings, NamedMpkFileRecorder},
    tensor::backend::AutodiffBackend,
};
use clap::Args;
use indicatif::{MultiProgress, ProgressBar};
use std::{f32, fs::File, num::NonZero, path::PathBuf, time::Duration};
use tokenizer::Tokenizer;
use tracing::{debug, info};

use crate::{
    InferenceDevice,
    cli::{progress_big, progress_persec},
    parts::{GptModel, GptModelConfig},
    train_test_iterator::TrainTestIterator,
};

// Text generation routine
/*
{
    let init = "Initial context. This is ";
    let tokens = tokenizer.encode(&init);

    let n_tokens = tokens.len();
    let input: Array1<u32> = Array1::from_vec(tokens);
    let mut input: Tensor<Cuda, 1, Int> =
        Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
            .reshape([n_tokens]);

    for _ in 0..100 {
        let tokens: Vec<u32> = input.clone().to_data().convert::<u32>().into_vec().unwrap();
        println!("{:?}", tokens);
        println!("{}", tokenizer.decode(&tokens));

        // Crop idx to context size;
        let batch = input
            .clone()
            .slice([0..config.context_size])
            .unsqueeze_dim(0);

        // shape: [tokens, vocab_size]
        let logits = model.forward(batch).squeeze_dim::<2>(0);

        // shape: [vocab_size]
        let logits = logits.slice([config.context_size - 1]).squeeze_dim::<1>(0);

        let probs = softmax(logits, 0); // shape: [n_tokens]
        let id_next = probs.argmax(0); // shape: [1]

        input = Tensor::cat(vec![input.slice([1..]), id_next], 0);
    }
}
*/

#[derive(Debug, Args, Clone)]
pub struct TrainModelArgs {
    /// Path to training data
    data: PathBuf,

    /// Path to tokenizer
    #[clap(long, default_value = "tokenizer.json")]
    tokenizer: PathBuf,

    /// Directory to save checkpoints
    #[clap(long, default_value = "checkpoints")]
    checkpoints: PathBuf,

    /// The device to use for compute. `wgpu:n`, `cuda:n`, or `cpu`
    #[clap(long, default_value = "cpu")]
    device: InferenceDevice,

    /// Training batch size
    #[clap(long, default_value = "10")]
    batch: NonZero<usize>,

    /// Proportion of data reserved for evaluation
    #[clap(long, default_value = "0.1")]
    eval_frac: f64,

    /// Eval hasher salt
    #[clap(long, default_value = "eval-salt")]
    eval_salt: String,

    /// Number of threads reading data
    #[clap(long, default_value = "5")]
    readers: usize,
}

pub struct ComputeConfig {
    pub batch_size: usize,
    pub eval_frac: f64,
    pub eval_salt: String,
}

impl TrainModelArgs {
    pub fn run(self, mp: Option<MultiProgress>) -> Result<()> {
        match self.device {
            InferenceDevice::Cpu => {
                use burn::backend::NdArray;
                use burn::backend::ndarray::NdArrayDevice;

                let device = NdArrayDevice::Cpu;
                self.run_inner::<Autodiff<NdArray>>(mp, device)?;
            }

            InferenceDevice::Cuda(x) => {
                use burn::backend::Cuda;
                use burn::backend::cuda::CudaDevice;

                let device = CudaDevice::new(x);
                self.run_inner::<Autodiff<Cuda>>(mp, device)?;
            }

            InferenceDevice::Wgpu(x) => {
                use burn::backend::Wgpu;
                use burn::backend::wgpu::WgpuDevice;

                let device = WgpuDevice::DiscreteGpu(x);
                self.run_inner::<Autodiff<Wgpu>>(mp, device)?;
            }
        };

        return Ok(());
    }

    fn run_inner<B: AutodiffBackend>(
        self,
        mp: Option<MultiProgress>,
        device: B::Device,
    ) -> Result<()> {
        let tokenizer = File::open(&self.tokenizer).context("while opening tokenizer")?;
        let tokenizer: Tokenizer =
            serde_json::from_reader(tokenizer).context("while loading tokenizer")?;

        let ccfg = ComputeConfig {
            batch_size: self.batch.get(),
            eval_frac: self.eval_frac,
            eval_salt: self.eval_salt.clone(),
        };

        let mcfg = GptModelConfig {
            vocab_size: tokenizer.vocab_size(),
            context_size: 256, // TODO: MORE!
            embed_dim: 768,
            n_heads: 12,
            head_dim: 64, // = 768 / 12
            n_layers: 12,
            embed_drop_rate: 0.1,
            attention_drop_rate: 0.1,
            shortcut_drop_rate: 0.1,
        };

        let mut model: GptModel<B> = mcfg.init(&device);

        let mut optim = AdamConfig::new().init();
        let learning_rate = 1e-4;

        std::fs::create_dir_all(&self.checkpoints).context("while creating checkpoint dir")?;
        let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();

        let main_pb = mp.as_ref().map(|mp| {
            let pb = mp.add(ProgressBar::new(10 as u64));
            pb.set_style(progress_big());
            pb.set_message("Training model");
            pb.enable_steady_tick(Duration::from_millis(100));
            pb
        });

        for epoch in 0..10 {
            debug!("Running epoch {epoch}");

            let epoch_pb = mp.as_ref().map(|mp| {
                let pb = mp.add(ProgressBar::no_length());
                pb.set_style(progress_persec());
                pb.set_message(format!("Running epoch {epoch}"));
                pb.enable_steady_tick(Duration::from_millis(100));
                pb
            });

            // Training phase
            let mut train_loss_sum = 0.0;
            let mut train_total = 0;

            let mut n_train = 0u64;
            for batch in TrainTestIterator::new(
                &self.data,
                &tokenizer,
                false,
                ccfg.batch_size,
                mcfg.context_size,
                ccfg.eval_frac,
                &ccfg.eval_salt,
                self.readers,
                &device,
            )
            .context("while initializing reader")?
            {
                let batch = batch.context("while reading batch")?;
                epoch_pb.as_ref().map(|x| x.inc(1));
                n_train += batch.inputs.shape()[0] as u64;

                // Forward pass with gradients
                let output = model.forward_train(batch.inputs, batch.targets);

                train_total += output.targets.dims()[0] as i32;
                train_loss_sum += output.loss.clone().into_scalar().to_f32();

                let grads = output.loss.backward();
                let grads = GradientsParams::from_grads(grads, &model);

                model = optim.step(learning_rate, model, grads);
            }

            epoch_pb.map(|x| x.finish_and_clear());

            let mut valid_loss_sum = 0.0;
            let mut valid_total = 0;

            let mut n_eval = 0;
            debug!("Evaluating batches");

            let eval_pb = mp.as_ref().map(|mp| {
                let pb = mp.add(ProgressBar::no_length());
                pb.set_style(progress_persec());
                pb.set_message(format!("Evaluating epoch {epoch}"));
                pb.enable_steady_tick(Duration::from_millis(100));
                pb
            });

            for batch in TrainTestIterator::new(
                &self.data,
                &tokenizer,
                true,
                ccfg.batch_size,
                mcfg.context_size,
                ccfg.eval_frac,
                &ccfg.eval_salt,
                self.readers,
                &device,
            )
            .context("while initializing reader")?
            {
                let batch = batch.context("while reading batch")?;
                eval_pb.as_ref().map(|x| x.inc(1));
                n_eval += batch.inputs.shape()[0] as u64;

                // Forward pass without gradients
                let output = model.valid().forward_train(batch.inputs, batch.targets);

                valid_total += output.targets.dims()[0] as i32;
                valid_loss_sum += output.loss.into_scalar().to_f32();
            }

            eval_pb.map(|x| x.finish_and_clear());

            // Compute and log epoch results
            let train_loss = if train_total > 0 {
                train_loss_sum / train_total as f32
            } else {
                0.0
            };
            let valid_loss = if valid_total > 0 {
                valid_loss_sum / valid_total as f32
            } else {
                0.0
            };

            info!(
                message = "Ran epoch",
                epoch, train_loss, valid_loss, n_train, n_eval
            );
            main_pb.as_ref().map(|x| x.inc(1));

            {
                let target = self.checkpoints.join(format!("epoch-{epoch:02}"));

                info!(message = "Saving checkpoint", ?target);
                std::fs::create_dir_all(&self.checkpoints)
                    .context("while creating checkpoint dir")?;

                model
                    .clone()
                    .save_file(target, &recorder)
                    .context("while saving checkpoint")?;
            }
        }

        if let Some(pb) = main_pb.as_ref() {
            pb.finish_and_clear();
            info!("Training complete");
        }

        Ok(())
    }
}
@@ -12,22 +12,25 @@ use crate::data_reader::DataReader;
#[derive(Debug, Args, Clone)]

pub struct TrainTokenizerArgs {
    /// Where to save tokenizer
    #[clap(default_value = "tokenizer.json")]
    target: PathBuf,

    /// Path to training data
    #[clap(long, default_value = "data")]
    data_dir: PathBuf,
    data: PathBuf,

    /// Where to save tokenizer
    #[clap(long, default_value = "tokenizer.json")]
    target: PathBuf,

    /// Only train on the first n texts
    #[clap(long)]
    first_n: Option<usize>,

    /// Number of threads to use for training
    /// Number of threads to use for training. 0 to autodetect.
    #[clap(long, default_value = "0")]
    threads: usize,

    /// Number of threads reading data
    #[clap(long, default_value = "5")]
    readers: usize,

    /// Tokenizer vocabulary size
    #[clap(long, default_value = "65535")]
    n_tokens: u32,
@@ -35,7 +38,8 @@ pub struct TrainTokenizerArgs {

impl TrainTokenizerArgs {
    pub fn run(self, mp: Option<MultiProgress>) -> Result<()> {
        let iter = DataReader::new(5, &self.data_dir).context("while initializing data reader")?;
        let iter = DataReader::new(self.readers.max(1), &self.data)
            .context("while initializing data reader")?;

        #[expect(clippy::unwrap_used)] // Lazy error handling
        let iter = iter.map(|x| x.unwrap());
@@ -3,6 +3,7 @@ use parking_lot::Mutex;
use parquet::errors::ParquetError;
use parquet::file::reader::{FileReader, SerializedFileReader};
use parquet::record::RowAccessor;
use rand::seq::SliceRandom;
use std::fs::File;
use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
@@ -25,10 +26,11 @@ pub enum DataReaderError {
///
/// All parquet files have exactly one text column.
/// No promises about this struct's behavior if this is not the case.
#[derive(Clone)]
pub struct DataReader {
    rx: Receiver<Result<String, DataReaderError>>,
    rx: Arc<Mutex<Receiver<Result<String, DataReaderError>>>>,
    total_rows: usize,
    consumed_rows: AtomicUsize,
    consumed_rows: Arc<AtomicUsize>,
}

impl DataReader {
@@ -57,6 +59,8 @@ impl DataReader {
                files.push(path);
            }
        }

        files.shuffle(&mut rand::rng());
        (Arc::new(Mutex::new(files)), total_rows)
    };
@@ -147,9 +151,9 @@ impl DataReader {
        });

        Ok(Self {
            rx,
            rx: Arc::new(Mutex::new(rx)),
            total_rows,
            consumed_rows: AtomicUsize::new(0),
            consumed_rows: Arc::new(AtomicUsize::new(0)),
        })
    }
@@ -157,7 +161,7 @@ impl DataReader {
    /// Order is arbitrary.
    /// Returns `None` when all rows have been read.
    pub fn recv(&self) -> Option<Result<String, DataReaderError>> {
        self.rx.recv().ok()
        self.rx.lock().recv().ok()
    }

    //pub fn try_recv(&self) -> Result<Result<String, DataReaderError>, TryRecvError> {
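The change above wraps the channel receiver and the row counter in `Arc` so that `DataReader` can derive `Clone`. A minimal standalone sketch of the same pattern (an illustration, not code from this commit; `SharedRx` is a hypothetical name): a `std::sync::mpsc::Receiver` cannot be shared across threads directly, so every clone shares one consumer behind a lock.

```rust
use parking_lot::Mutex;
use std::sync::{Arc, mpsc::Receiver};

/// A cloneable handle over a single mpsc consumer.
#[derive(Clone)]
struct SharedRx<T>(Arc<Mutex<Receiver<T>>>);

impl<T> SharedRx<T> {
    /// Receive one item; returns None once the channel is closed and drained.
    fn recv(&self) -> Option<T> {
        // Hold the lock only for a single recv so clones can interleave.
        self.0.lock().recv().ok()
    }
}
```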
@@ -12,8 +12,7 @@ use tracing_subscriber::{
// MARK: loglevel
//

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, ValueEnum)]
#[derive(Default)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, ValueEnum, Default)]
pub enum LogLevel {
    Trace,
    Debug,
@@ -23,7 +22,6 @@ pub enum LogLevel {
    Error,
}

impl Display for LogLevel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
@@ -47,7 +45,7 @@ pub struct LoggingConfig {
    pub silence: LogLevel,

    // Bins
    pub nanochat: LogLevel,
    pub llmfs: LogLevel,
}

impl From<LoggingConfig> for EnvFilter {
@@ -71,7 +69,7 @@ impl From<LoggingConfig> for EnvFilter {
            //
            // Bins
            //
            format!("nanochat_rs={}", conf.nanochat),
            format!("llmfs={}", conf.llmfs),
            conf.other.to_string(),
        ]
        .join(","),
@@ -164,31 +162,31 @@ impl LogFilterPreset {
            Self::Error => LoggingConfig {
                other: LogLevel::Error,
                silence: LogLevel::Error,
                nanochat: LogLevel::Error,
                llmfs: LogLevel::Error,
            },

            Self::Warn => LoggingConfig {
                other: LogLevel::Warn,
                silence: LogLevel::Warn,
                nanochat: LogLevel::Warn,
                llmfs: LogLevel::Warn,
            },

            Self::Info => LoggingConfig {
                other: LogLevel::Warn,
                silence: LogLevel::Warn,
                nanochat: LogLevel::Info,
                llmfs: LogLevel::Info,
            },

            Self::Debug => LoggingConfig {
                other: LogLevel::Warn,
                silence: LogLevel::Warn,
                nanochat: LogLevel::Debug,
                llmfs: LogLevel::Debug,
            },

            Self::Trace => LoggingConfig {
                other: LogLevel::Trace,
                silence: LogLevel::Warn,
                nanochat: LogLevel::Trace,
                llmfs: LogLevel::Trace,
            },
        }
    }
@@ -216,8 +214,7 @@ pub enum LoggingTarget {
}

/// How to print logs
#[derive(Debug, Clone, Copy, Deserialize)]
#[derive(Default)]
#[derive(Debug, Clone, Copy, Deserialize, Default)]
pub enum LoggingFormat {
    #[default]
    Ansi,
@@ -225,7 +222,6 @@ pub enum LoggingFormat {
    Json,
}

pub struct LoggingInitializer {
    /// Log filter for printed logs
    pub preset: LogFilterPreset,
@@ -1,5 +1,8 @@
#![recursion_limit = "256"] // needed to resolve burn types

use clap::Parser;
use indicatif::MultiProgress;
use serde::{Deserialize, Deserializer};
use tracing::error;

use crate::{
@@ -11,6 +14,8 @@ mod cli;
mod command;
mod data_reader;
mod logging;
mod parts;
mod train_test_iterator;

#[derive(Parser, Debug)]
#[command(version, about, long_about = None, styles=crate::cli::clap_styles())]
@@ -60,3 +65,66 @@ fn main() {
        std::process::exit(1);
    }
}

//
//
//

#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum InferenceDevice {
    #[default]
    Cpu,
    Cuda(usize),
    Wgpu(usize),
}

#[derive(Debug, PartialEq, Eq, thiserror::Error)]
#[error("{0}")]
pub struct ParseDeviceError(String);

impl std::str::FromStr for InferenceDevice {
    type Err = ParseDeviceError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let s = s.to_lowercase();

        if s == "cpu" {
            return Ok(InferenceDevice::Cpu);
        }

        if let Some(index_str) = s.strip_prefix("cuda:") {
            return match index_str.parse::<usize>() {
                Ok(index) => Ok(InferenceDevice::Cuda(index)),
                Err(_) => Err(ParseDeviceError(format!(
                    "Invalid device index: '{}'",
                    index_str
                ))),
            };
        }

        if let Some(index_str) = s.strip_prefix("wgpu:") {
            return match index_str.parse::<usize>() {
                Ok(index) => Ok(InferenceDevice::Wgpu(index)),
                Err(_) => Err(ParseDeviceError(format!(
                    "Invalid device index: '{}'",
                    index_str
                ))),
            };
        }

        return Err(ParseDeviceError(format!(
            "Invalid device format: '{}'. Expected 'cpu', 'cuda:N', 'wgpu:N'",
            s
        )));
    }
}

impl<'de> Deserialize<'de> for InferenceDevice {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        s.parse().map_err(serde::de::Error::custom)
    }
}
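A hypothetical usage sketch (not part of the commit) showing what the `FromStr` impl above accepts and rejects; the input is lowercased first, so the match is case-insensitive:

```rust
use std::str::FromStr;

fn main() {
    assert_eq!("cpu".parse::<InferenceDevice>(), Ok(InferenceDevice::Cpu));
    assert_eq!("CUDA:1".parse::<InferenceDevice>(), Ok(InferenceDevice::Cuda(1)));
    assert_eq!("wgpu:0".parse::<InferenceDevice>(), Ok(InferenceDevice::Wgpu(0)));

    // Non-numeric index or unknown prefix fails with ParseDeviceError.
    assert!(InferenceDevice::from_str("cuda:x").is_err());
    assert!(InferenceDevice::from_str("tpu:0").is_err());
}
```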
crates/llmfs/src/parts/attention.rs | 228 (new file)

@@ -0,0 +1,228 @@
use burn::{
    Tensor,
    config::Config,
    module::{Module, Param, ParamId},
    nn::Dropout,
    prelude::Backend,
    tensor::{Bool, Distribution, activation::softmax},
};
use std::f32;

#[derive(Debug, Config)]
pub struct MultiheadAttentionConfig {
    pub context_size: usize,
    pub embed_dim: usize,

    pub n_heads: usize,
    pub head_dim: usize,
    pub drop_rate: f64,
}

impl MultiheadAttentionConfig {
    pub fn init<B: Backend>(&self, device: &B::Device) -> MultiheadAttention<B> {
        let total_dim = self.head_dim * self.n_heads;
        let embedding_dim = self.embed_dim;

        MultiheadAttention {
            n_heads: self.n_heads,
            head_dim: self.head_dim,

            w_query: Param::uninitialized(
                ParamId::new(),
                move |device, is_require_grad| {
                    Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
                        .set_require_grad(is_require_grad)
                },
                device.clone(),
                true,
                [self.embed_dim, total_dim].into(),
            ),

            w_key: Param::uninitialized(
                ParamId::new(),
                move |device, is_require_grad| {
                    Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
                        .set_require_grad(is_require_grad)
                },
                device.clone(),
                true,
                [self.embed_dim, total_dim].into(),
            ),

            w_value: Param::uninitialized(
                ParamId::new(),
                move |device, is_require_grad| {
                    Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
                        .set_require_grad(is_require_grad)
                },
                device.clone(),
                true,
                [self.embed_dim, total_dim].into(),
            ),

            w_output: Param::uninitialized(
                ParamId::new(),
                move |device, is_require_grad| {
                    Tensor::random([total_dim, total_dim], Distribution::Default, device)
                        .set_require_grad(is_require_grad)
                },
                device.clone(),
                true,
                [total_dim, total_dim].into(),
            ),

            dropout: Dropout {
                prob: self.drop_rate,
            },

            utri_mask: Tensor::<B, 2, _>::tril_mask(
                [self.context_size, self.context_size],
                0,
                device,
            ),
        }
    }
}

/// Multihead attention.
///
/// Equivalent to many stacked CausalAttention layers.
/// These are packed inside one big tensor for efficiency.
#[derive(Module, Debug)]
pub struct MultiheadAttention<B: Backend> {
    n_heads: usize,
    head_dim: usize,

    // Can also use Linear layers with disabled bias
    // (they may also have a better initialization routine)
    // TODO: see source code, make this equivalent
    /// Query weight matrices for each head, stacked on the last dimension.
    /// (so that shape is [tokens, n_heads * head_dim])
    ///
    /// Intuitively, this learns "what question to ask about the text"
    /// for a given query token. (e.g, "it" -> what does "it" refer to?)
    w_query: Param<Tensor<B, 2>>,

    /// Key weight matrices for each head, stacked on the last dimension.
    /// (so that shape is [tokens, n_heads * head_dim])
    ///
    /// Intuitively, this learns what properties a certain token
    /// has when it appears as a context (non-query) token.
    w_key: Param<Tensor<B, 2>>,

    /// Value weight matrices for each head, stacked on the last dimension.
    /// (so that shape is [tokens, n_heads * head_dim])
    ///
    /// Intuitively, ???
    w_value: Param<Tensor<B, 2>>,

    /// Optional final projection.
    /// Maps [total_dim, total_dim] to [total_dim, total_dim]
    w_output: Param<Tensor<B, 2>>,

    dropout: Dropout,

    /// Upper-triangular matrix of ones, excluding diagonal.
    /// Used to mask future tokens.
    utri_mask: Tensor<B, 2, Bool>,
}

impl<B: Backend> MultiheadAttention<B> {
    /// Compute self-attention vector for the given batch
    ///
    /// - input shape is [batch, token, token_dim]
    /// - output shape is [batch, token, n_heads * head_dim]
    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        // Works similarly to self-attention, (where attn = softmax(tok @ tok^T); context = attn @ tok)
        // But adds an "inner latent space" using Wq, Wk, and Wv.
//
|
||||
// Multiple heads are batched into one tensor.
|
||||
|
||||
let batch = input.dims()[0];
|
||||
let tokens = input.dims()[1];
|
||||
|
||||
let w_query = self
|
||||
.w_query
|
||||
.val()
|
||||
.unsqueeze_dim::<3>(0)
|
||||
.expand([batch as i64, -1, -1]);
|
||||
|
||||
let w_key = self
|
||||
.w_key
|
||||
.val()
|
||||
.unsqueeze_dim::<3>(0)
|
||||
.expand([batch as i64, -1, -1]);
|
||||
|
||||
let w_value = self
|
||||
.w_value
|
||||
.val()
|
||||
.unsqueeze_dim::<3>(0)
|
||||
.expand([batch as i64, -1, -1]);
|
||||
|
||||
let w_output = self
|
||||
.w_output
|
||||
.val()
|
||||
.unsqueeze_dim::<3>(0)
|
||||
.expand([batch as i64, -1, -1]);
|
||||
|
||||
// Map batch to inner latent space.
|
||||
// shape: [batch, token, inner_dim]
|
||||
let queries = input.clone().matmul(w_query);
|
||||
let keys = input.clone().matmul(w_key);
|
||||
let values = input.clone().matmul(w_value);
|
||||
|
||||
// Split head dimensions
|
||||
let keys = keys.reshape([batch, tokens, self.n_heads, self.head_dim]);
|
||||
let values = values.reshape([batch, tokens, self.n_heads, self.head_dim]);
|
||||
let queries = queries.reshape([batch, tokens, self.n_heads, self.head_dim]);
|
||||
|
||||
// from: [batch, tok, head, head_dim]
|
||||
// to: [batch, head, tok, head_dim]
|
||||
let keys = keys.swap_dims(1, 2);
|
||||
let values = values.swap_dims(1, 2);
|
||||
let queries = queries.swap_dims(1, 2);
|
||||
|
||||
// Compute attention scores for each head
|
||||
        // (dot-product similarity of each query token to each context token, per head)
//
|
||||
// lhs shape: [batch, head, tok, head_dim]
|
||||
// rhs shape: [batch, head, head_dim, tok]
|
||||
// output shape: [batch, head, query_token, context_token]
|
||||
let attn_scores = queries.matmul(keys.clone().swap_dims(2, 3));
|
||||
|
||||
let mask = self
|
||||
.utri_mask
|
||||
.clone()
|
||||
.slice([0..tokens, 0..tokens])
|
||||
.unsqueeze_dim::<3>(0)
|
||||
.unsqueeze_dim::<4>(0)
|
||||
.expand(attn_scores.shape());
|
||||
|
||||
// Mask out future tokens by filling
|
||||
// upper-triangular with -inf, which becomes 0.0 after softmax.
|
||||
let attn_scores = attn_scores.mask_fill(mask, f32::NEG_INFINITY);
|
||||
|
||||
// Normalize attn weights.
|
||||
//
|
||||
// Divide by sqrt(inner_dim) because...
|
||||
// - dot products get larger with larger dimensions
|
||||
// - this causes softmax to "saturate", making all other values very small
|
||||
// - which makes gradients vanish during training
|
||||
let attn_weights = softmax(attn_scores / (keys.shape()[3] as f32).sqrt(), 3);
|
||||
let attn_weights = self.dropout.forward(attn_weights);
|
||||
|
||||
// lhs shape: [batch, head, query_token, context_token]
|
||||
// rhs shape: [batch, head, tok, head_dim]
|
||||
// matmul shape: [batch, head, tok, head_dim]
|
||||
// out shape: [batch, tok, head, head_dim]
|
||||
let context_vec = attn_weights.matmul(values).swap_dims(1, 2);
|
||||
|
||||
// shape: [batch, tok, stacked_dim]
|
||||
let context_vec = context_vec.reshape([batch, tokens, self.n_heads * self.head_dim]);
|
||||
|
||||
// Apply final projection (optional)
|
||||
let context_vec = context_vec.matmul(w_output);
|
||||
|
||||
return context_vec;
|
||||
}
|
||||
}
|
||||
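For reference, the `forward` above is the standard scaled dot-product attention. Per head, with causal mask M (negative infinity above the diagonal, zero elsewhere):

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_{\mathrm{head}}}} + M\right) V
```

The division by the square root of the head dimension is the `(keys.shape()[3] as f32).sqrt()` in the code; it keeps the score variance roughly constant so the softmax does not saturate.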
crates/llmfs/src/parts/mod.rs | 5 (new file)

@@ -0,0 +1,5 @@
mod attention;
pub use attention::*;

mod model;
pub use model::*;
crates/llmfs/src/parts/model.rs | 194 (new file)

@@ -0,0 +1,194 @@
|
||||
use burn::{
|
||||
Tensor,
|
||||
config::Config,
|
||||
module::{Module, Param, ParamId},
|
||||
nn::{
|
||||
Dropout, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
|
||||
loss::CrossEntropyLossConfig,
|
||||
transformer::{PositionWiseFeedForward, PositionWiseFeedForwardConfig},
|
||||
},
|
||||
prelude::Backend,
|
||||
tensor::{Distribution, Int},
|
||||
};
|
||||
use burn_train::ClassificationOutput;
|
||||
|
||||
use crate::parts::{MultiheadAttention, MultiheadAttentionConfig};
|
||||
|
||||
#[derive(Debug, Config)]
|
||||
pub struct GptModelConfig {
|
||||
/// Number of tokens
|
||||
pub vocab_size: u32,
|
||||
|
||||
/// Maximum number of input tokens with positional embeddings
|
||||
pub context_size: usize,
|
||||
|
||||
/// Dimension of each token's embedding
|
||||
pub embed_dim: usize,
|
||||
|
||||
/// Number of attention heads
|
||||
pub n_heads: usize,
|
||||
|
||||
/// Dimension of each attn head
|
||||
pub head_dim: usize,
|
||||
|
||||
/// Number of transformer blocks
|
||||
#[config(default = 12)]
|
||||
pub n_layers: usize,
|
||||
|
||||
#[config(default = 0.1)]
|
||||
pub embed_drop_rate: f64,
|
||||
|
||||
#[config(default = 0.1)]
|
||||
pub attention_drop_rate: f64,
|
||||
|
||||
#[config(default = 0.1)]
|
||||
pub shortcut_drop_rate: f64,
|
||||
}
|
||||
|
||||
impl GptModelConfig {
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> GptModel<B> {
|
||||
let out_head_shape = [self.embed_dim, self.vocab_size as usize];
|
||||
|
||||
GptModel {
|
||||
embedder_tok: EmbeddingConfig::new(self.vocab_size as usize, self.embed_dim)
|
||||
.init(device),
|
||||
|
||||
embedder_pos: EmbeddingConfig::new(self.context_size, self.embed_dim).init(device),
|
||||
|
||||
embedder_drop: Dropout {
|
||||
prob: self.embed_drop_rate,
|
||||
},
|
||||
|
||||
trf_blocks: (0..self.n_layers)
|
||||
.map(|_| TransformerBlock::new(&self, device))
|
||||
.collect(),
|
||||
|
||||
final_norm: LayerNormConfig::new(self.embed_dim).init(device),
|
||||
|
||||
out_head: Param::uninitialized(
|
||||
ParamId::new(),
|
||||
move |device, is_require_grad| {
|
||||
Tensor::random(out_head_shape, Distribution::Default, device)
|
||||
.set_require_grad(is_require_grad)
|
||||
},
|
||||
device.clone(),
|
||||
true,
|
||||
out_head_shape.into(),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct GptModel<B: Backend> {
|
||||
embedder_tok: Embedding<B>,
|
||||
embedder_pos: Embedding<B>,
|
||||
embedder_drop: Dropout,
|
||||
|
||||
trf_blocks: Vec<TransformerBlock<B>>,
|
||||
final_norm: LayerNorm<B>,
|
||||
out_head: Param<Tensor<B, 2>>,
|
||||
}
|
||||
|
||||
impl<B: Backend> GptModel<B> {
    pub fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 3> {
        let n_tokens = input.shape()[1];

        let embed_tok = self.embedder_tok.forward(input.clone());
        let embed_pos = self
            .embedder_pos
            .forward(Tensor::arange(0..n_tokens as i64, &input.device()).unsqueeze_dim(0));

        let x = embed_tok + embed_pos;
        let x = self.embedder_drop.forward(x);
        let x = self.trf_blocks.iter().fold(x, |x, l| l.forward(x));
        let x = self.final_norm.forward(x);
        let logits = x.matmul(self.out_head.val().unsqueeze_dim(0));

        return logits;
    }

    pub fn forward_train(
        &self,
        inputs: Tensor<B, 2, Int>,
        targets: Tensor<B, 1, Int>,
    ) -> ClassificationOutput<B> {
        // shape: [batch, n_tokens, n_vocabulary]
        let output = self.forward(inputs);

        // Keep only the last token's logits
        // shape: [batch, n_vocabulary]
        let output = output.slice_dim(1, -1).squeeze_dim::<2>(1);

        let loss = CrossEntropyLossConfig::new()
            .init(&targets.device())
            .forward(output.clone(), targets.clone());

        ClassificationOutput {
            loss,
            output,
            targets,
        }
    }
}

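// Shape walkthrough for `forward_train` (illustrative values, not from
// this commit): with batch = 32, context_size = 256, vocab_size = 50257,
//   forward(inputs)            -> [32, 256, 50257]
//   .slice_dim(1, -1)          -> [32, 1, 50257]   (logits at the last position)
//   .squeeze_dim::<2>(1)       -> [32, 50257]
//   cross-entropy vs. targets  -> scalar loss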
#[derive(Module, Debug)]
pub struct TransformerBlock<B: Backend> {
    attention: MultiheadAttention<B>,

    /// Position-wise feed-forward network applied after attention
    ff: PositionWiseFeedForward<B>,

    /// Layer norms applied before the attention and feed-forward
    /// sublayers (pre-LN arrangement)
    norm_a: LayerNorm<B>,
    norm_b: LayerNorm<B>,

    drop_shortcut: Dropout,
}

impl<B: Backend> TransformerBlock<B> {
    pub fn new(cfg: &GptModelConfig, device: &B::Device) -> Self {
        Self {
            attention: MultiheadAttentionConfig {
                embed_dim: cfg.embed_dim,
                head_dim: cfg.head_dim,
                n_heads: cfg.n_heads,
                context_size: cfg.context_size,
                drop_rate: cfg.attention_drop_rate,
            }
            .init(device),

            ff: PositionWiseFeedForwardConfig::new(cfg.embed_dim, 4 * cfg.embed_dim)
                .with_dropout(0.0)
                .init(device),

            norm_a: LayerNormConfig::new(cfg.embed_dim).init(device),
            norm_b: LayerNormConfig::new(cfg.embed_dim).init(device),

            drop_shortcut: Dropout {
                prob: cfg.shortcut_drop_rate,
            },
        }
    }

    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        // Attention sublayer with a residual (shortcut) connection
        let input = {
            let shortcut = input.clone();
            let x = self.norm_a.forward(input);
            let x = self.attention.forward(x);
            let x = self.drop_shortcut.forward(x);
            x + shortcut
        };

        // Feed-forward sublayer with a residual (shortcut) connection
        let input = {
            let shortcut = input.clone();
            let x = self.norm_b.forward(input);
            let x = self.ff.forward(x);
            let x = self.drop_shortcut.forward(x);
            x + shortcut
        };

        return input;
    }
}
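As a usage sketch only (not part of this commit; the backend choice and hyperparameter values are assumptions), the model can be constructed and run like this:

use burn::{backend::NdArray, prelude::*};

type B = NdArray<f32>;

let device = Default::default();

// Required fields in declaration order: vocab_size, context_size,
// embed_dim, n_heads, head_dim; n_layers and drop rates use their defaults.
// Note 12 heads * 64 head_dim = 768 = embed_dim.
let model: GptModel<B> = GptModelConfig::new(50257, 256, 768, 12, 64).init(&device);

// A dummy batch of token ids, shape [batch, context_size]
let inputs = Tensor::<B, 2, Int>::zeros([1, 256], &device);
let logits = model.forward(inputs); // shape [1, 256, 50257]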
164
crates/llmfs/src/train_test_iterator.rs
Normal file
@@ -0,0 +1,164 @@
use ahash::AHasher;
use anyhow::Result;
use burn::{
    Tensor,
    prelude::{Backend, ToElement},
    tensor::Int,
};
use ndarray::{Array1, Array2};
use std::{collections::VecDeque, hash::Hasher, path::Path};
use tokenizer::Tokenizer;

use crate::data_reader::{DataReader, DataReaderError};

#[derive(Debug, Clone)]
pub struct TrainBatch<B: Backend> {
    /// Input texts.
    /// shape: [batch, context_size]
    pub inputs: Tensor<B, 2, Int>,

    /// Correct next token for each input.
    /// shape: [batch]
    pub targets: Tensor<B, 1, Int>,
}

/// Read texts from a [DataReader], then
/// - extract context windows
/// - deterministically classify these as "test" or "train"
/// - batch output into tensors of token ids
pub struct TrainTestIterator<'a, B: Backend> {
    reader: DataReader,

    tokenizer: &'a Tokenizer,
    /// When true, yield evaluation pairs; otherwise training pairs
    eval: bool,
    device: &'a B::Device,

    batch_size: usize,
    context_size: usize,
    eval_frac: f64,
    eval_salt: String,

    // Tokenized input/output pairs waiting to be batched
    pairs: VecDeque<(Vec<u32>, u32)>,
    error: bool,
}

impl<'a, B: Backend> TrainTestIterator<'a, B> {
    pub fn new(
        data_dir: impl AsRef<Path>,
        tokenizer: &'a Tokenizer,
        eval: bool,
        batch_size: usize,
        context_size: usize,
        eval_frac: f64,
        eval_salt: impl Into<String>,
        readers: usize,
        device: &'a B::Device,
    ) -> Result<Self, std::io::Error> {
        let reader = DataReader::new(readers.max(1), data_dir)?;

        Ok(Self {
            reader,
            tokenizer,
            eval,
            device,

            batch_size,
            context_size,
            eval_frac,
            eval_salt: eval_salt.into(),

            pairs: VecDeque::new(),
            error: false,
        })
    }
}

impl<B: Backend> Iterator for TrainTestIterator<'_, B> {
    type Item = Result<TrainBatch<B>, DataReaderError>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.error {
            return None;
        }

        let mut inputs = Vec::with_capacity(self.batch_size);
        let mut targets = Vec::with_capacity(self.batch_size);
        let stride = self.context_size;

        while inputs.len() < self.batch_size {
            match self.pairs.pop_front() {
                Some((i, t)) => {
                    // Deterministic train/test split: hash the (input, target)
                    // pair with a salt, then compare against a threshold so
                    // that roughly `eval_frac` of all pairs land in the test set.
                    {
                        let mut hasher = AHasher::default();
                        hasher.write(self.eval_salt.as_bytes());

                        // Hash each token id individually; this avoids an
                        // unsound slice transmute, and ahash output is
                        // unstable across versions anyway.
                        for &tok in &i {
                            hasher.write_u32(tok);
                        }
                        hasher.write_u32(t);

                        // Is this pair in the test set? A hash below
                        // `eval_frac * u64::MAX` happens with probability
                        // `eval_frac`.
                        let test =
                            hasher.finish() < (u64::MAX as f64 * self.eval_frac).to_u64();

                        // Keep test pairs when evaluating, train pairs otherwise
                        if test ^ self.eval {
                            continue;
                        }
                    }

                    inputs.push(i);
                    targets.push(t);
                }

                None => {
                    let text = match self.reader.next() {
                        None => break,
                        Some(Ok(x)) => x,
                        Some(Err(x)) => {
                            self.error = true;
                            return Some(Err(x));
                        }
                    };

                    let emb = self.tokenizer.encode(&text);

                    // Skip texts too short to yield a full window
                    //
                    // TODO: do this better
                    // TODO: maybe using <|bos|>?
                    // TODO: non-uniform batches?
                    if emb.len() <= self.context_size {
                        continue;
                    }

                    // Slide a window of context_size + 1 tokens over the
                    // text: the first context_size tokens are the input,
                    // the final token is the target.
                    let pairs = emb
                        .windows(self.context_size + 1)
                        .step_by(stride)
                        .map(|x| (x[..self.context_size].to_vec(), x[self.context_size]));

                    self.pairs.extend(pairs);
                }
            }
        }

        if inputs.is_empty() {
            return None;
        }

        let shape = [inputs.len(), self.context_size];

        // Arrange data contiguously in memory
        let inputs: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| inputs[a][b]);
        let targets: Array1<u32> = Array1::from_vec(targets);

        // Create tensors on the target device
        #[expect(clippy::unwrap_used)]
        let inputs =
            Tensor::<B, 1, Int>::from_ints(inputs.as_slice().unwrap(), self.device).reshape(shape);

        #[expect(clippy::unwrap_used)]
        let targets = Tensor::<B, 1, Int>::from_ints(targets.as_slice().unwrap(), self.device);

        return Some(Ok(TrainBatch { inputs, targets }));
    }
}
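A sketch of how this iterator might be driven in a training loop; the path, parameter values, and surrounding names (tokenizer, model, device) are illustrative assumptions, not from this commit:

let batches = TrainTestIterator::<B>::new(
    "data/",   // data_dir (illustrative)
    &tokenizer,
    false,     // eval = false: yield training pairs only
    32,        // batch_size
    256,       // context_size
    0.1,       // eval_frac: ~10% of pairs reserved for evaluation
    "salt-v1", // eval_salt (illustrative)
    4,         // reader threads
    &device,
)?;

for batch in batches {
    let batch = batch?;
    let out = model.forward_train(batch.inputs, batch.targets);
    // ... backward pass and optimizer step on out.loss ...
}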
@@ -19,8 +19,7 @@ use tracing::{debug, info};

use crate::{progress_big, split::regex_segment};

// TODO:
// - maybe don't use regex
// Maybe don't use regex for performance?

#[derive(Debug, Clone, thiserror::Error)]
pub enum TokenizerTrainError {