From 60b63dfd2ea54987cd5f94a4ab926534de19d589 Mon Sep 17 00:00:00 2001
From: rm-dr <96270320+rm-dr@users.noreply.github.com>
Date: Fri, 12 Dec 2025 21:53:43 -0800
Subject: [PATCH] Chapter 3.5: attention

---
 crates/llmfs/src/command/sample_data.rs | 194 ++++++++++++++++++++++--
 1 file changed, 182 insertions(+), 12 deletions(-)

diff --git a/crates/llmfs/src/command/sample_data.rs b/crates/llmfs/src/command/sample_data.rs
index e7ffd93..740c9cf 100644
--- a/crates/llmfs/src/command/sample_data.rs
+++ b/crates/llmfs/src/command/sample_data.rs
@@ -2,13 +2,15 @@
 use anyhow::{Context, Result};
 use burn::{
     Tensor,
     backend::{Cuda, cuda::CudaDevice},
-    nn::{Embedding, EmbeddingConfig},
-    tensor::Int,
+    module::{Module, Param, ParamId},
+    nn::{Dropout, Embedding, EmbeddingConfig},
+    prelude::Backend,
+    tensor::{Bool, Distribution, Int, activation::softmax},
 };
 use clap::Args;
 use indicatif::MultiProgress;
 use ndarray::Array2;
-use std::{fs::File, path::PathBuf};
+use std::{f32, fs::File, path::PathBuf};
 use tokenizer::Tokenizer;
 
 use crate::data_reader::DataReader;
@@ -47,7 +49,11 @@ impl SampleDataArgs {
         let stride = 4;
 
         // Dimension of each token vector
-        let embedding_dim = 256;
+        let embedding_dim = 3;
+        // self-attention output dim
+        let sa_dim_out = 2;
+
+        let dropout = 0.5f64;
 
         let batch_size = 10;
         let mut input_batch = Vec::with_capacity(batch_size);
@@ -60,16 +66,18 @@ impl SampleDataArgs {
         let tok_embedder = EmbeddingConfig::new(tokenizer.vocab_size() as usize, embedding_dim);
         let tok_embedder: Embedding<Cuda> = tok_embedder.init(&device);
 
+        // TODO: do we want to train this?
         let pos_embedder = EmbeddingConfig::new(context_size, embedding_dim);
         let pos_embedder: Embedding<Cuda> = pos_embedder.init(&device);
         let pos_tensor: Tensor<Cuda, 2, Int> =
             Tensor::arange(0..context_size as i64, &device).unsqueeze_dim(0);
 
-        // [1, context_size, dim]
-        let pos_embedding = pos_embedder.forward(pos_tensor);
+        // [context_size, dim]
+        let pos_embedding = pos_embedder.forward(pos_tensor).squeeze_dim::<2>(0);
 
-        println!("{:?}", pos_embedding.shape());
+        let attention: CausalAttentionv1<Cuda> =
+            CausalAttentionv1::new(embedding_dim, sa_dim_out, context_size, dropout, &device);
 
         for i in iter {
             let tokens = tokenizer.encode(&i);
@@ -82,12 +90,13 @@ impl SampleDataArgs {
                 input_batch.push(a.to_owned());
                 output_batch.push(b.to_owned());
 
-                let context = a;
-                let desired = &b[b.len() - 1..];
-
-                println!("{context:?} -> {desired:?}");
+                // TODO: non-uniform batches?
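+                // Each (a, b) pair is a next-token example: `a` is the input context and
+                // the last element of `b` is the token the model should predict after it
+                // (see `context` / `desired` in the commented-out block below).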
+                /*
+                let context = a;
+                let desired = &b[b.len() - 1..];
+                println!("{context:?} -> {desired:?}");
 
                 let input = tokenizer.decode(context);
                 let target = tokenizer.decode(desired);
                 println!("{input:?} -> {target:?}");
@@ -110,9 +119,27 @@ impl SampleDataArgs {
                         Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device)
                             .reshape(shape);
 
+                    // Input token embeddings
+                    // dim: [batch, token, dim]
                     let tok_e = tok_embedder.forward(input);
-                    let tok_e: Tensor<Cuda, 3> = Tensor::from_data(tok_e.to_data(), &device);
                     let tok_e = tok_e.add(pos_embedding.clone().unsqueeze_dim(0));
+
+                    /*
+                    // simple self-attention
+
+                    // dim: [batch, query token, other token]
+                    let attention_scores = tok_e.clone().matmul(tok_e.clone().transpose());
+                    let attention_score = softmax(attention_scores, 2);
+
+                    // context vectors for each input token
+                    // dim: [batch, token, dim]
+                    let context_vectors = attention_score.matmul(tok_e.clone());
+                    */
+
+                    // Trainable self-attention
+
+                    // shape: [batch, tokens, out_dim]
+                    let a = attention.forward(tok_e);
                 }
             }
         }
@@ -120,3 +147,146 @@ impl SampleDataArgs {
         Ok(())
     }
 }
+
+#[derive(Module, Debug)]
+pub struct CausalAttentionv1<B: Backend> {
+    // Can also use Linear layers with disabled bias
+    // (they may also have a better initialization routine)
+    // TODO: see source code, make this equivalent
+    /// Query weight matrix.
+    /// Maps [tokens, dim] into [tokens, inner_dim].
+    ///
+    /// Intuitively, this learns "what question to ask about the text"
+    /// for a given query token. (e.g., "it" -> what does "it" refer to?)
+    w_query: Param<Tensor<B, 2>>,
+
+    /// Key weight matrix.
+    /// Maps [tokens, dim] into [tokens, inner_dim].
+    ///
+    /// Intuitively, this learns what properties a certain token
+    /// has when it appears as a context (non-query) token.
+    w_key: Param<Tensor<B, 2>>,
+
+    /// Value weight matrix.
+    /// Maps [tokens, dim] into [tokens, inner_dim].
+    ///
+    /// Intuitively, this learns what information a token contributes
+    /// to the context vectors of the tokens that attend to it.
+    w_value: Param<Tensor<B, 2>>,
+
+    dropout: Dropout,
+
+    /// Upper-triangular matrix of ones, excluding diagonal.
+    /// Used to mask future tokens.
+    utri_mask: Tensor<B, 2, Bool>,
+}
+
+impl<B: Backend> CausalAttentionv1<B> {
+    pub fn new(
+        embedding_dim: usize,
+        out_dim: usize,
+        context_length: usize,
+        dropout: f64,
+        device: &B::Device,
+    ) -> Self {
+        Self {
+            w_query: Param::uninitialized(
+                ParamId::new(),
+                move |device, is_require_grad| {
+                    Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
+                        .set_require_grad(is_require_grad)
+                },
+                device.clone(),
+                true,
+                [embedding_dim, out_dim].into(),
+            ),
+
+            w_key: Param::uninitialized(
+                ParamId::new(),
+                move |device, is_require_grad| {
+                    Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
+                        .set_require_grad(is_require_grad)
+                },
+                device.clone(),
+                true,
+                [embedding_dim, out_dim].into(),
+            ),
+
+            w_value: Param::uninitialized(
+                ParamId::new(),
+                move |device, is_require_grad| {
+                    Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
+                        .set_require_grad(is_require_grad)
+                },
+                device.clone(),
+                true,
+                [embedding_dim, out_dim].into(),
+            ),
+
+            dropout: Dropout { prob: dropout },
+
+            utri_mask: Tensor::<B, 2, Bool>::tril_mask(
+                [context_length, context_length],
+                0,
+                &device,
+            ),
+        }
+    }
+
+    /// Compute self-attention vectors for the given batch.
+    ///
+    /// - input shape is [batch, token, token_dim]
+    /// - output shape is [batch, token, attn_dim]
+    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
+        // Works like simple self-attention (where attn = softmax(tok @ tok^T); context = attn @ tok),
+        // but adds an "inner latent space" using Wq, Wk, and Wv.
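+        //
+        // Step by step (implemented below):
+        //   Q = input @ Wq,  K = input @ Wk,  V = input @ Wv   // [batch, token, out_dim]
+        //   scores  = Q @ K^T                                   // [batch, token, token]
+        //   scores[future positions] = -inf                     // causal mask
+        //   weights = dropout(softmax(scores / sqrt(out_dim)))  // rows sum to 1 (before dropout)
+        //   context = weights @ V                               // [batch, token, out_dim]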
+
+        let batch = input.dims()[0];
+
+        let w_query = self
+            .w_query
+            .val()
+            .unsqueeze_dim::<3>(0)
+            .expand([batch as i64, -1, -1]);
+
+        let w_key = self
+            .w_key
+            .val()
+            .unsqueeze_dim::<3>(0)
+            .expand([batch as i64, -1, -1]);
+
+        let w_value = self
+            .w_value
+            .val()
+            .unsqueeze_dim::<3>(0)
+            .expand([batch as i64, -1, -1]);
+
+        // Map the batch into the inner latent space.
+        // shape: [batch, token, inner_dim]
+        let queries = input.clone().matmul(w_query);
+        let keys = input.clone().matmul(w_key);
+        let values = input.clone().matmul(w_value);
+
+        // Compute attention scores
+        // (dot-product similarity of each query token to each context token).
+        // shape: [batch, query_token, context_token]
+        let attn_scores = queries.matmul(keys.clone().transpose());
+
+        let mask = self
+            .utri_mask
+            .clone()
+            .unsqueeze_dim::<3>(0)
+            .expand(attn_scores.shape());
+
+        // Mask out future tokens by filling the upper triangle with -inf,
+        // which becomes 0.0 after softmax.
+        let attn_scores = attn_scores.mask_fill(mask, f32::NEG_INFINITY);
+
+        // Normalize attention weights.
+        //
+        // Divide by sqrt(inner_dim) because...
+        // - dot products get larger with larger dimensions
+        // - this causes softmax to "saturate", pushing all but the largest weight toward zero
+        // - which makes gradients vanish during training
+        let attn_weights = softmax(attn_scores / (keys.shape()[2] as f32).sqrt(), 2);
+        let attn_weights = self.dropout.forward(attn_weights);
+
+        let context_vec = attn_weights.matmul(values);
+        context_vec
+    }
+}
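
Illustrative only, not part of the patch: a minimal smoke test for CausalAttentionv1,
assuming the Cuda backend and the imports already used in sample_data.rs. The function
name, the 16-token context length, and CudaDevice::default() are placeholders; the toy
sizes (embedding_dim = 3, out_dim = 2, dropout = 0.5) match the values above.

    fn attention_smoke_test() {
        let device = CudaDevice::default();

        let attention: CausalAttentionv1<Cuda> =
            CausalAttentionv1::new(3, 2, 16, 0.5, &device);

        // Stand-in token embeddings: [batch = 10, tokens = 16, embedding_dim = 3].
        // The token count must equal the context length passed to `new`, because the
        // causal mask is expanded to the attention-score shape.
        let tok_e = Tensor::<Cuda, 3>::random([10, 16, 3], Distribution::Default, &device);

        // Context vectors: [batch = 10, tokens = 16, out_dim = 2].
        let ctx = attention.forward(tok_e);
        assert_eq!(ctx.dims(), [10, 16, 2]);
    }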