Chapter 3.5: attention

2025-12-12 21:53:43 -08:00
parent 4430247d40
commit 60b63dfd2e

@@ -2,13 +2,15 @@ use anyhow::{Context, Result};
use burn::{
Tensor,
backend::{Cuda, cuda::CudaDevice},
nn::{Embedding, EmbeddingConfig},
tensor::Int,
module::{Module, Param, ParamId},
nn::{Dropout, Embedding, EmbeddingConfig},
prelude::Backend,
tensor::{Bool, Distribution, Int, activation::softmax},
};
use clap::Args;
use indicatif::MultiProgress;
use ndarray::Array2;
use std::{fs::File, path::PathBuf};
use std::{f32, fs::File, path::PathBuf};
use tokenizer::Tokenizer;
use crate::data_reader::DataReader;
@@ -47,7 +49,11 @@ impl SampleDataArgs {
let stride = 4;
// Dimension of each token vector
let embedding_dim = 256;
let embedding_dim = 3;
// self-attention output dim
let sa_dim_out = 2;
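// dropout probability, applied to the attention weights inside CausalAttentionv1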
let dropout = 0.5f64;
let batch_size = 10;
let mut input_batch = Vec::with_capacity(batch_size);
@@ -60,16 +66,18 @@ impl SampleDataArgs {
let tok_embedder = EmbeddingConfig::new(tokenizer.vocab_size() as usize, embedding_dim);
let tok_embedder: Embedding<Cuda> = tok_embedder.init(&device);
// TODO: do we want to train this?
let pos_embedder = EmbeddingConfig::new(context_size, embedding_dim);
let pos_embedder: Embedding<Cuda> = pos_embedder.init(&device);
let pos_tensor: Tensor<Cuda, 2, Int> =
Tensor::arange(0..context_size as i64, &device).unsqueeze_dim(0);
// [1, context_size, dim]
let pos_embedding = pos_embedder.forward(pos_tensor);
// [context_size, dim]
let pos_embedding = pos_embedder.forward(pos_tensor).squeeze_dim::<2>(0);
println!("{:?}", pos_embedding.shape());
let attention: CausalAttentionv1<Cuda> =
CausalAttentionv1::new(embedding_dim, sa_dim_out, context_size, dropout, &device);
for i in iter {
let tokens = tokenizer.encode(&i);
@@ -82,12 +90,13 @@ impl SampleDataArgs {
input_batch.push(a.to_owned());
output_batch.push(b.to_owned());
let context = a;
let desired = &b[b.len() - 1..];
println!("{context:?} -> {desired:?}");
// TODO: non-uniform batches?
/*
let context = a;
let desired = &b[b.len() - 1..];
println!("{context:?} -> {desired:?}");
let input = tokenizer.decode(context);
let target = tokenizer.decode(desired);
println!("{input:?} -> {target:?}");
@@ -110,9 +119,27 @@ impl SampleDataArgs {
Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device)
.reshape(shape);
// Input token embeddings
// dim: [batch, token, dim]
let tok_e = tok_embedder.forward(input);
let tok_e: Tensor<Cuda, 3> = Tensor::from_data(tok_e.to_data(), &device);
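// Add positional embeddings, broadcasting [1, context_size, dim] over the batch dimension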
let tok_e = tok_e.add(pos_embedding.clone().unsqueeze_dim(0));
/*
// simple self-attention
// dim: [batch, query token, other token]
let attention_scores = tok_e.clone().matmul(tok_e.clone().transpose());
let attention_score = softmax(attention_scores, 2);
// context vectors for each input token
// dim: [batch, token, dim]
let context_vectors = attention_score.matmul(tok_e.clone());
*/
// Trainable self-attention
// shape: [batch, tokens, out_dim]
let a = attention.forward(tok_e);
}
}
}
@@ -120,3 +147,146 @@ impl SampleDataArgs {
Ok(())
}
}
#[derive(Module, Debug)]
pub struct CausalAttentionv1<B: Backend> {
// Could also use Linear layers with the bias disabled
// (they may also have a better initialization routine); see the sketch after this struct.
// TODO: see source code, make this equivalent
/// Query weight matrix.
/// Maps [tokens, dim] into [tokens, inner_dim].
///
/// Intuitively, this learns "what question to ask about the text"
/// for a given query token (e.g., "it" -> what does "it" refer to?).
w_query: Param<Tensor<B, 2>>,
/// Key weight matrix.
/// Maps [tokens, dim] into [tokens, inner_dim].
///
/// Intuitively, this learns what properties a certain token
/// has when it appears as a context (non-query) token.
w_key: Param<Tensor<B, 2>>,
/// Value weight matrix.
/// Maps [tokens, dim] into [tokens, inner_dim].
///
/// Intuitively, this learns what information a token contributes
/// to the context vector once it has been attended to.
w_value: Param<Tensor<B, 2>>,
dropout: Dropout,
/// Boolean mask that is true strictly above the diagonal (upper triangle, diagonal excluded).
/// Used to mask out future tokens.
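/// E.g. for context_length = 4 (true marks the future positions that get masked):
/// [[false, true,  true,  true ],
///  [false, false, true,  true ],
///  [false, false, false, true ],
///  [false, false, false, false]]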
utri_mask: Tensor<B, 2, Bool>,
}
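// A minimal sketch (not part of this commit) of the Linear-based variant mentioned
// in the comment above, assuming burn's `nn::{Linear, LinearConfig}` with
// `with_bias(false)` and that `Linear::forward` maps [..., d_input] to [..., d_output]:
#[derive(Module, Debug)]
pub struct CausalAttentionLinear<B: Backend> {
    w_query: Linear<B>,
    w_key: Linear<B>,
    w_value: Linear<B>,
    dropout: Dropout,
    utri_mask: Tensor<B, 2, Bool>,
}

impl<B: Backend> CausalAttentionLinear<B> {
    pub fn new(
        embedding_dim: usize,
        out_dim: usize,
        context_length: usize,
        dropout: f64,
        device: &B::Device,
    ) -> Self {
        // Linear ships its own weight initializer, so no manual Param setup is needed.
        let proj = || -> Linear<B> {
            LinearConfig::new(embedding_dim, out_dim)
                .with_bias(false)
                .init(device)
        };
        Self {
            w_query: proj(),
            w_key: proj(),
            w_value: proj(),
            dropout: Dropout { prob: dropout },
            utri_mask: Tensor::<B, 2, _>::tril_mask([context_length, context_length], 0, device),
        }
    }
    // In forward, `self.w_query.forward(input)` (and likewise for keys/values) would
    // replace the `unsqueeze_dim` + `expand` + `matmul` steps below, since Linear
    // already broadcasts over the leading batch dimension.
}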
impl<B: Backend> CausalAttentionv1<B> {
pub fn new(
embedding_dim: usize,
out_dim: usize,
context_length: usize,
dropout: f64,
device: &B::Device,
) -> Self {
Self {
w_query: Param::uninitialized(
ParamId::new(),
move |device, is_require_grad| {
Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
.set_require_grad(is_require_grad)
},
device.clone(),
true,
[embedding_dim, out_dim].into(),
),
w_key: Param::uninitialized(
ParamId::new(),
move |device, is_require_grad| {
Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
.set_require_grad(is_require_grad)
},
device.clone(),
true,
[embedding_dim, out_dim].into(),
),
w_value: Param::uninitialized(
ParamId::new(),
move |device, is_require_grad| {
Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
.set_require_grad(is_require_grad)
},
device.clone(),
true,
[embedding_dim, out_dim].into(),
),
dropout: Dropout { prob: dropout },
utri_mask: Tensor::<B, 2, _>::tril_mask([context_length, context_length], 0, device),
}
}
/// Compute causal self-attention context vectors for the given batch.
///
/// - input shape: [batch, token, token_dim]
/// - output shape: [batch, token, attn_dim]
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
// Works like simple self-attention (where attn = softmax(tok @ tok^T); context = attn @ tok),
// but adds a learned "inner latent space" via Wq, Wk, and Wv.
let batch = input.dims()[0];
let w_query = self
.w_query
.val()
.unsqueeze_dim::<3>(0)
.expand([batch as i64, -1, -1]);
let w_key = self
.w_key
.val()
.unsqueeze_dim::<3>(0)
.expand([batch as i64, -1, -1]);
let w_value = self
.w_value
.val()
.unsqueeze_dim::<3>(0)
.expand([batch as i64, -1, -1]);
// Map batch to inner latent space.
// shape: [batch, token, inner_dim]
let queries = input.clone().matmul(w_query);
let keys = input.clone().matmul(w_key);
let values = input.clone().matmul(w_value);
// Compute attention scores
// (dot-product similarity of each query token with each context token)
// shape: [batch, query_token, context_token]
let attn_scores = queries.matmul(keys.clone().transpose());
let mask = self
.utri_mask
.clone()
.unsqueeze_dim::<3>(0)
.expand(attn_scores.shape());
// Mask out future tokens by filling
// upper-triangular with -inf, which becomes 0.0 after softmax.
let attn_scores = attn_scores.mask_fill(mask, f32::NEG_INFINITY);
// Normalize attn weights.
//
// Divide by sqrt(inner_dim) because...
// - dot products grow with the dimension of the vectors
// - large scores push softmax toward a near one-hot distribution ("saturation")
// - and a saturated softmax makes gradients vanish during training
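// e.g. softmax([1.0, 2.0]) ≈ [0.27, 0.73], but softmax([10.0, 20.0]) ≈ [0.00005, 0.99995]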
let attn_weights = softmax(attn_scores / (keys.shape()[2] as f32).sqrt(), 2);
let attn_weights = self.dropout.forward(attn_weights);
let context_vec = attn_weights.matmul(values);
context_vec
}
}
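A minimal usage sketch, not part of this commit: it reuses the Cuda backend imported above and the hyperparameters from SampleDataArgs (embedding_dim = 3, sa_dim_out = 2, dropout = 0.5, batch_size = 10) with an assumed context_size of 4, and only checks that forward maps [batch, tokens, embedding_dim] to [batch, tokens, sa_dim_out].

fn causal_attention_shape_check() {
    let device = CudaDevice::default();
    // embedding_dim = 3, sa_dim_out = 2, context_size = 4 (assumed), dropout = 0.5
    let attention: CausalAttentionv1<Cuda> = CausalAttentionv1::new(3, 2, 4, 0.5, &device);
    // Batch of 10 sequences, each with 4 tokens of dimension 3.
    let input: Tensor<Cuda, 3> = Tensor::random([10, 4, 3], Distribution::Default, &device);
    let context = attention.forward(input);
    // One 2-dimensional context vector per token.
    assert_eq!(context.dims(), [10, 4, 2]);
}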