diff --git a/Cargo.lock b/Cargo.lock
index 0767f61..69f6bc7 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -451,6 +451,7 @@ dependencies = [
  "ahash",
  "bincode",
  "burn-common",
+ "burn-dataset",
  "burn-derive",
  "burn-tensor",
  "data-encoding",
@@ -546,6 +547,25 @@ dependencies = [
  "log",
 ]
 
+[[package]]
+name = "burn-dataset"
+version = "0.19.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "534d4398fd6aaec32f8caeb3f20ddffcd8a059bdefc01cc2794b91b4e984e8ea"
+dependencies = [
+ "csv",
+ "derive-new",
+ "dirs",
+ "rand",
+ "rmp-serde",
+ "sanitize-filename",
+ "serde",
+ "serde_json",
+ "strum",
+ "tempfile",
+ "thiserror 2.0.17",
+]
+
 [[package]]
 name = "burn-derive"
 version = "0.19.1"
@@ -598,8 +618,10 @@ dependencies = [
  "burn-common",
  "burn-ir",
  "burn-tensor",
+ "bytemuck",
  "const-random",
  "derive-new",
+ "itertools 0.14.0",
  "libm",
  "macerator",
  "matrixmultiply",
@@ -608,6 +630,7 @@ dependencies = [
  "paste",
  "portable-atomic-util",
  "rand",
+ "seq-macro",
  "spin",
 ]
 
@@ -703,6 +726,25 @@ dependencies = [
  "serde_bytes",
 ]
 
+[[package]]
+name = "burn-train"
+version = "0.19.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b0f1553197d50668823a4bafc187c62439df49b218973f0ca79e034b57ce38d6"
+dependencies = [
+ "async-channel",
+ "burn-core",
+ "burn-ndarray",
+ "burn-optim",
+ "derive-new",
+ "log",
+ "rstest",
+ "serde",
+ "tracing-appender",
+ "tracing-core",
+ "tracing-subscriber",
+]
+
 [[package]]
 name = "burn-wgpu"
 version = "0.19.1"
@@ -1057,6 +1099,15 @@ version = "1.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b"
 
+[[package]]
+name = "crossbeam-channel"
+version = "0.5.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2"
+dependencies = [
+ "crossbeam-utils",
+]
+
 [[package]]
 name = "crossbeam-deque"
 version = "0.8.6"
@@ -1098,6 +1149,27 @@ dependencies = [
  "typenum",
 ]
 
+[[package]]
+name = "csv"
+version = "1.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "52cd9d68cf7efc6ddfaaee42e7288d3a99d613d4b50f76ce9827ae0c6e14f938"
+dependencies = [
+ "csv-core",
+ "itoa",
+ "ryu",
+ "serde_core",
+]
+
+[[package]]
+name = "csv-core"
+version = "0.1.13"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "704a3c26996a80471189265814dbc2c257598b96b8a7feae2d31ace646bb9782"
+dependencies = [
+ "memchr",
+]
+
 [[package]]
 name = "cubecl"
 version = "0.8.1"
@@ -1573,6 +1645,15 @@ version = "2.9.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476"
 
+[[package]]
+name = "deranged"
+version = "0.5.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ececcb659e7ba858fb4f10388c250a7252eb0a27373f1a72b8748afdd248e587"
+dependencies = [
+ "powerfmt",
+]
+
 [[package]]
 name = "derive-new"
 version = "0.7.0"
@@ -2064,6 +2145,12 @@ version = "0.3.31"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
 
+[[package]]
+name = "futures-timer"
+version = "3.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
+
 [[package]]
 name = "futures-util"
 version = "0.3.31"
@@ -3044,15 +3131,18 @@ checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092"
"11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092" name = "llmfs" version = "0.0.1" dependencies = [ + "ahash", "anstyle", "anyhow", "burn", + "burn-train", "clap", "futures-util", "indicatif", "ndarray", "parking_lot", "parquet", + "rand", "rayon", "reqwest", "serde", @@ -3149,7 +3239,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" dependencies = [ "autocfg", + "num_cpus", + "once_cell", "rawpointer", + "thread-tree", ] [[package]] @@ -3299,6 +3392,7 @@ dependencies = [ "portable-atomic", "portable-atomic-util", "rawpointer", + "rayon", "serde", ] @@ -3379,6 +3473,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-integer" version = "0.1.46" @@ -3689,6 +3789,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.21" @@ -3994,6 +4100,12 @@ version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" +[[package]] +name = "relative-path" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" + [[package]] name = "renderdoc-sys" version = "1.1.0" @@ -4084,6 +4196,35 @@ dependencies = [ "serde", ] +[[package]] +name = "rstest" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5a3193c063baaa2a95a33f03035c8a72b83d97a54916055ba22d35ed3839d49" +dependencies = [ + "futures-timer", + "futures-util", + "rstest_macros", +] + +[[package]] +name = "rstest_macros" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c845311f0ff7951c5506121a9ad75aec44d083c31583b2ea5a30bcb0b0abba0" +dependencies = [ + "cfg-if", + "glob", + "proc-macro-crate", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version", + "syn", + "unicode-ident", +] + [[package]] name = "rustc-hash" version = "1.1.0" @@ -4672,6 +4813,15 @@ dependencies = [ "syn", ] +[[package]] +name = "thread-tree" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffbd370cb847953a25954d9f63e14824a36113f8c72eecf6eccef5dc4b45d630" +dependencies = [ + "crossbeam-channel", +] + [[package]] name = "thread_local" version = "1.1.9" @@ -4692,6 +4842,37 @@ dependencies = [ "ordered-float 2.10.1", ] +[[package]] +name = "time" +version = "0.3.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" +dependencies = [ + "deranged", + "itoa", + "num-conv", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" + +[[package]] +name = "time-macros" +version = "0.2.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" +dependencies = [ + "num-conv", + "time-core", +] + [[package]] name = "tiny-keccak" version = "2.0.2" @@ -4989,6 +5170,18 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-appender" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "786d480bce6247ab75f005b14ae1624ad978d3029d9113f0a22fa1ac773faeaf" +dependencies = [ + "crossbeam-channel", + "thiserror 2.0.17", + "time", + "tracing-subscriber", +] + [[package]] name = "tracing-attributes" version = "0.1.31" diff --git a/Cargo.toml b/Cargo.toml index 82380cd..dd25181 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -75,10 +75,12 @@ compact_str = "0.9.0" dary_heap = "0.3.8" fancy-regex = "0.16.2" indicatif = { version = "0.18.3", features = ["improved_unicode"] } +itertools = "0.14.0" futures-util = "0.3.31" ndarray = { version = "0.16.1", features = ["serde"] } parking_lot = "0.12.5" parquet = "56.2.0" +rand = "0.9.2" rayon = "1.11.0" reqwest = { version = "0.12.24", features = ["json", "stream"] } serde = "1.0.228" @@ -91,7 +93,10 @@ tracing-indicatif = "0.3.13" tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] } url = "2.5.7" + +burn-train = { version = "0.19.1", default-features = false } + [workspace.dependencies.burn] version = "0.19.1" default-features = false -features = ["std", "fusion", "ndarray", "webgpu", "cuda"] +features = ["std", "fusion", "ndarray", "webgpu", "cuda", "autodiff"] diff --git a/crates/llmfs/Cargo.toml b/crates/llmfs/Cargo.toml index 69bd819..850317d 100644 --- a/crates/llmfs/Cargo.toml +++ b/crates/llmfs/Cargo.toml @@ -10,15 +10,18 @@ workspace = true [dependencies] tokenizer = { workspace = true } +ahash = { workspace = true } anstyle = { workspace = true } anyhow = { workspace = true } burn = { workspace = true } +burn-train = { workspace = true } clap = { workspace = true } futures-util = { workspace = true } indicatif = { workspace = true } ndarray = { workspace = true } parking_lot = { workspace = true } parquet = { workspace = true } +rand = { workspace = true } rayon = { workspace = true } reqwest = { workspace = true } serde = { workspace = true } diff --git a/crates/llmfs/src/command/mod.rs b/crates/llmfs/src/command/mod.rs index 779132d..cfa67e1 100644 --- a/crates/llmfs/src/command/mod.rs +++ b/crates/llmfs/src/command/mod.rs @@ -15,7 +15,6 @@ pub enum SubCommand { #[command(flatten)] args: train_tokenizer::TrainTokenizerArgs, }, - /// Sample data SampleData { #[command(flatten)] diff --git a/crates/llmfs/src/command/sample_data.rs b/crates/llmfs/src/command/sample_data.rs index 01b7662..0c7ce40 100644 --- a/crates/llmfs/src/command/sample_data.rs +++ b/crates/llmfs/src/command/sample_data.rs @@ -1,22 +1,206 @@ +use ahash::AHasher; use anyhow::{Context, Result}; use burn::{ Tensor, - backend::{Cuda, cuda::CudaDevice}, - module::{Module, Param, ParamId}, + backend::{Autodiff, Cuda, Wgpu, cuda::CudaDevice, wgpu::WgpuDevice}, + config::Config, + module::{AutodiffModule, Module, Param, ParamId}, nn::{ Dropout, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig, + loss::CrossEntropyLossConfig, transformer::{PositionWiseFeedForward, PositionWiseFeedForwardConfig}, }, - prelude::Backend, + optim::{AdamConfig, GradientsParams, Optimizer}, + prelude::{Backend, ToElement}, tensor::{Bool, Distribution, Int, activation::softmax}, }; +use burn_train::ClassificationOutput; use clap::Args; -use indicatif::MultiProgress; -use ndarray::Array2; -use std::{f32, 
+burn-train = { version = "0.19.1", default-features = false }
+
 [workspace.dependencies.burn]
 version = "0.19.1"
 default-features = false
-features = ["std", "fusion", "ndarray", "webgpu", "cuda"]
+features = ["std", "fusion", "ndarray", "webgpu", "cuda", "autodiff"]
diff --git a/crates/llmfs/Cargo.toml b/crates/llmfs/Cargo.toml
index 69bd819..850317d 100644
--- a/crates/llmfs/Cargo.toml
+++ b/crates/llmfs/Cargo.toml
@@ -10,15 +10,18 @@ workspace = true
 
 [dependencies]
 tokenizer = { workspace = true }
+ahash = { workspace = true }
 anstyle = { workspace = true }
 anyhow = { workspace = true }
 burn = { workspace = true }
+burn-train = { workspace = true }
 clap = { workspace = true }
 futures-util = { workspace = true }
 indicatif = { workspace = true }
 ndarray = { workspace = true }
 parking_lot = { workspace = true }
 parquet = { workspace = true }
+rand = { workspace = true }
 rayon = { workspace = true }
 reqwest = { workspace = true }
 serde = { workspace = true }
diff --git a/crates/llmfs/src/command/mod.rs b/crates/llmfs/src/command/mod.rs
index 779132d..cfa67e1 100644
--- a/crates/llmfs/src/command/mod.rs
+++ b/crates/llmfs/src/command/mod.rs
@@ -15,7 +15,6 @@ pub enum SubCommand {
         #[command(flatten)]
         args: train_tokenizer::TrainTokenizerArgs,
     },
-
     /// Sample data
     SampleData {
         #[command(flatten)]
diff --git a/crates/llmfs/src/command/sample_data.rs b/crates/llmfs/src/command/sample_data.rs
index 01b7662..0c7ce40 100644
--- a/crates/llmfs/src/command/sample_data.rs
+++ b/crates/llmfs/src/command/sample_data.rs
@@ -1,22 +1,206 @@
+use ahash::AHasher;
 use anyhow::{Context, Result};
 use burn::{
     Tensor,
-    backend::{Cuda, cuda::CudaDevice},
-    module::{Module, Param, ParamId},
+    backend::{Autodiff, Cuda, Wgpu, cuda::CudaDevice, wgpu::WgpuDevice},
+    config::Config,
+    module::{AutodiffModule, Module, Param, ParamId},
     nn::{
         Dropout, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
+        loss::CrossEntropyLossConfig,
         transformer::{PositionWiseFeedForward, PositionWiseFeedForwardConfig},
     },
-    prelude::Backend,
+    optim::{AdamConfig, GradientsParams, Optimizer},
+    prelude::{Backend, ToElement},
     tensor::{Bool, Distribution, Int, activation::softmax},
 };
+use burn_train::ClassificationOutput;
 use clap::Args;
-use indicatif::MultiProgress;
-use ndarray::Array2;
-use std::{f32, fs::File, path::PathBuf};
+use indicatif::{MultiProgress, ProgressIterator};
+use ndarray::{Array1, Array2};
+use std::{
+    collections::VecDeque,
+    f32,
+    fs::File,
+    hash::Hasher,
+    path::{Path, PathBuf},
+};
 use tokenizer::Tokenizer;
+use tracing::{debug, info};
 
-use crate::data_reader::DataReader;
+use crate::data_reader::{DataReader, DataReaderError};
+
+// Text generation routine
+
+/*
+{
+    let init = "Initial context. This is ";
+    let tokens = tokenizer.encode(&init);
+
+    let n_tokens = tokens.len();
+    let input: Array1<u32> = Array1::from_vec(tokens);
+    let mut input: Tensor<Cuda, 1, Int> =
+        Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
+            .reshape([n_tokens]);
+
+    for _ in 0..100 {
+        let tokens: Vec<u32> = input.clone().to_data().convert::<u32>().into_vec().unwrap();
+        println!("{:?}", tokens);
+        println!("{}", tokenizer.decode(&tokens));
+
+        // Crop idx to context size;
+        let batch = input
+            .clone()
+            .slice([0..config.context_size])
+            .unsqueeze_dim(0);
+
+        // shape: [tokens, vocab_size]
+        let logits = model.forward(batch).squeeze_dim::<2>(0);
+
+        // shape: [vocab_size]
+        let logits = logits.slice([config.context_size - 1]).squeeze_dim::<1>(0);
+
+        let probs = softmax(logits, 0); // shape: [n_tokens]
+        let id_next = probs.argmax(0); // shape: [1]
+
+        input = Tensor::cat(vec![input.slice([1..]), id_next], 0);
+    }
+}
+*/
+
+struct TrainTestIterator<'a, B: Backend> {
+    reader: DataReader,
+
+    ccfg: &'a ComputeConfig,
+    mcfg: &'a GptModelConfig,
+    tokenizer: &'a Tokenizer,
+    eval: bool,
+    device: &'a B::Device,
+
+    error: bool,
+
+    // Tokenized input/output pairs
+    pairs: VecDeque<(Vec<u32>, u32)>,
+}
+
+impl<'a, B: Backend> TrainTestIterator<'a, B> {
+    pub fn new(
+        data_dir: impl AsRef<Path>,
+        ccfg: &'a ComputeConfig,
+        mcfg: &'a GptModelConfig,
+        tokenizer: &'a Tokenizer,
+        eval: bool,
+        device: &'a B::Device,
+    ) -> Result<Self> {
+        let reader = DataReader::new(3, data_dir)?;
+
+        Ok(Self {
+            reader,
+            ccfg,
+            mcfg,
+            tokenizer,
+            eval,
+            device,
+
+            error: false,
+            pairs: VecDeque::new(),
+        })
+    }
+}
+
+impl<B: Backend> Iterator for TrainTestIterator<'_, B> {
+    type Item = Result<TrainBatch<B>, DataReaderError>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.error {
+            return None;
+        }
+
+        let mut inputs = Vec::with_capacity(self.ccfg.batch_size);
+        let mut targets = Vec::with_capacity(self.ccfg.batch_size);
+        let stride = self.mcfg.context_size;
+
+        while inputs.len() < self.ccfg.batch_size {
+            match self.pairs.pop_front() {
+                Some((i, t)) => {
+                    // train/test split
+                    {
+                        let mut hasher = AHasher::default();
+                        hasher.write(self.ccfg.eval_salt.as_bytes());
+
+                        // Hash every token individually; endianness doesn't
+                        // matter, ahash output is unstable anyway
+                        for &token in &i {
+                            hasher.write_u32(token);
+                        }
+                        hasher.write_u32(t);
+
+                        let test = // is this point in the test set?
+                            hasher.finish() < (u64::MAX as f64 * self.ccfg.eval_frac).to_u64();
+
+                        if test ^ self.eval {
+                            continue;
+                        }
+                    }
+
+                    inputs.push(i);
+                    targets.push(t);
+                }
+
+                None => {
+                    let text = match self.reader.next() {
+                        None => break,
+                        Some(Ok(x)) => x,
+                        Some(Err(x)) => {
+                            self.error = true;
+                            return Some(Err(x));
+                        }
+                    };
+
+                    let emb = self.tokenizer.encode(&text);
+
+                    // Skip small texts
+                    //
+                    // TODO: do this better
+                    // TODO: maybe using <|bos|>?
+                    // TODO: non-uniform batches?
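+                    // TODO: pad short texts up to context_size instead of
+                    //       dropping them entirely?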
+                    if emb.len() < self.mcfg.context_size {
+                        continue;
+                    }
+
+                    let pairs = emb
+                        .windows(self.mcfg.context_size + 1)
+                        .step_by(stride)
+                        .map(|x| {
+                            (
+                                x[..self.mcfg.context_size].to_vec(),
+                                x[self.mcfg.context_size],
+                            )
+                        });
+
+                    self.pairs.extend(pairs);
+                }
+            }
+        }
+
+        if inputs.is_empty() {
+            return None;
+        }
+
+        let shape = [inputs.len(), self.mcfg.context_size];
+
+        // Arrange data in memory
+        let inputs: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| inputs[a][b]);
+        let targets: Array1<u32> = Array1::from_vec(targets);
+
+        // Create tensors on gpu
+        #[expect(clippy::unwrap_used)]
+        let inputs =
+            Tensor::<B, 1, Int>::from_ints(inputs.as_slice().unwrap(), self.device).reshape(shape);
+
+        #[expect(clippy::unwrap_used)]
+        let targets = Tensor::<B, 1, Int>::from_ints(targets.as_slice().unwrap(), self.device);
+
+        return Some(Ok(TrainBatch { inputs, targets }));
+    }
+}
 
 #[derive(Debug, Args, Clone)]
@@ -32,190 +216,144 @@ pub struct SampleDataArgs {
     /// How many texts to return
     #[clap(long, short = 'n', default_value = "10")]
     n: usize,
-
-    /// How many texts to skip
-    #[clap(long, short = 's', default_value = "0")]
-    skip: usize,
 }
 
-#[derive(Debug, Clone)]
-pub struct Config {
-    /// Number of tokens
-    pub vocab_size: u32,
-
-    /// Maximum number of input tokens with positional embeddings
-    pub context_size: usize,
-
-    /// Dimension of each token's embedding
-    pub embed_dim: usize,
-
-    /// Number of attention heads
-    pub n_heads: usize,
-
-    /// Dimension of each attn head
-    pub head_dim: usize,
-
-    /// Number of transformer blocks
-    pub n_layers: usize,
-
-    pub embed_drop_rate: f64,
-    pub attention_drop_rate: f64,
-    pub shortcut_drop_rate: f64,
+pub struct ComputeConfig {
+    pub batch_size: usize,
+    pub eval_frac: f64,
+    pub eval_salt: String,
 }
 
 impl SampleDataArgs {
     pub fn run(self, _mp: Option<MultiProgress>) -> Result<()> {
         let device = CudaDevice::new(0);
-
-        let iter = DataReader::new(1, &self.data_dir).context("while initializing data reader")?;
+        //let device = WgpuDevice::DiscreteGpu(0);
 
         let tokenizer = File::open(&self.tokenizer).context("while opening tokenizer")?;
         let tokenizer: Tokenizer =
             serde_json::from_reader(tokenizer).context("while loading tokenizer")?;
 
-        let config = Config {
+        let ccfg = ComputeConfig {
+            batch_size: 10,
+            eval_frac: 0.1,
+            eval_salt: "salt".into(),
+        };
+
+        let mcfg = GptModelConfig {
             vocab_size: tokenizer.vocab_size(),
-            context_size: 4,
+            context_size: 256,
             embed_dim: 768,
 
             n_heads: 12,
             head_dim: 64, // = 768 / 12
-            n_layers: 12,
+            n_layers: 1,
 
             embed_drop_rate: 0.1,
             attention_drop_rate: 0.1,
             shortcut_drop_rate: 0.1,
         };
 
-        let stride = config.context_size;
-
-        let batch_size = 10;
-        let mut input_batch = Vec::with_capacity(batch_size);
-        let mut output_batch = Vec::with_capacity(batch_size);
-
-        #[expect(clippy::unwrap_used)] // Lazy error handling
-        let iter = iter.map(|x| x.unwrap()).skip(self.skip).take(self.n);
-
-        let model = GptModel::new(&config, &device);
-
-        // Text generation routine
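+        // Autodiff<Cuda> records the graph so loss.backward() works during
+        // training; model.valid() strips it back to plain Cuda for eval.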
This is "; - let tokens = tokenizer.encode(&init); + let loader_train = DataLoaderBuilder::new(batcher.clone()) + .batch_size(ccfg.batch_size) + //.shuffle(config.seed) + .num_workers(5) + .build(Loader::new(&self.data_dir).context("while initializing loader")?); - let n_tokens = tokens.len(); - let input: Array1 = Array1::from_vec(tokens); - let mut input: Tensor = - Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device) - .reshape([n_tokens]); + let loader_test = DataLoaderBuilder::new(batcher) + .batch_size(ccfg.batch_size) + //.shuffle(config.seed) + .num_workers(5) + .build(Loader::new(&self.data_dir).context("while initializing loader")?); - for _ in 0..100 { - let tokens: Vec = input.clone().to_data().convert::().into_vec().unwrap(); - println!("{:?}", tokens); - println!("{}", tokenizer.decode(&tokens)); + let learner = LearnerBuilder::new("./tmp") + .metric_train_numeric(AccuracyMetric::new()) + .metric_valid_numeric(AccuracyMetric::new()) + .metric_train_numeric(LossMetric::new()) + .metric_valid_numeric(LossMetric::new()) + .with_file_checkpointer(CompactRecorder::new()) + .learning_strategy(LearningStrategy::SingleDevice(device.clone())) + .num_epochs(10) + .summary() + .build(model, AdamConfig::new().init(), 1e-4); - // Crop idx to context size; - let batch = input - .clone() - .slice([0..config.context_size]) - .unsqueeze_dim(0); - - // shape: [tokens, vocab_size] - let logits = model.forward(batch).squeeze_dim::<2>(0); - - // shape: [vocab_size] - let logits = logits.slice([config.context_size - 1]).squeeze_dim::<1>(0); - - let probs = softmax(logits, 0); // shape: [n_tokens] - let id_next = probs.argmax(0); // shape: [1] - - input = Tensor::cat(vec![input.slice([1..]), id_next], 0); - } - } + learner.fit(loader_train, loader_test); */ - for i in iter { - let tokens = tokenizer.encode(&i); + // Initialize optimizer + let mut optim = AdamConfig::new().init(); + let learning_rate = 1e-4; - // Skip small texts. - // TODO: do this better - // TODO: non-uniform batches? - if tokens.len() < config.context_size { - continue; + for epoch in 0..10 { + debug!("Running epoch {epoch}"); + + // Training phase + let mut train_loss_sum = 0.0; + let mut train_total = 0; + + for batch in + TrainTestIterator::new(&self.data_dir, &ccfg, &mcfg, &tokenizer, false, &device) + .context("while initializing reader")? + { + let batch = batch.context("while reading batch")?; + + // Forward pass with gradients + let output = model.forward_train(batch.inputs, batch.targets); + + train_total += output.targets.dims()[0] as i32; + train_loss_sum += output.loss.clone().into_scalar().to_f32(); + + debug!("Running backward pass"); + let grads = output.loss.backward(); + let grads = GradientsParams::from_grads(grads, &model); + + debug!("Running optimizer step"); + model = optim.step(learning_rate, model, grads); } - for (a, b) in tokens.windows(config.context_size).step_by(stride).zip( - tokens[stride..] 
+            let mut valid_loss_sum = 0.0;
+            let mut valid_total = 0;
 
-                /*
-                let context = a;
-                let desired = &b[b.len() - 1..];
-                println!("{context:?} -> {desired:?}");
+            let mut n_eval = 0;
+            debug!("Evaluating batches");
 
-                let input = tokenizer.decode(context);
-                let target = tokenizer.decode(desired);
-                println!("{input:?} -> {target:?}");
-                */
+            for batch in
+                TrainTestIterator::new(&self.data_dir, &ccfg, &mcfg, &tokenizer, true, &device)
+                    .context("while initializing reader")?
+            {
+                let batch = batch.context("while reading batch")?;
+                n_eval += batch.targets.shape()[0];
 
-            if input_batch.len() >= batch_size {
-                let shape = [input_batch.len(), config.context_size];
+                // Forward pass without gradients
+                let output = model.valid().forward_train(batch.inputs, batch.targets);
 
-                let input = std::mem::replace(&mut input_batch, Vec::with_capacity(batch_size));
-                let input: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| input[a][b]);
-
-                #[expect(clippy::unwrap_used)]
-                let input: Tensor<Cuda, 2, Int> =
-                    Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
-                        .reshape(shape);
-
-                let output =
-                    std::mem::replace(&mut output_batch, Vec::with_capacity(batch_size));
-                let output: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| output[a][b]);
-
-                #[expect(clippy::unwrap_used)]
-                let output: Tensor<Cuda, 2, Int> =
-                    Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device)
-                        .reshape(shape);
-
-                self.batch(&config, input, &model);
-            }
+                let batch_len = output.targets.dims()[0];
+                valid_total += batch_len as i32;
+                valid_loss_sum += output.loss.into_scalar().to_f32() * batch_len as f32;
             }
-        }
 
-        if !input_batch.is_empty() {
-            let shape = [input_batch.len(), config.context_size];
+            // Compute and log epoch results
+            let train_loss = if train_total > 0 {
+                train_loss_sum / train_total as f32
+            } else {
+                0.0
+            };
+            let valid_loss = if valid_total > 0 {
+                valid_loss_sum / valid_total as f32
+            } else {
+                0.0
+            };
 
-            let input = std::mem::replace(&mut input_batch, Vec::with_capacity(batch_size));
-            let input: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| input[a][b]);
-
-            #[expect(clippy::unwrap_used)]
-            let input: Tensor<Cuda, 2, Int> =
-                Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device).reshape(shape);
-
-            let output = std::mem::replace(&mut output_batch, Vec::with_capacity(batch_size));
-            let output: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| output[a][b]);
-
-            #[expect(clippy::unwrap_used)]
-            let output: Tensor<Cuda, 2, Int> =
-                Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device).reshape(shape);
-
-            self.batch(&config, input, &model);
+            info!(message = "Ran epoch", epoch, train_loss, valid_loss, n_eval);
         }
 
         Ok(())
     }
-
-    fn batch(&self, _cfg: &Config, input: Tensor<Cuda, 2, Int>, model: &GptModel<Cuda>) {
-        let logits = model.forward(input);
-        println!("{:?}", logits.shape());
-    }
 }
 
+//
+// MARK: model
+//
+
 /// Multihead attention.
 ///
 /// Equivalent to many stacked CausalAttention layers.
@@ -315,7 +453,7 @@ impl<B: Backend> MultiheadAttention<B> {
                 },
                 device.clone(),
                 true,
-                [embedding_dim, total_dim].into(),
+                [total_dim, total_dim].into(),
             ),
 
             dropout: Dropout { prob: dropout },
@@ -389,6 +527,7 @@ impl<B: Backend> MultiheadAttention<B> {
         let mask = self
             .utri_mask
             .clone()
+            .slice([0..tokens, 0..tokens])
            .unsqueeze_dim::<3>(0)
             .unsqueeze_dim::<4>(0)
             .expand(attn_scores.shape());
@@ -422,35 +561,50 @@ impl<B: Backend> MultiheadAttention<B> {
     }
 }
 
-#[derive(Module, Debug)]
-pub struct GptModel<B: Backend> {
-    embedder_tok: Embedding<B>,
-    embedder_pos: Embedding<B>,
-    embedder_drop: Dropout,
+#[derive(Config, Debug)]
+pub struct GptModelConfig {
+    /// Number of tokens
+    pub vocab_size: u32,
 
-    trf_blocks: Vec<TransformerBlock<B>>,
-    final_norm: LayerNorm<B>,
-    out_head: Param<Tensor<B, 2>>,
+    /// Maximum number of input tokens with positional embeddings
+    pub context_size: usize,
+
+    /// Dimension of each token's embedding
+    pub embed_dim: usize,
+
+    /// Number of attention heads
+    pub n_heads: usize,
+
+    /// Dimension of each attn head
+    pub head_dim: usize,
+
+    /// Number of transformer blocks
+    pub n_layers: usize,
+
+    pub embed_drop_rate: f64,
+    pub attention_drop_rate: f64,
+    pub shortcut_drop_rate: f64,
 }
 
-impl<B: Backend> GptModel<B> {
-    pub fn new(cfg: &Config, device: &B::Device) -> Self {
-        let out_head_shape = [cfg.embed_dim, cfg.vocab_size as usize];
+impl GptModelConfig {
+    pub fn init<B: Backend>(&self, device: &B::Device) -> GptModel<B> {
+        let out_head_shape = [self.embed_dim, self.vocab_size as usize];
 
-        Self {
-            embedder_tok: EmbeddingConfig::new(cfg.vocab_size as usize, cfg.embed_dim).init(device),
+        GptModel {
+            embedder_tok: EmbeddingConfig::new(self.vocab_size as usize, self.embed_dim)
+                .init(device),
 
-            embedder_pos: EmbeddingConfig::new(cfg.context_size, cfg.embed_dim).init(device),
+            embedder_pos: EmbeddingConfig::new(self.context_size, self.embed_dim).init(device),
 
             embedder_drop: Dropout {
-                prob: cfg.embed_drop_rate,
+                prob: self.embed_drop_rate,
             },
 
-            trf_blocks: (0..cfg.n_layers)
-                .map(|_| TransformerBlock::new(cfg, device))
+            trf_blocks: (0..self.n_layers)
+                .map(|_| TransformerBlock::new(self, device))
                 .collect(),
 
-            final_norm: LayerNormConfig::new(cfg.embed_dim).init(device),
+            final_norm: LayerNormConfig::new(self.embed_dim).init(device),
 
             out_head: Param::uninitialized(
                 ParamId::new(),
@@ -464,13 +618,34 @@ impl<B: Backend> GptModel<B> {
             ),
         }
     }
+}
 
+#[derive(Debug, Clone)]
+pub struct TrainBatch<B: Backend> {
+    pub inputs: Tensor<B, 2, Int>,
+
+    /// Correct next token for each input
+    pub targets: Tensor<B, 1, Int>,
+}
+
+#[derive(Module, Debug)]
+pub struct GptModel<B: Backend> {
+    embedder_tok: Embedding<B>,
+    embedder_pos: Embedding<B>,
+    embedder_drop: Dropout,
+
+    trf_blocks: Vec<TransformerBlock<B>>,
+    final_norm: LayerNorm<B>,
+    out_head: Param<Tensor<B, 2>>,
+}
+
+impl<B: Backend> GptModel<B> {
     pub fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 3> {
         let n_tokens = input.shape()[1];
 
         let embed_tok = self.embedder_tok.forward(input.clone());
         let embed_pos = self
-            .embedder_tok
+            .embedder_pos
            .forward(Tensor::arange(0..n_tokens as i64, &input.device()).unsqueeze_dim(0));
 
         let x = embed_tok + embed_pos;
@@ -481,6 +656,29 @@ impl<B: Backend> GptModel<B> {
 
         return logits;
     }
+
+    pub fn forward_train(
+        &self,
+        inputs: Tensor<B, 2, Int>,
+        targets: Tensor<B, 1, Int>,
+    ) -> ClassificationOutput<B> {
+        // shape: [batch, n_tokens, n_vocabulary]
+        let output = self.forward(inputs);
+
+        // Get last token
+        // shape: [batch, n_vocabulary]
+        let output = output.slice_dim(1, -1).squeeze_dim::<2>(1);
+
+        let loss = CrossEntropyLossConfig::new()
+            .init(&targets.device())
+            .forward(output.clone(), targets.clone());
+
+        ClassificationOutput {
+            loss,
+            output,
+            targets,
+        }
+    }
 }
 
 #[derive(Module, Debug)]
@@ -498,7 +696,7 @@ pub struct TransformerBlock<B: Backend> {
 }
 
 impl<B: Backend> TransformerBlock<B> {
-    pub fn new(cfg: &Config, device: &B::Device) -> Self {
+    pub fn new(cfg: &GptModelConfig, device: &B::Device) -> Self {
         Self {
             attention: MultiheadAttention::new(
                 cfg.embed_dim,
diff --git a/crates/llmfs/src/data_reader.rs b/crates/llmfs/src/data_reader.rs
index 7e1188d..66dfc5f 100644
--- a/crates/llmfs/src/data_reader.rs
+++ b/crates/llmfs/src/data_reader.rs
@@ -25,10 +25,11 @@ pub enum DataReaderError {
 ///
 /// All parquet files have exactly one text column.
 /// No promises about this struct's behavior if this is not the case.
+#[derive(Clone)]
 pub struct DataReader {
-    rx: Receiver<Result<String, DataReaderError>>,
+    rx: Arc<Mutex<Receiver<Result<String, DataReaderError>>>>,
     total_rows: usize,
-    consumed_rows: AtomicUsize,
+    consumed_rows: Arc<AtomicUsize>,
 }
 
 impl DataReader {
@@ -57,6 +58,15 @@ impl DataReader {
                 files.push(path);
             }
         }
+
+        files.sort_by_key(|a| {
+            a.file_name()
+                .and_then(|x| x.to_str())
+                .unwrap_or("")
+                .to_owned()
+        });
+
         (Arc::new(Mutex::new(files)), total_rows)
     };
 
@@ -147,9 +157,9 @@ impl DataReader {
         });
 
         Ok(Self {
-            rx,
+            rx: Arc::new(Mutex::new(rx)),
             total_rows,
-            consumed_rows: AtomicUsize::new(0),
+            consumed_rows: Arc::new(AtomicUsize::new(0)),
         })
     }
 
@@ -157,7 +167,7 @@ impl DataReader {
     /// Order is arbitrary.
     /// Returns `None` when all rows have been read.
     pub fn recv(&self) -> Option<Result<String, DataReaderError>> {
-        self.rx.recv().ok()
+        self.rx.lock().recv().ok()
     }
 
     //pub fn try_recv(&self) -> Result<Option<String>, TryRecvError> {
diff --git a/crates/llmfs/src/logging.rs b/crates/llmfs/src/logging.rs
index 7660b9e..9b07eaf 100644
--- a/crates/llmfs/src/logging.rs
+++ b/crates/llmfs/src/logging.rs
@@ -12,18 +12,16 @@ use tracing_subscriber::{
 // MARK: loglevel
 //
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, ValueEnum)]
-#[derive(Default)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, ValueEnum, Default)]
 pub enum LogLevel {
     Trace,
     Debug,
     #[default]
-    Info,
+    Info,
     Warn,
     Error,
 }
-
 impl Display for LogLevel {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
@@ -71,7 +69,7 @@ impl From<LogFilterPreset> for EnvFilter {
                 //
                 // Bins
                 //
-                format!("nanochat_rs={}", conf.nanochat),
+                format!("llmfs={}", conf.nanochat),
                 conf.other.to_string(),
             ]
             .join(","),
@@ -216,16 +214,14 @@ pub enum LoggingTarget {
 }
 
 /// How to print logs
-#[derive(Debug, Clone, Copy, Deserialize)]
-#[derive(Default)]
+#[derive(Debug, Clone, Copy, Deserialize, Default)]
 pub enum LoggingFormat {
     #[default]
-    Ansi,
+    Ansi,
     AnsiNoColor,
     Json,
 }
-
 pub struct LoggingInitializer {
     /// Log filter for printed logs
     pub preset: LogFilterPreset,
diff --git a/crates/llmfs/src/main.rs b/crates/llmfs/src/main.rs
index 2297c77..e92596e 100644
--- a/crates/llmfs/src/main.rs
+++ b/crates/llmfs/src/main.rs
@@ -1,3 +1,5 @@
+#![recursion_limit = "256"]
+
 use clap::Parser;
 use indicatif::MultiProgress;
 use tracing::error;
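
Note on the train/eval split introduced in sample_data.rs above: hashing the salted (tokens, target) pair and comparing the result against u64::MAX * eval_frac gives a deterministic split, so a given pair lands in the same set on every run for a fixed salt (with roughly eval_frac of all pairs routed to evaluation), regardless of the order texts are read. A minimal standalone sketch of the same idea; is_eval is a hypothetical helper, not part of the diff, and it relies only on ahash::AHasher::default(), which the new code already uses and which is stable within a given ahash version:

    use ahash::AHasher;
    use std::hash::Hasher;

    /// True if this (tokens, target) pair belongs to the eval split.
    fn is_eval(salt: &str, tokens: &[u32], target: u32, eval_frac: f64) -> bool {
        let mut hasher = AHasher::default();
        hasher.write(salt.as_bytes());
        for &t in tokens {
            hasher.write_u32(t);
        }
        hasher.write_u32(target);

        // finish() is roughly uniform over u64, so a scaled threshold
        // selects about `eval_frac` of all pairs.
        hasher.finish() < (u64::MAX as f64 * eval_frac) as u64
    }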