Chapter 3.5: attention
@@ -2,13 +2,15 @@ use anyhow::{Context, Result};
use burn::{
    Tensor,
    backend::{Cuda, cuda::CudaDevice},
    nn::{Embedding, EmbeddingConfig},
    tensor::Int,
    module::{Module, Param, ParamId},
    nn::{Dropout, Embedding, EmbeddingConfig},
    prelude::Backend,
    tensor::{Bool, Distribution, Int, activation::softmax},
};
use clap::Args;
use indicatif::MultiProgress;
use ndarray::Array2;
use std::{fs::File, path::PathBuf};
use std::{f32, fs::File, path::PathBuf};
use tokenizer::Tokenizer;

use crate::data_reader::DataReader;
@@ -47,7 +49,11 @@ impl SampleDataArgs {
let stride = 4;

// Dimension of each token vector
let embedding_dim = 256;
let embedding_dim = 3;
// self-attention output dim
let sa_dim_out = 2;

let dropout = 0.5f64;

let batch_size = 10;
let mut input_batch = Vec::with_capacity(batch_size);
@@ -60,16 +66,18 @@ impl SampleDataArgs {
let tok_embedder = EmbeddingConfig::new(tokenizer.vocab_size() as usize, embedding_dim);
let tok_embedder: Embedding<Cuda> = tok_embedder.init(&device);

// TODO: do we want to train this?
let pos_embedder = EmbeddingConfig::new(context_size, embedding_dim);
let pos_embedder: Embedding<Cuda> = pos_embedder.init(&device);

let pos_tensor: Tensor<Cuda, 2, Int> =
    Tensor::arange(0..context_size as i64, &device).unsqueeze_dim(0);

// [1, context_size, dim]
let pos_embedding = pos_embedder.forward(pos_tensor);
// [context_size, dim]
let pos_embedding = pos_embedder.forward(pos_tensor).squeeze_dim::<2>(0);

println!("{:?}", pos_embedding.shape());
let attention: CausalAttentionv1<Cuda> =
    CausalAttentionv1::new(embedding_dim, sa_dim_out, context_size, dropout, &device);

for i in iter {
    let tokens = tokenizer.encode(&i);
@@ -82,12 +90,13 @@ impl SampleDataArgs {
input_batch.push(a.to_owned());
output_batch.push(b.to_owned());

let context = a;
let desired = &b[b.len() - 1..];

println!("{context:?} -> {desired:?}");
// TODO: non-uniform batches?

/*
let context = a;
let desired = &b[b.len() - 1..];
println!("{context:?} -> {desired:?}");

let input = tokenizer.decode(context);
let target = tokenizer.decode(desired);
println!("{input:?} -> {target:?}");
@@ -110,9 +119,27 @@ impl SampleDataArgs {
Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device)
    .reshape(shape);

// Input token embeddings
// dim: [batch, token, dim]
let tok_e = tok_embedder.forward(input);
let tok_e: Tensor<Cuda, 3> = Tensor::from_data(tok_e.to_data(), &device);
let tok_e = tok_e.add(pos_embedding.clone().unsqueeze_dim(0));

/*
// simple self-attention

// dim: [batch, query token, other token]
let attention_scores = tok_e.clone().matmul(tok_e.clone().transpose());
let attention_score = softmax(attention_scores, 1);

// context vectors for each input token
// dim: [batch, token, dim]
let context_vectors = attention_score.matmul(tok_e.clone());
*/

// Trainable self-attention

// shape: [batch, tokens, out_dim]
let a = attention.forward(tok_e);
}
}
}
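The hunk above combines a learned token embedding with a learned positional embedding before handing the batch to the attention module. As a standalone illustration of that shape flow, here is a minimal sketch reusing only the burn calls already present in the diff; the vocabulary size, context length, and embedding dimension are made-up illustration values, and it assumes a default CUDA device rather than this command's own device setup.

use burn::{
    Tensor,
    backend::{Cuda, cuda::CudaDevice},
    nn::{Embedding, EmbeddingConfig},
    tensor::Int,
};

fn main() {
    let device = CudaDevice::default();

    // Illustration values only: vocab of 100 tokens, context of 4, embedding dim of 3.
    let tok_embedder: Embedding<Cuda> = EmbeddingConfig::new(100, 3).init(&device);
    let pos_embedder: Embedding<Cuda> = EmbeddingConfig::new(4, 3).init(&device);

    // A single "batch" holding the token ids 0..4, shape [1, 4].
    let tokens: Tensor<Cuda, 2, Int> = Tensor::arange(0..4, &device).unsqueeze_dim(0);

    // Token embeddings: [1, 4, 3]; positional embeddings: [4, 3].
    let tok_e = tok_embedder.forward(tokens.clone());
    let pos_e = pos_embedder.forward(tokens).squeeze_dim::<2>(0);

    // Add the positions to every sequence in the batch (here batch = 1): still [1, 4, 3].
    let x = tok_e.add(pos_e.unsqueeze_dim(0));
    println!("{:?}", x.shape());
}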
@@ -120,3 +147,146 @@ impl SampleDataArgs {
Ok(())
}
}

#[derive(Module, Debug)]
pub struct CausalAttentionv1<B: Backend> {
    // Can also use Linear layers with disabled bias
    // (they may also have a better initialization routine)
    // TODO: see source code, make this equivalent
    /// Query weight matrix.
    /// Maps [tokens, dim] into [tokens, inner_dim].
    ///
    /// Intuitively, this learns "what question to ask about the text"
    /// for a given query token. (e.g., "it" -> what does "it" refer to?)
    w_query: Param<Tensor<B, 2>>,

    /// Key weight matrix.
    /// Maps [tokens, dim] into [tokens, inner_dim].
    ///
    /// Intuitively, this learns what properties a certain token
    /// has when it appears as a context (non-query) token.
    w_key: Param<Tensor<B, 2>>,

    /// Value weight matrix.
    /// Maps [tokens, dim] into [tokens, inner_dim].
    ///
    /// Intuitively, this learns what information a token contributes
    /// to the context vector once it has been attended to.
    w_value: Param<Tensor<B, 2>>,

    dropout: Dropout,

    /// Upper-triangular matrix of ones, excluding diagonal.
    /// Used to mask future tokens.
    utri_mask: Tensor<B, 2, Bool>,
}

impl<B: Backend> CausalAttentionv1<B> {
    pub fn new(
        embedding_dim: usize,
        out_dim: usize,
        context_length: usize,
        dropout: f64,
        device: &B::Device,
    ) -> Self {
        Self {
            w_query: Param::uninitialized(
                ParamId::new(),
                move |device, is_require_grad| {
                    Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
                        .set_require_grad(is_require_grad)
                },
                device.clone(),
                true,
                [embedding_dim, out_dim].into(),
            ),

            w_key: Param::uninitialized(
                ParamId::new(),
                move |device, is_require_grad| {
                    Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
                        .set_require_grad(is_require_grad)
                },
                device.clone(),
                true,
                [embedding_dim, out_dim].into(),
            ),

            w_value: Param::uninitialized(
                ParamId::new(),
                move |device, is_require_grad| {
                    Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
                        .set_require_grad(is_require_grad)
                },
                device.clone(),
                true,
                [embedding_dim, out_dim].into(),
            ),

            dropout: Dropout { prob: dropout },

            utri_mask: Tensor::<B, 2, _>::tril_mask([context_length, context_length], 0, device),
        }
    }

    /// Compute self-attention vectors for the given batch
    ///
    /// - input shape is [batch, token, token_dim]
    /// - output shape is [batch, token, attn_dim]
    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        // Works like simple self-attention (where attn = softmax(tok @ tok^T); context = attn @ tok),
        // but adds an "inner latent space" using Wq, Wk, and Wv.

        let batch = input.dims()[0];

        let w_query = self
            .w_query
            .val()
            .unsqueeze_dim::<3>(0)
            .expand([batch as i64, -1, -1]);

        let w_key = self
            .w_key
            .val()
            .unsqueeze_dim::<3>(0)
            .expand([batch as i64, -1, -1]);

        let w_value = self
            .w_value
            .val()
            .unsqueeze_dim::<3>(0)
            .expand([batch as i64, -1, -1]);

        // Map batch to inner latent space.
        // shape: [batch, token, inner_dim]
        let queries = input.clone().matmul(w_query);
        let keys = input.clone().matmul(w_key);
        let values = input.clone().matmul(w_value);

        // Compute attention scores
        // (dot-product similarity of each query token to each context token)
        // shape: [batch, query_token, context_token]
        let attn_scores = queries.matmul(keys.clone().transpose());

        let mask = self
            .utri_mask
            .clone()
            .unsqueeze_dim::<3>(0)
            .expand(attn_scores.shape());

        // Mask out future tokens by filling the
        // upper triangle with -inf, which becomes 0.0 after softmax.
        let attn_scores = attn_scores.mask_fill(mask, f32::NEG_INFINITY);

        // Normalize attn weights.
        //
        // Divide by sqrt(inner_dim) because...
        // - dot products get larger with larger dimensions
        // - this causes softmax to "saturate", making all other values very small
        // - which makes gradients vanish during training
        let attn_weights = softmax(attn_scores / (keys.shape()[2] as f32).sqrt(), 2);
        let attn_weights = self.dropout.forward(attn_weights);

        let context_vec = attn_weights.matmul(values);
        context_vec
    }
}
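To sanity-check the masking and scaling performed in forward(), here is a minimal standalone sketch built from the same burn calls used above; the 3-token, 2-dimensional shapes are invented for illustration, and a default CUDA device is assumed. Each row of the printed weight matrix sums to 1, and entries above the diagonal are 0 because those scores were filled with -inf before the softmax (mirroring the utri_mask behavior documented above).

use burn::{
    Tensor,
    backend::{Cuda, cuda::CudaDevice},
    tensor::{Bool, Distribution, activation::softmax},
};

fn main() {
    let device = CudaDevice::default();

    // Stand-ins for `input.matmul(w_query)` / `input.matmul(w_key)`:
    // 3 tokens, inner dimension 2.
    let queries: Tensor<Cuda, 2> = Tensor::random([3, 2], Distribution::Default, &device);
    let keys: Tensor<Cuda, 2> = Tensor::random([3, 2], Distribution::Default, &device);

    // Raw attention scores, shape [query_token, context_token].
    let scores = queries.matmul(keys.clone().transpose());

    // Mask is true strictly above the diagonal, i.e. on future tokens.
    let mask: Tensor<Cuda, 2, Bool> = Tensor::tril_mask([3, 3], 0, &device);
    let masked = scores.mask_fill(mask, f32::NEG_INFINITY);

    // Scale by sqrt(d_k) before the softmax so it does not saturate.
    let weights = softmax(masked / (keys.dims()[1] as f32).sqrt(), 1);

    println!("{:?}", weights.to_data());
}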