Chapter 3: multihead attention
@@ -50,8 +50,9 @@ impl SampleDataArgs {
 
         // Dimension of each token vector
         let embedding_dim = 3;
-        // self-attention output dim
-        let sa_dim_out = 2;
+        // attention
+        let head_dim = 2;
+        let n_heads = 2;
 
         let dropout = 0.5f64;
 
@@ -76,8 +77,14 @@ impl SampleDataArgs {
         // [context_size, dim]
         let pos_embedding = pos_embedder.forward(pos_tensor).squeeze_dim::<2>(0);
 
-        let attention: CausalAttentionv1<Cuda> =
-            CausalAttentionv1::new(embedding_dim, sa_dim_out, context_size, dropout, &device);
+        let attention: MultiheadAttention<Cuda> = MultiheadAttention::new(
+            embedding_dim,
+            head_dim,
+            n_heads,
+            context_size,
+            dropout,
+            &device,
+        );
 
         for i in iter {
             let tokens = tokenizer.encode(&i);
@@ -124,19 +131,7 @@ impl SampleDataArgs {
             let tok_e = tok_embedder.forward(input);
             let tok_e = tok_e.add(pos_embedding.clone().unsqueeze_dim(0));
 
-            /*
-            // simple self-attention
-
-            // dim: [batch, query token, other token]
-            let attention_scores = tok_e.clone().matmul(tok_e.clone().transpose());
-            let attention_score = softmax(attention_scores, 1);
-
-            // context vectors for each input token
-            // dim: [batch, token, dim]
-            let context_vectors = attention_score.matmul(tok_e.clone());
-            */
-
-            // Trainable self-attention
+            // Attention
 
             // shape: [batch, tokens, out_dim]
            let a = attention.forward(tok_e);
@@ -148,31 +143,42 @@ impl SampleDataArgs {
     }
 }
 
+/// Multihead attention.
+///
+/// Equivalent to many stacked CausalAttention layers.
+/// These are packed inside one big tensor for efficiency.
 #[derive(Module, Debug)]
-pub struct CausalAttentionv1<B: Backend> {
+pub struct MultiheadAttention<B: Backend> {
+    n_heads: usize,
+    head_dim: usize,
+
     // Can also use Linear layers with disabled bias
     // (they may also have a better initialization routine)
     // TODO: see source code, make this equivalent
-    /// Query weight matrix.
-    /// Maps [tokens, dim] into [tokens, inner_dim].
+    /// Query weight matrices for each head, stacked on the last dimension.
+    /// (so that the shape is [tokens, n_heads * head_dim])
     ///
     /// Intuitively, this learns "what question to ask about the text"
     /// for a given query token. (e.g., "it" -> what does "it" refer to?)
     w_query: Param<Tensor<B, 2>>,
 
-    /// Key weight matrix.
-    /// Maps [tokens, dim] into [tokens, inner_dim].
+    /// Key weight matrices for each head, stacked on the last dimension.
+    /// (so that the shape is [tokens, n_heads * head_dim])
    ///
     /// Intuitively, this learns what properties a certain token
     /// has when it appears as a context (non-query) token.
     w_key: Param<Tensor<B, 2>>,
 
-    /// Value weight matrix.
-    /// Maps [tokens, dim] into [tokens, inner_dim].
+    /// Value weight matrices for each head, stacked on the last dimension.
+    /// (so that the shape is [tokens, n_heads * head_dim])
     ///
     /// Intuitively, this learns what information a token contributes once attended to.
     w_value: Param<Tensor<B, 2>>,
 
+    /// Optional final projection.
+    /// Maps [tokens, total_dim] to [tokens, total_dim].
+    w_output: Param<Tensor<B, 2>>,
+
     dropout: Dropout,
 
     /// Upper-triangular matrix of ones, excluding diagonal.
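Aside: the doc comments above claim that the packed Wq/Wk/Wv matrices are just the per-head matrices stacked on the last dimension. A minimal sketch of that claim (hypothetical helper, not part of the commit), using only the Burn tensor ops already used in this file:

use burn::tensor::{backend::Backend, Distribution, Tensor};

// Two per-head query matrices vs. one packed matrix:
// the packed projection is the per-head projections laid side by side.
fn packed_heads_demo<B: Backend>(device: &B::Device) {
    let (tokens, embedding_dim, head_dim): (usize, usize, usize) = (4, 3, 2);

    let x: Tensor<B, 2> = Tensor::random([tokens, embedding_dim], Distribution::Default, device);
    let w_head_0: Tensor<B, 2> =
        Tensor::random([embedding_dim, head_dim], Distribution::Default, device);
    let w_head_1: Tensor<B, 2> =
        Tensor::random([embedding_dim, head_dim], Distribution::Default, device);

    // Stack the two matrices on the last dimension: [embedding_dim, 2 * head_dim].
    let w_packed = Tensor::cat(vec![w_head_0.clone(), w_head_1.clone()], 1);

    // One matmul projects for both heads at once: [tokens, 2 * head_dim] ...
    let q_packed = x.clone().matmul(w_packed);
    // ... and a reshape recovers the per-head layout: [tokens, n_heads, head_dim].
    let q_heads = q_packed.reshape([tokens, 2, head_dim]);
    assert_eq!(q_heads.dims(), [4, 2, 2]);

    // Head 0 of q_heads equals x.matmul(w_head_0), and head 1 equals
    // x.matmul(w_head_1) (up to floating-point rounding).
    let _q_0 = x.clone().matmul(w_head_0);
    let _q_1 = x.matmul(w_head_1);
}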
@@ -180,46 +186,63 @@ pub struct CausalAttentionv1<B: Backend> {
     utri_mask: Tensor<B, 2, Bool>,
 }
 
-impl<B: Backend> CausalAttentionv1<B> {
+impl<B: Backend> MultiheadAttention<B> {
     pub fn new(
         embedding_dim: usize,
-        out_dim: usize,
+        head_dim: usize,
+        n_heads: usize,
         context_length: usize,
         dropout: f64,
         device: &B::Device,
     ) -> Self {
+        let total_dim = head_dim * n_heads;
 
         Self {
+            n_heads,
+            head_dim,
+
             w_query: Param::uninitialized(
                 ParamId::new(),
                 move |device, is_require_grad| {
-                    Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
+                    Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
                         .set_require_grad(is_require_grad)
                },
                 device.clone(),
                 true,
-                [embedding_dim, out_dim].into(),
+                [embedding_dim, total_dim].into(),
             ),
 
             w_key: Param::uninitialized(
                 ParamId::new(),
                 move |device, is_require_grad| {
-                    Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
+                    Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
                         .set_require_grad(is_require_grad)
                 },
                 device.clone(),
                 true,
-                [embedding_dim, out_dim].into(),
+                [embedding_dim, total_dim].into(),
             ),
 
             w_value: Param::uninitialized(
                 ParamId::new(),
                 move |device, is_require_grad| {
-                    Tensor::random([embedding_dim, out_dim], Distribution::Default, device)
+                    Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
                         .set_require_grad(is_require_grad)
                 },
                 device.clone(),
                 true,
-                [embedding_dim, out_dim].into(),
+                [embedding_dim, total_dim].into(),
+            ),
+
+            w_output: Param::uninitialized(
+                ParamId::new(),
+                move |device, is_require_grad| {
+                    Tensor::random([total_dim, total_dim], Distribution::Default, device)
+                        .set_require_grad(is_require_grad)
+                },
+                device.clone(),
+                true,
+                [total_dim, total_dim].into(),
             ),
 
             dropout: Dropout { prob: dropout },
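Aside: the TODO above mentions swapping the raw Param tensors for Linear layers with the bias disabled. A rough sketch of what that alternative might look like, assuming Burn's nn::LinearConfig builder; the helper name is made up, and the exact init signature differs between Burn versions:

use burn::nn::{Linear, LinearConfig};
use burn::tensor::backend::Backend;

// Hypothetical alternative to the Param::uninitialized blocks above:
// a Linear layer with the bias disabled acts as a plain weight matrix,
// and uses Burn's default Linear initializer instead of Distribution::Default.
fn query_projection<B: Backend>(
    embedding_dim: usize,
    total_dim: usize,
    device: &B::Device,
) -> Linear<B> {
    LinearConfig::new(embedding_dim, total_dim)
        .with_bias(false)
        .init(device)
}

// The struct would then store `w_query: Linear<B>` and apply it with
// `self.w_query.forward(input)` instead of `input.matmul(w_query)`.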
@@ -231,12 +254,15 @@ impl<B: Backend> CausalAttentionv1<B> {
     /// Compute self-attention vectors for the given batch.
     ///
     /// - input shape is [batch, token, token_dim]
-    /// - input shape is [batch, token, attn_dim]
+    /// - output shape is [batch, token, n_heads * head_dim]
     pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
         // Works similarly to simple self-attention (attn = softmax(tok @ tok^T); context = attn @ tok),
         // but adds an "inner latent space" using Wq, Wk, and Wv.
+        //
+        // Multiple heads are batched into one tensor.
 
         let batch = input.dims()[0];
+        let tokens = input.dims()[1];
 
         let w_query = self
             .w_query
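Aside: for reference, the plain single-head attention that the comment describes (attn = softmax(q @ k^T / sqrt(d)); context = attn @ v), written with the same tensor ops but without the causal mask, dropout, or batching of the real module; this helper is illustrative only:

use burn::tensor::activation::softmax;
use burn::tensor::{backend::Backend, Tensor};

/// Single-head attention over one sequence.
fn single_head_attention<B: Backend>(
    x: Tensor<B, 2>,       // [tokens, dim]
    w_query: Tensor<B, 2>, // [dim, head_dim]
    w_key: Tensor<B, 2>,   // [dim, head_dim]
    w_value: Tensor<B, 2>, // [dim, head_dim]
) -> Tensor<B, 2> {
    let queries = x.clone().matmul(w_query); // [tokens, head_dim]
    let keys = x.clone().matmul(w_key);      // [tokens, head_dim]
    let values = x.matmul(w_value);          // [tokens, head_dim]

    let head_dim = keys.dims()[1];

    // [query_token, context_token]
    let scores = queries.matmul(keys.transpose());
    let weights = softmax(scores / (head_dim as f32).sqrt(), 1);

    // [tokens, head_dim]
    weights.matmul(values)
}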
@@ -256,21 +282,42 @@ impl<B: Backend> CausalAttentionv1<B> {
             .unsqueeze_dim::<3>(0)
             .expand([batch as i64, -1, -1]);
 
+        let w_output = self
+            .w_output
+            .val()
+            .unsqueeze_dim::<3>(0)
+            .expand([batch as i64, -1, -1]);
+
         // Map batch to inner latent space.
         // shape: [batch, token, inner_dim]
         let queries = input.clone().matmul(w_query);
         let keys = input.clone().matmul(w_key);
         let values = input.clone().matmul(w_value);
 
-        // Compute attention scores
-        // (cosine similarity of each query token to each context token)
-        // shape: [batch, query_token, context_token]
-        let attn_scores = queries.matmul(keys.clone().transpose());
+        // Split head dimensions
+        let keys = keys.reshape([batch, tokens, self.n_heads, self.head_dim]);
+        let values = values.reshape([batch, tokens, self.n_heads, self.head_dim]);
+        let queries = queries.reshape([batch, tokens, self.n_heads, self.head_dim]);
+
+        // from: [batch, tok, head, head_dim]
+        // to:   [batch, head, tok, head_dim]
+        let keys = keys.swap_dims(1, 2);
+        let values = values.swap_dims(1, 2);
+        let queries = queries.swap_dims(1, 2);
+
+        // Compute attention scores for each head
+        // (cosine similarity of each query token to each context token, per head)
+        //
+        // lhs shape:    [batch, head, tok, head_dim]
+        // rhs shape:    [batch, head, head_dim, tok]
+        // output shape: [batch, head, query_token, context_token]
+        let attn_scores = queries.matmul(keys.clone().swap_dims(2, 3));
 
         let mask = self
             .utri_mask
            .clone()
             .unsqueeze_dim::<3>(0)
+            .unsqueeze_dim::<4>(0)
             .expand(attn_scores.shape());
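Aside: a shape-only walkthrough of the reshape/swap_dims steps above, using the chapter's toy sizes (batch 1, 4 tokens, 2 heads of size 2); the helper is hypothetical and only checks dimensions:

use burn::tensor::{backend::Backend, Distribution, Tensor};

// Shape bookkeeping only; values are random.
fn head_split_demo<B: Backend>(device: &B::Device) {
    let (batch, tokens, n_heads, head_dim): (usize, usize, usize, usize) = (1, 4, 2, 2);

    // Packed projection output, as produced by `input.matmul(w_query)`.
    let queries: Tensor<B, 3> =
        Tensor::random([batch, tokens, n_heads * head_dim], Distribution::Default, device);

    // [batch, tokens, n_heads * head_dim] -> [batch, tokens, n_heads, head_dim]
    let queries = queries.reshape([batch, tokens, n_heads, head_dim]);
    assert_eq!(queries.dims(), [1, 4, 2, 2]);

    // [batch, tokens, n_heads, head_dim] -> [batch, n_heads, tokens, head_dim]
    let queries = queries.swap_dims(1, 2);
    assert_eq!(queries.dims(), [1, 2, 4, 2]);

    // Scores per head: [batch, n_heads, query_token, context_token]
    let scores = queries.clone().matmul(queries.swap_dims(2, 3));
    assert_eq!(scores.dims(), [1, 2, 4, 4]);
}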
@@ -283,10 +330,21 @@ impl<B: Backend> CausalAttentionv1<B> {
         // - dot products get larger with larger dimensions
         // - this causes softmax to "saturate", making all other values very small
         // - which makes gradients vanish during training
-        let attn_weights = softmax(attn_scores / (keys.shape()[2] as f32).sqrt(), 2);
+        let attn_weights = softmax(attn_scores / (keys.shape()[3] as f32).sqrt(), 3);
         let attn_weights = self.dropout.forward(attn_weights);
 
-        let context_vec = attn_weights.matmul(values);
+        // lhs shape:    [batch, head, query_token, context_token]
+        // rhs shape:    [batch, head, tok, head_dim]
+        // matmul shape: [batch, head, tok, head_dim]
+        // out shape:    [batch, tok, head, head_dim]
+        let context_vec = attn_weights.matmul(values).swap_dims(1, 2);
+
+        // shape: [batch, tok, stacked_dim]
+        let context_vec = context_vec.reshape([batch, tokens, self.n_heads * self.head_dim]);
+
+        // Apply final projection (optional)
+        let context_vec = context_vec.matmul(w_output);
+
         return context_vec;
     }
 }
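Aside: a tiny illustration of the saturation argument behind the sqrt(head_dim) scaling; scaling logits up pushes softmax toward a one-hot distribution, which is what dividing the scores counteracts (hypothetical helper, assuming Burn's Tensor::from_floats and Display for tensors):

use burn::tensor::activation::softmax;
use burn::tensor::{backend::Backend, Tensor};

fn saturation_demo<B: Backend>(device: &B::Device) {
    let logits = Tensor::<B, 1>::from_floats([2.0, 1.0, 0.0], device);

    // Moderate logits: a fairly soft distribution, roughly [0.67, 0.24, 0.09].
    println!("{}", softmax(logits.clone(), 0));

    // The same logits scaled up 8x: nearly one-hot, so the gradients for the
    // non-maximal entries are close to zero.
    println!("{}", softmax(logits * 8.0, 0));
}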