Tokenizer

2025-12-11 17:37:33 -08:00
parent 62fcf781c1
commit 1805b7f430
10 changed files with 7678 additions and 0 deletions

@@ -0,0 +1,720 @@
use ahash::{AHashMap, AHashSet};
use compact_str::CompactString;
use dary_heap::DaryHeap;
use fancy_regex::Regex;
use indicatif::{MultiProgress, ProgressBar};
use rayon::iter::{
IndexedParallelIterator, IntoParallelRefIterator, ParallelBridge, ParallelIterator,
};
use serde::{Deserialize, Serialize};
use std::{
cmp::Ordering,
collections::{HashMap, VecDeque},
str::FromStr,
sync::atomic::AtomicUsize,
time::Instant,
};
use strum::{AsRefStr, EnumString};
use tracing::{debug, info};
use crate::{progress_big, split::regex_segment};
// TODO:
// - maybe don't use regex
#[derive(Debug, Clone, thiserror::Error)]
pub enum TokenizerTrainError {
#[error("vocab_size must be at least {is}, but it is {min}")]
InvalidVocabularySize { is: u32, min: u32 },
}
/// A pair of adjacent tokens
type Pair = (u32, u32);
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, EnumString, AsRefStr)]
pub enum SpecialTokens {
/// Every document begins with the Beginning of Sequence (BOS) token, which delimits documents
#[strum(to_string = "<|bos|>")]
BeginningOfSequence,
// User messages
#[strum(to_string = "<|user_start|>")]
UserStart,
#[strum(to_string = "<|user_end|>")]
UserEnd,
// Assistant messages
#[strum(to_string = "<|assistant_start|>")]
AssistantStart,
#[strum(to_string = "<|assistant_end|>")]
AssistantEnd,
/// The assistant invokes the Python REPL tool
#[strum(to_string = "<|python_start|>")]
PythonStart,
#[strum(to_string = "<|python_end|>")]
PythonEnd,
// Python REPL outputs back to assistant
#[strum(to_string = "<|output_start|>")]
OutputStart,
#[strum(to_string = "<|output_end|>")]
OutputEnd,
}
impl SpecialTokens {
/// An array that contains every variant of this enum
pub const VARIANTS: &[Self] = &[
Self::BeginningOfSequence,
Self::UserStart,
Self::UserEnd,
Self::AssistantStart,
Self::AssistantEnd,
Self::PythonStart,
Self::PythonEnd,
Self::OutputStart,
Self::OutputEnd,
];
/// Get the index of this variant in [Self::VARIANTS]
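///
/// Sketch of the expected mapping, following the order of [Self::VARIANTS]:
/// ```ignore
/// assert_eq!(SpecialTokens::BeginningOfSequence.idx(), 0);
/// assert_eq!(SpecialTokens::UserStart.idx(), 1);
/// ```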
#[expect(clippy::unwrap_used)]
#[inline]
pub fn idx(&self) -> u32 {
Self::VARIANTS.iter().position(|p| p == self).unwrap() as u32
}
}
//
// MARK: word
//
/// A segment of text that is undergoing tokenization.
///
/// `ids` contains the ids of the tokens this word consists of.
/// Initially, these are the word's raw byte values (0..=255); larger ids appear as we merge pairs.
///
/// This implementation of BPE does not cross word boundaries.
/// - one token cannot represent more than one word, which may be a feature.
/// - this keeps multi-byte unicode together, since the split regex handles it correctly.
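///
/// A minimal sketch of the initial state (illustrative only; `Word` is private):
/// ```ignore
/// // "hi" starts out as its raw bytes promoted to u32 token ids
/// let word = Word::new("hi".bytes().map(u32::from).collect());
/// assert_eq!(word.ids, vec![104, 105]);
/// ```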
#[derive(Clone, Debug)]
struct Word {
ids: Vec<u32>,
}
impl Word {
#[inline]
fn new(ids: Vec<u32>) -> Self {
Self { ids }
}
/// Return an iterator over all adjacent pairs
/// of tokens in this word
#[inline]
fn pairs<'a>(&'a self) -> impl Iterator<Item = Pair> + 'a {
self.ids.windows(2).map(|w| (w[0], w[1]))
}
/// Merge all non-overlapping occurrences of pair -> new_id.
/// Returns a small Vec of local pair-count deltas for THIS word only:
/// -1 for removed pairs, +1 for newly created pairs.
///
/// NOTE: this version deliberately avoids a HashMap in the hot loop.
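///
/// Sketch of the delta bookkeeping for a single word (illustrative only):
/// ```ignore
/// let mut w = Word::new(vec![1, 2, 3]);
/// let deltas = w.merge_pair((1, 2), 300);
/// assert_eq!(w.ids, vec![300, 3]);
/// // (1,2) and (2,3) are destroyed, (300,3) is created
/// assert_eq!(deltas, vec![((1, 2), -1), ((2, 3), -1), ((300, 3), 1)]);
/// ```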
fn merge_pair(&mut self, pair: Pair, new_id: u32) -> Vec<(Pair, i32)> {
let (a, b) = pair;
let n = self.ids.len();
if n < 2 {
return Vec::new();
}
let mut out: Vec<u32> = Vec::with_capacity(n);
let mut deltas: Vec<(Pair, i32)> = Vec::with_capacity(6);
let mut i = 0;
while i < n {
if i + 1 < n && self.ids[i] == a && self.ids[i + 1] == b {
let left = out.last().copied();
let right = if i + 2 < n {
Some(self.ids[i + 2])
} else {
None
};
// remove old pairs
if let Some(x) = left {
deltas.push(((x, a), -1));
deltas.push(((x, new_id), 1));
}
deltas.push(((a, b), -1));
if let Some(y) = right {
deltas.push(((b, y), -1));
deltas.push(((new_id, y), 1));
}
// write merged token
out.push(new_id);
i += 2; // skip 'a' and 'b'
} else {
out.push(self.ids[i]);
i += 1;
}
}
self.ids = out;
deltas
}
}
//
// MARK: mergejob
//
#[derive(Debug, Eq)]
struct MergeJob {
pair: Pair,
count: u64,
/// set of word indices where this pair may occur and needs processing
pos: AHashSet<usize>,
}
impl PartialEq for MergeJob {
fn eq(&self, other: &Self) -> bool {
self.count == other.count && self.pair == other.pair
}
}
impl PartialOrd for MergeJob {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
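// Heap ordering note: the heap is a max-heap, so the job with the highest count
// pops first; on ties, the lexicographically smaller pair compares greater and
// therefore pops first, keeping merge order deterministic.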
impl Ord for MergeJob {
fn cmp(&self, other: &Self) -> Ordering {
self.count
.cmp(&other.count)
.then_with(|| other.pair.cmp(&self.pair))
}
}
//
// MARK: tokenizer
//
/// A Byte Pair Encoding tokenizer that matches the GPT-4 style implementation
///
/// BPE's goal is to build compound tokens for a text corpus.
/// It starts with a simple initial state---in this case, each ascii u8
/// is a token. We then examine all adjacent pairs of tokens, replacing
/// common pairs with a new token id. This is repeated until our vocabulary
/// is as large as it needs to be.
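///
/// A minimal usage sketch (assumes `tokenizer` was trained or deserialized elsewhere):
/// ```ignore
/// let ids = tokenizer.encode("hello world");
/// assert_eq!(tokenizer.decode(&ids), "hello world");
/// ```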
pub struct Tokenizer {
/// Maps pairs of token IDs to their merged token ID
merges: HashMap<Pair, u32>,
/// Inverse of merges
unmerges: Vec<Pair>,
vocab_size: u32,
/// The regex pattern used for text splitting
split_regex: Regex,
special_regex: Regex,
/// Source of split_regex
/// (debug info)
split_regex_string: String,
/// The number of texts this tokenizer was trained on
/// (debug info)
n_train_texts: u64,
/// The number of words this tokenizer was trained on
/// (debug info)
n_train_words: u64,
}
impl Serialize for Tokenizer {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeStruct;
let mut state = serializer.serialize_struct("Tokenizer", 4)?;
// Convert HashMap to Vec for serialization
// (only string keys are valid in json)
let merges_vec: Vec<(Pair, u32)> = self.merges.iter().map(|(&k, &v)| (k, v)).collect();
state.serialize_field("merges", &merges_vec)?;
state.serialize_field("split_regex_string", &self.split_regex_string)?;
state.serialize_field("n_train_texts", &self.n_train_texts)?;
state.serialize_field("n_train_words", &self.n_train_words)?;
state.serialize_field("vocab_size", &self.vocab_size)?;
state.end()
}
}
// Custom deserializer that automatically compiles split_regex
impl<'de> Deserialize<'de> for Tokenizer {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize)]
struct TokenizerData {
merges: Vec<(Pair, u32)>,
split_regex_string: String,
n_train_texts: u64,
n_train_words: u64,
vocab_size: u32,
}
let data = TokenizerData::deserialize(deserializer)?;
let split_regex = Regex::new(&data.split_regex_string).map_err(serde::de::Error::custom)?;
let merges: HashMap<Pair, u32> = data.merges.into_iter().collect();
#[expect(clippy::unwrap_used)] // Special regex must be valid
let special_regex = Regex::new(Self::SPECIAL_REGEX).unwrap();
Ok(Tokenizer {
unmerges: Self::reverse_merges(&merges),
merges,
split_regex,
split_regex_string: data.split_regex_string,
special_regex,
n_train_texts: data.n_train_texts,
n_train_words: data.n_train_words,
vocab_size: data.vocab_size,
})
}
}
impl Tokenizer {
/// Default regex pattern for splitting text
const DEFAULT_REGEX: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+";
/// Regex for splitting special tokens
pub const SPECIAL_REGEX: &str = "<[|][a-z_]+[|]>";
/// Minimum size of this tokenizer's vocabulary.
/// - 256 for all `u8`s (token ids `0..=255`)
/// - `n` for special tokens (token ids `256..256+n`)
pub const MIN_VOCAB_SIZE: u32 = 256 + SpecialTokens::VARIANTS.len() as u32;
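// Token-id layout implied by the constants above:
//   0..=255                      raw bytes
//   256..MIN_VOCAB_SIZE          special tokens, in `SpecialTokens::VARIANTS` order
//   MIN_VOCAB_SIZE..vocab_size   merged (compound) tokens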
/// Reverse a merge map, returning an array of pairs.
/// `result[token_id - Self::MIN_VOCAB_SIZE]` returns the pair that `token_id` represents.
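///
/// Sketch: if `merges = {(104, 105) -> MIN_VOCAB_SIZE}`, then
/// `reverse_merges(&merges) == vec![(104, 105)]`.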
fn reverse_merges(merges: &HashMap<Pair, u32>) -> Vec<Pair> {
let iter = merges.iter().map(|x| (*x.0, *x.1));
let mut array = Vec::from_iter(iter);
array.sort_by_key(|x| x.1);
let mut unmerges = Vec::with_capacity(array.len());
for (i, p) in array.iter().enumerate() {
assert_eq!(p.1, i as u32 + Self::MIN_VOCAB_SIZE);
unmerges.push(p.0);
}
unmerges
}
/// Return the regex pattern used to split words
#[inline]
pub fn get_regex(&self) -> &str {
&self.split_regex_string
}
/// Return the total number of tokens this tokenizer can produce
#[inline]
pub fn vocab_size(&self) -> u32 {
self.vocab_size
}
/// Return this tokenizer's "beginning of seq" token
#[inline]
pub fn bos_token(&self) -> u32 {
256 + SpecialTokens::BeginningOfSequence.idx()
}
/// Tokenize a string
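///
/// Special tokens embedded in the text become single ids (sketch, not a doctest):
/// ```ignore
/// let ids = tokenizer.encode("<|bos|>hi");
/// assert_eq!(ids[0], tokenizer.bos_token());
/// ```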
pub fn encode(&self, text: &str) -> Vec<u32> {
let mut all_ids = Vec::new();
let special = regex_segment(&self.special_regex, text);
for s in special {
if let Ok(special) = SpecialTokens::from_str(s) {
all_ids.push(256 + special.idx());
continue;
}
for m in self.split_regex.find_iter(s) {
#[expect(clippy::unwrap_used)] // Shouldn't ever fail
let word = m.unwrap().as_str();
let mut word_tokens = word.bytes().map(|b| b as u32).collect::<Vec<_>>();
// Apply merges
while word_tokens.len() >= 2 {
// Merge the pair whose replacement token id is lowest,
// i.e. the earliest-learned merge: (pair_start_idx, replace_with)
let mut best_pair: Option<(usize, u32)> = None;
for (i, pair) in word_tokens.windows(2).map(|x| (x[0], x[1])).enumerate() {
let new_id = match self.merges.get(&pair) {
None => continue,
Some(x) => *x,
};
#[expect(clippy::unwrap_used)]
if best_pair.is_none() || new_id < best_pair.unwrap().1 {
best_pair = Some((i, new_id));
}
}
match best_pair {
Some((idx, new_id)) => {
word_tokens[idx] = new_id;
word_tokens.remove(idx + 1);
}
None => {
// No merges possible
break;
}
}
}
all_ids.extend(word_tokens);
}
}
all_ids
}
/// Decode an array of tokens
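///
/// Sketch: ids below 256 decode directly as bytes:
/// ```ignore
/// assert_eq!(tokenizer.decode(&[104, 105]), "hi");
/// ```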
pub fn decode(&self, tokens: &[u32]) -> String {
let mut str_bytes = Vec::new();
let mut buffer = VecDeque::from_iter(tokens.iter().copied());
while let Some(t) = buffer.pop_front() {
// This is a byte
if t < 256 {
str_bytes.push(t as u8);
continue;
}
let t = t - 256;
// This is a special token
if let Some(t) = SpecialTokens::VARIANTS.get(t as usize) {
str_bytes.extend_from_slice(t.as_ref().as_bytes());
continue;
}
let t = t - SpecialTokens::VARIANTS.len() as u32;
// This is a compound token
#[expect(clippy::unwrap_used)]
let pair = self.unmerges.get(t as usize).unwrap();
buffer.push_front(pair.1);
buffer.push_front(pair.0);
}
String::from_utf8_lossy(&str_bytes).to_string()
}
//
// MARK: main training code
//
/// Given an array of words and an array of counts,
/// - count the number of occurrences of each token pair
/// - map each pair to the indices of the words that contain it
///
/// ## Notes:
/// - will panic if `words.len() != counts.len()`
/// - will not behave correctly if words are repeated
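///
/// Sketch: for a single word `[1, 2, 2]` with count 3, the result is
/// pair counts `{(1,2): 3, (2,2): 3}`, with both pairs mapped to word index 0.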
#[inline]
fn count_pairs(
words: &[Word],
counts: &[i32],
) -> (AHashMap<Pair, i32>, AHashMap<Pair, AHashSet<usize>>) {
words
.par_iter()
.enumerate()
.map(|(i, w)| {
let mut pair_counts: AHashMap<Pair, i32> = AHashMap::new();
let mut word_map: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
if w.ids.len() >= 2 && counts[i] != 0 {
for (a, b) in w.pairs() {
*pair_counts.entry((a, b)).or_default() += counts[i];
word_map.entry((a, b)).or_default().insert(i);
}
}
(pair_counts, word_map)
})
.reduce(
|| (AHashMap::new(), AHashMap::new()),
|(mut acc_pc, mut acc_wtu), (pc, wtu)| {
for (k, v) in pc {
*acc_pc.entry(k).or_default() += v;
}
for (k, s) in wtu {
acc_wtu.entry(k).or_default().extend(s);
}
(acc_pc, acc_wtu)
},
)
}
/// Core incremental BPE training given unique words and their counts.
/// - `words`: one entry per unique chunk (Vec<u32> of token-ids/bytes).
/// - `counts`: same length as `words`, count per chunk.
///
/// ## Notes:
/// - vocab size must be >= Self::MIN_VOCAB_SIZE
/// - will panic if `words.len() != counts.len()`
/// - will not behave correctly if words are repeated
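///
/// Sketch: with deduplicated `words = [[97, 98], [99]]` and `counts = [2, 1]`
/// (i.e. "ab" twice, "c" once), the first merge recorded is
/// `(97, 98) -> Self::MIN_VOCAB_SIZE`, since it is the most frequent pair.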
fn train_core(
mp: Option<MultiProgress>,
mut words: Vec<Word>,
counts: Vec<i32>,
vocab_size: u32,
) -> HashMap<Pair, u32> {
assert!(
vocab_size >= Self::MIN_VOCAB_SIZE,
"vocab_size must be at least {}",
Self::MIN_VOCAB_SIZE
);
let num_merges = vocab_size - Self::MIN_VOCAB_SIZE;
let mut merges = HashMap::with_capacity(num_merges as usize);
info!(message = "Training tokenizer", num_merges);
let now = Instant::now();
info!(
message = "Computing initial pair counts",
n_words = words.len()
);
let (mut pair_counts, mut where_to_update) = Self::count_pairs(&words, &counts);
let mut heap = {
debug!("Building heap with {} unique pairs", pair_counts.len());
let mut heap = DaryHeap::<_, 8>::with_capacity(pair_counts.len());
for (pair, pos) in where_to_update.drain() {
let c = *pair_counts.get(&pair).unwrap_or(&0);
if c > 0 {
heap.push(MergeJob {
pair,
count: c as u64,
pos,
});
}
}
heap
};
let pb = mp.as_ref().map(|mp| {
let pb = mp.add(ProgressBar::new(num_merges as u64));
pb.set_style(progress_big());
pb.set_message("Computing merges");
pb
});
let mut merges_done = 0u32;
while merges_done < num_merges {
let Some(mut top) = heap.pop() else {
break;
};
// Lazy refresh
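// The popped entry may be stale: counts change as other merges are applied,
// and heap entries are not updated in place. Re-check against the live count
// and re-push the job if the pair is still viable.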
let current = *pair_counts.get(&top.pair).unwrap_or(&0);
if top.count != current as u64 {
top.count = current as u64;
if top.count > 0 {
heap.push(top);
}
continue;
}
if top.count == 0 {
break;
}
// Record merge
let new_id = Self::MIN_VOCAB_SIZE + merges_done;
merges.insert(top.pair, new_id);
// Merge this pair in all words where it occurs
let mut local_pos_updates: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
for &word_idx in &top.pos {
// Apply merge to this word and collect pair-count deltas
let changes = words[word_idx].merge_pair(top.pair, new_id);
// Update global pair counts based on this word's count
for (pair, delta) in changes {
let delta_total = delta * counts[word_idx];
if delta_total != 0 {
*pair_counts.entry(pair).or_default() += delta_total;
if delta > 0 {
local_pos_updates.entry(pair).or_default().insert(word_idx);
}
}
}
}
// Add the updated pair counts back to the heap
for (pair, pos) in local_pos_updates {
let cnt = *pair_counts.get(&pair).unwrap_or(&0);
if cnt > 0 {
heap.push(MergeJob {
pair,
count: cnt as u64,
pos,
});
}
}
merges_done += 1;
if let Some(pb) = pb.as_ref() {
pb.set_position(merges_done as u64);
}
}
if let Some(pb) = pb {
pb.finish_and_clear();
}
info!("Computed merges in {:.03?}", now.elapsed());
merges
}
//
// MARK: train
//
/// Train a new tokenizer from an iterator of texts.
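///
/// A minimal usage sketch (corpus and vocabulary size are illustrative):
/// ```ignore
/// let texts = vec!["hello world".to_string(), "hello there".to_string()];
/// let tokenizer = Tokenizer::train(None, texts.into_iter(), Tokenizer::MIN_VOCAB_SIZE + 16)?;
/// assert!(tokenizer.vocab_size() <= Tokenizer::MIN_VOCAB_SIZE + 16);
/// ```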
pub fn train<I>(
mp: Option<MultiProgress>,
iterator: I,
vocab_size: u32,
) -> Result<Self, TokenizerTrainError>
where
I: Iterator<Item = String> + ExactSizeIterator,
I: ParallelBridge + Send,
{
if vocab_size < Self::MIN_VOCAB_SIZE {
return Err(TokenizerTrainError::InvalidVocabularySize {
is: vocab_size,
min: Self::MIN_VOCAB_SIZE,
});
}
let split_regex_string = Self::DEFAULT_REGEX.to_owned();
#[expect(clippy::unwrap_used)] // Default regex must be valid
let split_regex = Regex::new(&split_regex_string).unwrap();
#[expect(clippy::unwrap_used)] // Special regex must be valid
let special_regex = Regex::new(Self::SPECIAL_REGEX).unwrap();
let now = Instant::now();
let n_train_texts = iterator.len() as u64;
debug!("Counting words in {} texts", n_train_texts);
let pb = mp.as_ref().map(|mp| {
let pb = mp.add(ProgressBar::new(iterator.len() as u64));
pb.set_style(progress_big());
pb.set_message("Counting words");
pb
});
// Process texts in parallel and collect chunk counts
let counter = AtomicUsize::new(0);
let counts: AHashMap<CompactString, i32> = iterator
.par_bridge()
.map(|text| {
let mut local_counts: AHashMap<CompactString, i32> = AHashMap::new();
let special = regex_segment(&special_regex, &text);
for s in special {
if let Ok(special) = SpecialTokens::from_str(s) {
*local_counts
.entry(CompactString::from(special.as_ref()))
.or_default() += 1;
continue;
}
for mat in split_regex.find_iter(s) {
#[expect(clippy::unwrap_used)]
let piece = mat.unwrap().as_str();
*local_counts.entry(CompactString::from(piece)).or_default() += 1;
}
}
let count = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if let Some(ref pb) = pb
&& count.is_multiple_of(1000)
{
pb.inc(1000);
}
local_counts
})
.reduce(AHashMap::new, |mut acc, local_counts| {
for (k, v) in local_counts {
*acc.entry(k).or_default() += v;
}
acc
});
if let Some(pb) = pb {
pb.finish_and_clear();
}
info!(
"Counted {} unique words in {:.03?}",
counts.len(),
now.elapsed()
);
// Materialize words & counts
let mut words = Vec::with_capacity(counts.len());
let mut cvec = Vec::with_capacity(counts.len());
let mut n_train_words = 0u64;
for (chunk, c) in counts.into_iter() {
let token_ids = match SpecialTokens::from_str(&chunk).ok() {
Some(x) => vec![256 + x.idx()],
None => chunk
.as_bytes()
.iter()
.map(|&b| b as u32)
.collect::<Vec<_>>(),
};
words.push(Word::new(token_ids));
cvec.push(c);
n_train_words += c as u64;
}
let merges = Self::train_core(mp, words, cvec, vocab_size);
Ok(Self {
split_regex,
split_regex_string,
special_regex,
n_train_texts,
n_train_words,
vocab_size: Self::MIN_VOCAB_SIZE + merges.len() as u32,
unmerges: Self::reverse_merges(&merges),
merges,
})
}
}