Chapter 5: training loop

commit c59ca2164e
parent e29b25c162
2025-12-15 19:24:40 -08:00
2 changed files with 369 additions and 179 deletions
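In outline, the new training loop streams (context window, next token) pairs from the raw data via `TrainTestIterator`, routes each pair into the train or eval set with a salted hash, and runs a plain epoch loop with Adam. Roughly, as a sketch of the code in the second file (`train_batches()` / `eval_batches()` here just stand in for the two `TrainTestIterator::new(...)` calls):

let mut optim = AdamConfig::new().init();
for epoch in 0..10 {
    for batch in train_batches() {
        // forward pass + cross-entropy loss on the predicted next token
        let out = model.forward_train(batch.inputs, batch.targets);
        let grads = GradientsParams::from_grads(out.loss.backward(), &model);
        model = optim.step(1e-4, model, grads);
    }
    for batch in eval_batches() {
        // same forward pass, but on the non-autodiff backend
        model.valid().forward_train(batch.inputs, batch.targets);
    }
    // log mean train / eval loss for the epoch
}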


@@ -16,10 +16,10 @@ pub enum SubCommand {
args: train_tokenizer::TrainTokenizerArgs,
},
/// Sample data
SampleData {
/// Train model
TrainModel {
#[command(flatten)]
args: sample_data::SampleDataArgs,
args: sample_data::TrainModelArgs,
},
}
@@ -28,7 +28,7 @@ impl SubCommand {
match self {
Self::Download { args } => args.run(mp),
Self::TrainTokenizer { args } => args.run(mp),
Self::SampleData { args } => args.run(mp),
Self::TrainModel { args } => args.run(mp),
}
}
}


@@ -1,100 +1,34 @@
use ahash::AHasher;
use anyhow::{Context, Result};
use burn::{
Tensor,
backend::{Cuda, cuda::CudaDevice},
module::{Module, Param, ParamId},
backend::{Autodiff, Cuda, cuda::CudaDevice},
config::Config,
module::{AutodiffModule, Module, Param, ParamId},
nn::{
Dropout, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
loss::CrossEntropyLossConfig,
transformer::{PositionWiseFeedForward, PositionWiseFeedForwardConfig},
},
prelude::Backend,
optim::{AdamConfig, GradientsParams, Optimizer},
prelude::{Backend, ToElement},
tensor::{Bool, Distribution, Int, activation::softmax},
};
use burn_train::ClassificationOutput;
use clap::Args;
use indicatif::MultiProgress;
use ndarray::Array2;
use std::{f32, fs::File, path::PathBuf};
use tokenizer::Tokenizer;
use crate::data_reader::DataReader;
#[derive(Debug, Args, Clone)]
pub struct SampleDataArgs {
/// Path to training data
#[clap(long, default_value = "data")]
data_dir: PathBuf,
/// Path to tokenizer
#[clap(long)]
tokenizer: PathBuf,
/// How many texts to return
#[clap(long, short = 'n', default_value = "10")]
n: usize,
/// How many texts to skip
#[clap(long, short = 's', default_value = "0")]
skip: usize,
}
#[derive(Debug, Clone)]
pub struct Config {
/// Number of tokens
pub vocab_size: u32,
/// Maximum number of input tokens with positional embeddings
pub context_size: usize,
/// Dimension of each token's embedding
pub embed_dim: usize,
/// Number of attention heads
pub n_heads: usize,
/// Dimension of each attn head
pub head_dim: usize,
/// Number of transformer blocks
pub n_layers: usize,
pub embed_drop_rate: f64,
pub attention_drop_rate: f64,
pub shortcut_drop_rate: f64,
}
impl SampleDataArgs {
pub fn run(self, _mp: Option<MultiProgress>) -> Result<()> {
let device = CudaDevice::new(0);
let iter = DataReader::new(1, &self.data_dir).context("while initializing data reader")?;
let tokenizer = File::open(&self.tokenizer).context("while opening tokenizer")?;
let tokenizer: Tokenizer =
serde_json::from_reader(tokenizer).context("while loading tokenizer")?;
let config = Config {
vocab_size: tokenizer.vocab_size(),
context_size: 4,
embed_dim: 768,
n_heads: 12,
head_dim: 64, // = 768 / 12
n_layers: 12,
embed_drop_rate: 0.1,
attention_drop_rate: 0.1,
shortcut_drop_rate: 0.1,
use ndarray::{Array1, Array2};
use std::{
collections::VecDeque,
f32,
fs::File,
hash::Hasher,
path::{Path, PathBuf},
};
use tokenizer::Tokenizer;
use tracing::{debug, info};
let stride = config.context_size;
let batch_size = 10;
let mut input_batch = Vec::with_capacity(batch_size);
let mut output_batch = Vec::with_capacity(batch_size);
#[expect(clippy::unwrap_used)] // Lazy error handling
let iter = iter.map(|x| x.unwrap()).skip(self.skip).take(self.n);
let model = GptModel::new(&config, &device);
use crate::data_reader::{DataReader, DataReaderError};
// Text generation routine
@@ -134,87 +68,283 @@ impl SampleDataArgs {
}
*/
for i in iter {
let tokens = tokenizer.encode(&i);
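/// Streams batches straight from the raw data files: reads each text,
/// tokenizes it, cuts it into (context window, next token) pairs, applies
/// the salted-hash train/eval split, and yields `TrainBatch`es on `device`.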
struct TrainTestIterator<'a, B: Backend> {
reader: DataReader,
// Skip small texts.
ccfg: &'a ComputeConfig,
mcfg: &'a GptModelConfig,
tokenizer: &'a Tokenizer,
eval: bool,
device: &'a B::Device,
error: bool,
// Tokenized input/output pairs
pairs: VecDeque<(Vec<u32>, u32)>,
}
impl<'a, B: Backend> TrainTestIterator<'a, B> {
pub fn new(
data_dir: impl AsRef<Path>,
ccfg: &'a ComputeConfig,
mcfg: &'a GptModelConfig,
tokenizer: &'a Tokenizer,
eval: bool,
device: &'a B::Device,
) -> Result<Self, std::io::Error> {
let reader = DataReader::new(10, data_dir)?;
Ok(Self {
reader,
ccfg,
mcfg,
tokenizer,
eval,
device,
error: false,
pairs: VecDeque::new(),
})
}
}
impl<B: Backend> Iterator for TrainTestIterator<'_, B> {
type Item = Result<TrainBatch<B>, DataReaderError>;
fn next(&mut self) -> Option<Self::Item> {
if self.error {
return None;
}
let mut inputs = Vec::with_capacity(self.ccfg.batch_size);
let mut targets = Vec::with_capacity(self.ccfg.batch_size);
let stride = self.mcfg.context_size;
while inputs.len() < self.ccfg.batch_size {
match self.pairs.pop_front() {
Some((i, t)) => {
// train/test split
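// Hash each pair with a fixed salt so the assignment to train or eval is
// deterministic across runs and independent of read order; roughly
// `eval_frac` of pairs land in the eval set.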
{
let mut hasher = AHasher::default();
hasher.write(self.ccfg.eval_salt.as_bytes());
// ahash output is unstable across versions anyway; hash the token values directly
i.iter().for_each(|&tok| hasher.write_u32(tok));
hasher.write_u32(t);
let test = // is this pair in the eval set?
hasher.finish() < (u64::MAX as f64 * self.ccfg.eval_frac).to_u64();
if test ^ self.eval {
continue;
}
}
inputs.push(i);
targets.push(t);
}
None => {
let text = match self.reader.next() {
None => break,
Some(Ok(x)) => x,
Some(Err(x)) => {
self.error = true;
return Some(Err(x));
}
};
let emb = self.tokenizer.encode(&text);
// Skip small texts
//
// TODO: do this better
// TODO: maybe using <|bos|>?
// TODO: non-uniform batches?
if tokens.len() < config.context_size {
if emb.len() < self.mcfg.context_size {
continue;
}
for (a, b) in tokens.windows(config.context_size).step_by(stride).zip(
tokens[stride..]
.windows(config.context_size)
.step_by(stride),
) {
input_batch.push(a.to_owned());
output_batch.push(b.to_owned());
let pairs = emb
.windows(self.mcfg.context_size + 1)
.step_by(stride)
.map(|x| {
(
x[..self.mcfg.context_size].to_vec(),
x[self.mcfg.context_size],
)
});
self.pairs.extend(pairs);
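// e.g. with context_size = 2 (so stride = 2), tokens [5, 8, 2, 9, 7]
// yield the pairs ([5, 8] -> 2) and ([2, 9] -> 7): each window is
// trained to predict the single token that follows it.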
}
}
}
if inputs.is_empty() {
return None;
}
let shape = [inputs.len(), self.mcfg.context_size];
// Arrange data in memory
let inputs: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| inputs[a][b]);
let targets: Array1<u32> = Array1::from_vec(targets);
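// inputs: [batch, context_size] token ids; targets: [batch] next-token ids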
// Create tensors on gpu
#[expect(clippy::unwrap_used)]
let inputs =
Tensor::<B, 1, Int>::from_ints(inputs.as_slice().unwrap(), self.device).reshape(shape);
#[expect(clippy::unwrap_used)]
let targets = Tensor::<B, 1, Int>::from_ints(targets.as_slice().unwrap(), self.device);
return Some(Ok(TrainBatch { inputs, targets }));
}
}
#[derive(Debug, Args, Clone)]
pub struct TrainModelArgs {
/// Path to training data
data: PathBuf,
/// Path to tokenizer
#[clap(long)]
tokenizer: PathBuf,
}
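/// Settings for a training run that are independent of the model
/// architecture: batch size plus the parameters of the hash-based
/// train/eval split.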
pub struct ComputeConfig {
pub batch_size: usize,
pub eval_frac: f64,
pub eval_salt: String,
}
impl TrainModelArgs {
pub fn run(self, _mp: Option<MultiProgress>) -> Result<()> {
let device = CudaDevice::new(0);
//let device = WgpuDevice::DiscreteGpu(0);
let tokenizer = File::open(&self.tokenizer).context("while opening tokenizer")?;
let tokenizer: Tokenizer =
serde_json::from_reader(tokenizer).context("while loading tokenizer")?;
let ccfg = ComputeConfig {
batch_size: 10,
eval_frac: 0.1,
eval_salt: "salt".into(),
};
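// Roughly the GPT-2 small shape (768-dim embeddings, 12 heads of 64 dims),
// but with a shorter context and a single transformer block for now.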
let mcfg = GptModelConfig {
vocab_size: tokenizer.vocab_size(),
context_size: 256,
embed_dim: 768,
n_heads: 12,
head_dim: 64, // = 768 / 12
n_layers: 1,
embed_drop_rate: 0.1,
attention_drop_rate: 0.1,
shortcut_drop_rate: 0.1,
};
let mut model: GptModel<Autodiff<Cuda>> = mcfg.init(&device);
/*
let context = a;
let desired = &b[b.len() - 1..];
println!("{context:?} -> {desired:?}");
let loader_train = DataLoaderBuilder::new(batcher.clone())
.batch_size(ccfg.batch_size)
//.shuffle(config.seed)
.num_workers(5)
.build(Loader::new(&self.data_dir).context("while initializing loader")?);
let input = tokenizer.decode(context);
let target = tokenizer.decode(desired);
println!("{input:?} -> {target:?}");
let loader_test = DataLoaderBuilder::new(batcher)
.batch_size(ccfg.batch_size)
//.shuffle(config.seed)
.num_workers(5)
.build(Loader::new(&self.data_dir).context("while initializing loader")?);
let learner = LearnerBuilder::new("./tmp")
.metric_train_numeric(AccuracyMetric::new())
.metric_valid_numeric(AccuracyMetric::new())
.metric_train_numeric(LossMetric::new())
.metric_valid_numeric(LossMetric::new())
.with_file_checkpointer(CompactRecorder::new())
.learning_strategy(LearningStrategy::SingleDevice(device.clone()))
.num_epochs(10)
.summary()
.build(model, AdamConfig::new().init(), 1e-4);
learner.fit(loader_train, loader_test);
*/
if input_batch.len() >= batch_size {
let shape = [input_batch.len(), config.context_size];
// Initialize optimizer
let mut optim = AdamConfig::new().init();
let learning_rate = 1e-4;
let input = std::mem::replace(&mut input_batch, Vec::with_capacity(batch_size));
let input: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| input[a][b]);
for epoch in 0..10 {
debug!("Running epoch {epoch}");
#[expect(clippy::unwrap_used)]
let input: Tensor<Cuda, 2, Int> =
Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
.reshape(shape);
// Training phase
let mut train_loss_sum = 0.0;
let mut train_total = 0;
let output =
std::mem::replace(&mut output_batch, Vec::with_capacity(batch_size));
let output: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| output[a][b]);
for batch in
TrainTestIterator::new(&self.data, &ccfg, &mcfg, &tokenizer, false, &device)
.context("while initializing reader")?
{
let batch = batch.context("while reading batch")?;
#[expect(clippy::unwrap_used)]
let output: Tensor<Cuda, 2, Int> =
Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device)
.reshape(shape);
// Forward pass with gradients
let output = model.forward_train(batch.inputs, batch.targets);
self.batch(&config, input, &model);
}
}
train_total += output.targets.dims()[0] as i32;
train_loss_sum += output.loss.clone().into_scalar().to_f32();
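// Burn optimizers are functional: backward() produces the gradients,
// GradientsParams keys them by parameter id, and step() consumes the
// old model and returns the updated one.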
let grads = output.loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
model = optim.step(learning_rate, model, grads);
}
if !input_batch.is_empty() {
let shape = [input_batch.len(), config.context_size];
let mut valid_loss_sum = 0.0;
let mut valid_total = 0;
let input = std::mem::replace(&mut input_batch, Vec::with_capacity(batch_size));
let input: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| input[a][b]);
let mut n_eval = 0;
debug!("Evaluating batches");
#[expect(clippy::unwrap_used)]
let input: Tensor<Cuda, 2, Int> =
Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device).reshape(shape);
for batch in TrainTestIterator::new(&self.data, &ccfg, &mcfg, &tokenizer, true, &device)
.context("while initializing reader")?
{
let batch = batch.context("while reading batch")?;
n_eval += batch.targets.shape()[0];
let output = std::mem::replace(&mut output_batch, Vec::with_capacity(batch_size));
let output: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| output[a][b]);
// Forward pass without gradients
let output = model.valid().forward_train(batch.inputs, batch.targets);
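// valid() drops the Autodiff wrapper, so this forward pass records no graph.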
#[expect(clippy::unwrap_used)]
let output: Tensor<Cuda, 2, Int> =
Tensor::<_, 1, Int>::from_ints(output.as_slice().unwrap(), &device).reshape(shape);
valid_total += output.targets.dims()[0] as i32;
valid_loss_sum += output.loss.into_scalar().to_f32();
}
self.batch(&config, input, &model);
// Compute and log epoch results
let train_loss = if train_total > 0 {
train_loss_sum / train_total as f32
} else {
0.0
};
let valid_loss = if valid_total > 0 {
valid_loss_sum / valid_total as f32
} else {
0.0
};
info!(message = "Ran epoch", epoch, train_loss, valid_loss, n_eval);
}
Ok(())
}
}
fn batch(&self, _cfg: &Config, input: Tensor<Cuda, 2, Int>, model: &GptModel<Cuda>) {
let logits = model.forward(input);
println!("{:?}", logits.shape());
}
}
//
// MARK: model
//
/// Multihead attention.
///
@@ -315,7 +445,7 @@ impl<B: Backend> MultiheadAttention<B> {
},
device.clone(),
true,
[embedding_dim, total_dim].into(),
[total_dim, total_dim].into(),
),
dropout: Dropout { prob: dropout },
@@ -389,6 +519,7 @@ impl<B: Backend> MultiheadAttention<B> {
let mask = self
.utri_mask
.clone()
.slice([0..tokens, 0..tokens])
.unsqueeze_dim::<3>(0)
.unsqueeze_dim::<4>(0)
.expand(attn_scores.shape());
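// Trim the precomputed causal mask to the actual sequence length, then
// broadcast it over the batch and head dimensions of the attention scores.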
@@ -422,35 +553,50 @@ impl<B: Backend> MultiheadAttention<B> {
}
}
#[derive(Module, Debug)]
pub struct GptModel<B: Backend> {
embedder_tok: Embedding<B>,
embedder_pos: Embedding<B>,
embedder_drop: Dropout,
#[derive(Config, Debug)]
pub struct GptModelConfig {
/// Number of tokens
pub vocab_size: u32,
trf_blocks: Vec<TransformerBlock<B>>,
final_norm: LayerNorm<B>,
out_head: Param<Tensor<B, 2>>,
/// Maximum number of input tokens with positional embeddings
pub context_size: usize,
/// Dimension of each token's embedding
pub embed_dim: usize,
/// Number of attention heads
pub n_heads: usize,
/// Dimension of each attn head
pub head_dim: usize,
/// Number of transformer blocks
pub n_layers: usize,
pub embed_drop_rate: f64,
pub attention_drop_rate: f64,
pub shortcut_drop_rate: f64,
}
impl<B: Backend> GptModel<B> {
pub fn new(cfg: &Config, device: &B::Device) -> Self {
let out_head_shape = [cfg.embed_dim, cfg.vocab_size as usize];
impl GptModelConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> GptModel<B> {
let out_head_shape = [self.embed_dim, self.vocab_size as usize];
Self {
embedder_tok: EmbeddingConfig::new(cfg.vocab_size as usize, cfg.embed_dim).init(device),
GptModel {
embedder_tok: EmbeddingConfig::new(self.vocab_size as usize, self.embed_dim)
.init(device),
embedder_pos: EmbeddingConfig::new(cfg.context_size, cfg.embed_dim).init(device),
embedder_pos: EmbeddingConfig::new(self.context_size, self.embed_dim).init(device),
embedder_drop: Dropout {
prob: cfg.embed_drop_rate,
prob: self.embed_drop_rate,
},
trf_blocks: (0..cfg.n_layers)
.map(|_| TransformerBlock::new(cfg, device))
trf_blocks: (0..self.n_layers)
.map(|_| TransformerBlock::new(self, device))
.collect(),
final_norm: LayerNormConfig::new(cfg.embed_dim).init(device),
final_norm: LayerNormConfig::new(self.embed_dim).init(device),
out_head: Param::uninitialized(
ParamId::new(),
@@ -464,13 +610,34 @@ impl<B: Backend> GptModel<B> {
),
}
}
}
#[derive(Debug, Clone)]
pub struct TrainBatch<B: Backend> {
pub inputs: Tensor<B, 2, Int>,
/// Correct next token for each input
pub targets: Tensor<B, 1, Int>,
}
#[derive(Module, Debug)]
pub struct GptModel<B: Backend> {
embedder_tok: Embedding<B>,
embedder_pos: Embedding<B>,
embedder_drop: Dropout,
trf_blocks: Vec<TransformerBlock<B>>,
final_norm: LayerNorm<B>,
out_head: Param<Tensor<B, 2>>,
}
impl<B: Backend> GptModel<B> {
pub fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 3> {
let n_tokens = input.shape()[1];
let embed_tok = self.embedder_tok.forward(input.clone());
let embed_pos = self
.embedder_tok
.embedder_pos
.forward(Tensor::arange(0..n_tokens as i64, &input.device()).unsqueeze_dim(0));
let x = embed_tok + embed_pos;
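// Learned positional embeddings (one per position in the context window)
// are added to the token embeddings before the transformer blocks.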
@@ -481,6 +648,29 @@ impl<B: Backend> GptModel<B> {
return logits;
}
pub fn forward_train(
&self,
inputs: Tensor<B, 2, Int>,
targets: Tensor<B, 1, Int>,
) -> ClassificationOutput<B> {
// shape: [batch, n_tokens, n_vocabulary]
let output = self.forward(inputs);
// Get last token
// shape: [batch, n_vocabulary]
let output = output.slice_dim(1, -1).squeeze_dim::<2>(1);
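// Only the final position is scored: each training pair supplies one
// context window plus the single token that follows it.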
let loss = CrossEntropyLossConfig::new()
.init(&targets.device())
.forward(output.clone(), targets.clone());
ClassificationOutput {
loss,
output,
targets,
}
}
}
#[derive(Module, Debug)]
@@ -498,7 +688,7 @@ pub struct TransformerBlock<B: Backend> {
}
impl<B: Backend> TransformerBlock<B> {
pub fn new(cfg: &Config, device: &B::Device) -> Self {
pub fn new(cfg: &GptModelConfig, device: &B::Device) -> Self {
Self {
attention: MultiheadAttention::new(
cfg.embed_dim,