Chapter 4: full model
@@ -3,7 +3,10 @@ use burn::{
     Tensor,
     backend::{Cuda, cuda::CudaDevice},
     module::{Module, Param, ParamId},
-    nn::{Dropout, Embedding, EmbeddingConfig},
+    nn::{
+        Dropout, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
+        transformer::{PositionWiseFeedForward, PositionWiseFeedForwardConfig},
+    },
     prelude::Backend,
     tensor::{Bool, Distribution, Int, activation::softmax},
 };
@@ -35,26 +38,54 @@ pub struct SampleDataArgs {
     skip: usize,
 }
 
+#[derive(Debug, Clone)]
+pub struct Config {
+    /// Number of tokens
+    pub vocab_size: u32,
+
+    /// Maximum number of input tokens with positional embeddings
+    pub context_size: usize,
+
+    /// Dimension of each token's embedding
+    pub embed_dim: usize,
+
+    /// Number of attention heads
+    pub n_heads: usize,
+
+    /// Dimension of each attn head
+    pub head_dim: usize,
+
+    /// Number of transformer blocks
+    pub n_layers: usize,
+
+    pub embed_drop_rate: f64,
+    pub attention_drop_rate: f64,
+    pub shortcut_drop_rate: f64,
+}
+
 impl SampleDataArgs {
     pub fn run(self, _mp: Option<MultiProgress>) -> Result<()> {
         let device = CudaDevice::new(0);
 
         let iter = DataReader::new(1, &self.data_dir).context("while initializing data reader")?;
 
-        let tokenizer = File::open(self.tokenizer).context("while opening tokenizer")?;
+        let tokenizer = File::open(&self.tokenizer).context("while opening tokenizer")?;
         let tokenizer: Tokenizer =
             serde_json::from_reader(tokenizer).context("while loading tokenizer")?;
 
-        let context_size = 4;
-        let stride = 4;
+        let config = Config {
+            vocab_size: tokenizer.vocab_size(),
+            context_size: 4,
+            embed_dim: 768,
+            n_heads: 12,
+            head_dim: 64, // = 768 / 12
+            n_layers: 12,
+            embed_drop_rate: 0.1,
+            attention_drop_rate: 0.1,
+            shortcut_drop_rate: 0.1,
+        };
 
-        // Dimension of each token vector
-        let embedding_dim = 3;
-        // attention
-        let head_dim = 2;
-        let n_heads = 2;
+        let stride = config.context_size;
 
-        let dropout = 0.5f64;
-
         let batch_size = 10;
         let mut input_batch = Vec::with_capacity(batch_size);
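Apart from the tiny `context_size` of 4, the hard-coded values match GPT-2 small: 12 layers, 12 heads, and a 768-dimensional embedding, which only hang together because the heads tile the embedding width. A minimal standalone sketch of that invariant (plain Rust, values copied from the `Config` literal above, not part of the commit):

// Sketch only: the relationship the 768 / 12 / 64 literals rely on.
fn main() {
    let (embed_dim, n_heads, head_dim) = (768usize, 12usize, 64usize);
    // Each head projects to head_dim dimensions; concatenating the heads
    // must give back the embedding width.
    assert_eq!(n_heads * head_dim, embed_dim);
    println!("per-head width {head_dim}, total width {embed_dim}");
}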
@@ -63,42 +94,64 @@ impl SampleDataArgs {
         #[expect(clippy::unwrap_used)] // Lazy error handling
         let iter = iter.map(|x| x.unwrap()).skip(self.skip).take(self.n);
 
-        // TODO: what is this?
-        let tok_embedder = EmbeddingConfig::new(tokenizer.vocab_size() as usize, embedding_dim);
-        let tok_embedder: Embedding<Cuda> = tok_embedder.init(&device);
+        let model = GptModel::new(&config, &device);
 
-        // TODO: do we want to train this?
-        let pos_embedder = EmbeddingConfig::new(context_size, embedding_dim);
-        let pos_embedder: Embedding<Cuda> = pos_embedder.init(&device);
-
-        let pos_tensor: Tensor<Cuda, 2, Int> =
-            Tensor::arange(0..context_size as i64, &device).unsqueeze_dim(0);
-
-        // [context_size, dim]
-        let pos_embedding = pos_embedder.forward(pos_tensor).squeeze_dim::<2>(0);
-
-        let attention: MultiheadAttention<Cuda> = MultiheadAttention::new(
-            embedding_dim,
-            head_dim,
-            n_heads,
-            context_size,
-            dropout,
-            &device,
-        );
+        // Text generation routine
+        /*
+        {
+            let init = "Initial context. This is ";
+            let tokens = tokenizer.encode(&init);
+
+            let n_tokens = tokens.len();
+            let input: Array1<u32> = Array1::from_vec(tokens);
+            let mut input: Tensor<Cuda, 1, Int> =
+                Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
+                    .reshape([n_tokens]);
+
+            for _ in 0..100 {
+                let tokens: Vec<u32> = input.clone().to_data().convert::<u32>().into_vec().unwrap();
+                println!("{:?}", tokens);
+                println!("{}", tokenizer.decode(&tokens));
+
+                // Crop idx to context size;
+                let batch = input
+                    .clone()
+                    .slice([0..config.context_size])
+                    .unsqueeze_dim(0);
+
+                // shape: [tokens, vocab_size]
+                let logits = model.forward(batch).squeeze_dim::<2>(0);
+
+                // shape: [vocab_size]
+                let logits = logits.slice([config.context_size - 1]).squeeze_dim::<1>(0);
+
+                let probs = softmax(logits, 0); // shape: [n_tokens]
+                let id_next = probs.argmax(0); // shape: [1]
+
+                input = Tensor::cat(vec![input.slice([1..]), id_next], 0);
+            }
+        }
+        */
 
         for i in iter {
             let tokens = tokenizer.encode(&i);
 
-            for (a, b) in tokens
-                .windows(context_size)
-                .step_by(stride)
-                .zip(tokens[stride..].windows(context_size).step_by(stride))
-            {
+            // Skip small texts.
+            // TODO: do this better
+            // TODO: non-uniform batches?
+            if tokens.len() < config.context_size {
+                continue;
+            }
+
+            for (a, b) in tokens.windows(config.context_size).step_by(stride).zip(
+                tokens[stride..]
+                    .windows(config.context_size)
+                    .step_by(stride),
+            ) {
                 input_batch.push(a.to_owned());
                 output_batch.push(b.to_owned());
 
-                // TODO: non-uniform batches?
-
                 /*
                 let context = a;
                 let desired = &b[b.len() - 1..];
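For reference, the `windows`/`step_by`/`zip` pairing yields (input, target) windows where the target is the input shifted forward by `stride` tokens; with `stride` set to `config.context_size`, the target window is the block immediately after the input. A small standalone illustration with plain vectors (illustrative token values only, not from the dataset):

// Sketch: how the windows + step_by + zip pairing slices a token stream.
fn main() {
    let tokens: Vec<u32> = (0..12).collect();
    let context_size = 4;
    let stride = context_size; // mirrors `let stride = config.context_size;`

    for (a, b) in tokens
        .windows(context_size)
        .step_by(stride)
        .zip(tokens[stride..].windows(context_size).step_by(stride))
    {
        // Prints [0, 1, 2, 3] -> [4, 5, 6, 7] and [4, 5, 6, 7] -> [8, 9, 10, 11].
        println!("{a:?} -> {b:?}");
    }
}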
@@ -109,12 +162,13 @@ impl SampleDataArgs {
                 println!("{input:?} -> {target:?}");
                 */
 
-                // TODO: last batch
                 if input_batch.len() >= batch_size {
-                    let shape = [batch_size, context_size];
+                    let shape = [input_batch.len(), config.context_size];
 
                     let input = std::mem::replace(&mut input_batch, Vec::with_capacity(batch_size));
                     let input: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| input[a][b]);
+
+                    #[expect(clippy::unwrap_used)]
                     let input: Tensor<Cuda, 2, Int> =
                         Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
                             .reshape(shape);
@@ -122,25 +176,44 @@ impl SampleDataArgs {
                     let output =
                         std::mem::replace(&mut output_batch, Vec::with_capacity(batch_size));
                     let output: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| output[a][b]);
 
+                    #[expect(clippy::unwrap_used)]
                     let output: Tensor<Cuda, 2, Int> =
                         Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device)
                             .reshape(shape);
 
-                    // Input token embeddings
-                    // dim: [batch, token, dim]
-                    let tok_e = tok_embedder.forward(input);
-                    let tok_e = tok_e.add(pos_embedding.clone().unsqueeze_dim(0));
-
-                    // Attention
-
-                    // shape: [batch, tokens, out_dim]
-                    let a = attention.forward(tok_e);
+                    self.batch(&config, input, &model);
                 }
             }
         }
 
+        if !input_batch.is_empty() {
+            let shape = [input_batch.len(), config.context_size];
+
+            let input = std::mem::replace(&mut input_batch, Vec::with_capacity(batch_size));
+            let input: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| input[a][b]);
+
+            #[expect(clippy::unwrap_used)]
+            let input: Tensor<Cuda, 2, Int> =
+                Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device).reshape(shape);
+
+            let output = std::mem::replace(&mut output_batch, Vec::with_capacity(batch_size));
+            let output: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| output[a][b]);
+
+            #[expect(clippy::unwrap_used)]
+            let output: Tensor<Cuda, 2, Int> =
+                Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device).reshape(shape);
+
+            self.batch(&config, input, &model);
+        }
+
         Ok(())
     }
+
+    fn batch(&self, _cfg: &Config, input: Tensor<Cuda, 2, Int>, model: &GptModel<Cuda>) {
+        let logits = model.forward(input);
+        println!("{:?}", logits.shape());
+    }
 }
 
 /// Multihead attention.
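The Array2 round-trip above works because `from_shape_fn` produces a standard row-major layout, so `as_slice()` can hand the whole batch to `from_ints` as one contiguous buffer before the final `reshape`. A standalone sketch of just that flatten step, using only ndarray (the same crate the code above already uses) with illustrative values:

// Sketch: why the Array2 -> 1-D slice -> reshape round trip is safe.
use ndarray::Array2;

fn main() {
    let batch: Vec<Vec<u32>> = vec![vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
    let shape = [batch.len(), 4];

    // Same construction as in the diff: index into the Vec of windows.
    let arr: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| batch[a][b]);

    // as_slice() only returns Some for a contiguous, standard-order array,
    // which from_shape_fn guarantees; the rows appear back to back.
    let flat = arr.as_slice().unwrap();
    assert_eq!(flat, &[1, 2, 3, 4, 5, 6, 7, 8]);
}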
@@ -247,7 +320,7 @@ impl<B: Backend> MultiheadAttention<B> {
 
             dropout: Dropout { prob: dropout },
 
-            utri_mask: Tensor::<B, 2, _>::tril_mask([context_length, context_length], 0, &device),
+            utri_mask: Tensor::<B, 2, _>::tril_mask([context_length, context_length], 0, device),
         }
     }
 
@@ -348,3 +421,125 @@ impl<B: Backend> MultiheadAttention<B> {
         return context_vec;
     }
 }
+
+#[derive(Module, Debug)]
+pub struct GptModel<B: Backend> {
+    embedder_tok: Embedding<B>,
+    embedder_pos: Embedding<B>,
+    embedder_drop: Dropout,
+
+    trf_blocks: Vec<TransformerBlock<B>>,
+    final_norm: LayerNorm<B>,
+    out_head: Param<Tensor<B, 2>>,
+}
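A note on the struct: `out_head` is a raw `Param<Tensor<B, 2>>` rather than a `Linear` layer, shaped `[embed_dim, vocab_size]` in `new` below, so the forward pass multiplies by it directly and the output projection stays untied from `embedder_tok`. A rough sketch of the resulting shape flow for one batch (plain Rust, assuming a GPT-2-sized vocabulary of 50,257; the real value comes from the tokenizer):

// Sketch: shapes through GptModel::forward for a [batch, tokens] Int input.
fn main() {
    let (batch, tokens, embed_dim, vocab) = (10usize, 4usize, 768usize, 50_257usize);

    let after_embed = [batch, tokens, embed_dim]; // token + positional embeddings
    let after_blocks = after_embed;               // each TransformerBlock keeps the shape
    let logits = [batch, tokens, vocab];          // matmul with the [embed_dim, vocab] out_head

    println!("{after_embed:?} -> {after_blocks:?} -> {logits:?}");
}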
+
+impl<B: Backend> GptModel<B> {
+    pub fn new(cfg: &Config, device: &B::Device) -> Self {
+        let out_head_shape = [cfg.embed_dim, cfg.vocab_size as usize];
+
+        Self {
+            embedder_tok: EmbeddingConfig::new(cfg.vocab_size as usize, cfg.embed_dim).init(device),
+
+            embedder_pos: EmbeddingConfig::new(cfg.context_size, cfg.embed_dim).init(device),
+
+            embedder_drop: Dropout {
+                prob: cfg.embed_drop_rate,
+            },
+
+            trf_blocks: (0..cfg.n_layers)
+                .map(|_| TransformerBlock::new(cfg, device))
+                .collect(),
+
+            final_norm: LayerNormConfig::new(cfg.embed_dim).init(device),
+
+            out_head: Param::uninitialized(
+                ParamId::new(),
+                move |device, is_require_grad| {
+                    Tensor::random(out_head_shape, Distribution::Default, device)
+                        .set_require_grad(is_require_grad)
+                },
+                device.clone(),
+                true,
+                out_head_shape.into(),
+            ),
+        }
+    }
+
+    pub fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 3> {
+        let n_tokens = input.shape()[1];
+
+        let embed_tok = self.embedder_tok.forward(input.clone());
+        let embed_pos = self
+            .embedder_pos
+            .forward(Tensor::arange(0..n_tokens as i64, &input.device()).unsqueeze_dim(0));
+
+        let x = embed_tok + embed_pos;
+        let x = self.embedder_drop.forward(x);
+        let x = self.trf_blocks.iter().fold(x, |x, l| l.forward(x));
+        let x = self.final_norm.forward(x);
+        let logits = x.matmul(self.out_head.val().unsqueeze_dim(0));
+
+        return logits;
+    }
+}
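For scale, a back-of-the-envelope parameter count of the pieces wired up above, under the `Config` literal from earlier (sketch only: assumes a GPT-2-sized vocabulary of 50,257, assumes the usual four d-by-d attention projections per block, and ignores biases and LayerNorm gains):

// Rough weight count for GptModel with embed_dim 768, 12 layers, context_size 4.
fn main() {
    let (vocab, ctx, d, layers) = (50_257usize, 4usize, 768usize, 12usize);

    let tok_emb = vocab * d;          // embedder_tok
    let pos_emb = ctx * d;            // embedder_pos (tiny with context_size = 4)
    let attention = 4 * d * d;        // assumed query/key/value/output projections
    let feed_forward = 2 * d * 4 * d; // the 4x expansion and contraction
    let out_head = d * vocab;         // untied output projection

    let total = tok_emb + pos_emb + layers * (attention + feed_forward) + out_head;
    println!("roughly {total} weights"); // on the order of 1.6e8 with these numbers
}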
+
+#[derive(Module, Debug)]
+pub struct TransformerBlock<B: Backend> {
+    attention: MultiheadAttention<B>,
+
+    /// TODO: wtf?
+    ff: PositionWiseFeedForward<B>,
+
+    /// TODO: wtf?
+    norm_a: LayerNorm<B>,
+    norm_b: LayerNorm<B>,
+
+    drop_shortcut: Dropout,
+}
+
+impl<B: Backend> TransformerBlock<B> {
+    pub fn new(cfg: &Config, device: &B::Device) -> Self {
+        Self {
+            attention: MultiheadAttention::new(
+                cfg.embed_dim,
+                cfg.head_dim,
+                cfg.n_heads,
+                cfg.context_size,
+                cfg.attention_drop_rate,
+                device,
+            ),
+
+            ff: PositionWiseFeedForwardConfig::new(cfg.embed_dim, 4 * cfg.embed_dim)
+                .with_dropout(0.0)
+                .init(device),
+
+            norm_a: LayerNormConfig::new(cfg.embed_dim).init(device),
+            norm_b: LayerNormConfig::new(cfg.embed_dim).init(device),
+
+            drop_shortcut: Dropout {
+                prob: cfg.shortcut_drop_rate,
+            },
+        }
+    }
+
+    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
+        let input = {
+            let shortcut = input.clone();
+            let x = self.norm_a.forward(input);
+            let x = self.attention.forward(x);
+            let x = self.drop_shortcut.forward(x);
+            x + shortcut
+        };
+
+        let input = {
+            // TODO: wtf?
+            let shortcut = input.clone();
+            let x = self.norm_b.forward(input);
+            let x = self.ff.forward(x);
+            let x = self.drop_shortcut.forward(x);
+            x + shortcut
+        };
+
+        return input;
+    }
+}
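Both halves of `TransformerBlock::forward` follow the same pre-norm residual pattern: normalize, apply the sub-layer (attention or feed-forward), apply dropout, then add the untouched shortcut back. A minimal shape-free sketch of that wiring (f32 stands in for a tensor and the closures stand in for the sub-layers; dropout omitted):

// Sketch of the pre-norm residual wiring used twice in TransformerBlock::forward.
fn residual_block(input: f32, norm: impl Fn(f32) -> f32, sublayer: impl Fn(f32) -> f32) -> f32 {
    let shortcut = input;  // keep the un-normalized input
    let x = norm(input);   // LayerNorm first ("pre-norm")
    let x = sublayer(x);   // attention or feed-forward
    x + shortcut           // add the skip connection back
}

fn main() {
    // Identity norm and a doubling sub-layer: 2 * 1.0 + 1.0 = 3.0.
    let out = residual_block(1.0, |x| x, |x| 2.0 * x);
    assert_eq!(out, 3.0);
}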