
TMP
Some checks failed
CI / Check links (push) Successful in 6s
CI / Check typos (push) Successful in 10s
CI / Clippy (push) Failing after 1m23s
CI / Build and test (push) Failing after 2m12s

2025-12-14 13:49:22 -08:00
parent e29b25c162
commit 437c1fbf82
8 changed files with 590 additions and 184 deletions

Cargo.lock generated

@@ -451,6 +451,7 @@ dependencies = [
"ahash",
"bincode",
"burn-common",
"burn-dataset",
"burn-derive",
"burn-tensor",
"data-encoding",
@@ -546,6 +547,25 @@ dependencies = [
"log",
]
[[package]]
name = "burn-dataset"
version = "0.19.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "534d4398fd6aaec32f8caeb3f20ddffcd8a059bdefc01cc2794b91b4e984e8ea"
dependencies = [
"csv",
"derive-new",
"dirs",
"rand",
"rmp-serde",
"sanitize-filename",
"serde",
"serde_json",
"strum",
"tempfile",
"thiserror 2.0.17",
]
[[package]]
name = "burn-derive"
version = "0.19.1"
@@ -598,8 +618,10 @@ dependencies = [
"burn-common",
"burn-ir",
"burn-tensor",
"bytemuck",
"const-random",
"derive-new",
"itertools 0.14.0",
"libm",
"macerator",
"matrixmultiply",
@@ -608,6 +630,7 @@ dependencies = [
"paste",
"portable-atomic-util",
"rand",
"seq-macro",
"spin",
]
@@ -703,6 +726,25 @@ dependencies = [
"serde_bytes",
]
[[package]]
name = "burn-train"
version = "0.19.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0f1553197d50668823a4bafc187c62439df49b218973f0ca79e034b57ce38d6"
dependencies = [
"async-channel",
"burn-core",
"burn-ndarray",
"burn-optim",
"derive-new",
"log",
"rstest",
"serde",
"tracing-appender",
"tracing-core",
"tracing-subscriber",
]
[[package]]
name = "burn-wgpu"
version = "0.19.1"
@@ -1057,6 +1099,15 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b"
[[package]]
name = "crossbeam-channel"
version = "0.5.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-deque"
version = "0.8.6"
@@ -1098,6 +1149,27 @@ dependencies = [
"typenum",
]
[[package]]
name = "csv"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52cd9d68cf7efc6ddfaaee42e7288d3a99d613d4b50f76ce9827ae0c6e14f938"
dependencies = [
"csv-core",
"itoa",
"ryu",
"serde_core",
]
[[package]]
name = "csv-core"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "704a3c26996a80471189265814dbc2c257598b96b8a7feae2d31ace646bb9782"
dependencies = [
"memchr",
]
[[package]]
name = "cubecl"
version = "0.8.1"
@@ -1573,6 +1645,15 @@ version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476"
[[package]]
name = "deranged"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ececcb659e7ba858fb4f10388c250a7252eb0a27373f1a72b8748afdd248e587"
dependencies = [
"powerfmt",
]
[[package]]
name = "derive-new"
version = "0.7.0"
@@ -2064,6 +2145,12 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-timer"
version = "3.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
[[package]]
name = "futures-util"
version = "0.3.31"
@@ -3044,15 +3131,18 @@ checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092"
name = "llmfs"
version = "0.0.1"
dependencies = [
"ahash",
"anstyle",
"anyhow",
"burn",
"burn-train",
"clap",
"futures-util",
"indicatif",
"ndarray",
"parking_lot",
"parquet",
"rand",
"rayon",
"reqwest",
"serde",
@@ -3149,7 +3239,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08"
dependencies = [
"autocfg",
"num_cpus",
"once_cell",
"rawpointer",
"thread-tree",
]
[[package]]
@@ -3299,6 +3392,7 @@ dependencies = [
"portable-atomic",
"portable-atomic-util",
"rawpointer",
"rayon",
"serde",
]
@@ -3379,6 +3473,12 @@ dependencies = [
"num-traits",
]
[[package]]
name = "num-conv"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
name = "num-integer"
version = "0.1.46"
@@ -3689,6 +3789,12 @@ dependencies = [
"zerovec",
]
[[package]]
name = "powerfmt"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391"
[[package]]
name = "ppv-lite86"
version = "0.2.21"
@@ -3994,6 +4100,12 @@ version = "0.8.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58"
[[package]]
name = "relative-path"
version = "1.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2"
[[package]]
name = "renderdoc-sys"
version = "1.1.0"
@@ -4084,6 +4196,35 @@ dependencies = [
"serde",
]
[[package]]
name = "rstest"
version = "0.26.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f5a3193c063baaa2a95a33f03035c8a72b83d97a54916055ba22d35ed3839d49"
dependencies = [
"futures-timer",
"futures-util",
"rstest_macros",
]
[[package]]
name = "rstest_macros"
version = "0.26.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c845311f0ff7951c5506121a9ad75aec44d083c31583b2ea5a30bcb0b0abba0"
dependencies = [
"cfg-if",
"glob",
"proc-macro-crate",
"proc-macro2",
"quote",
"regex",
"relative-path",
"rustc_version",
"syn",
"unicode-ident",
]
[[package]]
name = "rustc-hash"
version = "1.1.0"
@@ -4672,6 +4813,15 @@ dependencies = [
"syn",
]
[[package]]
name = "thread-tree"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ffbd370cb847953a25954d9f63e14824a36113f8c72eecf6eccef5dc4b45d630"
dependencies = [
"crossbeam-channel",
]
[[package]]
name = "thread_local"
version = "1.1.9"
@@ -4692,6 +4842,37 @@ dependencies = [
"ordered-float 2.10.1",
]
[[package]]
name = "time"
version = "0.3.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d"
dependencies = [
"deranged",
"itoa",
"num-conv",
"powerfmt",
"serde",
"time-core",
"time-macros",
]
[[package]]
name = "time-core"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b"
[[package]]
name = "time-macros"
version = "0.2.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3"
dependencies = [
"num-conv",
"time-core",
]
[[package]]
name = "tiny-keccak"
version = "2.0.2"
@@ -4989,6 +5170,18 @@ dependencies = [
"tracing-core",
]
[[package]]
name = "tracing-appender"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "786d480bce6247ab75f005b14ae1624ad978d3029d9113f0a22fa1ac773faeaf"
dependencies = [
"crossbeam-channel",
"thiserror 2.0.17",
"time",
"tracing-subscriber",
]
[[package]]
name = "tracing-attributes"
version = "0.1.31"


@@ -75,10 +75,12 @@ compact_str = "0.9.0"
dary_heap = "0.3.8"
fancy-regex = "0.16.2"
indicatif = { version = "0.18.3", features = ["improved_unicode"] }
itertools = "0.14.0"
futures-util = "0.3.31"
ndarray = { version = "0.16.1", features = ["serde"] }
parking_lot = "0.12.5"
parquet = "56.2.0"
rand = "0.9.2"
rayon = "1.11.0"
reqwest = { version = "0.12.24", features = ["json", "stream"] }
serde = "1.0.228"
@@ -91,7 +93,10 @@ tracing-indicatif = "0.3.13"
tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] }
url = "2.5.7"
burn-train = { version = "0.19.1", default-features = false }
[workspace.dependencies.burn]
version = "0.19.1"
default-features = false
features = ["std", "fusion", "ndarray", "webgpu", "cuda"]
features = ["std", "fusion", "ndarray", "webgpu", "cuda", "autodiff"]
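
The autodiff feature enabled here is what lets a backend be wrapped in Autodiff<...> so loss.backward() is available during training; the train loop added later in this commit does exactly that with Autodiff<Cuda>. A minimal sketch of the same mechanism on the CPU ndarray backend (also enabled above), illustrative only and not part of the commit:

use burn::backend::ndarray::NdArrayDevice;
use burn::backend::{Autodiff, NdArray};
use burn::tensor::Tensor;

// Any backend wrapped in Autodiff gains backward() and per-tensor gradients.
type B = Autodiff<NdArray>;

fn main() {
    let device = NdArrayDevice::default();
    // A leaf tensor that tracks gradients.
    let x = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0], &device).require_grad();
    let y = (x.clone() * x.clone()).sum(); // y = sum(x^2)
    let grads = y.backward();
    let dx = x.grad(&grads).unwrap(); // dy/dx = 2x = [2, 4, 6]
    println!("{dx}");
}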


@@ -10,15 +10,18 @@ workspace = true
[dependencies]
tokenizer = { workspace = true }
ahash = { workspace = true }
anstyle = { workspace = true }
anyhow = { workspace = true }
burn = { workspace = true }
burn-train = { workspace = true }
clap = { workspace = true }
futures-util = { workspace = true }
indicatif = { workspace = true }
ndarray = { workspace = true }
parking_lot = { workspace = true }
parquet = { workspace = true }
rand = { workspace = true }
rayon = { workspace = true }
reqwest = { workspace = true }
serde = { workspace = true }


@@ -15,7 +15,6 @@ pub enum SubCommand {
#[command(flatten)]
args: train_tokenizer::TrainTokenizerArgs,
},
/// Sample data
SampleData {
#[command(flatten)]


@@ -1,22 +1,206 @@
use ahash::AHasher;
use anyhow::{Context, Result};
use burn::{
Tensor,
backend::{Cuda, cuda::CudaDevice},
module::{Module, Param, ParamId},
backend::{Autodiff, Cuda, Wgpu, cuda::CudaDevice, wgpu::WgpuDevice},
config::Config,
module::{AutodiffModule, Module, Param, ParamId},
nn::{
Dropout, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
loss::CrossEntropyLossConfig,
transformer::{PositionWiseFeedForward, PositionWiseFeedForwardConfig},
},
prelude::Backend,
optim::{AdamConfig, GradientsParams, Optimizer},
prelude::{Backend, ToElement},
tensor::{Bool, Distribution, Int, activation::softmax},
};
use burn_train::ClassificationOutput;
use clap::Args;
use indicatif::MultiProgress;
use ndarray::Array2;
use std::{f32, fs::File, path::PathBuf};
use indicatif::{MultiProgress, ProgressIterator};
use ndarray::{Array1, Array2};
use std::{
collections::VecDeque,
f32,
fs::File,
hash::Hasher,
path::{Path, PathBuf},
};
use tokenizer::Tokenizer;
use tracing::{debug, info};
use crate::data_reader::DataReader;
use crate::data_reader::{DataReader, DataReaderError};
// Text generation routine
/*
{
let init = "Initial context. This is ";
let tokens = tokenizer.encode(&init);
let n_tokens = tokens.len();
let input: Array1<u32> = Array1::from_vec(tokens);
let mut input: Tensor<Cuda, 1, Int> =
Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
.reshape([n_tokens]);
for _ in 0..100 {
let tokens: Vec<u32> = input.clone().to_data().convert::<u32>().into_vec().unwrap();
println!("{:?}", tokens);
println!("{}", tokenizer.decode(&tokens));
// Crop idx to context size;
let batch = input
.clone()
.slice([0..config.context_size])
.unsqueeze_dim(0);
// shape: [tokens, vocab_size]
let logits = model.forward(batch).squeeze_dim::<2>(0);
// shape: [vocab_size]
let logits = logits.slice([config.context_size - 1]).squeeze_dim::<1>(0);
let probs = softmax(logits, 0); // shape: [n_tokens]
let id_next = probs.argmax(0); // shape: [1]
input = Tensor::cat(vec![input.slice([1..]), id_next], 0);
}
}
*/
struct TrainTestIterator<'a, B: Backend> {
reader: DataReader,
ccfg: &'a ComputeConfig,
mcfg: &'a GptModelConfig,
tokenizer: &'a Tokenizer,
eval: bool,
device: &'a B::Device,
error: bool,
// Tokenized input/output pairs
pairs: VecDeque<(Vec<u32>, u32)>,
}
impl<'a, B: Backend> TrainTestIterator<'a, B> {
pub fn new(
data_dir: impl AsRef<Path>,
ccfg: &'a ComputeConfig,
mcfg: &'a GptModelConfig,
tokenizer: &'a Tokenizer,
eval: bool,
device: &'a B::Device,
) -> Result<Self, std::io::Error> {
let reader = DataReader::new(3, data_dir)?;
Ok(Self {
reader,
ccfg,
mcfg,
tokenizer,
eval,
device,
error: false,
pairs: VecDeque::new(),
})
}
}
impl<B: Backend> Iterator for TrainTestIterator<'_, B> {
type Item = Result<TrainBatch<B>, DataReaderError>;
fn next(&mut self) -> Option<Self::Item> {
if self.error {
return None;
}
let mut inputs = Vec::with_capacity(self.ccfg.batch_size);
let mut targets = Vec::with_capacity(self.ccfg.batch_size);
let stride = self.mcfg.context_size;
while inputs.len() < self.ccfg.batch_size {
match self.pairs.pop_front() {
Some((i, t)) => {
// train/test split
{
let mut hasher = AHasher::default();
hasher.write(self.ccfg.eval_salt.as_bytes());
// Don't care about endianness, ahash output is unstable anyway
hasher.write(unsafe { std::mem::transmute(&i[..]) });
hasher.write_u32(t);
let test = // is this point in the test set?
hasher.finish() > (u64::MAX as f64 * self.ccfg.eval_frac).to_u64();
if test ^ self.eval {
continue;
}
}
inputs.push(i);
targets.push(t);
}
None => {
let text = match self.reader.next() {
None => break,
Some(Ok(x)) => x,
Some(Err(x)) => {
self.error = true;
return Some(Err(x));
}
};
let emb = self.tokenizer.encode(&text);
// Skip small texts
//
// TODO: do this better
// TODO: maybe using <|bos|>?
// TODO: non-uniform batches?
if emb.len() < self.mcfg.context_size {
continue;
}
let pairs = emb
.windows(self.mcfg.context_size + 1)
.step_by(stride)
.map(|x| {
(
x[..self.mcfg.context_size].to_vec(),
x[self.mcfg.context_size],
)
});
self.pairs.extend(pairs);
}
}
}
if inputs.is_empty() {
return None;
}
let shape = [inputs.len(), self.mcfg.context_size];
// Arrange data in memory
let inputs: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| inputs[a][b]);
let targets: Array1<u32> = Array1::from_vec(targets);
// Create tensors on gpu
#[expect(clippy::unwrap_used)]
let inputs =
Tensor::<B, 1, Int>::from_ints(inputs.as_slice().unwrap(), self.device).reshape(shape);
#[expect(clippy::unwrap_used)]
let targets = Tensor::<B, 1, Int>::from_ints(targets.as_slice().unwrap(), self.device);
return Some(Ok(TrainBatch { inputs, targets }));
}
}
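
The iterator above routes each (input, target) pair to train or eval by hashing it together with eval_salt and comparing the hash against a threshold derived from eval_frac, so the split is deterministic across epochs without storing an index. A standalone sketch of the same idea, using std's DefaultHasher instead of AHasher and hashing the token ids directly rather than transmuting the slice (hypothetical helper, not code from this commit):

use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;

/// Returns true if this (input, target) pair belongs to the eval set.
fn is_eval(salt: &str, input: &[u32], target: u32, eval_frac: f64) -> bool {
    let mut hasher = DefaultHasher::new();
    hasher.write(salt.as_bytes());
    for tok in input {
        hasher.write_u32(*tok);
    }
    hasher.write_u32(target);
    // Map the hash onto [0, 1) and keep roughly eval_frac of the points.
    (hasher.finish() as f64 / u64::MAX as f64) < eval_frac
}

fn main() {
    let n_eval = (0..10_000u32)
        .filter(|i| is_eval("salt", &[*i, i + 1], i + 2, 0.1))
        .count();
    println!("{n_eval} of 10000 pairs routed to eval (~10% expected)");
}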
#[derive(Debug, Args, Clone)]
@@ -32,190 +216,144 @@ pub struct SampleDataArgs {
/// How many texts to return
#[clap(long, short = 'n', default_value = "10")]
n: usize,
/// How many texts to skip
#[clap(long, short = 's', default_value = "0")]
skip: usize,
}
#[derive(Debug, Clone)]
pub struct Config {
/// Number of tokens
pub vocab_size: u32,
/// Maximum number of input tokens with positional embeddings
pub context_size: usize,
/// Dimension of each token's embedding
pub embed_dim: usize,
/// Number of attention heads
pub n_heads: usize,
/// Dimension of each attn head
pub head_dim: usize,
/// Number of transformer blocks
pub n_layers: usize,
pub embed_drop_rate: f64,
pub attention_drop_rate: f64,
pub shortcut_drop_rate: f64,
pub struct ComputeConfig {
pub batch_size: usize,
pub eval_frac: f64,
pub eval_salt: String,
}
impl SampleDataArgs {
pub fn run(self, _mp: Option<MultiProgress>) -> Result<()> {
let device = CudaDevice::new(0);
let iter = DataReader::new(1, &self.data_dir).context("while initializing data reader")?;
//let device = WgpuDevice::DiscreteGpu(0);
let tokenizer = File::open(&self.tokenizer).context("while opening tokenizer")?;
let tokenizer: Tokenizer =
serde_json::from_reader(tokenizer).context("while loading tokenizer")?;
let config = Config {
let ccfg = ComputeConfig {
batch_size: 10,
eval_frac: 0.1,
eval_salt: "salt".into(),
};
let mcfg = GptModelConfig {
vocab_size: tokenizer.vocab_size(),
context_size: 4,
context_size: 256,
embed_dim: 768,
n_heads: 12,
head_dim: 64, // = 768 / 12
n_layers: 12,
n_layers: 1,
embed_drop_rate: 0.1,
attention_drop_rate: 0.1,
shortcut_drop_rate: 0.1,
};
let stride = config.context_size;
let batch_size = 10;
let mut input_batch = Vec::with_capacity(batch_size);
let mut output_batch = Vec::with_capacity(batch_size);
#[expect(clippy::unwrap_used)] // Lazy error handling
let iter = iter.map(|x| x.unwrap()).skip(self.skip).take(self.n);
let model = GptModel::new(&config, &device);
// Text generation routine
let mut model: GptModel<Autodiff<Cuda>> = mcfg.init(&device);
/*
{
let init = "Initial context. This is ";
let tokens = tokenizer.encode(&init);
let loader_train = DataLoaderBuilder::new(batcher.clone())
.batch_size(ccfg.batch_size)
//.shuffle(config.seed)
.num_workers(5)
.build(Loader::new(&self.data_dir).context("while initializing loader")?);
let n_tokens = tokens.len();
let input: Array1<u32> = Array1::from_vec(tokens);
let mut input: Tensor<Cuda, 1, Int> =
Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
.reshape([n_tokens]);
let loader_test = DataLoaderBuilder::new(batcher)
.batch_size(ccfg.batch_size)
//.shuffle(config.seed)
.num_workers(5)
.build(Loader::new(&self.data_dir).context("while initializing loader")?);
for _ in 0..100 {
let tokens: Vec<u32> = input.clone().to_data().convert::<u32>().into_vec().unwrap();
println!("{:?}", tokens);
println!("{}", tokenizer.decode(&tokens));
let learner = LearnerBuilder::new("./tmp")
.metric_train_numeric(AccuracyMetric::new())
.metric_valid_numeric(AccuracyMetric::new())
.metric_train_numeric(LossMetric::new())
.metric_valid_numeric(LossMetric::new())
.with_file_checkpointer(CompactRecorder::new())
.learning_strategy(LearningStrategy::SingleDevice(device.clone()))
.num_epochs(10)
.summary()
.build(model, AdamConfig::new().init(), 1e-4);
// Crop idx to context size;
let batch = input
.clone()
.slice([0..config.context_size])
.unsqueeze_dim(0);
// shape: [tokens, vocab_size]
let logits = model.forward(batch).squeeze_dim::<2>(0);
// shape: [vocab_size]
let logits = logits.slice([config.context_size - 1]).squeeze_dim::<1>(0);
let probs = softmax(logits, 0); // shape: [n_tokens]
let id_next = probs.argmax(0); // shape: [1]
input = Tensor::cat(vec![input.slice([1..]), id_next], 0);
}
}
learner.fit(loader_train, loader_test);
*/
for i in iter {
let tokens = tokenizer.encode(&i);
// Initialize optimizer
let mut optim = AdamConfig::new().init();
let learning_rate = 1e-4;
// Skip small texts.
// TODO: do this better
// TODO: non-uniform batches?
if tokens.len() < config.context_size {
continue;
for epoch in 0..10 {
debug!("Running epoch {epoch}");
// Training phase
let mut train_loss_sum = 0.0;
let mut train_total = 0;
for batch in
TrainTestIterator::new(&self.data_dir, &ccfg, &mcfg, &tokenizer, false, &device)
.context("while initializing reader")?
{
let batch = batch.context("while reading batch")?;
// Forward pass with gradients
let output = model.forward_train(batch.inputs, batch.targets);
train_total += output.targets.dims()[0] as i32;
train_loss_sum += output.loss.clone().into_scalar().to_f32();
debug!("Running backward pass");
let grads = output.loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
debug!("Running optimizer step");
model = optim.step(learning_rate, model, grads);
}
for (a, b) in tokens.windows(config.context_size).step_by(stride).zip(
tokens[stride..]
.windows(config.context_size)
.step_by(stride),
) {
input_batch.push(a.to_owned());
output_batch.push(b.to_owned());
let mut valid_loss_sum = 0.0;
let mut valid_total = 0;
/*
let context = a;
let desired = &b[b.len() - 1..];
println!("{context:?} -> {desired:?}");
let mut n_eval = 0;
debug!("Evaluating batches");
let input = tokenizer.decode(context);
let target = tokenizer.decode(desired);
println!("{input:?} -> {target:?}");
*/
for batch in
TrainTestIterator::new(&self.data_dir, &ccfg, &mcfg, &tokenizer, true, &device)
.context("while initializing reader")?
{
let batch = batch.context("while reading batch")?;
n_eval += batch.targets.shape()[0];
if input_batch.len() >= batch_size {
let shape = [input_batch.len(), config.context_size];
// Forward pass without gradients
let output = model.valid().forward_train(batch.inputs, batch.targets);
let input = std::mem::replace(&mut input_batch, Vec::with_capacity(batch_size));
let input: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| input[a][b]);
#[expect(clippy::unwrap_used)]
let input: Tensor<Cuda, 2, Int> =
Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
.reshape(shape);
let output =
std::mem::replace(&mut output_batch, Vec::with_capacity(batch_size));
let output: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| output[a][b]);
#[expect(clippy::unwrap_used)]
let output: Tensor<Cuda, 2, Int> =
Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device)
.reshape(shape);
self.batch(&config, input, &model);
}
valid_total += output.targets.dims()[0] as i32;
valid_loss_sum += output.loss.into_scalar().to_f32();
}
}
if !input_batch.is_empty() {
let shape = [input_batch.len(), config.context_size];
// Compute and log epoch results
let train_loss = if train_total > 0 {
train_loss_sum / train_total as f32
} else {
0.0
};
let valid_loss = if valid_total > 0 {
valid_loss_sum / valid_total as f32
} else {
0.0
};
let input = std::mem::replace(&mut input_batch, Vec::with_capacity(batch_size));
let input: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| input[a][b]);
#[expect(clippy::unwrap_used)]
let input: Tensor<Cuda, 2, Int> =
Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device).reshape(shape);
let output = std::mem::replace(&mut output_batch, Vec::with_capacity(batch_size));
let output: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| output[a][b]);
#[expect(clippy::unwrap_used)]
let output: Tensor<Cuda, 2, Int> =
Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device).reshape(shape);
self.batch(&config, input, &model);
info!(message = "Ran epoch", epoch, train_loss, valid_loss, n_eval);
}
Ok(())
}
fn batch(&self, _cfg: &Config, input: Tensor<Cuda, 2, Int>, model: &GptModel<Cuda>) {
let logits = model.forward(input);
println!("{:?}", logits.shape());
}
}
//
// MARK: model
//
/// Multihead attention.
///
/// Equivalent to many stacked CausalAttention layers.
@@ -315,7 +453,7 @@ impl<B: Backend> MultiheadAttention<B> {
},
device.clone(),
true,
[embedding_dim, total_dim].into(),
[total_dim, total_dim].into(),
),
dropout: Dropout { prob: dropout },
@@ -389,6 +527,7 @@ impl<B: Backend> MultiheadAttention<B> {
let mask = self
.utri_mask
.clone()
.slice([0..tokens, 0..tokens])
.unsqueeze_dim::<3>(0)
.unsqueeze_dim::<4>(0)
.expand(attn_scores.shape());
@@ -422,35 +561,50 @@ impl<B: Backend> MultiheadAttention<B> {
}
}
#[derive(Module, Debug)]
pub struct GptModel<B: Backend> {
embedder_tok: Embedding<B>,
embedder_pos: Embedding<B>,
embedder_drop: Dropout,
#[derive(Config, Debug)]
pub struct GptModelConfig {
/// Number of tokens
pub vocab_size: u32,
trf_blocks: Vec<TransformerBlock<B>>,
final_norm: LayerNorm<B>,
out_head: Param<Tensor<B, 2>>,
/// Maximum number of input tokens with positional embeddings
pub context_size: usize,
/// Dimension of each token's embedding
pub embed_dim: usize,
/// Number of attention heads
pub n_heads: usize,
/// Dimension of each attn head
pub head_dim: usize,
/// Number of transformer blocks
pub n_layers: usize,
pub embed_drop_rate: f64,
pub attention_drop_rate: f64,
pub shortcut_drop_rate: f64,
}
impl<B: Backend> GptModel<B> {
pub fn new(cfg: &Config, device: &B::Device) -> Self {
let out_head_shape = [cfg.embed_dim, cfg.vocab_size as usize];
impl GptModelConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> GptModel<B> {
let out_head_shape = [self.embed_dim, self.vocab_size as usize];
Self {
embedder_tok: EmbeddingConfig::new(cfg.vocab_size as usize, cfg.embed_dim).init(device),
GptModel {
embedder_tok: EmbeddingConfig::new(self.vocab_size as usize, self.embed_dim)
.init(device),
embedder_pos: EmbeddingConfig::new(cfg.context_size, cfg.embed_dim).init(device),
embedder_pos: EmbeddingConfig::new(self.context_size, self.embed_dim).init(device),
embedder_drop: Dropout {
prob: cfg.embed_drop_rate,
prob: self.embed_drop_rate,
},
trf_blocks: (0..cfg.n_layers)
.map(|_| TransformerBlock::new(cfg, device))
trf_blocks: (0..self.n_layers)
.map(|_| TransformerBlock::new(&self, device))
.collect(),
final_norm: LayerNormConfig::new(cfg.embed_dim).init(device),
final_norm: LayerNormConfig::new(self.embed_dim).init(device),
out_head: Param::uninitialized(
ParamId::new(),
@@ -464,13 +618,34 @@ impl<B: Backend> GptModel<B> {
),
}
}
}
#[derive(Debug, Clone)]
pub struct TrainBatch<B: Backend> {
pub inputs: Tensor<B, 2, Int>,
/// Correct next token for each input
pub targets: Tensor<B, 1, Int>,
}
#[derive(Module, Debug)]
pub struct GptModel<B: Backend> {
embedder_tok: Embedding<B>,
embedder_pos: Embedding<B>,
embedder_drop: Dropout,
trf_blocks: Vec<TransformerBlock<B>>,
final_norm: LayerNorm<B>,
out_head: Param<Tensor<B, 2>>,
}
impl<B: Backend> GptModel<B> {
pub fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 3> {
let n_tokens = input.shape()[1];
let embed_tok = self.embedder_tok.forward(input.clone());
let embed_pos = self
.embedder_tok
.embedder_pos
.forward(Tensor::arange(0..n_tokens as i64, &input.device()).unsqueeze_dim(0));
let x = embed_tok + embed_pos;
@@ -481,6 +656,29 @@ impl<B: Backend> GptModel<B> {
return logits;
}
pub fn forward_train(
&self,
inputs: Tensor<B, 2, Int>,
targets: Tensor<B, 1, Int>,
) -> ClassificationOutput<B> {
// shape: [batch, n_tokens, n_vocabulary]
let output = self.forward(inputs);
// Get last token
// shape: [batch, n_vocabulary]
let output = output.slice_dim(1, -1).squeeze_dim::<2>(1);
let loss = CrossEntropyLossConfig::new()
.init(&targets.device())
.forward(output.clone(), targets.clone());
ClassificationOutput {
loss,
output,
targets,
}
}
}
#[derive(Module, Debug)]
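
forward_train keeps only the logits of the last position (slice_dim(1, -1)) and scores the true next token with cross-entropy. A plain-Rust sketch of that per-sample loss, with made-up numbers and no burn types, just to show the arithmetic:

// Softmax the last-position logits, then take the negative log-likelihood
// of the target token.
fn cross_entropy(last_logits: &[f32], target: usize) -> f32 {
    let max = last_logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exp: Vec<f32> = last_logits.iter().map(|l| (l - max).exp()).collect();
    let sum: f32 = exp.iter().sum();
    -(exp[target] / sum).ln()
}

fn main() {
    // vocab_size = 4 here; token 1 has the largest logit.
    let last_logits = [1.0, 2.0, 0.5, -1.0];
    println!("per-sample loss = {:.4}", cross_entropy(&last_logits, 1));
}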
@@ -498,7 +696,7 @@ pub struct TransformerBlock<B: Backend> {
}
impl<B: Backend> TransformerBlock<B> {
pub fn new(cfg: &Config, device: &B::Device) -> Self {
pub fn new(cfg: &GptModelConfig, device: &B::Device) -> Self {
Self {
attention: MultiheadAttention::new(
cfg.embed_dim,


@@ -25,10 +25,11 @@ pub enum DataReaderError {
///
/// All parquet files have exactly one text column.
/// No promises about this struct's behavior if this is not the case.
#[derive(Clone)]
pub struct DataReader {
rx: Receiver<Result<String, DataReaderError>>,
rx: Arc<Mutex<Receiver<Result<String, DataReaderError>>>>,
total_rows: usize,
consumed_rows: AtomicUsize,
consumed_rows: Arc<AtomicUsize>,
}
impl DataReader {
@@ -57,6 +58,15 @@ impl DataReader {
files.push(path);
}
}
files.sort_by_key(|a| {
a.file_name()
.map(|x| x.to_str())
.flatten()
.unwrap_or("")
.to_owned()
});
(Arc::new(Mutex::new(files)), total_rows)
};
@@ -147,9 +157,9 @@ impl DataReader {
});
Ok(Self {
rx,
rx: Arc::new(Mutex::new(rx)),
total_rows,
consumed_rows: AtomicUsize::new(0),
consumed_rows: Arc::new(AtomicUsize::new(0)),
})
}
@@ -157,7 +167,7 @@ impl DataReader {
/// Order is arbitrary.
/// Returns `None` when all rows have been read.
pub fn recv(&self) -> Option<Result<String, DataReaderError>> {
self.rx.recv().ok()
self.rx.lock().recv().ok()
}
//pub fn try_recv(&self) -> Result<Result<String, DataReaderError>, TryRecvError> {
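
The change above wraps the channel receiver and the consumed-rows counter in Arc so that DataReader can derive Clone and several consumers can pull rows from the same underlying stream. A minimal standalone sketch of that pattern with std types (the crate itself uses parking_lot, whose lock() does not return a Result):

use std::sync::{mpsc, Arc, Mutex};
use std::thread;

#[derive(Clone)]
struct Reader {
    rx: Arc<Mutex<mpsc::Receiver<String>>>,
}

impl Reader {
    fn recv(&self) -> Option<String> {
        // Each call locks the shared receiver; a row is delivered to exactly one clone.
        self.rx.lock().unwrap().recv().ok()
    }
}

fn main() {
    let (tx, rx) = mpsc::channel();
    let reader = Reader { rx: Arc::new(Mutex::new(rx)) };

    thread::spawn(move || {
        for i in 0..6 {
            tx.send(format!("row {i}")).unwrap();
        }
    });

    let workers: Vec<_> = (0..2)
        .map(|id| {
            let r = reader.clone();
            thread::spawn(move || {
                while let Some(row) = r.recv() {
                    println!("worker {id} got {row}");
                }
            })
        })
        .collect();
    for w in workers {
        w.join().unwrap();
    }
}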


@@ -12,18 +12,16 @@ use tracing_subscriber::{
// MARK: loglevel
//
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, ValueEnum)]
#[derive(Default)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, ValueEnum, Default)]
pub enum LogLevel {
Trace,
Debug,
#[default]
Info,
Info,
Warn,
Error,
}
impl Display for LogLevel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
@@ -71,7 +69,7 @@ impl From<LoggingConfig> for EnvFilter {
//
// Bins
//
format!("nanochat_rs={}", conf.nanochat),
format!("llmfs={}", conf.nanochat),
conf.other.to_string(),
]
.join(","),
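
With the log target renamed from nanochat_rs to llmfs, the directive string built here takes the form "llmfs=<level>,<other-level>". A small sketch with example levels, assuming the env-filter feature of tracing-subscriber that the workspace already enables:

use tracing_subscriber::EnvFilter;

fn main() {
    // e.g. conf.nanochat = Debug, conf.other = Warn
    let directives = ["llmfs=debug", "warn"].join(",");
    let _filter = EnvFilter::try_new(&directives).expect("valid filter directives");
    println!("{directives}");
}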
@@ -216,16 +214,14 @@ pub enum LoggingTarget {
}
/// How to print logs
#[derive(Debug, Clone, Copy, Deserialize)]
#[derive(Default)]
#[derive(Debug, Clone, Copy, Deserialize, Default)]
pub enum LoggingFormat {
#[default]
Ansi,
Ansi,
AnsiNoColor,
Json,
}
pub struct LoggingInitializer {
/// Log filter for printed logs
pub preset: LogFilterPreset,


@@ -1,3 +1,5 @@
#![recursion_limit = "256"]
use clap::Parser;
use indicatif::MultiProgress;
use tracing::error;