1
0

Refactor
Some checks failed
CI / Check links (push) Successful in 11s
CI / Check typos (push) Failing after 11s
CI / Clippy (push) Failing after 1m18s
CI / Build and test (push) Failing after 2m2s

This commit is contained in:
2025-12-15 21:59:14 -08:00
parent c59ca2164e
commit 993e813b7e
19 changed files with 1583 additions and 1066 deletions

View File

@@ -10,15 +10,18 @@ workspace = true
[dependencies]
tokenizer = { workspace = true }
ahash = { workspace = true }
anstyle = { workspace = true }
anyhow = { workspace = true }
burn = { workspace = true }
burn-train = { workspace = true }
clap = { workspace = true }
futures-util = { workspace = true }
indicatif = { workspace = true }
ndarray = { workspace = true }
parking_lot = { workspace = true }
parquet = { workspace = true }
rand = { workspace = true }
rayon = { workspace = true }
reqwest = { workspace = true }
serde = { workspace = true }

View File

@@ -62,3 +62,16 @@ pub fn progress_bytes() -> ProgressStyle {
"⠉⠉", "⠈⠙", "⠀⠹", "⠀⢸", "⠀⣰", "⢀⣠", "⣀⣀", "⣄⡀", "⣆⠀", "⡇⠀", "⠏⠀", "⠋⠁", "⣏⣹",
]);
}
#[expect(clippy::unwrap_used)]
pub fn progress_persec() -> ProgressStyle {
return ProgressStyle::default_bar()
.template(
" {bar:16.red/white.dim} {elapsed_precise:.dim} {pos}/{len} ({per_sec:>3}) {msg:.dim} ({eta})",
)
.unwrap()
.progress_chars("---")
.tick_strings(&[
"⠉⠉", "⠈⠙", "⠀⠹", "⠀⢸", "⠀⣰", "⢀⣠", "⣀⣀", "⣄⡀", "⣆⠀", "⡇⠀", "⠏⠀", "⠋⠁", "⣏⣹",
]);
}

View File

@@ -21,9 +21,8 @@ const MAX_SHARD: usize = 1822;
#[derive(Debug, Args, Clone)]
pub struct DownloadArgs {
/// Training data dir
#[clap(default_value = "data")]
data_dir: PathBuf,
/// Training data directory (will be created)
data: PathBuf,
/// Number of shards to download (-1 for all)
#[arg(short = 'n', long, default_value = "-1")]
@@ -37,7 +36,7 @@ pub struct DownloadArgs {
impl DownloadArgs {
pub fn run(self, mp: Option<MultiProgress>) -> Result<()> {
info!("Downloading files from {BASE_URL}");
fs::create_dir_all(&self.data_dir)?;
fs::create_dir_all(&self.data)?;
let num_shards_to_download = if self.num_files == -1 {
MAX_SHARD + 1
@@ -48,7 +47,7 @@ impl DownloadArgs {
let ids_to_download: Vec<usize> = (0..num_shards_to_download).collect();
info!("Downloading {} shards...", ids_to_download.len(),);
info!("Target directory: {}", self.data_dir.display());
info!("Target directory: {}", self.data.display());
let main_pb = mp.as_ref().map(|mp| {
let pb = mp.add(ProgressBar::new(ids_to_download.len() as u64));
@@ -70,7 +69,7 @@ impl DownloadArgs {
ids_to_download
.into_par_iter()
.for_each_with(tx, |tx, index| {
let target = self.data_dir.clone();
let target = self.data.clone();
let main_pb = main_pb.clone();
let mp_clone = mp.clone();
let rt_handle = rt.handle().clone(); // Clone the runtime handle for each thread

View File

@@ -1,5 +1,5 @@
mod download;
mod sample_data;
mod train_model;
mod train_tokenizer;
#[derive(Debug, clap::Subcommand)]
@@ -19,7 +19,7 @@ pub enum SubCommand {
/// Train model
TrainModel {
#[command(flatten)]
args: sample_data::TrainModelArgs,
args: train_model::TrainModelArgs,
},
}

View File

@@ -1,735 +0,0 @@
use ahash::AHasher;
use anyhow::{Context, Result};
use burn::{
Tensor,
backend::{Autodiff, Cuda, cuda::CudaDevice},
config::Config,
module::{AutodiffModule, Module, Param, ParamId},
nn::{
Dropout, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
loss::CrossEntropyLossConfig,
transformer::{PositionWiseFeedForward, PositionWiseFeedForwardConfig},
},
optim::{AdamConfig, GradientsParams, Optimizer},
prelude::{Backend, ToElement},
tensor::{Bool, Distribution, Int, activation::softmax},
};
use burn_train::ClassificationOutput;
use clap::Args;
use indicatif::MultiProgress;
use ndarray::{Array1, Array2};
use std::{
collections::VecDeque,
f32,
fs::File,
hash::Hasher,
path::{Path, PathBuf},
};
use tokenizer::Tokenizer;
use tracing::{debug, info};
use crate::data_reader::{DataReader, DataReaderError};
// Text generation routine
/*
{
let init = "Initial context. This is ";
let tokens = tokenizer.encode(&init);
let n_tokens = tokens.len();
let input: Array1<u32> = Array1::from_vec(tokens);
let mut input: Tensor<Cuda, 1, Int> =
Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
.reshape([n_tokens]);
for _ in 0..100 {
let tokens: Vec<u32> = input.clone().to_data().convert::<u32>().into_vec().unwrap();
println!("{:?}", tokens);
println!("{}", tokenizer.decode(&tokens));
// Crop idx to context size;
let batch = input
.clone()
.slice([0..config.context_size])
.unsqueeze_dim(0);
// shape: [tokens, vocab_size]
let logits = model.forward(batch).squeeze_dim::<2>(0);
// shape: [vocab_size]
let logits = logits.slice([config.context_size - 1]).squeeze_dim::<1>(0);
let probs = softmax(logits, 0); // shape: [n_tokens]
let id_next = probs.argmax(0); // shape: [1]
input = Tensor::cat(vec![input.slice([1..]), id_next], 0);
}
}
*/
struct TrainTestIterator<'a, B: Backend> {
reader: DataReader,
ccfg: &'a ComputeConfig,
mcfg: &'a GptModelConfig,
tokenizer: &'a Tokenizer,
eval: bool,
device: &'a B::Device,
error: bool,
// Tokenized input/output pairs
pairs: VecDeque<(Vec<u32>, u32)>,
}
impl<'a, B: Backend> TrainTestIterator<'a, B> {
pub fn new(
data_dir: impl AsRef<Path>,
ccfg: &'a ComputeConfig,
mcfg: &'a GptModelConfig,
tokenizer: &'a Tokenizer,
eval: bool,
device: &'a B::Device,
) -> Result<Self, std::io::Error> {
let reader = DataReader::new(10, data_dir)?;
Ok(Self {
reader,
ccfg,
mcfg,
tokenizer,
eval,
device,
error: false,
pairs: VecDeque::new(),
})
}
}
impl<B: Backend> Iterator for TrainTestIterator<'_, B> {
type Item = Result<TrainBatch<B>, DataReaderError>;
fn next(&mut self) -> Option<Self::Item> {
if self.error {
return None;
}
let mut inputs = Vec::with_capacity(self.ccfg.batch_size);
let mut targets = Vec::with_capacity(self.ccfg.batch_size);
let stride = self.mcfg.context_size;
while inputs.len() < self.ccfg.batch_size {
match self.pairs.pop_front() {
Some((i, t)) => {
// train/test split
{
let mut hasher = AHasher::default();
hasher.write(self.ccfg.eval_salt.as_bytes());
// Don't care about endianness, ahash output is unstable anyway
hasher.write(unsafe { std::mem::transmute(&i[..]) });
hasher.write_u32(t);
let test = // is this point in the test set?
hasher.finish() > (u64::MAX as f64 * self.ccfg.eval_frac).to_u64();
if test ^ self.eval {
continue;
}
}
inputs.push(i);
targets.push(t);
}
None => {
let text = match self.reader.next() {
None => break,
Some(Ok(x)) => x,
Some(Err(x)) => {
self.error = true;
return Some(Err(x));
}
};
let emb = self.tokenizer.encode(&text);
// Skip small texts
//
// TODO: do this better
// TODO: maybe using <|bos|>?
// TODO: non-uniform batches?
if emb.len() < self.mcfg.context_size {
continue;
}
let pairs = emb
.windows(self.mcfg.context_size + 1)
.step_by(stride)
.map(|x| {
(
x[..self.mcfg.context_size].to_vec(),
x[self.mcfg.context_size],
)
});
self.pairs.extend(pairs);
}
}
}
if inputs.is_empty() {
return None;
}
let shape = [inputs.len(), self.mcfg.context_size];
// Arrange data in memory
let inputs: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| inputs[a][b]);
let targets: Array1<u32> = Array1::from_vec(targets);
// Create tensors on gpu
#[expect(clippy::unwrap_used)]
let inputs =
Tensor::<B, 1, Int>::from_ints(inputs.as_slice().unwrap(), self.device).reshape(shape);
#[expect(clippy::unwrap_used)]
let targets = Tensor::<B, 1, Int>::from_ints(targets.as_slice().unwrap(), self.device);
return Some(Ok(TrainBatch { inputs, targets }));
}
}
#[derive(Debug, Args, Clone)]
pub struct TrainModelArgs {
/// Path to training data
data: PathBuf,
/// Path to tokenizer
#[clap(long)]
tokenizer: PathBuf,
}
pub struct ComputeConfig {
pub batch_size: usize,
pub eval_frac: f64,
pub eval_salt: String,
}
impl TrainModelArgs {
pub fn run(self, _mp: Option<MultiProgress>) -> Result<()> {
let device = CudaDevice::new(0);
//let device = WgpuDevice::DiscreteGpu(0);
let tokenizer = File::open(&self.tokenizer).context("while opening tokenizer")?;
let tokenizer: Tokenizer =
serde_json::from_reader(tokenizer).context("while loading tokenizer")?;
let ccfg = ComputeConfig {
batch_size: 10,
eval_frac: 0.1,
eval_salt: "salt".into(),
};
let mcfg = GptModelConfig {
vocab_size: tokenizer.vocab_size(),
context_size: 256,
embed_dim: 768,
n_heads: 12,
head_dim: 64, // = 768 / 12
n_layers: 1,
embed_drop_rate: 0.1,
attention_drop_rate: 0.1,
shortcut_drop_rate: 0.1,
};
let mut model: GptModel<Autodiff<Cuda>> = mcfg.init(&device);
/*
let loader_train = DataLoaderBuilder::new(batcher.clone())
.batch_size(ccfg.batch_size)
//.shuffle(config.seed)
.num_workers(5)
.build(Loader::new(&self.data_dir).context("while initializing loader")?);
let loader_test = DataLoaderBuilder::new(batcher)
.batch_size(ccfg.batch_size)
//.shuffle(config.seed)
.num_workers(5)
.build(Loader::new(&self.data_dir).context("while initializing loader")?);
let learner = LearnerBuilder::new("./tmp")
.metric_train_numeric(AccuracyMetric::new())
.metric_valid_numeric(AccuracyMetric::new())
.metric_train_numeric(LossMetric::new())
.metric_valid_numeric(LossMetric::new())
.with_file_checkpointer(CompactRecorder::new())
.learning_strategy(LearningStrategy::SingleDevice(device.clone()))
.num_epochs(10)
.summary()
.build(model, AdamConfig::new().init(), 1e-4);
learner.fit(loader_train, loader_test);
*/
// Initialize optimizer
let mut optim = AdamConfig::new().init();
let learning_rate = 1e-4;
for epoch in 0..10 {
debug!("Running epoch {epoch}");
// Training phase
let mut train_loss_sum = 0.0;
let mut train_total = 0;
for batch in
TrainTestIterator::new(&self.data, &ccfg, &mcfg, &tokenizer, false, &device)
.context("while initializing reader")?
{
let batch = batch.context("while reading batch")?;
// Forward pass with gradients
let output = model.forward_train(batch.inputs, batch.targets);
train_total += output.targets.dims()[0] as i32;
train_loss_sum += output.loss.clone().into_scalar().to_f32();
let grads = output.loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
model = optim.step(learning_rate, model, grads);
}
let mut valid_loss_sum = 0.0;
let mut valid_total = 0;
let mut n_eval = 0;
debug!("Evaluating batches");
for batch in TrainTestIterator::new(&self.data, &ccfg, &mcfg, &tokenizer, true, &device)
.context("while initializing reader")?
{
let batch = batch.context("while reading batch")?;
n_eval += batch.targets.shape()[0];
// Forward pass without gradients
let output = model.valid().forward_train(batch.inputs, batch.targets);
valid_total += output.targets.dims()[0] as i32;
valid_loss_sum += output.loss.into_scalar().to_f32();
}
// Compute and log epoch results
let train_loss = if train_total > 0 {
train_loss_sum / train_total as f32
} else {
0.0
};
let valid_loss = if valid_total > 0 {
valid_loss_sum / valid_total as f32
} else {
0.0
};
info!(message = "Ran epoch", epoch, train_loss, valid_loss, n_eval);
}
Ok(())
}
}
//
// MARK: model
//
/// Multihead attention.
///
/// Equivalent to many stacked CausalAttention layers.
/// These are packed inside one big tensor for efficiency.
#[derive(Module, Debug)]
pub struct MultiheadAttention<B: Backend> {
n_heads: usize,
head_dim: usize,
// Can also use Linear layers with disabled bias
// (they may also have a better initialization routine)
// TODO: see source code, make this equivalent
/// Query weight matrices for each head, stacked on the last dimension.
/// (so that shape is [tokens, n_heads * head_dim])
///
/// Intuitively, this learns "what question to ask about the text"
/// for a given query token. (e.g, "it" -> what does "it" refer to?)
w_query: Param<Tensor<B, 2>>,
/// Key weight matrices for each head, stacked on the last dimension.
/// (so that shape is [tokens, n_heads * head_dim])
///
/// Intuitively, this learns what properties a certain token
/// has when it appears as a context (non-query) token.
w_key: Param<Tensor<B, 2>>,
/// Value weight matrices for each head, stacked on the last dimension.
/// (so that shape is [tokens, n_heads * head_dim])
///
/// Intuitively, ???
w_value: Param<Tensor<B, 2>>,
/// Optional final projection.
/// Maps [total_dim, total_dim] to [total_dim, total_dim]
w_output: Param<Tensor<B, 2>>,
dropout: Dropout,
/// Upper-triangular matrix of ones, excluding diagonal.
/// Used to mask future tokens.
utri_mask: Tensor<B, 2, Bool>,
}
impl<B: Backend> MultiheadAttention<B> {
pub fn new(
embedding_dim: usize,
head_dim: usize,
n_heads: usize,
context_length: usize,
dropout: f64,
device: &B::Device,
) -> Self {
let total_dim = head_dim * n_heads;
Self {
n_heads,
head_dim,
w_query: Param::uninitialized(
ParamId::new(),
move |device, is_require_grad| {
Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
.set_require_grad(is_require_grad)
},
device.clone(),
true,
[embedding_dim, total_dim].into(),
),
w_key: Param::uninitialized(
ParamId::new(),
move |device, is_require_grad| {
Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
.set_require_grad(is_require_grad)
},
device.clone(),
true,
[embedding_dim, total_dim].into(),
),
w_value: Param::uninitialized(
ParamId::new(),
move |device, is_require_grad| {
Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
.set_require_grad(is_require_grad)
},
device.clone(),
true,
[embedding_dim, total_dim].into(),
),
w_output: Param::uninitialized(
ParamId::new(),
move |device, is_require_grad| {
Tensor::random([total_dim, total_dim], Distribution::Default, device)
.set_require_grad(is_require_grad)
},
device.clone(),
true,
[total_dim, total_dim].into(),
),
dropout: Dropout { prob: dropout },
utri_mask: Tensor::<B, 2, _>::tril_mask([context_length, context_length], 0, device),
}
}
/// Compute self-attention vector for the given batch
///
/// - input shape is [batch, token, token_dim]
/// - input shape is [batch, token, n_heads * head_dim]
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
// Works similarly to self-attention, (where attn = softmax(tok @ tok^T); context = attn @ tok)
// But adds an "inner latent space" using Wq, Qk, and Wv.
//
// Multiple heads are batched into one tensor.
let batch = input.dims()[0];
let tokens = input.dims()[1];
let w_query = self
.w_query
.val()
.unsqueeze_dim::<3>(0)
.expand([batch as i64, -1, -1]);
let w_key = self
.w_key
.val()
.unsqueeze_dim::<3>(0)
.expand([batch as i64, -1, -1]);
let w_value = self
.w_value
.val()
.unsqueeze_dim::<3>(0)
.expand([batch as i64, -1, -1]);
let w_output = self
.w_output
.val()
.unsqueeze_dim::<3>(0)
.expand([batch as i64, -1, -1]);
// Map batch to inner latent space.
// shape: [batch, token, inner_dim]
let queries = input.clone().matmul(w_query);
let keys = input.clone().matmul(w_key);
let values = input.clone().matmul(w_value);
// Split head dimensions
let keys = keys.reshape([batch, tokens, self.n_heads, self.head_dim]);
let values = values.reshape([batch, tokens, self.n_heads, self.head_dim]);
let queries = queries.reshape([batch, tokens, self.n_heads, self.head_dim]);
// from: [batch, tok, head, head_dim]
// to: [batch, head, tok, head_dim]
let keys = keys.swap_dims(1, 2);
let values = values.swap_dims(1, 2);
let queries = queries.swap_dims(1, 2);
// Compute attention scores for each head
// (cosine similarity of each query token to each context token, per head)
//
// lhs shape: [batch, head, tok, head_dim]
// rhs shape: [batch, head, head_dim, tok]
// output shape: [batch, head, query_token, context_token]
let attn_scores = queries.matmul(keys.clone().swap_dims(2, 3));
let mask = self
.utri_mask
.clone()
.slice([0..tokens, 0..tokens])
.unsqueeze_dim::<3>(0)
.unsqueeze_dim::<4>(0)
.expand(attn_scores.shape());
// Mask out future tokens by filling
// upper-triangular with -inf, which becomes 0.0 after softmax.
let attn_scores = attn_scores.mask_fill(mask, f32::NEG_INFINITY);
// Normalize attn weights.
//
// Divide by sqrt(inner_dim) because...
// - dot products get larger with larger dimensions
// - this causes softmax to "saturate", making all other values very small
// - which makes gradients vanish during training
let attn_weights = softmax(attn_scores / (keys.shape()[3] as f32).sqrt(), 3);
let attn_weights = self.dropout.forward(attn_weights);
// lhs shape: [batch, head, query_token, context_token]
// rhs shape: [batch, head, tok, head_dim]
// matmul shape: [batch, head, tok, head_dim]
// out shape: [batch, tok, head, head_dim]
let context_vec = attn_weights.matmul(values).swap_dims(1, 2);
// shape: [batch, tok, stacked_dim]
let context_vec = context_vec.reshape([batch, tokens, self.n_heads * self.head_dim]);
// Apply final projection (optional)
let context_vec = context_vec.matmul(w_output);
return context_vec;
}
}
#[derive(Config, Debug)]
pub struct GptModelConfig {
/// Number of tokens
pub vocab_size: u32,
/// Maximum number of input tokens with positional embeddings
pub context_size: usize,
/// Dimension of each token's embedding
pub embed_dim: usize,
/// Number of attention heads
pub n_heads: usize,
/// Dimension of each attn head
pub head_dim: usize,
/// Number of transformer blocks
pub n_layers: usize,
pub embed_drop_rate: f64,
pub attention_drop_rate: f64,
pub shortcut_drop_rate: f64,
}
impl GptModelConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> GptModel<B> {
let out_head_shape = [self.embed_dim, self.vocab_size as usize];
GptModel {
embedder_tok: EmbeddingConfig::new(self.vocab_size as usize, self.embed_dim)
.init(device),
embedder_pos: EmbeddingConfig::new(self.context_size, self.embed_dim).init(device),
embedder_drop: Dropout {
prob: self.embed_drop_rate,
},
trf_blocks: (0..self.n_layers)
.map(|_| TransformerBlock::new(&self, device))
.collect(),
final_norm: LayerNormConfig::new(self.embed_dim).init(device),
out_head: Param::uninitialized(
ParamId::new(),
move |device, is_require_grad| {
Tensor::random(out_head_shape, Distribution::Default, device)
.set_require_grad(is_require_grad)
},
device.clone(),
true,
out_head_shape.into(),
),
}
}
}
#[derive(Debug, Clone)]
pub struct TrainBatch<B: Backend> {
pub inputs: Tensor<B, 2, Int>,
/// Correct next token for each input
pub targets: Tensor<B, 1, Int>,
}
#[derive(Module, Debug)]
pub struct GptModel<B: Backend> {
embedder_tok: Embedding<B>,
embedder_pos: Embedding<B>,
embedder_drop: Dropout,
trf_blocks: Vec<TransformerBlock<B>>,
final_norm: LayerNorm<B>,
out_head: Param<Tensor<B, 2>>,
}
impl<B: Backend> GptModel<B> {
pub fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 3> {
let n_tokens = input.shape()[1];
let embed_tok = self.embedder_tok.forward(input.clone());
let embed_pos = self
.embedder_pos
.forward(Tensor::arange(0..n_tokens as i64, &input.device()).unsqueeze_dim(0));
let x = embed_tok + embed_pos;
let x = self.embedder_drop.forward(x);
let x = self.trf_blocks.iter().fold(x, |x, l| l.forward(x));
let x = self.final_norm.forward(x);
let logits = x.matmul(self.out_head.val().unsqueeze_dim(0));
return logits;
}
pub fn forward_train(
&self,
inputs: Tensor<B, 2, Int>,
targets: Tensor<B, 1, Int>,
) -> ClassificationOutput<B> {
// shape: [batch, n_tokens, n_vocabulary]
let output = self.forward(inputs);
// Get last token
// shape: [batch, n_vocabulary]
let output = output.slice_dim(1, -1).squeeze_dim::<2>(1);
let loss = CrossEntropyLossConfig::new()
.init(&targets.device())
.forward(output.clone(), targets.clone());
ClassificationOutput {
loss,
output,
targets,
}
}
}
#[derive(Module, Debug)]
pub struct TransformerBlock<B: Backend> {
attention: MultiheadAttention<B>,
/// TODO: wtf?
ff: PositionWiseFeedForward<B>,
/// TODO: wtf?
norm_a: LayerNorm<B>,
norm_b: LayerNorm<B>,
drop_shortcut: Dropout,
}
impl<B: Backend> TransformerBlock<B> {
pub fn new(cfg: &GptModelConfig, device: &B::Device) -> Self {
Self {
attention: MultiheadAttention::new(
cfg.embed_dim,
cfg.head_dim,
cfg.n_heads,
cfg.context_size,
cfg.attention_drop_rate,
device,
),
ff: PositionWiseFeedForwardConfig::new(cfg.embed_dim, 4 * cfg.embed_dim)
.with_dropout(0.0)
.init(device),
norm_a: LayerNormConfig::new(cfg.embed_dim).init(device),
norm_b: LayerNormConfig::new(cfg.embed_dim).init(device),
drop_shortcut: Dropout {
prob: cfg.shortcut_drop_rate,
},
}
}
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
let input = {
let shortcut = input.clone();
let x = self.norm_a.forward(input);
let x = self.attention.forward(x);
let x = self.drop_shortcut.forward(x);
x + shortcut
};
let input = {
// TODO: wtf?
let shortcut = input.clone();
let x = self.norm_b.forward(input);
let x = self.ff.forward(x);
let x = self.drop_shortcut.forward(x);
x + shortcut
};
return input;
}
}

View File

@@ -0,0 +1,312 @@
use anyhow::{Context, Result};
use burn::{
backend::Autodiff,
module::{AutodiffModule, Module},
optim::{AdamConfig, GradientsParams, Optimizer},
prelude::ToElement,
record::{FullPrecisionSettings, NamedMpkFileRecorder},
tensor::backend::AutodiffBackend,
};
use clap::Args;
use indicatif::{MultiProgress, ProgressBar};
use std::{
f32,
fs::File,
num::NonZero,
path::PathBuf,
time::{Duration, Instant},
};
use tokenizer::Tokenizer;
use tracing::{debug, info};
use crate::{
InferenceDevice,
cli::{progress_big, progress_persec},
parts::{GptModel, GptModelConfig},
train_test_iterator::TrainTestIterator,
};
// Text generation routine
/*
{
let init = "Initial context. This is ";
let tokens = tokenizer.encode(&init);
let n_tokens = tokens.len();
let input: Array1<u32> = Array1::from_vec(tokens);
let mut input: Tensor<Cuda, 1, Int> =
Tensor::<_, 1, Int>::from_ints(input.as_slice().unwrap(), &device)
.reshape([n_tokens]);
for _ in 0..100 {
let tokens: Vec<u32> = input.clone().to_data().convert::<u32>().into_vec().unwrap();
println!("{:?}", tokens);
println!("{}", tokenizer.decode(&tokens));
// Crop idx to context size;
let batch = input
.clone()
.slice([0..config.context_size])
.unsqueeze_dim(0);
// shape: [tokens, vocab_size]
let logits = model.forward(batch).squeeze_dim::<2>(0);
// shape: [vocab_size]
let logits = logits.slice([config.context_size - 1]).squeeze_dim::<1>(0);
let probs = softmax(logits, 0); // shape: [n_tokens]
let id_next = probs.argmax(0); // shape: [1]
input = Tensor::cat(vec![input.slice([1..]), id_next], 0);
}
}
*/
#[derive(Debug, Args, Clone)]
pub struct TrainModelArgs {
/// Path to training data
data: PathBuf,
/// Path to tokenizer
#[clap(long, default_value = "tokenizer.json")]
tokenizer: PathBuf,
/// directory to save checkpoints
#[clap(long, default_value = "checkpoints")]
checkpoints: PathBuf,
/// The device to use for compute. `wgpu:n`, `cuda:n`, or `cpu`
#[clap(long, default_value = "cpu")]
device: InferenceDevice,
/// Training batch size
#[clap(long, default_value = "10")]
batch: NonZero<usize>,
/// Proportion of data reserved for evaluation
#[clap(long, default_value = "0.1")]
eval_frac: f64,
/// Eval hasher salt
#[clap(long, default_value = "eval-salt")]
eval_salt: String,
/// Number of threads reading data
#[clap(long, default_value = "5")]
readers: usize,
}
pub struct ComputeConfig {
pub batch_size: usize,
pub eval_frac: f64,
pub eval_salt: String,
}
impl TrainModelArgs {
pub fn run(self, mp: Option<MultiProgress>) -> Result<()> {
match self.device {
InferenceDevice::Cpu => {
use burn::backend::NdArray;
use burn::backend::ndarray::NdArrayDevice;
let device = NdArrayDevice::Cpu;
self.run_inner::<Autodiff<NdArray>>(mp, device)?;
}
InferenceDevice::Cuda(x) => {
use burn::backend::Cuda;
use burn::backend::cuda::CudaDevice;
let device = CudaDevice::new(x);
self.run_inner::<Autodiff<Cuda>>(mp, device)?;
}
InferenceDevice::Wgpu(x) => {
use burn::backend::Wgpu;
use burn::backend::wgpu::WgpuDevice;
let device = WgpuDevice::DiscreteGpu(x);
self.run_inner::<Autodiff<Wgpu>>(mp, device)?;
}
};
return Ok(());
}
fn run_inner<B: AutodiffBackend>(
self,
mp: Option<MultiProgress>,
device: B::Device,
) -> Result<()> {
let tokenizer = File::open(&self.tokenizer).context("while opening tokenizer")?;
let tokenizer: Tokenizer =
serde_json::from_reader(tokenizer).context("while loading tokenizer")?;
let ccfg = ComputeConfig {
batch_size: self.batch.get(),
eval_frac: self.eval_frac,
eval_salt: self.eval_salt.clone(),
};
let mcfg = GptModelConfig {
vocab_size: tokenizer.vocab_size(),
context_size: 256, // TODO: MORE!
embed_dim: 768,
n_heads: 12,
head_dim: 64, // = 768 / 12
n_layers: 12,
embed_drop_rate: 0.1,
attention_drop_rate: 0.1,
shortcut_drop_rate: 0.1,
};
let mut model: GptModel<B> = mcfg.init(&device);
let mut optim = AdamConfig::new().init();
let learning_rate = 1e-4;
std::fs::create_dir_all(&self.checkpoints).context("while creating checkpoint dir")?;
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
let main_pb = mp.as_ref().map(|mp| {
let pb = mp.add(ProgressBar::new(10 as u64));
pb.set_style(progress_big());
pb.set_message("Training model");
pb.enable_steady_tick(Duration::from_millis(100));
pb
});
for epoch in 0..10 {
let start = Instant::now();
debug!("Running epoch {epoch}");
let epoch_pb = mp.as_ref().map(|mp| {
let pb = mp.add(ProgressBar::no_length());
pb.set_style(progress_persec());
pb.set_message(format!("Running epoch {epoch}"));
pb.enable_steady_tick(Duration::from_millis(100));
pb
});
// Training phase
let mut train_loss_sum = 0.0;
let mut train_total = 0;
let mut n_train = 0u64;
for batch in TrainTestIterator::new(
&self.data,
&tokenizer,
false,
ccfg.batch_size,
mcfg.context_size,
ccfg.eval_frac,
&ccfg.eval_salt,
self.readers,
&device,
)
.context("while initializing reader")?
{
let batch = batch.context("while reading batch")?;
epoch_pb.as_ref().map(|x| x.set_position(n_train));
n_train += batch.inputs.shape()[0] as u64;
// Forward pass with gradients
let output = model.forward_train(batch.inputs, batch.targets);
train_total += output.targets.dims()[0] as i32;
train_loss_sum += output.loss.clone().into_scalar().to_f32();
let grads = output.loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
model = optim.step(learning_rate, model, grads);
}
epoch_pb.map(|x| x.finish_and_clear());
let mut valid_loss_sum = 0.0;
let mut valid_total = 0;
let mut n_eval = 0;
debug!("Evaluating batches");
let eval_pb = mp.as_ref().map(|mp| {
let pb = mp.add(ProgressBar::no_length());
pb.set_style(progress_persec());
pb.set_message(format!("Evaluating epoch {epoch}"));
pb.enable_steady_tick(Duration::from_millis(100));
pb
});
for batch in TrainTestIterator::new(
&self.data,
&tokenizer,
true,
ccfg.batch_size,
mcfg.context_size,
ccfg.eval_frac,
&ccfg.eval_salt,
self.readers,
&device,
)
.context("while initializing reader")?
{
let batch = batch.context("while reading batch")?;
eval_pb.as_ref().map(|x| x.set_position(n_eval));
n_eval += batch.inputs.shape()[0] as u64;
// Forward pass without gradients
let output = model.valid().forward_train(batch.inputs, batch.targets);
valid_total += output.targets.dims()[0] as i32;
valid_loss_sum += output.loss.into_scalar().to_f32();
}
eval_pb.map(|x| x.finish_and_clear());
// Compute and log epoch results
let train_loss = if train_total > 0 {
train_loss_sum / train_total as f32
} else {
0.0
};
let valid_loss = if valid_total > 0 {
valid_loss_sum / valid_total as f32
} else {
0.0
};
info!(
message = "Ran epoch",
epoch,
train_loss,
valid_loss,
n_train,
n_eval,
time_ms = start.elapsed().as_millis()
);
main_pb.as_ref().map(|x| x.inc(1));
{
let target = self.checkpoints.join(format!("epoch-{epoch:02}"));
info!(message = "Saving checkpoint", ?target);
std::fs::create_dir_all(&self.checkpoints)
.context("while creating checkpoint dir")?;
model
.clone()
.save_file(target, &recorder)
.context("while saving checkpoint")?;
}
}
if let Some(pb) = main_pb.as_ref() {
pb.finish_and_clear();
info!("Training complete");
}
Ok(())
}
}

View File

@@ -12,22 +12,25 @@ use crate::data_reader::DataReader;
#[derive(Debug, Args, Clone)]
pub struct TrainTokenizerArgs {
/// Where to save tokenizer
#[clap(default_value = "tokenizer.json")]
target: PathBuf,
/// Path to training data
#[clap(long, default_value = "data")]
data_dir: PathBuf,
data: PathBuf,
/// Where to save tokenizer
#[clap(long, default_value = "tokenizer.json")]
target: PathBuf,
/// Only train on the first n texts
#[clap(long)]
first_n: Option<usize>,
/// Number of threads to use for training
/// Number of threads to use for training. 0 to autodetect.
#[clap(long, default_value = "0")]
threads: usize,
/// Number of threads reading data
#[clap(long, default_value = "5")]
readers: usize,
/// Tokenizer vocabulary size
#[clap(long, default_value = "65535")]
n_tokens: u32,
@@ -35,7 +38,8 @@ pub struct TrainTokenizerArgs {
impl TrainTokenizerArgs {
pub fn run(self, mp: Option<MultiProgress>) -> Result<()> {
let iter = DataReader::new(5, &self.data_dir).context("while initializing data reader")?;
let iter = DataReader::new(self.readers.max(1), &self.data)
.context("while initializing data reader")?;
#[expect(clippy::unwrap_used)] // Lazy error handling
let iter = iter.map(|x| x.unwrap());

View File

@@ -3,6 +3,7 @@ use parking_lot::Mutex;
use parquet::errors::ParquetError;
use parquet::file::reader::{FileReader, SerializedFileReader};
use parquet::record::RowAccessor;
use rand::seq::SliceRandom;
use std::fs::File;
use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
@@ -25,10 +26,11 @@ pub enum DataReaderError {
///
/// All parquet files have exactly one text column.
/// No promises about this struct's behavior if this is not the case.
#[derive(Clone)]
pub struct DataReader {
rx: Receiver<Result<String, DataReaderError>>,
rx: Arc<Mutex<Receiver<Result<String, DataReaderError>>>>,
total_rows: usize,
consumed_rows: AtomicUsize,
consumed_rows: Arc<AtomicUsize>,
}
impl DataReader {
@@ -57,6 +59,8 @@ impl DataReader {
files.push(path);
}
}
files.shuffle(&mut rand::rng());
(Arc::new(Mutex::new(files)), total_rows)
};
@@ -147,9 +151,9 @@ impl DataReader {
});
Ok(Self {
rx,
rx: Arc::new(Mutex::new(rx)),
total_rows,
consumed_rows: AtomicUsize::new(0),
consumed_rows: Arc::new(AtomicUsize::new(0)),
})
}
@@ -157,7 +161,7 @@ impl DataReader {
/// Order is arbitrary.
/// Returns `None` when all rows have been read.
pub fn recv(&self) -> Option<Result<String, DataReaderError>> {
self.rx.recv().ok()
self.rx.lock().recv().ok()
}
//pub fn try_recv(&self) -> Result<Result<String, DataReaderError>, TryRecvError> {

View File

@@ -12,18 +12,16 @@ use tracing_subscriber::{
// MARK: loglevel
//
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, ValueEnum)]
#[derive(Default)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, ValueEnum, Default)]
pub enum LogLevel {
Trace,
Debug,
#[default]
Info,
Info,
Warn,
Error,
}
impl Display for LogLevel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
@@ -47,7 +45,7 @@ pub struct LoggingConfig {
pub silence: LogLevel,
// Bins
pub nanochat: LogLevel,
pub llmfs: LogLevel,
}
impl From<LoggingConfig> for EnvFilter {
@@ -71,7 +69,7 @@ impl From<LoggingConfig> for EnvFilter {
//
// Bins
//
format!("nanochat_rs={}", conf.nanochat),
format!("llmfs={}", conf.llmfs),
conf.other.to_string(),
]
.join(","),
@@ -164,31 +162,31 @@ impl LogFilterPreset {
Self::Error => LoggingConfig {
other: LogLevel::Error,
silence: LogLevel::Error,
nanochat: LogLevel::Error,
llmfs: LogLevel::Error,
},
Self::Warn => LoggingConfig {
other: LogLevel::Warn,
silence: LogLevel::Warn,
nanochat: LogLevel::Warn,
llmfs: LogLevel::Warn,
},
Self::Info => LoggingConfig {
other: LogLevel::Warn,
silence: LogLevel::Warn,
nanochat: LogLevel::Info,
llmfs: LogLevel::Info,
},
Self::Debug => LoggingConfig {
other: LogLevel::Warn,
silence: LogLevel::Warn,
nanochat: LogLevel::Debug,
llmfs: LogLevel::Debug,
},
Self::Trace => LoggingConfig {
other: LogLevel::Trace,
silence: LogLevel::Warn,
nanochat: LogLevel::Trace,
llmfs: LogLevel::Trace,
},
}
}
@@ -216,16 +214,14 @@ pub enum LoggingTarget {
}
/// How to print logs
#[derive(Debug, Clone, Copy, Deserialize)]
#[derive(Default)]
#[derive(Debug, Clone, Copy, Deserialize, Default)]
pub enum LoggingFormat {
#[default]
Ansi,
Ansi,
AnsiNoColor,
Json,
}
pub struct LoggingInitializer {
/// Log filter for printed logs
pub preset: LogFilterPreset,

View File

@@ -1,5 +1,8 @@
#![recursion_limit = "256"] // needed to resolve burn types
use clap::Parser;
use indicatif::MultiProgress;
use serde::{Deserialize, Deserializer};
use tracing::error;
use crate::{
@@ -11,6 +14,8 @@ mod cli;
mod command;
mod data_reader;
mod logging;
mod parts;
mod train_test_iterator;
#[derive(Parser, Debug)]
#[command(version, about, long_about = None, styles=crate::cli::clap_styles())]
@@ -60,3 +65,66 @@ fn main() {
std::process::exit(1);
}
}
//
//
//
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum InferenceDevice {
#[default]
Cpu,
Cuda(usize),
Wgpu(usize),
}
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
#[error("{0}")]
pub struct ParseDeviceError(String);
impl std::str::FromStr for InferenceDevice {
type Err = ParseDeviceError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let s = s.to_lowercase();
if s == "cpu" {
return Ok(InferenceDevice::Cpu);
}
if let Some(index_str) = s.strip_prefix("cuda:") {
return match index_str.parse::<usize>() {
Ok(index) => Ok(InferenceDevice::Cuda(index)),
Err(_) => Err(ParseDeviceError(format!(
"Invalid device index: '{}'",
index_str
))),
};
}
if let Some(index_str) = s.strip_prefix("wgpu:") {
return match index_str.parse::<usize>() {
Ok(index) => Ok(InferenceDevice::Wgpu(index)),
Err(_) => Err(ParseDeviceError(format!(
"Invalid device index: '{}'",
index_str
))),
};
}
return Err(ParseDeviceError(format!(
"Invalid device format: '{}'. Expected 'cpu', 'cuda:N', 'wgpu:N'",
s
)));
}
}
impl<'de> Deserialize<'de> for InferenceDevice {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
s.parse().map_err(serde::de::Error::custom)
}
}

View File

@@ -0,0 +1,228 @@
use burn::{
Tensor,
config::Config,
module::{Module, Param, ParamId},
nn::Dropout,
prelude::Backend,
tensor::{Bool, Distribution, activation::softmax},
};
use std::f32;
#[derive(Debug, Config)]
pub struct MultiheadAttentionConfig {
pub context_size: usize,
pub embed_dim: usize,
pub n_heads: usize,
pub head_dim: usize,
pub drop_rate: f64,
}
impl MultiheadAttentionConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> MultiheadAttention<B> {
let total_dim = self.head_dim * self.n_heads;
let embedding_dim = self.embed_dim;
MultiheadAttention {
n_heads: self.n_heads,
head_dim: self.head_dim,
w_query: Param::uninitialized(
ParamId::new(),
move |device, is_require_grad| {
Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
.set_require_grad(is_require_grad)
},
device.clone(),
true,
[self.embed_dim, total_dim].into(),
),
w_key: Param::uninitialized(
ParamId::new(),
move |device, is_require_grad| {
Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
.set_require_grad(is_require_grad)
},
device.clone(),
true,
[self.embed_dim, total_dim].into(),
),
w_value: Param::uninitialized(
ParamId::new(),
move |device, is_require_grad| {
Tensor::random([embedding_dim, total_dim], Distribution::Default, device)
.set_require_grad(is_require_grad)
},
device.clone(),
true,
[self.embed_dim, total_dim].into(),
),
w_output: Param::uninitialized(
ParamId::new(),
move |device, is_require_grad| {
Tensor::random([total_dim, total_dim], Distribution::Default, device)
.set_require_grad(is_require_grad)
},
device.clone(),
true,
[total_dim, total_dim].into(),
),
dropout: Dropout {
prob: self.drop_rate,
},
utri_mask: Tensor::<B, 2, _>::tril_mask(
[self.context_size, self.context_size],
0,
device,
),
}
}
}
/// Multihead attention.
///
/// Equivalent to many stacked CausalAttention layers.
/// These are packed inside one big tensor for efficiency.
#[derive(Module, Debug)]
pub struct MultiheadAttention<B: Backend> {
n_heads: usize,
head_dim: usize,
// Can also use Linear layers with disabled bias
// (they may also have a better initialization routine)
// TODO: see source code, make this equivalent
/// Query weight matrices for each head, stacked on the last dimension.
/// (so that shape is [tokens, n_heads * head_dim])
///
/// Intuitively, this learns "what question to ask about the text"
/// for a given query token. (e.g, "it" -> what does "it" refer to?)
w_query: Param<Tensor<B, 2>>,
/// Key weight matrices for each head, stacked on the last dimension.
/// (so that shape is [tokens, n_heads * head_dim])
///
/// Intuitively, this learns what properties a certain token
/// has when it appears as a context (non-query) token.
w_key: Param<Tensor<B, 2>>,
/// Value weight matrices for each head, stacked on the last dimension.
/// (so that shape is [tokens, n_heads * head_dim])
///
/// Intuitively, ???
w_value: Param<Tensor<B, 2>>,
/// Optional final projection.
/// Maps [total_dim, total_dim] to [total_dim, total_dim]
w_output: Param<Tensor<B, 2>>,
dropout: Dropout,
/// Upper-triangular matrix of ones, excluding diagonal.
/// Used to mask future tokens.
utri_mask: Tensor<B, 2, Bool>,
}
impl<B: Backend> MultiheadAttention<B> {
/// Compute self-attention vector for the given batch
///
/// - input shape is [batch, token, token_dim]
/// - input shape is [batch, token, n_heads * head_dim]
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
// Works similarly to self-attention, (where attn = softmax(tok @ tok^T); context = attn @ tok)
// But adds an "inner latent space" using Wq, Qk, and Wv.
//
// Multiple heads are batched into one tensor.
let batch = input.dims()[0];
let tokens = input.dims()[1];
let w_query = self
.w_query
.val()
.unsqueeze_dim::<3>(0)
.expand([batch as i64, -1, -1]);
let w_key = self
.w_key
.val()
.unsqueeze_dim::<3>(0)
.expand([batch as i64, -1, -1]);
let w_value = self
.w_value
.val()
.unsqueeze_dim::<3>(0)
.expand([batch as i64, -1, -1]);
let w_output = self
.w_output
.val()
.unsqueeze_dim::<3>(0)
.expand([batch as i64, -1, -1]);
// Map batch to inner latent space.
// shape: [batch, token, inner_dim]
let queries = input.clone().matmul(w_query);
let keys = input.clone().matmul(w_key);
let values = input.clone().matmul(w_value);
// Split head dimensions
let keys = keys.reshape([batch, tokens, self.n_heads, self.head_dim]);
let values = values.reshape([batch, tokens, self.n_heads, self.head_dim]);
let queries = queries.reshape([batch, tokens, self.n_heads, self.head_dim]);
// from: [batch, tok, head, head_dim]
// to: [batch, head, tok, head_dim]
let keys = keys.swap_dims(1, 2);
let values = values.swap_dims(1, 2);
let queries = queries.swap_dims(1, 2);
// Compute attention scores for each head
// (cosine similarity of each query token to each context token, per head)
//
// lhs shape: [batch, head, tok, head_dim]
// rhs shape: [batch, head, head_dim, tok]
// output shape: [batch, head, query_token, context_token]
let attn_scores = queries.matmul(keys.clone().swap_dims(2, 3));
let mask = self
.utri_mask
.clone()
.slice([0..tokens, 0..tokens])
.unsqueeze_dim::<3>(0)
.unsqueeze_dim::<4>(0)
.expand(attn_scores.shape());
// Mask out future tokens by filling
// upper-triangular with -inf, which becomes 0.0 after softmax.
let attn_scores = attn_scores.mask_fill(mask, f32::NEG_INFINITY);
// Normalize attn weights.
//
// Divide by sqrt(inner_dim) because...
// - dot products get larger with larger dimensions
// - this causes softmax to "saturate", making all other values very small
// - which makes gradients vanish during training
let attn_weights = softmax(attn_scores / (keys.shape()[3] as f32).sqrt(), 3);
let attn_weights = self.dropout.forward(attn_weights);
// lhs shape: [batch, head, query_token, context_token]
// rhs shape: [batch, head, tok, head_dim]
// matmul shape: [batch, head, tok, head_dim]
// out shape: [batch, tok, head, head_dim]
let context_vec = attn_weights.matmul(values).swap_dims(1, 2);
// shape: [batch, tok, stacked_dim]
let context_vec = context_vec.reshape([batch, tokens, self.n_heads * self.head_dim]);
// Apply final projection (optional)
let context_vec = context_vec.matmul(w_output);
return context_vec;
}
}

View File

@@ -0,0 +1,5 @@
mod attention;
pub use attention::*;
mod model;
pub use model::*;

View File

@@ -0,0 +1,194 @@
use burn::{
Tensor,
config::Config,
module::{Module, Param, ParamId},
nn::{
Dropout, Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig,
loss::CrossEntropyLossConfig,
transformer::{PositionWiseFeedForward, PositionWiseFeedForwardConfig},
},
prelude::Backend,
tensor::{Distribution, Int},
};
use burn_train::ClassificationOutput;
use crate::parts::{MultiheadAttention, MultiheadAttentionConfig};
#[derive(Debug, Config)]
pub struct GptModelConfig {
/// Number of tokens
pub vocab_size: u32,
/// Maximum number of input tokens with positional embeddings
pub context_size: usize,
/// Dimension of each token's embedding
pub embed_dim: usize,
/// Number of attention heads
pub n_heads: usize,
/// Dimension of each attn head
pub head_dim: usize,
/// Number of transformer blocks
#[config(default = 12)]
pub n_layers: usize,
#[config(default = 0.1)]
pub embed_drop_rate: f64,
#[config(default = 0.1)]
pub attention_drop_rate: f64,
#[config(default = 0.1)]
pub shortcut_drop_rate: f64,
}
impl GptModelConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> GptModel<B> {
let out_head_shape = [self.embed_dim, self.vocab_size as usize];
GptModel {
embedder_tok: EmbeddingConfig::new(self.vocab_size as usize, self.embed_dim)
.init(device),
embedder_pos: EmbeddingConfig::new(self.context_size, self.embed_dim).init(device),
embedder_drop: Dropout {
prob: self.embed_drop_rate,
},
trf_blocks: (0..self.n_layers)
.map(|_| TransformerBlock::new(&self, device))
.collect(),
final_norm: LayerNormConfig::new(self.embed_dim).init(device),
out_head: Param::uninitialized(
ParamId::new(),
move |device, is_require_grad| {
Tensor::random(out_head_shape, Distribution::Default, device)
.set_require_grad(is_require_grad)
},
device.clone(),
true,
out_head_shape.into(),
),
}
}
}
#[derive(Module, Debug)]
pub struct GptModel<B: Backend> {
embedder_tok: Embedding<B>,
embedder_pos: Embedding<B>,
embedder_drop: Dropout,
trf_blocks: Vec<TransformerBlock<B>>,
final_norm: LayerNorm<B>,
out_head: Param<Tensor<B, 2>>,
}
impl<B: Backend> GptModel<B> {
pub fn forward(&self, input: Tensor<B, 2, Int>) -> Tensor<B, 3> {
let n_tokens = input.shape()[1];
let embed_tok = self.embedder_tok.forward(input.clone());
let embed_pos = self
.embedder_pos
.forward(Tensor::arange(0..n_tokens as i64, &input.device()).unsqueeze_dim(0));
let x = embed_tok + embed_pos;
let x = self.embedder_drop.forward(x);
let x = self.trf_blocks.iter().fold(x, |x, l| l.forward(x));
let x = self.final_norm.forward(x);
let logits = x.matmul(self.out_head.val().unsqueeze_dim(0));
return logits;
}
pub fn forward_train(
&self,
inputs: Tensor<B, 2, Int>,
targets: Tensor<B, 1, Int>,
) -> ClassificationOutput<B> {
// shape: [batch, n_tokens, n_vocabulary]
let output = self.forward(inputs);
// Get last token
// shape: [batch, n_vocabulary]
let output = output.slice_dim(1, -1).squeeze_dim::<2>(1);
let loss = CrossEntropyLossConfig::new()
.init(&targets.device())
.forward(output.clone(), targets.clone());
ClassificationOutput {
loss,
output,
targets,
}
}
}
#[derive(Module, Debug)]
pub struct TransformerBlock<B: Backend> {
attention: MultiheadAttention<B>,
/// TODO: wtf?
ff: PositionWiseFeedForward<B>,
/// TODO: wtf?
norm_a: LayerNorm<B>,
norm_b: LayerNorm<B>,
drop_shortcut: Dropout,
}
impl<B: Backend> TransformerBlock<B> {
pub fn new(cfg: &GptModelConfig, device: &B::Device) -> Self {
Self {
attention: MultiheadAttentionConfig {
embed_dim: cfg.embed_dim,
head_dim: cfg.head_dim,
n_heads: cfg.n_heads,
context_size: cfg.context_size,
drop_rate: cfg.attention_drop_rate,
}
.init(device),
ff: PositionWiseFeedForwardConfig::new(cfg.embed_dim, 4 * cfg.embed_dim)
.with_dropout(0.0)
.init(device),
norm_a: LayerNormConfig::new(cfg.embed_dim).init(device),
norm_b: LayerNormConfig::new(cfg.embed_dim).init(device),
drop_shortcut: Dropout {
prob: cfg.shortcut_drop_rate,
},
}
}
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
let input = {
let shortcut = input.clone();
let x = self.norm_a.forward(input);
let x = self.attention.forward(x);
let x = self.drop_shortcut.forward(x);
x + shortcut
};
let input = {
// TODO: wtf?
let shortcut = input.clone();
let x = self.norm_b.forward(input);
let x = self.ff.forward(x);
let x = self.drop_shortcut.forward(x);
x + shortcut
};
return input;
}
}

View File

@@ -0,0 +1,164 @@
use ahash::AHasher;
use anyhow::Result;
use burn::{
Tensor,
prelude::{Backend, ToElement},
tensor::Int,
};
use ndarray::{Array1, Array2};
use std::{collections::VecDeque, hash::Hasher, path::Path};
use tokenizer::Tokenizer;
use crate::data_reader::{DataReader, DataReaderError};
#[derive(Debug, Clone)]
pub struct TrainBatch<B: Backend> {
/// Input texts.
/// shape: [batch, context_size]
pub inputs: Tensor<B, 2, Int>,
/// Correct next token for each input.
/// shape: [batch]
pub targets: Tensor<B, 1, Int>,
}
/// Read texts from a [DataReader], then
/// - extract context windows
/// - deterministically classify these as "test" or "train"
/// - batch output into tensors of token ids
pub struct TrainTestIterator<'a, B: Backend> {
reader: DataReader,
tokenizer: &'a Tokenizer,
eval: bool,
device: &'a B::Device,
batch_size: usize,
context_size: usize,
eval_frac: f64,
eval_salt: String,
// Tokenized input/output pairs
pairs: VecDeque<(Vec<u32>, u32)>,
error: bool,
}
impl<'a, B: Backend> TrainTestIterator<'a, B> {
pub fn new(
data_dir: impl AsRef<Path>,
tokenizer: &'a Tokenizer,
eval: bool,
batch_size: usize,
context_size: usize,
eval_frac: f64,
eval_salt: impl Into<String>,
readers: usize,
device: &'a B::Device,
) -> Result<Self, std::io::Error> {
let reader = DataReader::new(readers.max(1), data_dir)?;
Ok(Self {
reader,
tokenizer,
eval,
device,
batch_size,
context_size,
eval_frac,
eval_salt: eval_salt.into(),
pairs: VecDeque::new(),
error: false,
})
}
}
impl<B: Backend> Iterator for TrainTestIterator<'_, B> {
type Item = Result<TrainBatch<B>, DataReaderError>;
fn next(&mut self) -> Option<Self::Item> {
if self.error {
return None;
}
let mut inputs = Vec::with_capacity(self.batch_size);
let mut targets = Vec::with_capacity(self.batch_size);
let stride = self.context_size;
while inputs.len() < self.batch_size {
match self.pairs.pop_front() {
Some((i, t)) => {
// train/test split
{
let mut hasher = AHasher::default();
hasher.write(self.eval_salt.as_bytes());
// Don't care about endianness, ahash output is unstable anyway
hasher.write(unsafe { std::mem::transmute(&i[..]) });
hasher.write_u32(t);
let train = // is this point in the training set?
hasher.finish() > (u64::MAX as f64 * self.eval_frac).to_u64();
if train && self.eval {
continue;
}
}
inputs.push(i);
targets.push(t);
}
None => {
let text = match self.reader.next() {
None => break,
Some(Ok(x)) => x,
Some(Err(x)) => {
self.error = true;
return Some(Err(x));
}
};
let emb = self.tokenizer.encode(&text);
// Skip small texts
//
// TODO: do this better
// TODO: maybe using <|bos|>?
// TODO: non-uniform batches?
if emb.len() < self.context_size {
continue;
}
let pairs = emb
.windows(self.context_size + 1)
.step_by(stride)
.map(|x| (x[..self.context_size].to_vec(), x[self.context_size]));
self.pairs.extend(pairs);
}
}
}
if inputs.is_empty() {
return None;
}
let shape = [inputs.len(), self.context_size];
// Arrange data in memory
let inputs: Array2<u32> = Array2::from_shape_fn(shape, |(a, b)| inputs[a][b]);
let targets: Array1<u32> = Array1::from_vec(targets);
// Create tensors on gpu
#[expect(clippy::unwrap_used)]
let inputs =
Tensor::<B, 1, Int>::from_ints(inputs.as_slice().unwrap(), self.device).reshape(shape);
#[expect(clippy::unwrap_used)]
let targets = Tensor::<B, 1, Int>::from_ints(targets.as_slice().unwrap(), self.device);
return Some(Ok(TrainBatch { inputs, targets }));
}
}

View File

@@ -19,8 +19,7 @@ use tracing::{debug, info};
use crate::{progress_big, split::regex_segment};
// TODO:
// - maybe don't use regex
// Maybe don't use regex for performance?
#[derive(Debug, Clone, thiserror::Error)]
pub enum TokenizerTrainError {