Chapter 3: multihead attention

2025-12-13 08:42:22 -08:00
parent 1d2a4da269
commit db4c76c76b

@@ -50,8 +50,9 @@ impl SampleDataArgs {
         // Dimension of each token vector
         let embedding_dim = 3;
-        // self-attention output dim
-        let sa_dim_out = 2;
+        // attention
+        let head_dim = 2;
+        let n_heads = 2;
         let dropout = 0.5f64;
@@ -76,8 +77,14 @@ impl SampleDataArgs {
         // [context_size, dim]
         let pos_embedding = pos_embedder.forward(pos_tensor).squeeze_dim::<2>(0);
-        let attention: CausalAttentionv1<Cuda> =
-            CausalAttentionv1::new(embedding_dim, sa_dim_out, context_size, dropout, &device);
+        let attention: MultiheadAttention<Cuda> = MultiheadAttention::new(
+            embedding_dim,
+            head_dim,
+            n_heads,
+            context_size,
+            dropout,
+            &device,
+        );
         for i in iter {
             let tokens = tokenizer.encode(&i);
@@ -124,19 +131,7 @@ impl SampleDataArgs {
             let tok_e = tok_embedder.forward(input);
             let tok_e = tok_e.add(pos_embedding.clone().unsqueeze_dim(0));
-            /*
-            // simple self-attention
-            // dim: [batch, query token, other token]
-            let attention_scores = tok_e.clone().matmul(tok_e.clone().transpose());
-            let attention_score = softmax(attention_scores, 1);
-            // context vectors for each input token
-            // dim: [batch, token, dim]
-            let context_vectors = attention_score.matmul(tok_e.clone());
-            */
-            // Trainable self-attention
+            // Attention
             // shape: [batch, tokens, out_dim]
             let a = attention.forward(tok_e);
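
The block this hunk deletes was the chapter's earlier, parameter-free self-attention: scores are raw dot products between token embeddings, softmax turns each row of scores into weights, and each context vector is the weighted sum of all token embeddings. A minimal, dependency-free sketch of that computation (not part of the commit; the toy numbers are made up):

// Sketch only: parameter-free self-attention on three toy tokens.
fn softmax(xs: &[f64]) -> Vec<f64> {
    let max = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = xs.iter().map(|x| (x - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    // Three tokens, embedding_dim = 3 (matching the chapter's toy setup).
    let tok_e = [[0.1, 0.8, 0.3], [0.9, 0.2, 0.4], [0.4, 0.4, 0.7]];

    for (q, query) in tok_e.iter().enumerate() {
        // attention_scores[q][k] = query . key   (tok_e @ tok_e^T)
        let scores: Vec<f64> = tok_e
            .iter()
            .map(|key| query.iter().zip(key).map(|(a, b)| a * b).sum())
            .collect();
        let weights = softmax(&scores);

        // context[q] = sum_k weights[k] * tok_e[k]   (attn @ tok_e)
        let mut context = [0.0; 3];
        for (w, key) in weights.iter().zip(tok_e.iter()) {
            for d in 0..3 {
                context[d] += w * key[d];
            }
        }
        println!("token {q}: weights {weights:?} -> context {context:?}");
    }
}

The trainable module below keeps exactly this structure but first maps the embeddings through Wq, Wk, and Wv, and does it once per head.
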
@@ -148,31 +143,42 @@ impl SampleDataArgs {
     }
 }
+/// Multihead attention.
+///
+/// Equivalent to many stacked CausalAttention layers.
+/// These are packed inside one big tensor for efficiency.
 #[derive(Module, Debug)]
-pub struct CausalAttentionv1<B: Backend> {
+pub struct MultiheadAttention<B: Backend> {
+    n_heads: usize,
+    head_dim: usize,
     // Can also use Linear layers with disabled bias
     // (they may also have a better initialization routine)
     // TODO: see source code, make this equivalent
-    /// Query weight matrix.
-    /// Maps [tokens, dim] into [tokens, inner_dim].
+    /// Query weight matrices for each head, stacked on the last dimension
+    /// (so the projected output has shape [tokens, n_heads * head_dim]).
     ///
    /// Intuitively, this learns "what question to ask about the text"
    /// for a given query token. (e.g., "it" -> what does "it" refer to?)
    w_query: Param<Tensor<B, 2>>,
-    /// Key weight matrix.
-    /// Maps [tokens, dim] into [tokens, inner_dim].
+    /// Key weight matrices for each head, stacked on the last dimension
+    /// (so the projected output has shape [tokens, n_heads * head_dim]).
    ///
    /// Intuitively, this learns what properties a certain token
    /// has when it appears as a context (non-query) token.
    w_key: Param<Tensor<B, 2>>,
-    /// Value weight matrix.
-    /// Maps [tokens, dim] into [tokens, inner_dim].
+    /// Value weight matrices for each head, stacked on the last dimension
+    /// (so the projected output has shape [tokens, n_heads * head_dim]).
    ///
    /// Intuitively, ???
    w_value: Param<Tensor<B, 2>>,
+    /// Optional final projection.
+    /// Maps [tokens, total_dim] to [tokens, total_dim].
+    w_output: Param<Tensor<B, 2>>,
    dropout: Dropout,
    /// Upper-triangular matrix of ones, excluding diagonal.
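
The doc comment's claim that the heads are "packed inside one big tensor" comes down to concatenating the per-head weight matrices along the output dimension: one [embedding_dim, n_heads * head_dim] matmul produces the same numbers as n_heads separate [embedding_dim, head_dim] matmuls, laid out side by side. A dependency-free sketch of that equivalence (not part of the commit; matmul and concat_cols are illustrative helpers):

// Sketch only: packing two per-head weight matrices column-wise is
// equivalent to applying each head's projection separately.

/// Naive [rows, inner] x [inner, cols] matrix multiply.
fn matmul(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let (rows, inner, cols) = (a.len(), b.len(), b[0].len());
    let mut out = vec![vec![0.0; cols]; rows];
    for r in 0..rows {
        for k in 0..inner {
            for c in 0..cols {
                out[r][c] += a[r][k] * b[k][c];
            }
        }
    }
    out
}

/// Concatenate weight matrices along the column (last) dimension.
fn concat_cols(ws: &[Vec<Vec<f64>>]) -> Vec<Vec<f64>> {
    let rows = ws[0].len();
    (0..rows)
        .map(|r| ws.iter().flat_map(|w| w[r].iter().copied()).collect())
        .collect()
}

fn main() {
    // tokens = 2, embedding_dim = 3, head_dim = 2, n_heads = 2
    let tokens = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
    let w_head0 = vec![vec![0.1, 0.2], vec![0.3, 0.4], vec![0.5, 0.6]];
    let w_head1 = vec![vec![0.7, 0.8], vec![0.9, 1.0], vec![1.1, 1.2]];

    // One "big" projection: [tokens, embedding_dim] x [embedding_dim, n_heads * head_dim]
    let packed = matmul(&tokens, &concat_cols(&[w_head0.clone(), w_head1.clone()]));

    // Per-head projections, done separately.
    let h0 = matmul(&tokens, &w_head0);
    let h1 = matmul(&tokens, &w_head1);

    // Columns 0..2 of the packed output are head 0, columns 2..4 are head 1.
    for (r, row) in packed.iter().enumerate() {
        assert_eq!(&row[0..2], &h0[r][..]);
        assert_eq!(&row[2..4], &h1[r][..]);
    }
    println!("packed projection matches per-head projections");
}

In the real module the packed matrix is a single Param<Tensor<B, 2>>, so the backend runs one large matmul per projection instead of n_heads small ones.
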
@@ -180,46 +186,63 @@ pub struct CausalAttentionv1<B: Backend> {
     utri_mask: Tensor<B, 2, Bool>,
 }
-impl<B: Backend> CausalAttentionv1<B> {
+impl<B: Backend> MultiheadAttention<B> {
     pub fn new(
         embedding_dim: usize,
-        out_dim: usize,
+        head_dim: usize,
+        n_heads: usize,
         context_length: usize,
         dropout: f64,
         device: &B::Device,
     ) -> Self {
+        let total_dim = head_dim * n_heads;
         Self {
+            n_heads,
+            head_dim,
             w_query: Param::uninitialized(
                 ParamId::new(),
                 move |device, is_require_grad| {
-                    Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
+                    Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
                         .set_require_grad(is_require_grad)
                 },
                 device.clone(),
                 true,
-                [embedding_dim, out_dim].into(),
+                [embedding_dim, total_dim].into(),
             ),
             w_key: Param::uninitialized(
                 ParamId::new(),
                 move |device, is_require_grad| {
-                    Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
+                    Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
                         .set_require_grad(is_require_grad)
                 },
                 device.clone(),
                 true,
-                [embedding_dim, out_dim].into(),
+                [embedding_dim, total_dim].into(),
             ),
             w_value: Param::uninitialized(
                 ParamId::new(),
                 move |device, is_require_grad| {
-                    Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
+                    Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
                         .set_require_grad(is_require_grad)
                 },
                 device.clone(),
                 true,
-                [embedding_dim, out_dim].into(),
+                [embedding_dim, total_dim].into(),
             ),
+            w_output: Param::uninitialized(
+                ParamId::new(),
+                move |device, is_require_grad| {
+                    Tensor::random([total_dim, total_dim], Distribution::Default, device)
+                        .set_require_grad(is_require_grad)
+                },
+                device.clone(),
+                true,
+                [total_dim, total_dim].into(),
+            ),
             dropout: Dropout { prob: dropout },
@@ -231,12 +254,15 @@ impl<B: Backend> CausalAttentionv1<B> {
     /// Compute self-attention vector for the given batch
     ///
     /// - input shape is [batch, token, token_dim]
-    /// - input shape is [batch, token, attn_dim]
+    /// - output shape is [batch, token, n_heads * head_dim]
     pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
         // Works similarly to self-attention (where attn = softmax(tok @ tok^T); context = attn @ tok)
         // But adds an "inner latent space" using Wq, Wk, and Wv.
+        //
+        // Multiple heads are batched into one tensor.
         let batch = input.dims()[0];
+        let tokens = input.dims()[1];
         let w_query = self
             .w_query
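
Written out as math (notation mine, not from the commit), forward computes the following per head h, with X the embedded tokens, M standing for the causal masking step (the utri_mask fill, written additively here so future positions get no weight after the softmax), and d_head = head_dim:

\begin{aligned}
Q_h &= X W^Q_h, \qquad K_h = X W^K_h, \qquad V_h = X W^V_h \\
\mathrm{head}_h &= \mathrm{softmax}\!\left( \frac{Q_h K_h^{\top}}{\sqrt{d_{\mathrm{head}}}} + M \right) V_h \\
\mathrm{output} &= \mathrm{concat}(\mathrm{head}_1, \ldots, \mathrm{head}_H)\, W^O
\end{aligned}

The W^Q_h, W^K_h, W^V_h are the per-head slices of w_query, w_key, and w_value, and W^O is w_output.
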
@@ -256,21 +282,42 @@ impl<B: Backend> CausalAttentionv1<B> {
             .unsqueeze_dim::<3>(0)
             .expand([batch as i64, -1, -1]);
+        let w_output = self
+            .w_output
+            .val()
+            .unsqueeze_dim::<3>(0)
+            .expand([batch as i64, -1, -1]);
         // Map batch to inner latent space.
         // shape: [batch, token, inner_dim]
         let queries = input.clone().matmul(w_query);
         let keys = input.clone().matmul(w_key);
         let values = input.clone().matmul(w_value);
-        // Compute attention scores
-        // (cosine similarity of each query token to each context token)
-        // shape: [batch, query_token, context_token]
-        let attn_scores = queries.matmul(keys.clone().transpose());
+        // Split head dimensions
+        let keys = keys.reshape([batch, tokens, self.n_heads, self.head_dim]);
+        let values = values.reshape([batch, tokens, self.n_heads, self.head_dim]);
+        let queries = queries.reshape([batch, tokens, self.n_heads, self.head_dim]);
+        // from: [batch, tok, head, head_dim]
+        // to:   [batch, head, tok, head_dim]
+        let keys = keys.swap_dims(1, 2);
+        let values = values.swap_dims(1, 2);
+        let queries = queries.swap_dims(1, 2);
+        // Compute attention scores for each head
+        // (cosine similarity of each query token to each context token, per head)
+        //
+        // lhs shape:    [batch, head, tok, head_dim]
+        // rhs shape:    [batch, head, head_dim, tok]
+        // output shape: [batch, head, query_token, context_token]
+        let attn_scores = queries.matmul(keys.clone().swap_dims(2, 3));
         let mask = self
             .utri_mask
             .clone()
             .unsqueeze_dim::<3>(0)
+            .unsqueeze_dim::<4>(0)
             .expand(attn_scores.shape());
         // Mask out future tokens by filling
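
The reshape + swap_dims pair above is pure index bookkeeping: head h simply gets columns h*head_dim .. (h+1)*head_dim of each token's packed projection. A dependency-free sketch of that mapping (not part of the commit; the batch dimension is dropped for brevity, so the swap here is over dims 0 and 1 rather than 1 and 2):

// Sketch only: reshape + swap_dims is a relabeling of the same values.
// A value at flat position tok * (n_heads * head_dim) + head * head_dim + d
// in the [tokens, n_heads * head_dim] projection ends up at [head, tok, d].
fn main() {
    let (tokens, n_heads, head_dim) = (4, 2, 3);
    let total_dim = n_heads * head_dim;

    // Fake projected activations, flattened row-major: shape [tokens, total_dim].
    let flat: Vec<f64> = (0..tokens * total_dim).map(|i| i as f64).collect();

    // "reshape" + "swap_dims": build [n_heads, tokens, head_dim] explicitly.
    let mut per_head = vec![vec![vec![0.0; head_dim]; tokens]; n_heads];
    for t in 0..tokens {
        for h in 0..n_heads {
            for d in 0..head_dim {
                per_head[h][t][d] = flat[t * total_dim + h * head_dim + d];
            }
        }
    }

    // Head h sees a contiguous slice of each token's projection:
    // columns h*head_dim .. (h+1)*head_dim.
    for h in 0..n_heads {
        for t in 0..tokens {
            let row = &flat[t * total_dim..(t + 1) * total_dim];
            assert_eq!(per_head[h][t], &row[h * head_dim..(h + 1) * head_dim]);
        }
    }
    println!("swap_dims only relabels the same values");
}
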
@@ -283,10 +330,21 @@ impl<B: Backend> CausalAttentionv1<B> {
         // - dot products get larger with larger dimensions
         // - this causes softmax to "saturate", making all other values very small
         // - which makes gradients vanish during training
-        let attn_weights = softmax(attn_scores / (keys.shape()[2] as f32).sqrt(), 2);
+        let attn_weights = softmax(attn_scores / (keys.shape()[3] as f32).sqrt(), 3);
         let attn_weights = self.dropout.forward(attn_weights);
-        let context_vec = attn_weights.matmul(values);
+        // lhs shape:    [batch, head, query_token, context_token]
+        // rhs shape:    [batch, head, tok, head_dim]
+        // matmul shape: [batch, head, tok, head_dim]
+        // out shape:    [batch, tok, head, head_dim]
+        let context_vec = attn_weights.matmul(values).swap_dims(1, 2);
+        // shape: [batch, tok, stacked_dim]
+        let context_vec = context_vec.reshape([batch, tokens, self.n_heads * self.head_dim]);
+        // Apply final projection (optional)
+        let context_vec = context_vec.matmul(w_output);
         return context_vec;
     }
 }
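
The sqrt(head_dim) scaling in the last hunk exists because dot products tend to grow with the dimension, and a softmax over large logits saturates, which is exactly what the code comment warns about. A dependency-free sketch (not part of the commit; the scores are made-up stand-ins for query-key dot products at head_dim = 64):

// Sketch only: softmax of large unscaled scores collapses onto one position;
// dividing by sqrt(head_dim) keeps the distribution soft.
fn softmax(xs: &[f64]) -> Vec<f64> {
    let max = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = xs.iter().map(|x| (x - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    let head_dim = 64.0;
    // Made-up dot products of one query against four keys; as the chapter's
    // comment notes, dot products tend to grow with the dimension.
    let scores = [0.9 * head_dim, 1.0 * head_dim, 0.8 * head_dim, 0.95 * head_dim];

    let unscaled = softmax(&scores);
    let scaled = softmax(
        &scores
            .iter()
            .map(|s| s / head_dim.sqrt())
            .collect::<Vec<_>>(),
    );

    println!("unscaled: {unscaled:?}");
    println!("scaled:   {scaled:?}");
}

Running it shows the unscaled distribution putting nearly all of its mass on the largest score, while the scaled one stays spread out, which is what keeps the other positions' gradients alive during training.
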