Chapter 3: multihead attention
@@ -50,8 +50,9 @@ impl SampleDataArgs {
// Dimension of each token vector
let embedding_dim = 3;
// self-attention output dim
let sa_dim_out = 2;
// attention
let head_dim = 2;
let n_heads = 2;

let dropout = 0.5f64;
@@ -76,8 +77,14 @@ impl SampleDataArgs {
// [context_size, dim]
let pos_embedding = pos_embedder.forward(pos_tensor).squeeze_dim::<2>(0);

let attention: CausalAttentionv1<Cuda> =
CausalAttentionv1::new(embedding_dim, sa_dim_out, context_size, dropout, &device);
let attention: MultiheadAttention<Cuda> = MultiheadAttention::new(
embedding_dim,
head_dim,
n_heads,
context_size,
dropout,
&device,
);
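// with these settings the multihead output dim per token is n_heads * head_dim = 4
// (the single-head CausalAttentionv1 above used sa_dim_out = 2)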

for i in iter {
let tokens = tokenizer.encode(&i);
@@ -124,19 +131,7 @@ impl SampleDataArgs {
let tok_e = tok_embedder.forward(input);
let tok_e = tok_e.add(pos_embedding.clone().unsqueeze_dim(0));

/*
// simple self-attention

// dim: [batch, query token, other token]
let attention_scores = tok_e.clone().matmul(tok_e.clone().transpose());
let attention_score = softmax(attention_scores, 1);

// context vectors for each input token
// dim: [batch, token, dim]
let context_vectors = attention_score.matmul(tok_e.clone());
*/

// Trainable self-attention
// Attention

// shape: [batch, tokens, out_dim]
let a = attention.forward(tok_e);
@@ -148,31 +143,42 @@ impl SampleDataArgs {
}
}

/// Multihead attention.
///
/// Equivalent to many stacked CausalAttention layers.
/// These are packed inside one big tensor for efficiency.
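///
/// For example, with the values used in the sample code above
/// (embedding_dim = 3, head_dim = 2, n_heads = 2), each packed Q/K/V weight
/// matrix has shape [embedding_dim, n_heads * head_dim] = [3, 4].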
#[derive(Module, Debug)]
pub struct CausalAttentionv1<B: Backend> {
pub struct MultiheadAttention<B: Backend> {
n_heads: usize,
head_dim: usize,

// Can also use Linear layers with disabled bias
// (they may also have a better initialization routine)
// TODO: see source code, make this equivalent
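// (a sketch, assuming the Burn version in use exposes this config API:
// `LinearConfig::new(embedding_dim, n_heads * head_dim).with_bias(false).init(device)`)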
/// Query weight matrix.
/// Maps [tokens, dim] into [tokens, inner_dim].
/// Query weight matrices for each head, stacked on the last dimension.
/// (so that shape is [tokens, n_heads * head_dim])
///
/// Intuitively, this learns "what question to ask about the text"
/// for a given query token. (e.g., "it" -> what does "it" refer to?)
w_query: Param<Tensor<B, 2>>,

/// Key weight matrix.
/// Maps [tokens, dim] into [tokens, inner_dim].
/// Key weight matrices for each head, stacked on the last dimension.
/// (so that shape is [tokens, n_heads * head_dim])
///
/// Intuitively, this learns what properties a certain token
/// has when it appears as a context (non-query) token.
w_key: Param<Tensor<B, 2>>,

/// Value weight matrix.
/// Maps [tokens, dim] into [tokens, inner_dim].
/// Value weight matrices for each head, stacked on the last dimension.
/// (so that shape is [tokens, n_heads * head_dim])
///
/// Intuitively, this learns what information a token actually
/// contributes to the context vector once other tokens attend to it.
w_value: Param<Tensor<B, 2>>,

/// Optional final projection.
/// Maps [tokens, total_dim] to [tokens, total_dim].
w_output: Param<Tensor<B, 2>>,

dropout: Dropout,

/// Upper-triangular matrix of ones, excluding diagonal.
@@ -180,46 +186,63 @@ pub struct CausalAttentionv1<B: Backend> {
utri_mask: Tensor<B, 2, Bool>,
}

impl<B: Backend> CausalAttentionv1<B> {
impl<B: Backend> MultiheadAttention<B> {
pub fn new(
embedding_dim: usize,
out_dim: usize,
head_dim: usize,
n_heads: usize,
context_length: usize,
dropout: f64,
device: &B::Device,
) -> Self {
let total_dim = head_dim * n_heads;

Self {
n_heads,
head_dim,

w_query: Param::uninitialized(
ParamId::new(),
move |device, is_require_grad| {
Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
.set_require_grad(is_require_grad)
},
device.clone(),
true,
[embedding_dim, out_dim].into(),
[embedding_dim, total_dim].into(),
),

w_key: Param::uninitialized(
ParamId::new(),
move |device, is_require_grad| {
Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
.set_require_grad(is_require_grad)
},
device.clone(),
true,
[embedding_dim, out_dim].into(),
[embedding_dim, total_dim].into(),
),

w_value: Param::uninitialized(
ParamId::new(),
move |device, is_require_grad| {
Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
.set_require_grad(is_require_grad)
},
device.clone(),
true,
[embedding_dim, out_dim].into(),
[embedding_dim, total_dim].into(),
),

w_output: Param::uninitialized(
ParamId::new(),
move |device, is_require_grad| {
Tensor::random([total_dim, total_dim], Distribution::Default, device)
.set_require_grad(is_require_grad)
},
device.clone(),
true,
[total_dim, total_dim].into(),
),

dropout: Dropout { prob: dropout },
@@ -231,12 +254,15 @@ impl<B: Backend> CausalAttentionv1<B> {
/// Compute self-attention vector for the given batch
///
/// - input shape is [batch, token, token_dim]
/// - output shape is [batch, token, attn_dim]
/// - output shape is [batch, token, n_heads * head_dim]
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
// Works similarly to simple self-attention (where attn = softmax(tok @ tok^T); context = attn @ tok),
// but adds an "inner latent space" using Wq, Wk, and Wv.
//
// Multiple heads are batched into one tensor.
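//
// Overall shape flow, summarizing the steps below:
// input [batch, tok, embedding_dim]
// -> Q/K/V [batch, tok, n_heads * head_dim]
// -> split + swap [batch, head, tok, head_dim]
// -> scores / weights [batch, head, tok, tok]
// -> context [batch, head, tok, head_dim]
// -> merge heads [batch, tok, n_heads * head_dim]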

let batch = input.dims()[0];
let tokens = input.dims()[1];

let w_query = self
.w_query
@@ -256,21 +282,42 @@ impl<B: Backend> CausalAttentionv1<B> {
.unsqueeze_dim::<3>(0)
.expand([batch as i64, -1, -1]);

let w_output = self
.w_output
.val()
.unsqueeze_dim::<3>(0)
.expand([batch as i64, -1, -1]);

// Map batch to inner latent space.
// shape: [batch, token, inner_dim]
let queries = input.clone().matmul(w_query);
let keys = input.clone().matmul(w_key);
let values = input.clone().matmul(w_value);

// Compute attention scores
// (dot product of each query token with each context token)
// shape: [batch, query_token, context_token]
let attn_scores = queries.matmul(keys.clone().transpose());
// Split head dimensions
let keys = keys.reshape([batch, tokens, self.n_heads, self.head_dim]);
let values = values.reshape([batch, tokens, self.n_heads, self.head_dim]);
let queries = queries.reshape([batch, tokens, self.n_heads, self.head_dim]);

// from: [batch, tok, head, head_dim]
// to: [batch, head, tok, head_dim]
let keys = keys.swap_dims(1, 2);
let values = values.swap_dims(1, 2);
let queries = queries.swap_dims(1, 2);

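// e.g. with the sample values above (n_heads = 2, head_dim = 2):
// [batch, tokens, 4] -> [batch, tokens, 2, 2] -> [batch, 2, tokens, 2]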
// Compute attention scores for each head
// (dot product of each query token with each context token, per head)
//
// lhs shape: [batch, head, tok, head_dim]
// rhs shape: [batch, head, head_dim, tok]
// output shape: [batch, head, query_token, context_token]
let attn_scores = queries.matmul(keys.clone().swap_dims(2, 3));

let mask = self
.utri_mask
.clone()
.unsqueeze_dim::<3>(0)
.unsqueeze_dim::<4>(0)
.expand(attn_scores.shape());
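// mask goes [tok, tok] -> [1, 1, tok, tok] -> [batch, head, query_token, context_token],
// matching attn_scores so it can be applied element-wise below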

// Mask out future tokens by filling
@@ -283,10 +330,21 @@ impl<B: Backend> CausalAttentionv1<B> {
// - dot products get larger with larger dimensions
// - this causes softmax to "saturate", making all other values very small
// - which makes gradients vanish during training
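// - e.g. for roughly unit-variance queries and keys, a dot product over head_dim
//   terms has variance on the order of head_dim, so dividing by sqrt(head_dim)
//   keeps the pre-softmax scores near unit scale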
let attn_weights = softmax(attn_scores / (keys.shape()[2] as f32).sqrt(), 2);
let attn_weights = softmax(attn_scores / (keys.shape()[3] as f32).sqrt(), 3);
let attn_weights = self.dropout.forward(attn_weights);

let context_vec = attn_weights.matmul(values);
// lhs shape: [batch, head, query_token, context_token]
// rhs shape: [batch, head, tok, head_dim]
// matmul shape: [batch, head, tok, head_dim]
// out shape: [batch, tok, head, head_dim]
let context_vec = attn_weights.matmul(values).swap_dims(1, 2);

// shape: [batch, tok, stacked_dim]
let context_vec = context_vec.reshape([batch, tokens, self.n_heads * self.head_dim]);

// Apply final projection (optional)
let context_vec = context_vec.matmul(w_output);
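// (besides projecting, this lets the model mix information across heads,
// which the reshape above only places side by side)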

return context_vec;
}
}