From db4c76c76b0e9ef8179b245968a62492bf26c446 Mon Sep 17 00:00:00 2001
From: rm-dr <96270320+rm-dr@users.noreply.github.com>
Date: Sat, 13 Dec 2025 08:42:22 -0800
Subject: [PATCH] Chapter 3: multihead attention

---
 crates/llmfs/src/command/sample_data.rs | 136 +++++++++++++++++-------
 1 file changed, 97 insertions(+), 39 deletions(-)

diff --git a/crates/llmfs/src/command/sample_data.rs b/crates/llmfs/src/command/sample_data.rs
index 740c9cf..4820ab1 100644
--- a/crates/llmfs/src/command/sample_data.rs
+++ b/crates/llmfs/src/command/sample_data.rs
@@ -50,8 +50,9 @@ impl SampleDataArgs {
         // Dimension of each token vector
         let embedding_dim = 3;
 
-        // self-attention output dim
-        let sa_dim_out = 2;
+        // attention head size and number of heads
+        let head_dim = 2;
+        let n_heads = 2;
 
         let dropout = 0.5f64;
 
@@ -76,8 +77,14 @@ impl SampleDataArgs {
         // [context_size, dim]
         let pos_embedding = pos_embedder.forward(pos_tensor).squeeze_dim::<2>(0);
 
-        let attention: CausalAttentionv1<B> =
-            CausalAttentionv1::new(embedding_dim, sa_dim_out, context_size, dropout, &device);
+        let attention: MultiheadAttention<B> = MultiheadAttention::new(
+            embedding_dim,
+            head_dim,
+            n_heads,
+            context_size,
+            dropout,
+            &device,
+        );
 
         for i in iter {
             let tokens = tokenizer.encode(&i);
@@ -124,19 +131,7 @@ impl SampleDataArgs {
             let tok_e = tok_embedder.forward(input);
             let tok_e = tok_e.add(pos_embedding.clone().unsqueeze_dim(0));
 
-            /*
-            // simple self-attention
-
-            // dim: [batch, query token, other token]
-            let attention_scores = tok_e.clone().matmul(tok_e.clone().transpose());
-            let attention_score = softmax(attention_scores, 1);
-
-            // context vectors for each input token
-            // dim: [batch, token, dim]
-            let context_vectors = attention_score.matmul(tok_e.clone());
-            */
-
-            // Trainable self-attention
+            // Attention
             // shape: [batch, tokens, out_dim]
             let a = attention.forward(tok_e);
 
@@ -148,31 +143,42 @@ impl SampleDataArgs {
     }
 }
 
+/// Multihead attention.
+///
+/// Equivalent to many stacked CausalAttentionv1 layers.
+/// These are packed inside one big tensor for efficiency.
 #[derive(Module, Debug)]
-pub struct CausalAttentionv1<B: Backend> {
+pub struct MultiheadAttention<B: Backend> {
+    n_heads: usize,
+    head_dim: usize,
+
     // Can also use Linear layers with disabled bias
     // (they may also have a better initialization routine)
     // TODO: see source code, make this equivalent
-    /// Query weight matrix.
-    /// Maps [tokens, dim] into [tokens, inner_dim].
+    /// Query weight matrices for each head, stacked on the last dimension.
+    /// (maps [tokens, embedding_dim] into [tokens, n_heads * head_dim])
     ///
     /// Intuitively, this learns "what question to ask about the text"
     /// for a given query token. (e.g, "it" -> what does "it" refer to?)
     w_query: Param<Tensor<B, 2>>,
 
-    /// Key weight matrix.
-    /// Maps [tokens, dim] into [tokens, inner_dim].
+    /// Key weight matrices for each head, stacked on the last dimension.
+    /// (maps [tokens, embedding_dim] into [tokens, n_heads * head_dim])
     ///
     /// Intuitively, this learns what properties a certain token
     /// has when it appears as a context (non-query) token.
     w_key: Param<Tensor<B, 2>>,
 
-    /// Value weight matrix.
-    /// Maps [tokens, dim] into [tokens, inner_dim].
+    /// Value weight matrices for each head, stacked on the last dimension.
+    /// (maps [tokens, embedding_dim] into [tokens, n_heads * head_dim])
     ///
     /// Intuitively, ???
     w_value: Param<Tensor<B, 2>>,
 
+    /// Optional final output projection, a [total_dim, total_dim] matrix.
+    /// Maps [tokens, total_dim] to [tokens, total_dim].
+    w_output: Param<Tensor<B, 2>>,
+
     dropout: Dropout,
 
     /// Upper-triangular matrix of ones, excluding diagonal.
@@ -180,46 +186,63 @@ pub struct CausalAttentionv1<B: Backend> {
     utri_mask: Tensor<B, 2, Bool>,
 }
 
-impl<B: Backend> CausalAttentionv1<B> {
+impl<B: Backend> MultiheadAttention<B> {
     pub fn new(
         embedding_dim: usize,
-        out_dim: usize,
+        head_dim: usize,
+        n_heads: usize,
         context_length: usize,
         dropout: f64,
         device: &B::Device,
     ) -> Self {
+        let total_dim = head_dim * n_heads;
+
         Self {
+            n_heads,
+            head_dim,
+
             w_query: Param::uninitialized(
                 ParamId::new(),
                 move |device, is_require_grad| {
-                    Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
+                    Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
                         .set_require_grad(is_require_grad)
                 },
                 device.clone(),
                 true,
-                [embedding_dim, out_dim].into(),
+                [embedding_dim, total_dim].into(),
             ),
 
             w_key: Param::uninitialized(
                 ParamId::new(),
                 move |device, is_require_grad| {
-                    Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
+                    Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
                         .set_require_grad(is_require_grad)
                 },
                 device.clone(),
                 true,
-                [embedding_dim, out_dim].into(),
+                [embedding_dim, total_dim].into(),
             ),
 
             w_value: Param::uninitialized(
                 ParamId::new(),
                 move |device, is_require_grad| {
-                    Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
+                    Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
                         .set_require_grad(is_require_grad)
                 },
                 device.clone(),
                 true,
-                [embedding_dim, out_dim].into(),
+                [embedding_dim, total_dim].into(),
+            ),
+
+            w_output: Param::uninitialized(
+                ParamId::new(),
+                move |device, is_require_grad| {
+                    Tensor::random([total_dim, total_dim], Distribution::Default, device)
+                        .set_require_grad(is_require_grad)
+                },
+                device.clone(),
+                true,
+                [total_dim, total_dim].into(),
             ),
 
             dropout: Dropout { prob: dropout },
@@ -231,12 +254,15 @@ impl<B: Backend> CausalAttentionv1<B> {
     /// Compute self-attention vector for the given batch
     ///
     /// - input shape is [batch, token, token_dim]
-    /// - input shape is [batch, token, attn_dim]
+    /// - output shape is [batch, token, n_heads * head_dim]
     pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
         // Works similarly to self-attention, (where attn = softmax(tok @ tok^T); context = attn @ tok)
         // But adds an "inner latent space" using Wq, Qk, and Wv.
+        //
+        // Multiple heads are batched into one tensor.
 
         let batch = input.dims()[0];
+        let tokens = input.dims()[1];
 
         let w_query = self
             .w_query
@@ -256,21 +282,42 @@ impl<B: Backend> CausalAttentionv1<B> {
             .unsqueeze_dim::<3>(0)
             .expand([batch as i64, -1, -1]);
 
+        let w_output = self
+            .w_output
+            .val()
+            .unsqueeze_dim::<3>(0)
+            .expand([batch as i64, -1, -1]);
+
         // Map batch to inner latent space.
         // shape: [batch, token, inner_dim]
         let queries = input.clone().matmul(w_query);
         let keys = input.clone().matmul(w_key);
         let values = input.clone().matmul(w_value);
 
-        // Compute attention scores
-        // (cosine similarity of each query token to each context token)
-        // shape: [batch, query_token, context_token]
-        let attn_scores = queries.matmul(keys.clone().transpose());
+        // Split head dimensions
+        let keys = keys.reshape([batch, tokens, self.n_heads, self.head_dim]);
+        let values = values.reshape([batch, tokens, self.n_heads, self.head_dim]);
+        let queries = queries.reshape([batch, tokens, self.n_heads, self.head_dim]);
+
+        // from: [batch, tok, head, head_dim]
+        // to:   [batch, head, tok, head_dim]
+        let keys = keys.swap_dims(1, 2);
+        let values = values.swap_dims(1, 2);
+        let queries = queries.swap_dims(1, 2);
+
+        // Compute attention scores for each head
+        // (dot-product similarity of each query token to each context token, per head)
+        //
+        // lhs shape: [batch, head, tok, head_dim]
+        // rhs shape: [batch, head, head_dim, tok]
+        // output shape: [batch, head, query_token, context_token]
+        let attn_scores = queries.matmul(keys.clone().swap_dims(2, 3));
 
         let mask = self
             .utri_mask
             .clone()
             .unsqueeze_dim::<3>(0)
+            .unsqueeze_dim::<4>(0)
             .expand(attn_scores.shape());
 
         // Mask out future tokens by filling
@@ -283,10 +330,21 @@ impl<B: Backend> CausalAttentionv1<B> {
         // - dot products get larger with larger dimensions
         // - this causes softmax to "saturate", making all other values very small
         // - which makes gradients vanish during training
-        let attn_weights = softmax(attn_scores / (keys.shape()[2] as f32).sqrt(), 2);
+        let attn_weights = softmax(attn_scores / (keys.shape()[3] as f32).sqrt(), 3);
 
         let attn_weights = self.dropout.forward(attn_weights);
-        let context_vec = attn_weights.matmul(values);
+        // lhs shape: [batch, head, query_token, context_token]
+        // rhs shape: [batch, head, tok, head_dim]
+        // matmul shape: [batch, head, tok, head_dim]
+        // out shape: [batch, tok, head, head_dim]
+        let context_vec = attn_weights.matmul(values).swap_dims(1, 2);
+
+        // shape: [batch, tok, stacked_dim]
+        let context_vec = context_vec.reshape([batch, tokens, self.n_heads * self.head_dim]);
+
+        // Apply final projection (optional)
+        let context_vec = context_vec.matmul(w_output);
+
         return context_vec;
     }
 }
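
Note (not part of the patch): a rough shape sanity check for the new MultiheadAttention, sketched under the assumption that a Backend type `B`, a `device: B::Device`, and the file's existing imports (Tensor, Distribution) are in scope. The dimensions mirror the constants used above (embedding_dim = 3, head_dim = 2, n_heads = 2, context_size = 4); the batch size of 8 is made up for illustration.

    // Hypothetical usage sketch, not taken from the patch.
    let attention: MultiheadAttention<B> =
        MultiheadAttention::new(3, 2, 2, 4, 0.5, &device);

    // [batch, tokens, embedding_dim]
    let input: Tensor<B, 3> = Tensor::random([8usize, 4, 3], Distribution::Default, &device);

    // Inside forward():
    //   queries / keys / values: [8, 4, 4]    = [batch, tokens, n_heads * head_dim]
    //   after reshape + swap:    [8, 2, 4, 2] = [batch, head, tokens, head_dim]
    //   attn_scores:             [8, 2, 4, 4] = [batch, head, query_token, context_token]
    //   merged context_vec:      [8, 4, 4]    = [batch, tokens, n_heads * head_dim]
    let out = attention.forward(input);
    assert_eq!(out.dims(), [8, 4, 4]);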