Chapter 5: training loop

2025-12-15 19:24:40 -08:00
parent e29b25c162
commit c59ca2164e
2 changed files with 369 additions and 179 deletions

View File

@@ -16,10 +16,10 @@ pub enum SubCommand {
         args: train_tokenizer::TrainTokenizerArgs,
     },
-    /// Sample data
-    SampleData {
+    /// Train model
+    TrainModel {
         #[command(flatten)]
-        args: sample_data::SampleDataArgs,
+        args: sample_data::TrainModelArgs,
     },
 }
@@ -28,7 +28,7 @@ impl SubCommand {
         match self {
             Self::Download { args } => args.run(mp),
             Self::TrainTokenizer { args } => args.run(mp),
-            Self::SampleData { args } => args.run(mp),
+            Self::TrainModel { args } => args.run(mp),
         }
     }
 }

View File

@@ -1,221 +1,351 @@
+use ahash::AHasher;
 use anyhow::{Context, Result};
 use burn::{
     Tensor,
-    backend::{Cuda, cuda::CudaDevice},
-    module::{Module, Param, ParamId},
+    backend::{Autodiff, Cuda, cuda::CudaDevice},
+    config::Config,
+    module::{AutodiffModule, Module, Param, ParamId},
     nn::{
         Dropout, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
+        loss::CrossEntropyLossConfig,
         transformer::{PositionWiseFeedForward, PositionWiseFeedForwardConfig},
     },
-    prelude::Backend,
+    optim::{AdamConfig, GradientsParams, Optimizer},
+    prelude::{Backend, ToElement},
     tensor::{Bool, Distribution, Int, activation::softmax},
 };
+use burn_train::ClassificationOutput;
 use clap::Args;
 use indicatif::MultiProgress;
-use ndarray::Array2;
-use std::{f32, fs::File, path::PathBuf};
+use ndarray::{Array1, Array2};
+use std::{
+    collections::VecDeque,
+    f32,
+    fs::File,
+    hash::Hasher,
+    path::{Path, PathBuf},
+};
 use tokenizer::Tokenizer;
+use tracing::{debug, info};
 
-use crate::data_reader::DataReader;
+use crate::data_reader::{DataReader, DataReaderError};
+
+// Text generation routine
+/*
+{
+    let init = "Initial context. This is ";
+    let tokens = tokenizer.encode(&init);
+    let n_tokens = tokens.len();
+    let input: Array1<u32> = Array1::from_vec(tokens);
+    let mut input: Tensor<Cuda, 1, Int> =
+        Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
+            .reshape([n_tokens]);
+
+    for _ in 0..100 {
+        let tokens: Vec<u32> = input.clone().to_data().convert::<u32>().into_vec().unwrap();
+        println!("{:?}", tokens);
+        println!("{}", tokenizer.decode(&tokens));
+
+        // Crop idx to context size;
+        let batch = input
+            .clone()
+            .slice([0..config.context_size])
+            .unsqueeze_dim(0);
+
+        // shape: [tokens, vocab_size]
+        let logits = model.forward(batch).squeeze_dim::<2>(0);
+
+        // shape: [vocab_size]
+        let logits = logits.slice([config.context_size - 1]).squeeze_dim::<1>(0);
+        let probs = softmax(logits, 0); // shape: [n_tokens]
+        let id_next = probs.argmax(0); // shape: [1]
+
+        input = Tensor::cat(vec![input.slice([1..]), id_next], 0);
+    }
+}
+*/
+
+struct TrainTestIterator<'a, B: Backend> {
+    reader: DataReader,
+    ccfg: &'a ComputeConfig,
+    mcfg: &'a GptModelConfig,
+    tokenizer: &'a Tokenizer,
+    eval: bool,
+    device: &'a B::Device,
+    error: bool,
+
+    // Tokenized input/output pairs
+    pairs: VecDeque<(Vec<u32>, u32)>,
+}
+
+impl<'a, B: Backend> TrainTestIterator<'a, B> {
+    pub fn new(
+        data_dir: impl AsRef<Path>,
+        ccfg: &'a ComputeConfig,
+        mcfg: &'a GptModelConfig,
+        tokenizer: &'a Tokenizer,
+        eval: bool,
+        device: &'a B::Device,
+    ) -> Result<Self, std::io::Error> {
+        let reader = DataReader::new(10, data_dir)?;
+        Ok(Self {
+            reader,
+            ccfg,
+            mcfg,
+            tokenizer,
+            eval,
+            device,
+            error: false,
+            pairs: VecDeque::new(),
+        })
+    }
+}
+
+impl<B: Backend> Iterator for TrainTestIterator<'_, B> {
+    type Item = Result<TrainBatch<B>, DataReaderError>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.error {
+            return None;
+        }
+
+        let mut inputs = Vec::with_capacity(self.ccfg.batch_size);
+        let mut targets = Vec::with_capacity(self.ccfg.batch_size);
+        let stride = self.mcfg.context_size;
+
+        while inputs.len() < self.ccfg.batch_size {
+            match self.pairs.pop_front() {
+                Some((i, t)) => {
+                    // train/test split
+                    {
+                        let mut hasher = AHasher::default();
+                        hasher.write(self.ccfg.eval_salt.as_bytes());
+                        // Don't care about endianness, ahash output is unstable anyway
+                        hasher.write(unsafe { std::mem::transmute(&i[..]) });
+                        hasher.write_u32(t);
+
+                        let test = // is this point in the test set?
+                            hasher.finish() > (u64::MAX as f64 * self.ccfg.eval_frac).to_u64();
+                        if test ^ self.eval {
+                            continue;
+                        }
+                    }
+
+                    inputs.push(i);
+                    targets.push(t);
+                }
+                None => {
+                    let text = match self.reader.next() {
+                        None => break,
+                        Some(Ok(x)) => x,
+                        Some(Err(x)) => {
+                            self.error = true;
+                            return Some(Err(x));
+                        }
+                    };
+
+                    let emb = self.tokenizer.encode(&text);
+
+                    // Skip small texts
+                    //
+                    // TODO: do this better
+                    // TODO: maybe using <|bos|>?
+                    // TODO: non-uniform batches?
+                    if emb.len() < self.mcfg.context_size {
+                        continue;
+                    }
+
+                    let pairs = emb
+                        .windows(self.mcfg.context_size + 1)
+                        .step_by(stride)
+                        .map(|x| {
+                            (
+                                x[..self.mcfg.context_size].to_vec(),
+                                x[self.mcfg.context_size],
+                            )
+                        });
+                    self.pairs.extend(pairs);
+                }
+            }
+        }
+
+        if inputs.is_empty() {
+            return None;
+        }
+
+        let shape = [inputs.len(), self.mcfg.context_size];
+
+        // Arrange data in memory
+        let inputs: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| inputs[a][b]);
+        let targets: Array1<u32> = Array1::from_vec(targets);
+
+        // Create tensors on gpu
+        #[expect(clippy::unwrap_used)]
+        let inputs =
+            Tensor::<B, 1, Int>::from_ints(inputs.as_slice().unwrap(), self.device).reshape(shape);
+        #[expect(clippy::unwrap_used)]
+        let targets = Tensor::<B, 1, Int>::from_ints(targets.as_slice().unwrap(), self.device);
+
+        return Some(Ok(TrainBatch { inputs, targets }));
+    }
+}
+
 #[derive(Debug, Args, Clone)]
-pub struct SampleDataArgs {
+pub struct TrainModelArgs {
     /// Path to training data
-    #[clap(long, default_value = "data")]
-    data_dir: PathBuf,
+    data: PathBuf,
 
     /// Path to tokenizer
     #[clap(long)]
     tokenizer: PathBuf,
-
-    /// How many texts to return
-    #[clap(long, short = 'n', default_value = "10")]
-    n: usize,
-
-    /// How many texts to skip
-    #[clap(long, short = 's', default_value = "0")]
-    skip: usize,
 }
 
-#[derive(Debug, Clone)]
-pub struct Config {
-    /// Number of tokens
-    pub vocab_size: u32,
-    /// Maximum number of input tokens with positional embeddings
-    pub context_size: usize,
-    /// Dimension of each token's embedding
-    pub embed_dim: usize,
-    /// Number of attention heads
-    pub n_heads: usize,
-    /// Dimension of each attn head
-    pub head_dim: usize,
-    /// Number of transformer blocks
-    pub n_layers: usize,
-    pub embed_drop_rate: f64,
-    pub attention_drop_rate: f64,
-    pub shortcut_drop_rate: f64,
+pub struct ComputeConfig {
+    pub batch_size: usize,
+    pub eval_frac: f64,
+    pub eval_salt: String,
 }
 
-impl SampleDataArgs {
+impl TrainModelArgs {
     pub fn run(self, _mp: Option<MultiProgress>) -> Result<()> {
         let device = CudaDevice::new(0);
-        //let device = WgpuDevice::DiscreteGpu(0);
-        let iter = DataReader::new(1, &self.data_dir).context("while initializing data reader")?;
         let tokenizer = File::open(&self.tokenizer).context("while opening tokenizer")?;
         let tokenizer: Tokenizer =
             serde_json::from_reader(tokenizer).context("while loading tokenizer")?;
-        let config = Config {
+        let ccfg = ComputeConfig {
+            batch_size: 10,
+            eval_frac: 0.1,
+            eval_salt: "salt".into(),
+        };
+        let mcfg = GptModelConfig {
             vocab_size: tokenizer.vocab_size(),
-            context_size: 4,
+            context_size: 256,
             embed_dim: 768,
             n_heads: 12,
             head_dim: 64, // = 768 / 12
-            n_layers: 12,
+            n_layers: 1,
             embed_drop_rate: 0.1,
             attention_drop_rate: 0.1,
             shortcut_drop_rate: 0.1,
         };
-        let stride = config.context_size;
-        let batch_size = 10;
-        let mut input_batch = Vec::with_capacity(batch_size);
-        let mut output_batch = Vec::with_capacity(batch_size);
-        #[expect(clippy::unwrap_used)] // Lazy error handling
-        let iter = iter.map(|x| x.unwrap()).skip(self.skip).take(self.n);
-        let model = GptModel::new(&config, &device);
-        // Text generation routine
+        let mut model: GptModel<Autodiff<Cuda>> = mcfg.init(&device);
+
         /*
-        {
-            let init = "Initial context. This is ";
-            let tokens = tokenizer.encode(&init);
-            let n_tokens = tokens.len();
-            let input: Array1<u32> = Array1::from_vec(tokens);
-            let mut input: Tensor<Cuda, 1, Int> =
-                Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
-                    .reshape([n_tokens]);
-            for _ in 0..100 {
-                let tokens: Vec<u32> = input.clone().to_data().convert::<u32>().into_vec().unwrap();
-                println!("{:?}", tokens);
-                println!("{}", tokenizer.decode(&tokens));
-                // Crop idx to context size;
-                let batch = input
-                    .clone()
-                    .slice([0..config.context_size])
-                    .unsqueeze_dim(0);
-                // shape: [tokens, vocab_size]
-                let logits = model.forward(batch).squeeze_dim::<2>(0);
-                // shape: [vocab_size]
-                let logits = logits.slice([config.context_size - 1]).squeeze_dim::<1>(0);
-                let probs = softmax(logits, 0); // shape: [n_tokens]
-                let id_next = probs.argmax(0); // shape: [1]
-                input = Tensor::cat(vec![input.slice([1..]), id_next], 0);
-            }
-        }
+        let loader_train = DataLoaderBuilder::new(batcher.clone())
+            .batch_size(ccfg.batch_size)
+            //.shuffle(config.seed)
+            .num_workers(5)
+            .build(Loader::new(&self.data_dir).context("while initializing loader")?);
+        let loader_test = DataLoaderBuilder::new(batcher)
+            .batch_size(ccfg.batch_size)
+            //.shuffle(config.seed)
+            .num_workers(5)
+            .build(Loader::new(&self.data_dir).context("while initializing loader")?);
+        let learner = LearnerBuilder::new("./tmp")
+            .metric_train_numeric(AccuracyMetric::new())
+            .metric_valid_numeric(AccuracyMetric::new())
+            .metric_train_numeric(LossMetric::new())
+            .metric_valid_numeric(LossMetric::new())
+            .with_file_checkpointer(CompactRecorder::new())
+            .learning_strategy(LearningStrategy::SingleDevice(device.clone()))
+            .num_epochs(10)
+            .summary()
+            .build(model, AdamConfig::new().init(), 1e-4);
+        learner.fit(loader_train, loader_test);
         */
-        for i in iter {
-            let tokens = tokenizer.encode(&i);
-            // Skip small texts.
-            // TODO: do this better
-            // TODO: non-uniform batches?
-            if tokens.len() < config.context_size {
-                continue;
-            }
-            for (a, b) in tokens.windows(config.context_size).step_by(stride).zip(
-                tokens[stride..]
-                    .windows(config.context_size)
-                    .step_by(stride),
-            ) {
-                input_batch.push(a.to_owned());
-                output_batch.push(b.to_owned());
-                /*
-                let context = a;
-                let desired = &b[b.len() - 1..];
-                println!("{context:?} -> {desired:?}");
-                let input = tokenizer.decode(context);
-                let target = tokenizer.decode(desired);
-                println!("{input:?} -> {target:?}");
-                */
-                if input_batch.len() >= batch_size {
-                    let shape = [input_batch.len(), config.context_size];
-                    let input = std::mem::replace(&mut input_batch, Vec::with_capacity(batch_size));
-                    let input: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| input[a][b]);
-                    #[expect(clippy::unwrap_used)]
-                    let input: Tensor<Cuda, 2, Int> =
-                        Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
-                            .reshape(shape);
-                    let output =
-                        std::mem::replace(&mut output_batch, Vec::with_capacity(batch_size));
-                    let output: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| output[a][b]);
-                    #[expect(clippy::unwrap_used)]
-                    let output: Tensor<Cuda, 2, Int> =
-                        Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device)
-                            .reshape(shape);
-                    self.batch(&config, input, &model);
-                }
-            }
-        }
-        if !input_batch.is_empty() {
-            let shape = [input_batch.len(), config.context_size];
-            let input = std::mem::replace(&mut input_batch, Vec::with_capacity(batch_size));
-            let input: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| input[a][b]);
-            #[expect(clippy::unwrap_used)]
-            let input: Tensor<Cuda, 2, Int> =
-                Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device).reshape(shape);
-            let output = std::mem::replace(&mut output_batch, Vec::with_capacity(batch_size));
-            let output: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| output[a][b]);
-            #[expect(clippy::unwrap_used)]
-            let output: Tensor<Cuda, 2, Int> =
-                Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device).reshape(shape);
-            self.batch(&config, input, &model);
-        }
+        // Initialize optimizer
+        let mut optim = AdamConfig::new().init();
+        let learning_rate = 1e-4;
+
+        for epoch in 0..10 {
+            debug!("Running epoch {epoch}");
+
+            // Training phase
+            let mut train_loss_sum = 0.0;
+            let mut train_total = 0;
+            for batch in
+                TrainTestIterator::new(&self.data, &ccfg, &mcfg, &tokenizer, false, &device)
+                    .context("while initializing reader")?
+            {
+                let batch = batch.context("while reading batch")?;
+
+                // Forward pass with gradients
+                let output = model.forward_train(batch.inputs, batch.targets);
+                train_total += output.targets.dims()[0] as i32;
+                train_loss_sum += output.loss.clone().into_scalar().to_f32();
+
+                let grads = output.loss.backward();
+                let grads = GradientsParams::from_grads(grads, &model);
+                model = optim.step(learning_rate, model, grads);
+            }
+
+            let mut valid_loss_sum = 0.0;
+            let mut valid_total = 0;
+            let mut n_eval = 0;
+            debug!("Evaluating batches");
+
+            for batch in TrainTestIterator::new(&self.data, &ccfg, &mcfg, &tokenizer, true, &device)
+                .context("while initializing reader")?
+            {
+                let batch = batch.context("while reading batch")?;
+                n_eval += batch.targets.shape()[0];
+
+                // Forward pass without gradients
+                let output = model.valid().forward_train(batch.inputs, batch.targets);
+                valid_total += output.targets.dims()[0] as i32;
+                valid_loss_sum += output.loss.into_scalar().to_f32();
+            }
+
+            // Compute and log epoch results
+            let train_loss = if train_total > 0 {
+                train_loss_sum / train_total as f32
+            } else {
+                0.0
+            };
+            let valid_loss = if valid_total > 0 {
+                valid_loss_sum / valid_total as f32
+            } else {
+                0.0
+            };
+            info!(message = "Ran epoch", epoch, train_loss, valid_loss, n_eval);
+        }
 
         Ok(())
     }
-
-    fn batch(&self, _cfg: &Config, input: Tensor<Cuda, 2, Int>, model: &GptModel<Cuda>) {
-        let logits = model.forward(input);
-        println!("{:?}", logits.shape());
-    }
 }
+
+//
+// MARK: model
+//
 
 /// Multihead attention.
 ///
 /// Equivalent to many stacked CausalAttention layers.
@@ -315,7 +445,7 @@ impl<B: Backend> MultiheadAttention<B> {
                 },
                 device.clone(),
                 true,
-                [embedding_dim, total_dim].into(),
+                [total_dim, total_dim].into(),
             ),
             dropout: Dropout { prob: dropout },
@@ -389,6 +519,7 @@ impl<B: Backend> MultiheadAttention<B> {
         let mask = self
             .utri_mask
             .clone()
+            .slice([0..tokens, 0..tokens])
             .unsqueeze_dim::<3>(0)
             .unsqueeze_dim::<4>(0)
             .expand(attn_scores.shape());
@@ -422,35 +553,50 @@ impl<B: Backend> MultiheadAttention<B> {
     }
 }
 
-#[derive(Module, Debug)]
-pub struct GptModel<B: Backend> {
-    embedder_tok: Embedding<B>,
-    embedder_pos: Embedding<B>,
-    embedder_drop: Dropout,
-    trf_blocks: Vec<TransformerBlock<B>>,
-    final_norm: LayerNorm<B>,
-    out_head: Param<Tensor<B, 2>>,
+#[derive(Config, Debug)]
+pub struct GptModelConfig {
+    /// Number of tokens
+    pub vocab_size: u32,
+    /// Maximum number of input tokens with positional embeddings
+    pub context_size: usize,
+    /// Dimension of each token's embedding
+    pub embed_dim: usize,
+    /// Number of attention heads
+    pub n_heads: usize,
+    /// Dimension of each attn head
+    pub head_dim: usize,
+    /// Number of transformer blocks
+    pub n_layers: usize,
+    pub embed_drop_rate: f64,
+    pub attention_drop_rate: f64,
+    pub shortcut_drop_rate: f64,
 }
 
-impl<B: Backend> GptModel<B> {
-    pub fn new(cfg: &Config, device: &B::Device) -> Self {
-        let out_head_shape = [cfg.embed_dim, cfg.vocab_size as usize];
-        Self {
-            embedder_tok: EmbeddingConfig::new(cfg.vocab_size as usize, cfg.embed_dim).init(device),
-            embedder_pos: EmbeddingConfig::new(cfg.context_size, cfg.embed_dim).init(device),
+impl GptModelConfig {
+    pub fn init<B: Backend>(&self, device: &B::Device) -> GptModel<B> {
+        let out_head_shape = [self.embed_dim, self.vocab_size as usize];
+        GptModel {
+            embedder_tok: EmbeddingConfig::new(self.vocab_size as usize, self.embed_dim)
+                .init(device),
+            embedder_pos: EmbeddingConfig::new(self.context_size, self.embed_dim).init(device),
             embedder_drop: Dropout {
-                prob: cfg.embed_drop_rate,
+                prob: self.embed_drop_rate,
             },
-            trf_blocks: (0..cfg.n_layers)
-                .map(|_| TransformerBlock::new(cfg, device))
+            trf_blocks: (0..self.n_layers)
+                .map(|_| TransformerBlock::new(&self, device))
                 .collect(),
-            final_norm: LayerNormConfig::new(cfg.embed_dim).init(device),
+            final_norm: LayerNormConfig::new(self.embed_dim).init(device),
             out_head: Param::uninitialized(
                 ParamId::new(),
@@ -464,13 +610,34 @@ impl<B: Backend> GptModel<B> {
             ),
         }
     }
+}
+
+#[derive(Debug, Clone)]
+pub struct TrainBatch<B: Backend> {
+    pub inputs: Tensor<B, 2, Int>,
+    /// Correct next token for each input
+    pub targets: Tensor<B, 1, Int>,
+}
+
+#[derive(Module, Debug)]
+pub struct GptModel<B: Backend> {
+    embedder_tok: Embedding<B>,
+    embedder_pos: Embedding<B>,
+    embedder_drop: Dropout,
+    trf_blocks: Vec<TransformerBlock<B>>,
+    final_norm: LayerNorm<B>,
+    out_head: Param<Tensor<B, 2>>,
+}
+
+impl<B: Backend> GptModel<B> {
     pub fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 3> {
         let n_tokens = input.shape()[1];
         let embed_tok = self.embedder_tok.forward(input.clone());
         let embed_pos = self
-            .embedder_tok
+            .embedder_pos
             .forward(Tensor::arange(0..n_tokens as i64, &input.device()).unsqueeze_dim(0));
         let x = embed_tok + embed_pos;
@@ -481,6 +648,29 @@ impl<B: Backend> GptModel<B> {
         return logits;
     }
+
+    pub fn forward_train(
+        &self,
+        inputs: Tensor<B, 2, Int>,
+        targets: Tensor<B, 1, Int>,
+    ) -> ClassificationOutput<B> {
+        // shape: [batch, n_tokens, n_vocabulary]
+        let output = self.forward(inputs);
+
+        // Get last token
+        // shape: [batch, n_vocabulary]
+        let output = output.slice_dim(1, -1).squeeze_dim::<2>(1);
+
+        let loss = CrossEntropyLossConfig::new()
+            .init(&targets.device())
+            .forward(output.clone(), targets.clone());
+
+        ClassificationOutput {
+            loss,
+            output,
+            targets,
+        }
+    }
 }
 
 #[derive(Module, Debug)]
@@ -498,7 +688,7 @@ pub struct TransformerBlock<B: Backend> {
 }
 
 impl<B: Backend> TransformerBlock<B> {
-    pub fn new(cfg: &Config, device: &B::Device) -> Self {
+    pub fn new(cfg: &GptModelConfig, device: &B::Device) -> Self {
         Self {
             attention: MultiheadAttention::new(
                 cfg.embed_dim,