From e29b25c16298986d8e4141cc0567794113a63b3d Mon Sep 17 00:00:00 2001
From: rm-dr <96270320+rm-dr@users.noreply.github.com>
Date: Sat, 13 Dec 2025 11:26:26 -0800
Subject: [PATCH] Chapter 4: full model

---
 crates/llmfs/src/command/sample_data.rs | 291 ++++++++++++++++++++----
 1 file changed, 243 insertions(+), 48 deletions(-)

diff --git a/crates/llmfs/src/command/sample_data.rs b/crates/llmfs/src/command/sample_data.rs
index 4820ab1..01b7662 100644
--- a/crates/llmfs/src/command/sample_data.rs
+++ b/crates/llmfs/src/command/sample_data.rs
@@ -3,7 +3,10 @@ use burn::{
     Tensor,
     backend::{Cuda, cuda::CudaDevice},
     module::{Module, Param, ParamId},
-    nn::{Dropout, Embedding, EmbeddingConfig},
+    nn::{
+        Dropout, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
+        transformer::{PositionWiseFeedForward, PositionWiseFeedForwardConfig},
+    },
     prelude::Backend,
     tensor::{Bool, Distribution, Int, activation::softmax},
 };
@@ -35,26 +38,54 @@ pub struct SampleDataArgs {
     skip: usize,
 }
 
+#[derive(Debug, Clone)]
+pub struct Config {
+    /// Number of tokens in the vocabulary
+    pub vocab_size: u32,
+
+    /// Maximum number of input tokens with positional embeddings
+    pub context_size: usize,
+
+    /// Dimension of each token's embedding
+    pub embed_dim: usize,
+
+    /// Number of attention heads
+    pub n_heads: usize,
+
+    /// Dimension of each attention head
+    pub head_dim: usize,
+
+    /// Number of transformer blocks
+    pub n_layers: usize,
+
+    pub embed_drop_rate: f64,
+    pub attention_drop_rate: f64,
+    pub shortcut_drop_rate: f64,
+}
+
 impl SampleDataArgs {
     pub fn run(self, _mp: Option<MultiProgress>) -> Result<()> {
         let device = CudaDevice::new(0);
         let iter = DataReader::new(1, &self.data_dir).context("while initializing data reader")?;
 
-        let tokenizer = File::open(self.tokenizer).context("while opening tokenizer")?;
+        let tokenizer = File::open(&self.tokenizer).context("while opening tokenizer")?;
         let tokenizer: Tokenizer =
             serde_json::from_reader(tokenizer).context("while loading tokenizer")?;
 
-        let context_size = 4;
-        let stride = 4;
+        let config = Config {
+            vocab_size: tokenizer.vocab_size(),
+            context_size: 4,
+            embed_dim: 768,
+            n_heads: 12,
+            head_dim: 64, // = 768 / 12
+            n_layers: 12,
+            embed_drop_rate: 0.1,
+            attention_drop_rate: 0.1,
+            shortcut_drop_rate: 0.1,
+        };
 
-        // Dimension of each token vector
-        let embedding_dim = 3;
-
-        // attention
-        let head_dim = 2;
-        let n_heads = 2;
-
-        let dropout = 0.5f64;
+        let stride = config.context_size;
 
         let batch_size = 10;
         let mut input_batch = Vec::with_capacity(batch_size);
@@ -63,42 +94,64 @@ impl SampleDataArgs {
         #[expect(clippy::unwrap_used)] // Lazy error handling
         let iter = iter.map(|x| x.unwrap()).skip(self.skip).take(self.n);
 
-        // TODO: what is this?
-        let tok_embedder = EmbeddingConfig::new(tokenizer.vocab_size() as usize, embedding_dim);
-        let tok_embedder: Embedding<Cuda> = tok_embedder.init(&device);
+        let model = GptModel::new(&config, &device);
 
-        // TODO: do we want to train this?
-        let pos_embedder = EmbeddingConfig::new(context_size, embedding_dim);
-        let pos_embedder: Embedding<Cuda> = pos_embedder.init(&device);
+        // Text generation routine (currently disabled)
 
-        let pos_tensor: Tensor<Cuda, 2, Int> =
-            Tensor::arange(0..context_size as i64, &device).unsqueeze_dim(0);
+        /*
+        {
+            let init = "Initial context. This is ";
+            let tokens = tokenizer.encode(&init);
 
-        // [context_size, dim]
-        let pos_embedding = pos_embedder.forward(pos_tensor).squeeze_dim::<2>(0);
+            let n_tokens = tokens.len();
+            let input: Array1<u32> = Array1::from_vec(tokens);
+            let mut input: Tensor<Cuda, 1, Int> =
+                Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
+                    .reshape([n_tokens]);
 
-        let attention: MultiheadAttention<Cuda> = MultiheadAttention::new(
-            embedding_dim,
-            head_dim,
-            n_heads,
-            context_size,
-            dropout,
-            &device,
-        );
+            for _ in 0..100 {
+                let tokens: Vec<u32> = input.clone().to_data().convert::<u32>().into_vec().unwrap();
+                println!("{:?}", tokens);
+                println!("{}", tokenizer.decode(&tokens));
+
+                // Crop idx to context size.
+                let batch = input
+                    .clone()
+                    .slice([0..config.context_size])
+                    .unsqueeze_dim(0);
+
+                // shape: [tokens, vocab_size]
+                let logits = model.forward(batch).squeeze_dim::<2>(0);
+
+                // shape: [vocab_size]
+                let logits = logits.slice([config.context_size - 1]).squeeze_dim::<1>(0);
+
+                let probs = softmax(logits, 0); // shape: [n_tokens]
+                let id_next = probs.argmax(0); // shape: [1]
+
+                input = Tensor::cat(vec![input.slice([1..]), id_next], 0);
+            }
+        }
+        */
 
         for i in iter {
             let tokens = tokenizer.encode(&i);
 
-            for (a, b) in tokens
-                .windows(context_size)
-                .step_by(stride)
-                .zip(tokens[stride..].windows(context_size).step_by(stride))
-            {
+            // Skip small texts.
+            // TODO: do this better
+            // TODO: non-uniform batches?
+            if tokens.len() < config.context_size {
+                continue;
+            }
+
+            for (a, b) in tokens.windows(config.context_size).step_by(stride).zip(
+                tokens[stride..]
+                    .windows(config.context_size)
+                    .step_by(stride),
+            ) {
                 input_batch.push(a.to_owned());
                 output_batch.push(b.to_owned());
 
-                // TODO: non-uniform batches?
-
                 /*
                 let context = a;
                 let desired = &b[b.len() - 1..];
@@ -109,12 +162,13 @@ impl SampleDataArgs {
                 println!("{input:?} -> {target:?}");
                 */
 
-                // TODO: last batch
                 if input_batch.len() >= batch_size {
-                    let shape = [batch_size, context_size];
+                    let shape = [input_batch.len(), config.context_size];
 
                     let input = std::mem::replace(&mut input_batch, Vec::with_capacity(batch_size));
                     let input: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| input[a][b]);
+
+                    #[expect(clippy::unwrap_used)]
                     let input: Tensor<Cuda, 2, Int> =
                         Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
                             .reshape(shape);
@@ -122,25 +176,44 @@ impl SampleDataArgs {
                     let output = std::mem::replace(&mut output_batch, Vec::with_capacity(batch_size));
                     let output: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| output[a][b]);
+
+                    #[expect(clippy::unwrap_used)]
                     let output: Tensor<Cuda, 2, Int> =
                         Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device)
                             .reshape(shape);
 
-                    // Input token embeddings
-                    // dim: [batch, token, dim]
-                    let tok_e = tok_embedder.forward(input);
-                    let tok_e = tok_e.add(pos_embedding.clone().unsqueeze_dim(0));
-
-                    // Attention
-
-                    // shape: [batch, tokens, out_dim]
-                    let a = attention.forward(tok_e);
+                    self.batch(&config, input, &model);
                 }
             }
         }
+
+        if !input_batch.is_empty() {
+            let shape = [input_batch.len(), config.context_size];
+
+            let input = std::mem::replace(&mut input_batch, Vec::with_capacity(batch_size));
+            let input: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| input[a][b]);
+
+            #[expect(clippy::unwrap_used)]
+            let input: Tensor<Cuda, 2, Int> =
+                Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device).reshape(shape);
+
+            let output = std::mem::replace(&mut output_batch, Vec::with_capacity(batch_size));
+            let output: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| output[a][b]);
+
+            #[expect(clippy::unwrap_used)]
+            let output: Tensor<Cuda, 2, Int> =
+                Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device).reshape(shape);
+
+            self.batch(&config, input, &model);
+        }
+
         Ok(())
     }
+
+    fn batch(&self, _cfg: &Config, input: Tensor<Cuda, 2, Int>, model: &GptModel<Cuda>) {
+        let logits = model.forward(input);
+        println!("{:?}", logits.shape());
+    }
 }
 
 /// Multihead attention.
@@ -247,7 +320,7 @@ impl<B: Backend> MultiheadAttention<B> {
 
             dropout: Dropout { prob: dropout },
 
-            utri_mask: Tensor::<B, 2, Bool>::tril_mask([context_length, context_length], 0, &device),
+            utri_mask: Tensor::<B, 2, Bool>::tril_mask([context_length, context_length], 0, device),
         }
     }
 
@@ -348,3 +421,125 @@ impl<B: Backend> MultiheadAttention<B> {
         return context_vec;
     }
 }
+
+#[derive(Module, Debug)]
+pub struct GptModel<B: Backend> {
+    embedder_tok: Embedding<B>,
+    embedder_pos: Embedding<B>,
+    embedder_drop: Dropout,
+
+    trf_blocks: Vec<TransformerBlock<B>>,
+    final_norm: LayerNorm<B>,
+    out_head: Param<Tensor<B, 2>>,
+}
+
+impl<B: Backend> GptModel<B> {
+    pub fn new(cfg: &Config, device: &B::Device) -> Self {
+        let out_head_shape = [cfg.embed_dim, cfg.vocab_size as usize];
+
+        Self {
+            embedder_tok: EmbeddingConfig::new(cfg.vocab_size as usize, cfg.embed_dim).init(device),
+
+            embedder_pos: EmbeddingConfig::new(cfg.context_size, cfg.embed_dim).init(device),
+
+            embedder_drop: Dropout {
+                prob: cfg.embed_drop_rate,
+            },
+
+            trf_blocks: (0..cfg.n_layers)
+                .map(|_| TransformerBlock::new(cfg, device))
+                .collect(),
+
+            final_norm: LayerNormConfig::new(cfg.embed_dim).init(device),
+
+            out_head: Param::uninitialized(
+                ParamId::new(),
+                move |device, is_require_grad| {
+                    Tensor::random(out_head_shape, Distribution::Default, device)
+                        .set_require_grad(is_require_grad)
+                },
+                device.clone(),
+                true,
+                out_head_shape.into(),
+            ),
+        }
+    }
+
+    pub fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 3> {
+        let n_tokens = input.shape()[1];
+
+        let embed_tok = self.embedder_tok.forward(input.clone());
+        let embed_pos = self
+            .embedder_pos
+            .forward(Tensor::arange(0..n_tokens as i64, &input.device()).unsqueeze_dim(0));
+
+        let x = embed_tok + embed_pos;
+        let x = self.embedder_drop.forward(x);
+        let x = self.trf_blocks.iter().fold(x, |x, l| l.forward(x));
+        let x = self.final_norm.forward(x);
+        let logits = x.matmul(self.out_head.val().unsqueeze_dim(0));
+
+        return logits;
+    }
+}
+
+#[derive(Module, Debug)]
+pub struct TransformerBlock<B: Backend> {
+    attention: MultiheadAttention<B>,
+
+    /// Position-wise feed-forward network.
+    ff: PositionWiseFeedForward<B>,
+
+    /// Layer norms applied before the attention and feed-forward sublayers.
+    norm_a: LayerNorm<B>,
+    norm_b: LayerNorm<B>,
+
+    drop_shortcut: Dropout,
+}
+
+impl<B: Backend> TransformerBlock<B> {
+    pub fn new(cfg: &Config, device: &B::Device) -> Self {
+        Self {
+            attention: MultiheadAttention::new(
+                cfg.embed_dim,
+                cfg.head_dim,
+                cfg.n_heads,
+                cfg.context_size,
+                cfg.attention_drop_rate,
+                device,
+            ),
+
+            ff: PositionWiseFeedForwardConfig::new(cfg.embed_dim, 4 * cfg.embed_dim)
+                .with_dropout(0.0)
+                .init(device),
+
+            norm_a: LayerNormConfig::new(cfg.embed_dim).init(device),
+            norm_b: LayerNormConfig::new(cfg.embed_dim).init(device),
+
+            drop_shortcut: Dropout {
+                prob: cfg.shortcut_drop_rate,
+            },
+        }
+    }
+
+    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
+        let input = {
+            let shortcut = input.clone();
+            let x = self.norm_a.forward(input);
+            let x = self.attention.forward(x);
+            let x = self.drop_shortcut.forward(x);
+            x + shortcut
+        };
+
+        let input = {
+            // Feed-forward sublayer, same shortcut pattern as above.
+            let shortcut = input.clone();
+            let x = self.norm_b.forward(input);
+            let x = self.ff.forward(x);
+            let x = self.drop_shortcut.forward(x);
+            x + shortcut
+        };
+
+        return input;
+    }
+}