Chapter 3.5: attention
@@ -2,13 +2,15 @@ use anyhow::{Context, Result};
 use burn::{
     Tensor,
     backend::{Cuda, cuda::CudaDevice},
-    nn::{Embedding, EmbeddingConfig},
-    tensor::Int,
+    module::{Module, Param, ParamId},
+    nn::{Dropout, Embedding, EmbeddingConfig},
+    prelude::Backend,
+    tensor::{Bool, Distribution, Int, activation::softmax},
 };
 use clap::Args;
 use indicatif::MultiProgress;
 use ndarray::Array2;
-use std::{fs::File, path::PathBuf};
+use std::{f32, fs::File, path::PathBuf};
 use tokenizer::Tokenizer;
 
 use crate::data_reader::DataReader;
@@ -47,7 +49,11 @@ impl SampleDataArgs {
         let stride = 4;
 
         // Dimension of each token vector
-        let embedding_dim = 256;
+        let embedding_dim = 3;
+        // self-attention output dim
+        let sa_dim_out = 2;
+
+        let dropout = 0.5f64;
 
         let batch_size = 10;
         let mut input_batch = Vec::with_capacity(batch_size);
@@ -60,16 +66,18 @@ impl SampleDataArgs {
         let tok_embedder = EmbeddingConfig::new(tokenizer.vocab_size() as usize, embedding_dim);
         let tok_embedder: Embedding<Cuda> = tok_embedder.init(&device);
 
+        // TODO: do we want to train this?
         let pos_embedder = EmbeddingConfig::new(context_size, embedding_dim);
         let pos_embedder: Embedding<Cuda> = pos_embedder.init(&device);
 
         let pos_tensor: Tensor<Cuda, 2, Int> =
             Tensor::arange(0..context_size as i64, &device).unsqueeze_dim(0);
 
-        // [1, context_size, dim]
-        let pos_embedding = pos_embedder.forward(pos_tensor);
+        // [context_size, dim]
+        let pos_embedding = pos_embedder.forward(pos_tensor).squeeze_dim::<2>(0);
 
-        println!("{:?}", pos_embedding.shape());
+        let attention: CausalAttentionv1<Cuda> =
+            CausalAttentionv1::new(embedding_dim, sa_dim_out, context_size, dropout, &device);
 
         for i in iter {
             let tokens = tokenizer.encode(&i);
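The shape bookkeeping above is easy to lose track of, so here is a brief sketch of how the positional table ends up aligned with the token batch (shapes only, using names from this commit):

// pos_tensor = arange(0..context_size).unsqueeze_dim(0)  -> [1, context_size] (Int)
// pos_embedder.forward(pos_tensor)                        -> [1, context_size, embedding_dim]
// .squeeze_dim::<2>(0)                                    -> [context_size, embedding_dim]
//
// Later in this commit, for a token batch of shape [batch, context_size, embedding_dim],
// the table is broadcast back up and added element-wise:
//   let tok_e = tok_e.add(pos_embedding.clone().unsqueeze_dim(0));
// so every sequence in the batch receives the same position vectors.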
@@ -82,12 +90,13 @@ impl SampleDataArgs {
                 input_batch.push(a.to_owned());
                 output_batch.push(b.to_owned());
 
-                let context = a;
-                let desired = &b[b.len() - 1..];
-                println!("{context:?} -> {desired:?}");
+                // TODO: non-uniform batches?
 
                 /*
+                let context = a;
+                let desired = &b[b.len() - 1..];
+                println!("{context:?} -> {desired:?}");
 
                 let input = tokenizer.decode(context);
                 let target = tokenizer.decode(desired);
                 println!("{input:?} -> {target:?}");
@@ -110,9 +119,27 @@ impl SampleDataArgs {
                     Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device)
                         .reshape(shape);
 
+                    // Input token embeddings
+                    // dim: [batch, token, dim]
                     let tok_e = tok_embedder.forward(input);
-                    let tok_e: Tensor<Cuda, 3> = Tensor::from_data(tok_e.to_data(), &device);
                     let tok_e = tok_e.add(pos_embedding.clone().unsqueeze_dim(0));
+
+                    /*
+                    // simple self-attention
+
+                    // dim: [batch, query token, other token]
+                    let attention_scores = tok_e.clone().matmul(tok_e.clone().transpose());
+                    let attention_weights = softmax(attention_scores, 2);
+
+                    // context vectors for each input token
+                    // dim: [batch, token, dim]
+                    let context_vectors = attention_weights.matmul(tok_e.clone());
+                    */
+
+                    // Trainable self-attention
+
+                    // shape: [batch, tokens, out_dim]
+                    let a = attention.forward(tok_e);
                 }
             }
         }
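As a compact comparison of what the commented-out block and the trainable path compute (a sketch, not code from this commit; CausalAttentionv1 is added in the hunk below):

// Simple self-attention (no learned weights):
//   scores  = tok_e @ tok_e^T           // [batch, token, token]
//   weights = softmax(scores, dim 2)    // normalize over the "other token" axis
//   context = weights @ tok_e           // [batch, token, embedding_dim]
//
// Trainable causal attention (CausalAttentionv1::forward):
//   Q = tok_e @ Wq,  K = tok_e @ Wk,  V = tok_e @ Wv    // [batch, token, out_dim]
//   context = softmax(mask(Q @ K^T) / sqrt(out_dim), dim 2) @ V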
@@ -120,3 +147,146 @@ impl SampleDataArgs {
         Ok(())
     }
 }
+
+#[derive(Module, Debug)]
+pub struct CausalAttentionv1<B: Backend> {
+    // Can also use Linear layers with disabled bias
+    // (they may also have a better initialization routine)
+    // TODO: see source code, make this equivalent
+    /// Query weight matrix.
+    /// Maps [tokens, dim] into [tokens, inner_dim].
+    ///
+    /// Intuitively, this learns "what question to ask about the text"
+    /// for a given query token. (e.g., "it" -> what does "it" refer to?)
+    w_query: Param<Tensor<B, 2>>,
+
+    /// Key weight matrix.
+    /// Maps [tokens, dim] into [tokens, inner_dim].
+    ///
+    /// Intuitively, this learns what properties a certain token
+    /// has when it appears as a context (non-query) token.
+    w_key: Param<Tensor<B, 2>>,
+
+    /// Value weight matrix.
+    /// Maps [tokens, dim] into [tokens, inner_dim].
+    ///
+    /// Intuitively, this learns what information a token contributes
+    /// to the context vector once attention decides to look at it.
+    w_value: Param<Tensor<B, 2>>,
+
+    dropout: Dropout,
+
+    /// Upper-triangular matrix of ones, excluding diagonal.
+    /// Used to mask future tokens.
+    utri_mask: Tensor<B, 2, Bool>,
+}
+
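The first comment points at a possible refactor rather than anything in this commit; a rough sketch of that alternative, assuming Burn's LinearConfig allows disabling the bias via with_bias:

// use burn::nn::{Linear, LinearConfig};
//
// w_query: Linear<B>,   // field type instead of Param<Tensor<B, 2>>
//
// // in new():
// let w_query: Linear<B> = LinearConfig::new(embedding_dim, out_dim)
//     .with_bias(false)
//     .init(device);
//
// // in forward(), `self.w_query.forward(input)` would then replace the manual
// // unsqueeze / expand / matmul against the raw weight tensor.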
+impl<B: Backend> CausalAttentionv1<B> {
+    pub fn new(
+        embedding_dim: usize,
+        out_dim: usize,
+        context_length: usize,
+        dropout: f64,
+        device: &B::Device,
+    ) -> Self {
+        Self {
+            w_query: Param::uninitialized(
+                ParamId::new(),
+                move |device, is_require_grad| {
+                    Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
+                        .set_require_grad(is_require_grad)
+                },
+                device.clone(),
+                true,
+                [embedding_dim, out_dim].into(),
+            ),
+
+            w_key: Param::uninitialized(
+                ParamId::new(),
+                move |device, is_require_grad| {
+                    Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
+                        .set_require_grad(is_require_grad)
+                },
+                device.clone(),
+                true,
+                [embedding_dim, out_dim].into(),
+            ),
+
+            w_value: Param::uninitialized(
+                ParamId::new(),
+                move |device, is_require_grad| {
+                    Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
+                        .set_require_grad(is_require_grad)
+                },
+                device.clone(),
+                true,
+                [embedding_dim, out_dim].into(),
+            ),
+
+            dropout: Dropout { prob: dropout },
+
+            utri_mask: Tensor::<B, 2, _>::tril_mask([context_length, context_length], 0, &device),
+        }
+    }
+
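For a small context_length, the utri_mask built above looks like this (following the field's doc comment: true strictly above the diagonal):

// context_length = 4:
//   [ F T T T ]    After mask_fill(mask, -inf) and the softmax in forward below,
//   [ F F T T ]    token i only attends to tokens 0..=i; every future position
//   [ F F F T ]    receives weight 0.0, so no information leaks backwards from
//   [ F F F F ]    later tokens into the context vector.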
+    /// Compute self-attention context vectors for the given batch.
+    ///
+    /// - input shape is [batch, token, token_dim]
+    /// - output shape is [batch, token, attn_dim]
+    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
+        // Works similarly to simple self-attention (where attn = softmax(tok @ tok^T); context = attn @ tok),
+        // but adds an "inner latent space" using Wq, Wk, and Wv.
+
+        let batch = input.dims()[0];
+
+        let w_query = self
+            .w_query
+            .val()
+            .unsqueeze_dim::<3>(0)
+            .expand([batch as i64, -1, -1]);
+
+        let w_key = self
+            .w_key
+            .val()
+            .unsqueeze_dim::<3>(0)
+            .expand([batch as i64, -1, -1]);
+
+        let w_value = self
+            .w_value
+            .val()
+            .unsqueeze_dim::<3>(0)
+            .expand([batch as i64, -1, -1]);
+
+        // Map the batch into the inner latent space.
+        // shape: [batch, token, inner_dim]
+        let queries = input.clone().matmul(w_query);
+        let keys = input.clone().matmul(w_key);
+        let values = input.clone().matmul(w_value);
+
+        // Compute attention scores
+        // (dot product of each query token with each context token).
+        // shape: [batch, query_token, context_token]
+        let attn_scores = queries.matmul(keys.clone().transpose());
+
+        let mask = self
+            .utri_mask
+            .clone()
+            .unsqueeze_dim::<3>(0)
+            .expand(attn_scores.shape());
+
+        // Mask out future tokens by filling the upper triangle
+        // with -inf, which becomes 0.0 after softmax.
+        let attn_scores = attn_scores.mask_fill(mask, f32::NEG_INFINITY);
+
+        // Normalize attn weights.
+        //
+        // Divide by sqrt(inner_dim) because...
+        // - dot products get larger with larger dimensions
+        // - this causes softmax to "saturate", making all other values very small
+        // - which makes gradients vanish during training
+        let attn_weights = softmax(attn_scores / (keys.shape()[2] as f32).sqrt(), 2);
+        let attn_weights = self.dropout.forward(attn_weights);
+
+        let context_vec = attn_weights.matmul(values);
+        return context_vec;
+    }
+}
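To make the sqrt(inner_dim) comment concrete, a rough numeric illustration (made-up values, not taken from the code):

// A raw score is a dot product over inner_dim terms, so with roughly unit-variance
// entries its typical magnitude grows like sqrt(inner_dim).
//
//   unscaled scores [8.0, 2.0, 1.0]  -> softmax -> [0.997, 0.002, 0.001]   (nearly one-hot)
//   divided by 4.0  [2.0, 0.5, 0.25] -> softmax -> [0.72,  0.16,  0.12]    (softer weights)
//
// The saturated distribution leaves near-zero gradients on all but one entry,
// which is the vanishing-gradient problem the comment refers to.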