Chapter 4: full model

2025-12-13 11:26:26 -08:00
parent 2f3c40b162
commit e29b25c162

@@ -3,7 +3,10 @@ use burn::{
     Tensor,
     backend::{Cuda, cuda::CudaDevice},
     module::{Module, Param, ParamId},
-    nn::{Dropout, Embedding, EmbeddingConfig},
+    nn::{
+        Dropout, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
+        transformer::{PositionWiseFeedForward, PositionWiseFeedForwardConfig},
+    },
     prelude::Backend,
     tensor::{Bool, Distribution, Int, activation::softmax},
 };
@@ -35,26 +38,54 @@ pub struct SampleDataArgs {
     skip: usize,
 }
 
+#[derive(Debug, Clone)]
+pub struct Config {
+    /// Number of tokens
+    pub vocab_size: u32,
+    /// Maximum number of input tokens with positional embeddings
+    pub context_size: usize,
+    /// Dimension of each token's embedding
+    pub embed_dim: usize,
+    /// Number of attention heads
+    pub n_heads: usize,
+    /// Dimension of each attn head
+    pub head_dim: usize,
+    /// Number of transformer blocks
+    pub n_layers: usize,
+    pub embed_drop_rate: f64,
+    pub attention_drop_rate: f64,
+    pub shortcut_drop_rate: f64,
+}
+
 impl SampleDataArgs {
     pub fn run(self, _mp: Option<MultiProgress>) -> Result<()> {
         let device = CudaDevice::new(0);
         let iter = DataReader::new(1, &self.data_dir).context("while initializing data reader")?;
-        let tokenizer = File::open(self.tokenizer).context("while opening tokenizer")?;
+        let tokenizer = File::open(&self.tokenizer).context("while opening tokenizer")?;
         let tokenizer: Tokenizer =
             serde_json::from_reader(tokenizer).context("while loading tokenizer")?;
-        let context_size = 4;
-        let stride = 4;
+        let config = Config {
+            vocab_size: tokenizer.vocab_size(),
+            context_size: 4,
+            embed_dim: 768,
+            n_heads: 12,
+            head_dim: 64, // = 768 / 12
+            n_layers: 12,
+            embed_drop_rate: 0.1,
+            attention_drop_rate: 0.1,
+            shortcut_drop_rate: 0.1,
+        };
+        let stride = config.context_size;
 
-        // Dimension of each token vector
-        let embedding_dim = 3;
-        // attention
-        let head_dim = 2;
-        let n_heads = 2;
-        let dropout = 0.5f64;
         let batch_size = 10;
         let mut input_batch = Vec::with_capacity(batch_size);
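
Aside from the tiny context_size of 4, these hyperparameters mirror GPT-2 small (768-dim embeddings, 12 heads, 12 layers). Assuming the attention output concatenates the heads back to n_heads * head_dim, as the "// = 768 / 12" comment suggests, embed_dim must equal n_heads * head_dim or the residual additions in TransformerBlock won't have matching shapes. A minimal sanity check, sketched as a hypothetical helper that is not part of this commit:

    impl Config {
        /// Hypothetical check: panic early if the head layout is inconsistent.
        pub fn validate(&self) {
            assert_eq!(
                self.embed_dim,
                self.n_heads * self.head_dim,
                "embed_dim must equal n_heads * head_dim"
            );
            assert!(self.context_size > 0);
        }
    }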
@@ -63,42 +94,64 @@ impl SampleDataArgs {
         #[expect(clippy::unwrap_used)] // Lazy error handling
         let iter = iter.map(|x| x.unwrap()).skip(self.skip).take(self.n);
 
-        // TODO: what is this?
-        let tok_embedder = EmbeddingConfig::new(tokenizer.vocab_size() as usize, embedding_dim);
-        let tok_embedder: Embedding<Cuda> = tok_embedder.init(&device);
-        // TODO: do we want to train this?
-        let pos_embedder = EmbeddingConfig::new(context_size, embedding_dim);
-        let pos_embedder: Embedding<Cuda> = pos_embedder.init(&device);
-        let pos_tensor: Tensor<Cuda, 2, Int> =
-            Tensor::arange(0..context_size as i64, &device).unsqueeze_dim(0);
-        // [context_size, dim]
-        let pos_embedding = pos_embedder.forward(pos_tensor).squeeze_dim::<2>(0);
-
-        let attention: MultiheadAttention<Cuda> = MultiheadAttention::new(
-            embedding_dim,
-            head_dim,
-            n_heads,
-            context_size,
-            dropout,
-            &device,
-        );
+        let model = GptModel::new(&config, &device);
+
+        // Text generation routine
+        /*
+        {
+            let init = "Initial context. This is ";
+            let tokens = tokenizer.encode(&init);
+            let n_tokens = tokens.len();
+            let input: Array1<u32> = Array1::from_vec(tokens);
+            let mut input: Tensor<Cuda, 1, Int> =
+                Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
+                    .reshape([n_tokens]);
+
+            for _ in 0..100 {
+                let tokens: Vec<u32> = input.clone().to_data().convert::<u32>().into_vec().unwrap();
+                println!("{:?}", tokens);
+                println!("{}", tokenizer.decode(&tokens));
+
+                // Crop idx to context size;
+                let batch = input
+                    .clone()
+                    .slice([0..config.context_size])
+                    .unsqueeze_dim(0);
+                // shape: [tokens, vocab_size]
+                let logits = model.forward(batch).squeeze_dim::<2>(0);
+                // shape: [vocab_size]
+                let logits = logits.slice([config.context_size - 1]).squeeze_dim::<1>(0);
+                let probs = softmax(logits, 0); // shape: [n_tokens]
+                let id_next = probs.argmax(0); // shape: [1]
+                input = Tensor::cat(vec![input.slice([1..]), id_next], 0);
+            }
+        }
+        */
 
         for i in iter {
             let tokens = tokenizer.encode(&i);
-            for (a, b) in tokens
-                .windows(context_size)
-                .step_by(stride)
-                .zip(tokens[stride..].windows(context_size).step_by(stride))
-            {
+            // Skip small texts.
+            // TODO: do this better
+            // TODO: non-uniform batches?
+            if tokens.len() < config.context_size {
+                continue;
+            }
+            for (a, b) in tokens.windows(config.context_size).step_by(stride).zip(
+                tokens[stride..]
+                    .windows(config.context_size)
+                    .step_by(stride),
+            ) {
                 input_batch.push(a.to_owned());
                 output_batch.push(b.to_owned());
-                // TODO: non-uniform batches?
 
                 /*
                 let context = a;
                 let desired = &b[b.len() - 1..];
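
With stride equal to config.context_size, the zip above pairs each input window with the window starting stride tokens later, i.e. the target block is the input shifted by a whole context rather than by a single token as in the usual next-token setup. A standalone illustration of what the loop yields on a toy token stream:

    let tokens: Vec<u32> = (0..12).collect();
    let (context_size, stride) = (4, 4);
    for (a, b) in tokens
        .windows(context_size)
        .step_by(stride)
        .zip(tokens[stride..].windows(context_size).step_by(stride))
    {
        println!("{a:?} -> {b:?}");
        // [0, 1, 2, 3] -> [4, 5, 6, 7]
        // [4, 5, 6, 7] -> [8, 9, 10, 11]
    }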
@@ -109,12 +162,13 @@ impl SampleDataArgs {
                 println!("{input:?} -> {target:?}");
                 */
 
-                // TODO: last batch
                 if input_batch.len() >= batch_size {
-                    let shape = [batch_size, context_size];
+                    let shape = [input_batch.len(), config.context_size];
                     let input = std::mem::replace(&mut input_batch, Vec::with_capacity(batch_size));
                     let input: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| input[a][b]);
+                    #[expect(clippy::unwrap_used)]
                     let input: Tensor<Cuda, 2, Int> =
                         Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
                             .reshape(shape);
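
The Array2 -> flat slice -> reshape round trip relies on ndarray's default row-major layout, so element (a, b) of the array lands at row a, column b of the reshaped [rows, cols] tensor. A small illustration of that layout assumption:

    let rows = vec![vec![1u32, 2, 3], vec![4, 5, 6]];
    let arr: Array2<u32> = Array2::from_shape_fn((2, 3), |(a, b)| rows[a][b]);
    // Row-major flattening: rows are contiguous.
    assert_eq!(arr.as_slice().unwrap(), &[1, 2, 3, 4, 5, 6][..]);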
@@ -122,25 +176,44 @@ impl SampleDataArgs {
                     let output =
                         std::mem::replace(&mut output_batch, Vec::with_capacity(batch_size));
                     let output: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| output[a][b]);
+                    #[expect(clippy::unwrap_used)]
                     let output: Tensor<Cuda, 2, Int> =
                         Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device)
                             .reshape(shape);
 
-                    // Input token embeddings
-                    // dim: [batch, token, dim]
-                    let tok_e = tok_embedder.forward(input);
-                    let tok_e = tok_e.add(pos_embedding.clone().unsqueeze_dim(0));
-
-                    // Attention
-                    // shape: [batch, tokens, out_dim]
-                    let a = attention.forward(tok_e);
+                    self.batch(&config, input, &model);
                 }
             }
         }
+
+        if !input_batch.is_empty() {
+            let shape = [input_batch.len(), config.context_size];
+            let input = std::mem::replace(&mut input_batch, Vec::with_capacity(batch_size));
+            let input: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| input[a][b]);
+            #[expect(clippy::unwrap_used)]
+            let input: Tensor<Cuda, 2, Int> =
+                Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device).reshape(shape);
+
+            let output = std::mem::replace(&mut output_batch, Vec::with_capacity(batch_size));
+            let output: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| output[a][b]);
+            #[expect(clippy::unwrap_used)]
+            let output: Tensor<Cuda, 2, Int> =
+                Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device).reshape(shape);
+
+            self.batch(&config, input, &model);
+        }
+
         Ok(())
     }
+
+    fn batch(&self, _cfg: &Config, input: Tensor<Cuda, 2, Int>, model: &GptModel<Cuda>) {
+        let logits = model.forward(input);
+        println!("{:?}", logits.shape());
+    }
 }
 
 /// Multihead attention.
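
For each flushed batch, `batch` runs the model and prints the logits shape, which should come out as [batch, context_size, vocab_size]. A hypothetical assertion to the same effect, assuming Burn's Tensor::dims():

    let [_batch, tokens, vocab] = model.forward(input).dims();
    assert_eq!(tokens, config.context_size);
    assert_eq!(vocab, config.vocab_size as usize);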
@@ -247,7 +320,7 @@ impl<B: Backend> MultiheadAttention<B> {
             dropout: Dropout { prob: dropout },
-            utri_mask: Tensor::<B, 2, _>::tril_mask([context_length, context_length], 0, &device),
+            utri_mask: Tensor::<B, 2, _>::tril_mask([context_length, context_length], 0, device),
         }
     }
@@ -348,3 +421,125 @@ impl<B: Backend> MultiheadAttention<B> {
         return context_vec;
     }
 }
+
+#[derive(Module, Debug)]
+pub struct GptModel<B: Backend> {
+    embedder_tok: Embedding<B>,
+    embedder_pos: Embedding<B>,
+    embedder_drop: Dropout,
+    trf_blocks: Vec<TransformerBlock<B>>,
+    final_norm: LayerNorm<B>,
+    out_head: Param<Tensor<B, 2>>,
+}
+
+impl<B: Backend> GptModel<B> {
+    pub fn new(cfg: &Config, device: &B::Device) -> Self {
+        let out_head_shape = [cfg.embed_dim, cfg.vocab_size as usize];
+        Self {
+            embedder_tok: EmbeddingConfig::new(cfg.vocab_size as usize, cfg.embed_dim).init(device),
+            embedder_pos: EmbeddingConfig::new(cfg.context_size, cfg.embed_dim).init(device),
+            embedder_drop: Dropout {
+                prob: cfg.embed_drop_rate,
+            },
+            trf_blocks: (0..cfg.n_layers)
+                .map(|_| TransformerBlock::new(cfg, device))
+                .collect(),
+            final_norm: LayerNormConfig::new(cfg.embed_dim).init(device),
+            out_head: Param::uninitialized(
+                ParamId::new(),
+                move |device, is_require_grad| {
+                    Tensor::random(out_head_shape, Distribution::Default, device)
+                        .set_require_grad(is_require_grad)
+                },
+                device.clone(),
+                true,
+                out_head_shape.into(),
+            ),
+        }
+    }
+
+    pub fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 3> {
+        let n_tokens = input.shape()[1];
+        let embed_tok = self.embedder_tok.forward(input.clone());
+        // Positional embeddings for positions 0..n_tokens.
+        let embed_pos = self
+            .embedder_pos
+            .forward(Tensor::arange(0..n_tokens as i64, &input.device()).unsqueeze_dim(0));
+        let x = embed_tok + embed_pos;
+        let x = self.embedder_drop.forward(x);
+        let x = self.trf_blocks.iter().fold(x, |x, l| l.forward(x));
+        let x = self.final_norm.forward(x);
+        let logits = x.matmul(self.out_head.val().unsqueeze_dim(0));
+        return logits;
+    }
+}
+
+#[derive(Module, Debug)]
+pub struct TransformerBlock<B: Backend> {
+    attention: MultiheadAttention<B>,
+    /// Position-wise feed-forward (MLP) sub-layer.
+    ff: PositionWiseFeedForward<B>,
+    /// Layer norms applied before the attention and feed-forward sub-layers (pre-norm).
+    norm_a: LayerNorm<B>,
+    norm_b: LayerNorm<B>,
+    drop_shortcut: Dropout,
+}
+
+impl<B: Backend> TransformerBlock<B> {
+    pub fn new(cfg: &Config, device: &B::Device) -> Self {
+        Self {
+            attention: MultiheadAttention::new(
+                cfg.embed_dim,
+                cfg.head_dim,
+                cfg.n_heads,
+                cfg.context_size,
+                cfg.attention_drop_rate,
+                device,
+            ),
+            ff: PositionWiseFeedForwardConfig::new(cfg.embed_dim, 4 * cfg.embed_dim)
+                .with_dropout(0.0)
+                .init(device),
+            norm_a: LayerNormConfig::new(cfg.embed_dim).init(device),
+            norm_b: LayerNormConfig::new(cfg.embed_dim).init(device),
+            drop_shortcut: Dropout {
+                prob: cfg.shortcut_drop_rate,
+            },
+        }
+    }
+
+    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
+        // Attention sub-layer with a residual shortcut (pre-norm).
+        let input = {
+            let shortcut = input.clone();
+            let x = self.norm_a.forward(input);
+            let x = self.attention.forward(x);
+            let x = self.drop_shortcut.forward(x);
+            x + shortcut
+        };
+        // Feed-forward sub-layer with a residual shortcut.
+        let input = {
+            let shortcut = input.clone();
+            let x = self.norm_b.forward(input);
+            let x = self.ff.forward(x);
+            let x = self.drop_shortcut.forward(x);
+            x + shortcut
+        };
+        return input;
+    }
+}
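
A minimal end-to-end sketch of the new model, for illustration only: the token ids are made up, the vocab size is a stand-in (in `run` it comes from tokenizer.vocab_size()), and the snippet assumes the imports already at the top of this file:

    let device = CudaDevice::new(0);
    let config = Config {
        vocab_size: 50257, // stand-in value
        context_size: 4,
        embed_dim: 768,
        n_heads: 12,
        head_dim: 64,
        n_layers: 12,
        embed_drop_rate: 0.1,
        attention_drop_rate: 0.1,
        shortcut_drop_rate: 0.1,
    };
    let model: GptModel<Cuda> = GptModel::new(&config, &device);

    // One batch of context_size made-up token ids.
    let input: Tensor<Cuda, 2, Int> =
        Tensor::<Cuda, 1, Int>::from_ints([11, 22, 33, 44], &device).unsqueeze_dim(0);

    // shape: [1, context_size, vocab_size]
    let logits = model.forward(input);
    // Greedy next-token ids per position, as in the commented-out generation routine.
    let next_ids = softmax(logits, 2).argmax(2);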