use ahash::{AHashMap, AHashSet};
use compact_str::CompactString;
use dary_heap::DaryHeap;
use fancy_regex::Regex;
use indicatif::{MultiProgress, ProgressBar};
use rayon::iter::{
    IndexedParallelIterator, IntoParallelRefIterator, ParallelBridge, ParallelIterator,
};
use serde::{Deserialize, Serialize};
use std::{
    cmp::Ordering,
    collections::{HashMap, VecDeque},
    str::FromStr,
    sync::atomic::AtomicUsize,
    time::Instant,
};
use strum::{AsRefStr, EnumString};
use tracing::{debug, info};

use crate::{progress_big, split::regex_segment}; // Maybe don't use regex for performance?

#[derive(Debug, Clone, thiserror::Error)]
pub enum TokenizerTrainError {
    #[error("vocab_size must be at least {min}, but it is {is}")]
    InvalidVocabularySize { is: u32, min: u32 },
}

/// A pair of adjacent tokens
type Pair = (u32, u32);

#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, EnumString, AsRefStr)]
pub enum SpecialTokens {
    /// Every document begins with the Beginning of Sequence (BOS) token that delimits documents
    #[strum(to_string = "<|bos|>")]
    BeginningOfSequence,

    // User messages
    #[strum(to_string = "<|user_start|>")]
    UserStart,
    #[strum(to_string = "<|user_end|>")]
    UserEnd,

    // Assistant messages
    #[strum(to_string = "<|assistant_start|>")]
    AssistantStart,
    #[strum(to_string = "<|assistant_end|>")]
    AssistantEnd,

    /// Assistant invokes python REPL tool
    #[strum(to_string = "<|python_start|>")]
    PythonStart,
    #[strum(to_string = "<|python_end|>")]
    PythonEnd,

    // Python REPL outputs back to assistant
    #[strum(to_string = "<|output_start|>")]
    OutputStart,
    #[strum(to_string = "<|output_end|>")]
    OutputEnd,
}

impl SpecialTokens {
    /// An array that contains every variant of this enum
    pub const VARIANTS: &[Self] = &[
        Self::BeginningOfSequence,
        Self::UserStart,
        Self::UserEnd,
        Self::AssistantStart,
        Self::AssistantEnd,
        Self::PythonStart,
        Self::PythonEnd,
        Self::OutputStart,
        Self::OutputEnd,
    ];

    /// Get the index of this variant in [Self::VARIANTS]
    #[expect(clippy::unwrap_used)]
    #[inline]
    pub fn idx(&self) -> u32 {
        Self::VARIANTS.iter().position(|p| p == self).unwrap() as u32
    }
}

//
// MARK: word
//

/// A segment of text that is undergoing tokenization.
///
/// `ids` contains the ids of the tokens this word consists of.
/// Initially, these are the word's raw bytes (`0..=255`). More ids are added as we merge pairs.
///
/// This implementation of BPE does not cross word boundaries.
/// - one token cannot represent more than one word, which may be a feature.
/// - this keeps multi-byte unicode together, since the split regex handles it correctly.
#[derive(Clone, Debug)]
struct Word {
    ids: Vec<u32>,
}

impl Word {
    #[inline]
    fn new(ids: Vec<u32>) -> Self {
        Self { ids }
    }

    /// Return an iterator over all adjacent pairs
    /// of tokens in this word
    #[inline]
    fn pairs<'a>(&'a self) -> impl Iterator<Item = Pair> + 'a {
        self.ids.windows(2).map(|w| (w[0], w[1]))
    }

    /// Merge all non-overlapping occurrences of `pair` -> `new_id`.
    /// Returns a small Vec of local pair-count deltas for THIS word only:
    /// -1 for removed pairs, +1 for newly created pairs.
    ///
    /// NOTE: this version deliberately avoids a HashMap in the hot loop.
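    ///
    /// For illustration (hypothetical ids): merging `(1, 2) -> 9` in `ids = [1, 2, 3, 1, 2]`
    /// leaves `ids = [9, 3, 9]` and returns the deltas `((1,2),-1)` twice, `((2,3),-1)`,
    /// `((9,3),+1)`, `((3,1),-1)`, and `((3,9),+1)`. Duplicates are not aggregated here;
    /// the caller sums them, weighted by the word's count.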
    fn merge_pair(&mut self, pair: Pair, new_id: u32) -> Vec<(Pair, i32)> {
        let (a, b) = pair;
        let n = self.ids.len();
        if n < 2 {
            return Vec::new();
        }

        let mut out: Vec<u32> = Vec::with_capacity(n);
        let mut deltas: Vec<(Pair, i32)> = Vec::with_capacity(6);

        let mut i = 0;
        while i < n {
            if i + 1 < n && self.ids[i] == a && self.ids[i + 1] == b {
                let left = out.last().copied();
                let right = if i + 2 < n { Some(self.ids[i + 2]) } else { None };

                // remove old pairs
                if let Some(x) = left {
                    deltas.push(((x, a), -1));
                    deltas.push(((x, new_id), 1));
                }
                deltas.push(((a, b), -1));
                if let Some(y) = right {
                    deltas.push(((b, y), -1));
                    deltas.push(((new_id, y), 1));
                }

                // write merged token
                out.push(new_id);
                i += 2; // skip 'a' and 'b'
            } else {
                out.push(self.ids[i]);
                i += 1;
            }
        }

        self.ids = out;
        deltas
    }
}

//
// MARK: mergejob
//

#[derive(Debug, Eq)]
struct MergeJob {
    pair: Pair,
    count: u64,
    /// set of word indices where this pair may occur and needs processing
    pos: AHashSet<usize>,
}

impl PartialEq for MergeJob {
    fn eq(&self, other: &Self) -> bool {
        self.count == other.count && self.pair == other.pair
    }
}

impl PartialOrd for MergeJob {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MergeJob {
    fn cmp(&self, other: &Self) -> Ordering {
        // Max-heap by count; ties prefer the lexicographically smaller pair for determinism
        self.count
            .cmp(&other.count)
            .then_with(|| other.pair.cmp(&self.pair))
    }
}

//
// MARK: tokenizer
//

/// A Byte Pair Encoding tokenizer that matches the GPT-4 style implementation.
///
/// BPE's goal is to build compound tokens for a text corpus.
/// It starts from a simple initial state: each byte value is a token.
/// We then examine all adjacent pairs of tokens, replacing
/// common pairs with a new token id. This is repeated until our vocabulary
/// is as large as it needs to be.
pub struct Tokenizer {
    /// Maps pairs of token IDs to their merged token ID
    merges: HashMap<Pair, u32>,

    /// Inverse of merges
    unmerges: Vec<Pair>,

    vocab_size: u32,

    /// The regex pattern used for text splitting
    split_regex: Regex,

    special_regex: Regex,

    /// Source of split_regex
    /// (debug info)
    split_regex_string: String,

    /// The number of texts this tokenizer was trained on
    /// (debug info)
    n_train_texts: u64,

    /// The number of words this tokenizer was trained on
    /// (debug info)
    n_train_words: u64,
}

impl Serialize for Tokenizer {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;

        let mut state = serializer.serialize_struct("Tokenizer", 5)?;

        // Convert HashMap to Vec for serialization
        // (only string keys are valid in json)
        let merges_vec: Vec<(Pair, u32)> = self.merges.iter().map(|(&k, &v)| (k, v)).collect();

        state.serialize_field("merges", &merges_vec)?;
        state.serialize_field("split_regex_string", &self.split_regex_string)?;
        state.serialize_field("n_train_texts", &self.n_train_texts)?;
        state.serialize_field("n_train_words", &self.n_train_words)?;
        state.serialize_field("vocab_size", &self.vocab_size)?;
        state.end()
    }
}

// Custom deserializer that automatically compiles split_regex
impl<'de> Deserialize<'de> for Tokenizer {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[derive(Deserialize)]
        struct TokenizerData {
            merges: Vec<(Pair, u32)>,
            split_regex_string: String,
            n_train_texts: u64,
            n_train_words: u64,
            vocab_size: u32,
        }

        let data = TokenizerData::deserialize(deserializer)?;
        let split_regex =
            Regex::new(&data.split_regex_string).map_err(serde::de::Error::custom)?;
        let merges: HashMap<Pair, u32> = data.merges.into_iter().collect();

        #[expect(clippy::unwrap_used)] // Special regex must be valid
        let special_regex = Regex::new(Self::SPECIAL_REGEX).unwrap();

        Ok(Tokenizer {
            unmerges: Self::reverse_merges(&merges),
            merges,
            split_regex,
            split_regex_string: data.split_regex_string,
            special_regex,
            n_train_texts: data.n_train_texts,
            n_train_words: data.n_train_words,
            vocab_size: data.vocab_size,
        })
    }
}

impl Tokenizer {
    /// Default regex pattern for splitting text
    const DEFAULT_REGEX: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+";

    /// Regex for splitting special tokens
    pub const SPECIAL_REGEX: &str = "<[|][a-z_]+[|]>";

    /// Minimum size of this tokenizer's vocabulary.
    /// - 256 for all `u8`s (token ids `0..=255`)
    /// - `n` for special tokens (token ids `256..256+n`)
    pub const MIN_VOCAB_SIZE: u32 = 256 + SpecialTokens::VARIANTS.len() as u32;

    /// Reverse a merge map, returning an array of pairs.
    /// `result[token_id - Self::MIN_VOCAB_SIZE]` is the pair that `token_id` represents.
    fn reverse_merges(merges: &HashMap<Pair, u32>) -> Vec<Pair> {
        let iter = merges.iter().map(|x| (*x.0, *x.1));
        let mut array = Vec::from_iter(iter);
        array.sort_by_key(|x| x.1);

        let mut unmerges = Vec::with_capacity(array.len());
        for (i, p) in array.iter().enumerate() {
            assert_eq!(p.1, i as u32 + Self::MIN_VOCAB_SIZE);
            unmerges.push(p.0);
        }
        unmerges
    }

    /// Return the regex pattern used to split words
    #[inline]
    pub fn get_regex(&self) -> &str {
        &self.split_regex_string
    }

    /// Return the total number of tokens this tokenizer can produce
    #[inline]
    pub fn vocab_size(&self) -> u32 {
        self.vocab_size
    }

    /// Return this tokenizer's "beginning of sequence" token
    #[inline]
    pub fn bos_token(&self) -> u32 {
        256 + SpecialTokens::BeginningOfSequence.idx()
    }

    /// Tokenize a string
    pub fn encode(&self, text: &str) -> Vec<u32> {
        let mut all_ids = Vec::new();
        let special = regex_segment(&self.special_regex, text);

        for s in special {
            if let Ok(special) = SpecialTokens::from_str(s) {
                all_ids.push(256 + special.idx());
                continue;
            }

            for m in self.split_regex.find_iter(s) {
                #[expect(clippy::unwrap_used)] // Shouldn't ever fail
                let word = m.unwrap().as_str();
                let mut word_tokens = word.bytes().map(|b| b as u32).collect::<Vec<u32>>();

                // Apply merges
                while word_tokens.len() >= 2 {
                    // Merge the pair with the smallest merged token id,
                    // i.e. the earliest-learned merge.
                    // (pair_start_idx, replace_with)
                    let mut best_pair: Option<(usize, u32)> = None;

                    for (i, pair) in word_tokens.windows(2).map(|x| (x[0], x[1])).enumerate() {
                        let new_id = match self.merges.get(&pair) {
                            None => continue,
                            Some(x) => *x,
                        };

                        #[expect(clippy::unwrap_used)]
                        if best_pair.is_none() || new_id < best_pair.unwrap().1 {
                            best_pair = Some((i, new_id));
                        }
                    }

                    match best_pair {
                        Some((idx, new_id)) => {
                            word_tokens[idx] = new_id;
                            word_tokens.remove(idx + 1);
                        }
                        None => {
                            // No merges possible
                            break;
                        }
                    }
                }

                all_ids.extend(word_tokens);
            }
        }

        all_ids
    }

    /// Decode an array of tokens
    pub fn decode(&self, tokens: &[u32]) -> String {
        let mut str_bytes = Vec::new();
        let mut buffer = VecDeque::from_iter(tokens.iter().copied());

        while let Some(t) = buffer.pop_front() {
            // This is a byte
            if t < 256 {
                str_bytes.push(t as u8);
                continue;
            }
            let t = t - 256;

            // This is a special token
            if let Some(t) = SpecialTokens::VARIANTS.get(t as usize) {
                str_bytes.extend_from_slice(t.as_ref().as_bytes());
                continue;
            }
            let t = t - SpecialTokens::VARIANTS.len() as u32;

            // This is a compound token
            #[expect(clippy::unwrap_used)]
            let pair = self.unmerges.get(t as usize).unwrap();
            buffer.push_front(pair.1);
            buffer.push_front(pair.0);
        }

        String::from_utf8_lossy(&str_bytes).to_string()
    }

    //
    // MARK: main training code
    //
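    // A rough worked sketch of the incremental training loop below (made-up ids):
    // given unique words ["ab", "abc"] with counts [2, 1], the initial pair counts are
    // (a,b) -> 3 and (b,c) -> 1. The heap pops (a,b); we record the merge
    // (a,b) -> MIN_VOCAB_SIZE, rewrite both words via `Word::merge_pair`, and apply the
    // returned deltas weighted by each word's count: (a,b) and (b,c) drop to 0 and the
    // new pair (MIN_VOCAB_SIZE, c) appears with count 1. Heap entries whose counts have
    // gone stale are fixed up lazily when popped instead of rebuilding the heap per merge.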
    /// Given an array of words and an array of counts,
    /// - count the number of occurrences of each token pair
    /// - map each pair to the indices of the words that contain it
    ///
    /// ## Notes:
    /// - will panic if `words.len() != counts.len()`
    /// - will not behave correctly if words are repeated
    #[inline]
    fn count_pairs(
        words: &[Word],
        counts: &[i32],
    ) -> (AHashMap<Pair, i32>, AHashMap<Pair, AHashSet<usize>>) {
        words
            .par_iter()
            .enumerate()
            .map(|(i, w)| {
                let mut pair_counts: AHashMap<Pair, i32> = AHashMap::new();
                let mut word_map: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();

                if w.ids.len() >= 2 && counts[i] != 0 {
                    for (a, b) in w.pairs() {
                        *pair_counts.entry((a, b)).or_default() += counts[i];
                        word_map.entry((a, b)).or_default().insert(i);
                    }
                }

                (pair_counts, word_map)
            })
            .reduce(
                || (AHashMap::new(), AHashMap::new()),
                |(mut acc_pc, mut acc_wtu), (pc, wtu)| {
                    for (k, v) in pc {
                        *acc_pc.entry(k).or_default() += v;
                    }
                    for (k, s) in wtu {
                        acc_wtu.entry(k).or_default().extend(s);
                    }
                    (acc_pc, acc_wtu)
                },
            )
    }

    /// Core incremental BPE training given unique words and their counts.
    /// - `words`: one entry per unique chunk (Vec of token ids / bytes).
    /// - `counts`: same length as `words`, count per chunk.
    ///
    /// ## Notes:
    /// - vocab_size must be >= Self::MIN_VOCAB_SIZE
    /// - will panic if `words.len() != counts.len()`
    /// - will not behave correctly if words are repeated
    fn train_core(
        mp: Option<MultiProgress>,
        mut words: Vec<Word>,
        counts: Vec<i32>,
        vocab_size: u32,
    ) -> HashMap<Pair, u32> {
        assert!(
            vocab_size >= Self::MIN_VOCAB_SIZE,
            "vocab_size must be at least {}",
            Self::MIN_VOCAB_SIZE
        );
        let num_merges = vocab_size - Self::MIN_VOCAB_SIZE;
        let mut merges = HashMap::with_capacity(num_merges as usize);
        info!(message = "Training tokenizer", num_merges);

        let now = Instant::now();
        info!(
            message = "Computing initial pair counts",
            n_words = words.len()
        );
        let (mut pair_counts, mut where_to_update) = Self::count_pairs(&words, &counts);

        let mut heap = {
            debug!("Building heap with {} unique pairs", pair_counts.len());
            let mut heap = DaryHeap::<_, 8>::with_capacity(pair_counts.len());
            for (pair, pos) in where_to_update.drain() {
                let c = *pair_counts.get(&pair).unwrap_or(&0);
                if c > 0 {
                    heap.push(MergeJob {
                        pair,
                        count: c as u64,
                        pos,
                    });
                }
            }
            heap
        };

        let pb = mp.as_ref().map(|mp| {
            let pb = mp.add(ProgressBar::new(num_merges as u64));
            pb.set_style(progress_big());
            pb.set_message("Computing merges");
            pb
        });

        let mut merges_done = 0u32;
        while merges_done < num_merges {
            let Some(mut top) = heap.pop() else {
                break;
            };

            // Lazy refresh: the popped count may be stale, fix it up and re-push
            let current = *pair_counts.get(&top.pair).unwrap_or(&0);
            if top.count != current as u64 {
                top.count = current as u64;
                if top.count > 0 {
                    heap.push(top);
                }
                continue;
            }
            if top.count == 0 {
                break;
            }

            // Record merge
            let new_id = Self::MIN_VOCAB_SIZE + merges_done;
            merges.insert(top.pair, new_id);

            // Merge this pair in all words where it occurs
            let mut local_pos_updates: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
            for &word_idx in &top.pos {
                // Apply merge to this word and collect pair-count deltas
                let changes = words[word_idx].merge_pair(top.pair, new_id);

                // Update global pair counts based on this word's count
                for (pair, delta) in changes {
                    let delta_total = delta * counts[word_idx];
                    if delta_total != 0 {
                        *pair_counts.entry(pair).or_default() += delta_total;
                        if delta > 0 {
                            local_pos_updates.entry(pair).or_default().insert(word_idx);
                        }
                    }
                }
            }

            // Add the updated pair counts back to the heap
            for (pair, pos) in local_pos_updates {
                let cnt = *pair_counts.get(&pair).unwrap_or(&0);
                if cnt > 0 {
                    heap.push(MergeJob {
                        pair,
                        count: cnt as u64,
                        pos,
                    });
                }
            }

            merges_done += 1;
            if let Some(pb) = pb.as_ref() {
                pb.set_position(merges_done as u64);
            }
        }

        if let Some(pb) = pb {
            pb.finish_and_clear();
        }

        info!("Computed merges in {:.03?}", now.elapsed());
        merges
    }

    //
    // MARK: train
    //

    /// Train a new tokenizer from an iterator of texts.
    pub fn train<I>(
        mp: Option<MultiProgress>,
        iterator: I,
        vocab_size: u32,
    ) -> Result<Self, TokenizerTrainError>
    where
        I: Iterator<Item = String> + ExactSizeIterator,
        I: ParallelBridge + Send,
    {
        if vocab_size < Self::MIN_VOCAB_SIZE {
            return Err(TokenizerTrainError::InvalidVocabularySize {
                is: vocab_size,
                min: Self::MIN_VOCAB_SIZE,
            });
        }

        let split_regex_string = Self::DEFAULT_REGEX.to_owned();
        #[expect(clippy::unwrap_used)] // Default regex must be valid
        let split_regex = Regex::new(&split_regex_string).unwrap();
        #[expect(clippy::unwrap_used)] // Special regex must be valid
        let special_regex = Regex::new(Self::SPECIAL_REGEX).unwrap();

        let now = Instant::now();
        let n_train_texts = iterator.len() as u64;
        debug!("Counting words in {} texts", n_train_texts);

        let pb = mp.as_ref().map(|mp| {
            let pb = mp.add(ProgressBar::new(iterator.len() as u64));
            pb.set_style(progress_big());
            pb.set_message("Counting words");
            pb
        });

        // Process texts in parallel and collect chunk counts
        let counter = AtomicUsize::new(0);
        let counts: AHashMap<CompactString, i32> = iterator
            .par_bridge()
            .map(|text| {
                let mut local_counts: AHashMap<CompactString, i32> = AHashMap::new();
                let special = regex_segment(&special_regex, &text);

                for s in special {
                    if let Ok(special) = SpecialTokens::from_str(s) {
                        *local_counts
                            .entry(CompactString::from(special.as_ref()))
                            .or_default() += 1;
                        continue;
                    }

                    for mat in split_regex.find_iter(s) {
                        #[expect(clippy::unwrap_used)]
                        let piece = mat.unwrap().as_str();
                        *local_counts.entry(CompactString::from(piece)).or_default() += 1;
                    }
                }

                let count = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                if let Some(ref pb) = pb
                    && count.is_multiple_of(1000)
                {
                    pb.inc(1000);
                }

                local_counts
            })
            .reduce(AHashMap::new, |mut acc, local_counts| {
                for (k, v) in local_counts {
                    *acc.entry(k).or_default() += v;
                }
                acc
            });

        if let Some(pb) = pb {
            pb.finish_and_clear();
        }
        info!(
            "Counted {} unique words in {:.03?}",
            counts.len(),
            now.elapsed()
        );

        // Materialize words & counts
        let mut words = Vec::with_capacity(counts.len());
        let mut cvec = Vec::with_capacity(counts.len());
        let mut n_train_words = 0u64;
        for (chunk, c) in counts.into_iter() {
            let token_ids = match SpecialTokens::from_str(&chunk).ok() {
                Some(x) => vec![256 + x.idx()],
                None => chunk
                    .as_bytes()
                    .iter()
                    .map(|&b| b as u32)
                    .collect::<Vec<u32>>(),
            };
            words.push(Word::new(token_ids));
            cvec.push(c);
            n_train_words += c as u64;
        }

        let merges = Self::train_core(mp, words, cvec, vocab_size);

        Ok(Self {
            split_regex,
            split_regex_string,
            special_regex,
            n_train_texts,
            n_train_words,
            vocab_size: Self::MIN_VOCAB_SIZE + merges.len() as u32,
            unmerges: Self::reverse_merges(&merges),
            merges,
        })
    }
}
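
// A minimal usage sketch (not part of the original API surface): train a tiny
// tokenizer on a made-up corpus and check that encode/decode round-trips. The
// corpus, the extra merge budget of 8, and the exact assertions are illustrative
// assumptions, not fixtures from the original crate.
#[cfg(test)]
mod usage_sketch {
    use super::*;

    #[expect(clippy::unwrap_used)]
    #[test]
    fn encode_decode_round_trip() {
        let texts = vec![
            "hello hello world".to_string(),
            "hello there".to_string(),
        ];
        // A small budget of merges on top of the 256 byte tokens + special tokens.
        let vocab_size = Tokenizer::MIN_VOCAB_SIZE + 8;
        let tok = Tokenizer::train(None, texts.into_iter(), vocab_size).unwrap();

        // Plain text should survive an encode/decode round trip.
        let ids = tok.encode("hello world");
        assert_eq!(tok.decode(&ids), "hello world");

        // The BOS id decodes back to its literal form (assuming `AsRefStr` yields
        // the `to_string` values declared on `SpecialTokens`).
        assert_eq!(tok.decode(&[tok.bos_token()]), "<|bos|>");
    }
}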