Compare commits

..

1 Commits

Author SHA1 Message Date
d5bf7ac5d1 Train tokenizer 2025-10-16 07:03:54 -07:00
5 changed files with 71 additions and 887 deletions

View File

@@ -41,7 +41,7 @@ pub fn clap_styles() -> clap::builder::Styles {
 pub fn progress_big() -> ProgressStyle {
     return ProgressStyle::default_bar()
         .template(
-            " {spinner:.green} [{elapsed_precise}] [{bar:40.green/dim}] {pos:>7}/{len:7} {msg:.dim} ({eta})",
+            " {spinner:.green} [{elapsed_precise}] [{bar:40.green/dim}] {pos:>7}/{len:7} {msg:.dim}",
         )
         .unwrap()
         .progress_chars("=>-")

View File

@@ -27,10 +27,6 @@ pub struct TrainTokenizerArgs {
     /// Number of threads to use for training
     #[clap(long, default_value = "0")]
     threads: usize,
-    /// Tokenizer vocabulary size
-    #[clap(long, default_value = "65535")]
-    n_tokens: u32,
 }

 impl TrainTokenizerArgs {
@@ -47,8 +43,8 @@ impl TrainTokenizerArgs {
         let tokenizer = pool
             .install(|| match self.first_n {
-                Some(n) => Tokenizer::train(mp, iter.take(n), self.n_tokens),
-                None => Tokenizer::train(mp, iter, self.n_tokens),
+                Some(n) => Tokenizer::train(mp, iter.take(n), 1024),
+                None => Tokenizer::train(mp, iter, 1024),
             })
             .context("while training tokenizer")?;

View File

@@ -11,7 +11,6 @@ mod cli;
 mod command;
 mod data_reader;
 mod logging;
-mod split;
 mod tokenizer;

 #[derive(Parser, Debug)]

View File

@@ -1,266 +0,0 @@
use fancy_regex::Regex;
/// Split text using a regex while keeping both the matched parts and the parts between matches
pub fn regex_segment<'a>(re: &Regex, text: &'a str) -> Vec<&'a str> {
let mut result = Vec::new();
let mut last = 0;
for mat in re.find_iter(text) {
#[expect(clippy::unwrap_used)]
let mat = mat.unwrap();
if mat.start() > last {
result.push(&text[last..mat.start()]);
}
result.push(mat.as_str());
last = mat.end();
}
if last < text.len() {
result.push(&text[last..]);
}
result.retain(|x| !x.is_empty());
result
}
//
// MARK: tests
//
#[cfg(test)]
mod tests {
use super::*;
use crate::tokenizer::Tokenizer;
#[test]
fn basic() {
let re = Regex::new(r"[,;]").unwrap();
let text = "apple,banana;cherry";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["apple", ",", "banana", ";", "cherry"]);
}
#[test]
fn empty_string() {
let re = Regex::new(r"[,;]").unwrap();
let text = "";
let result = regex_segment(&re, text);
assert_eq!(result, Vec::<&str>::new());
}
#[test]
fn no_matches() {
let re = Regex::new(r"[,;]").unwrap();
let text = "apple banana cherry";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["apple banana cherry"]);
}
#[test]
fn only_matches() {
let re = Regex::new(r"[,;]").unwrap();
let text = ",;,";
let result = regex_segment(&re, text);
assert_eq!(result, vec![",", ";", ","]);
}
#[test]
fn starts_with_match() {
let re = Regex::new(r"[,;]").unwrap();
let text = ",apple;banana";
let result = regex_segment(&re, text);
assert_eq!(result, vec![",", "apple", ";", "banana"]);
}
#[test]
fn ends_with_match() {
let re = Regex::new(r"[,;]").unwrap();
let text = "apple,banana;";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["apple", ",", "banana", ";"]);
}
#[test]
fn consecutive_matches() {
let re = Regex::new(r"[,;]").unwrap();
let text = "apple,,banana";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["apple", ",", ",", "banana"]);
}
#[test]
fn word_boundaries() {
let re = Regex::new(r"\b").unwrap();
let text = "hello world";
let result = regex_segment(&re, text);
// Word boundaries are zero-width, so we get empty matches between word chars and non-word chars
assert_eq!(result, vec!["hello", " ", "world"]);
}
#[test]
fn digits() {
let re = Regex::new(r"\d+").unwrap();
let text = "abc123def456ghi";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["abc", "123", "def", "456", "ghi"]);
}
#[test]
fn special_tokens() {
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
let text = "Hello <|user_start|>world<|user_end|> test";
let result = regex_segment(&re, text);
assert_eq!(
result,
vec!["Hello ", "<|user_start|>", "world", "<|user_end|>", " test"]
);
}
#[test]
fn unicode() {
let re = Regex::new(r"[=<3D>=<3D>]+").unwrap();
let text = "Hello=<3D>world=<3D>test";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["Hello", "=<3D>", "world", "=<3D>", "test"]);
}
#[test]
fn single_char() {
let re = Regex::new(r"x").unwrap();
let text = "x";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["x"]);
}
#[test]
fn multichar_match() {
let re = Regex::new(r"abc").unwrap();
let text = "123abc456abc789";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["123", "abc", "456", "abc", "789"]);
}
#[test]
fn bos_token() {
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
let text = "<|bos|>This is a document";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["<|bos|>", "This is a document"]);
}
#[test]
fn conversation_flow() {
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
let text = "<|user_start|>Hello<|user_end|><|assistant_start|>Hi there!<|assistant_end|>";
let result = regex_segment(&re, text);
assert_eq!(
result,
vec![
"<|user_start|>",
"Hello",
"<|user_end|>",
"<|assistant_start|>",
"Hi there!",
"<|assistant_end|>"
]
);
}
#[test]
fn python_code_block() {
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
let text = "Code: <|python_start|>print('hello')<|python_end|> Output: <|output_start|>hello<|output_end|>";
let result = regex_segment(&re, text);
assert_eq!(
result,
vec![
"Code: ",
"<|python_start|>",
"print('hello')",
"<|python_end|>",
" Output: ",
"<|output_start|>",
"hello",
"<|output_end|>"
]
);
}
#[test]
fn mixed_special_tokens() {
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
let text =
"<|bos|><|user_start|>Question<|user_end|><|assistant_start|>Answer<|assistant_end|>";
let result = regex_segment(&re, text);
assert_eq!(
result,
vec![
"<|bos|>",
"<|user_start|>",
"Question",
"<|user_end|>",
"<|assistant_start|>",
"Answer",
"<|assistant_end|>"
]
);
}
#[test]
fn no_special_tokens() {
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
let text = "This is just regular text with no special tokens";
let result = regex_segment(&re, text);
assert_eq!(
result,
vec!["This is just regular text with no special tokens"]
);
}
#[test]
fn malformed_special_tokens() {
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
let text = "This has <|invalid_token> and <user_start> which shouldn't match";
let result = regex_segment(&re, text);
assert_eq!(
result,
vec!["This has <|invalid_token> and <user_start> which shouldn't match"]
);
}
#[test]
fn special_tokens_with_whitespace() {
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
let text = " <|bos|> \n<|user_start|>\tHello\n<|user_end|> ";
let result = regex_segment(&re, text);
assert_eq!(
result,
vec![
" ",
"<|bos|>",
" \n",
"<|user_start|>",
"\tHello\n",
"<|user_end|>",
" "
]
);
}
#[test]
fn only_special_tokens() {
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
let text = "<|bos|><|user_start|><|user_end|>";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["<|bos|>", "<|user_start|>", "<|user_end|>"]);
}
#[test]
fn nested() {
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
let text = "<|<|bos|>|>";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["<|", "<|bos|>", "|>"]);
}
}
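Aside, for readers skimming the removed module: the loop below is a minimal standalone sketch of the split-and-keep-delimiters behavior that `regex_segment` provided. The regex mirrors the `SPECIAL_REGEX` constant that is also removed in this commit; the input string is illustrative. Concatenating the returned segments always reproduces the input.

    use fancy_regex::Regex;

    fn main() {
        // Same idea as the removed `regex_segment`: keep both the matches and
        // the text between them, skipping empty in-between segments.
        let re = Regex::new("<[|][a-z_]+[|]>").unwrap();
        let text = "Hello <|user_start|>world<|user_end|>";

        let mut segments: Vec<&str> = Vec::new();
        let mut last = 0;
        for mat in re.find_iter(text) {
            let mat = mat.unwrap();
            if mat.start() > last {
                segments.push(&text[last..mat.start()]);
            }
            segments.push(mat.as_str());
            last = mat.end();
        }
        if last < text.len() {
            segments.push(&text[last..]);
        }

        assert_eq!(segments, ["Hello ", "<|user_start|>", "world", "<|user_end|>"]);
        // Joining the segments reproduces the original text.
        assert_eq!(segments.concat(), text);
    }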

View File

@@ -8,100 +8,26 @@ use rayon::iter::{
     IndexedParallelIterator, IntoParallelRefIterator, ParallelBridge, ParallelIterator,
 };
 use serde::{Deserialize, Serialize};
-use std::{
-    cmp::Ordering,
-    collections::{HashMap, VecDeque},
-    sync::atomic::AtomicUsize,
-    time::Instant,
-};
+use std::{cmp::Ordering, collections::HashMap, sync::atomic::AtomicUsize, time::Instant};
 use tracing::{debug, info};
 use crate::cli::progress_big;
-use crate::split::regex_segment;

-// TODO:
-// - maybe don't use regex
+// Notes:
+//
+// ## Why ahash?
+// See docs.
+// Ahash provides a uses a higher performance hash fn,
+// but is not resistant to DOS and thus cannot be used
+// with untrusted input.
+//
+// ## Changes from original impl
+// - idiomatic Rust
+// - significantly less memory usage

 /// A pair of adjacent tokens
 type Pair = (u32, u32);
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum SpecialTokens {
/// every document begins with the Beginning of Sequence (BOS) token that delimits documents
BeginningOfSequence,
// User messages
UserStart,
UserEnd,
// Assistant messages
AssistantStart,
AssistantEnd,
/// Assistant invokes python REPL tool
PythonStart,
PythonEnd,
// Python REPL outputs back to assistant
OutputStart,
OutputEnd,
}
impl SpecialTokens {
/// An array that contains every varaiant of this enum
pub const VARIANTS: &[Self] = &[
Self::BeginningOfSequence,
Self::UserStart,
Self::UserEnd,
Self::AssistantStart,
Self::AssistantEnd,
Self::PythonStart,
Self::PythonEnd,
Self::OutputStart,
Self::OutputEnd,
];
/// Get the string representation of this variant
#[inline]
pub fn as_str(&self) -> &'static str {
match self {
Self::BeginningOfSequence => "<|bos|>",
Self::UserStart => "<|user_start|>",
Self::UserEnd => "<|user_end|>",
Self::AssistantStart => "<|assistant_start|>",
Self::AssistantEnd => "<|assistant_end|>",
Self::PythonStart => "<|python_start|>",
Self::PythonEnd => "<|python_end|>",
Self::OutputStart => "<|output_start|>",
Self::OutputEnd => "<|output_end|>",
}
}
/// Get the string representation of this variant
#[inline]
pub fn from_str(s: &str) -> Option<Self> {
Some(match s {
"<|bos|>" => Self::BeginningOfSequence,
"<|user_start|>" => Self::UserStart,
"<|user_end|>" => Self::UserEnd,
"<|assistant_start|>" => Self::AssistantStart,
"<|assistant_end|>" => Self::AssistantEnd,
"<|python_start|>" => Self::PythonStart,
"<|python_end|>" => Self::PythonEnd,
"<|output_start|>" => Self::OutputStart,
"<|output_end|>" => Self::OutputEnd,
_ => return None,
})
}
/// Get the index of this variant in [Self::VARIANTS]
#[expect(clippy::unwrap_used)]
#[inline]
pub fn idx(&self) -> u32 {
Self::VARIANTS.iter().position(|p| p == self).unwrap() as u32
}
}
 //
 // MARK: word
 //
@@ -230,14 +156,11 @@ pub struct Tokenizer {
     /// Maps pairs of token IDs to their merged token ID
     merges: HashMap<Pair, u32>,
-    /// Inverse of merges
-    unmerges: Vec<Pair>,
-    vocab_size: u32,
+    n_tokens: u32,
     /// The regex pattern used for text splitting
-    #[expect(dead_code)]
     split_regex: Regex,
-    special_regex: Regex,
     /// Source of split_regex
     /// (debug info)
@@ -269,7 +192,6 @@ impl Serialize for Tokenizer {
         state.serialize_field("split_regex_string", &self.split_regex_string)?;
         state.serialize_field("n_train_texts", &self.n_train_texts)?;
         state.serialize_field("n_train_words", &self.n_train_words)?;
-        state.serialize_field("vocab_size", &self.vocab_size)?;
         state.end()
     }
 }
@@ -286,28 +208,20 @@ impl<'de> Deserialize<'de> for Tokenizer {
             split_regex_string: String,
             n_train_texts: u64,
             n_train_words: u64,
-            vocab_size: u32,
+            n_tokens: u32,
         }
         let data = TokenizerData::deserialize(deserializer)?;
         let split_regex = Regex::new(&data.split_regex_string).map_err(serde::de::Error::custom)?;
         let merges: HashMap<Pair, u32> = data.merges.into_iter().collect();
-        #[expect(clippy::unwrap_used)] // Special regex must be valid
-        let special_regex = Regex::new(Self::SPECIAL_REGEX).unwrap();

         Ok(Tokenizer {
-            unmerges: Self::reverse_merges(&merges),
             merges,
             split_regex,
             split_regex_string: data.split_regex_string,
-            special_regex,
             n_train_texts: data.n_train_texts,
             n_train_words: data.n_train_words,
-            vocab_size: data.vocab_size,
+            n_tokens: data.n_tokens,
         })
     }
 }
@@ -316,138 +230,64 @@ impl Tokenizer {
     /// Default regex pattern for splitting text
     const DEFAULT_REGEX: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+";
/// Regex for splitting special tokens
pub const SPECIAL_REGEX: &str = "<[|][a-z_]+[|]>";
/// Minimum size of this tokenizer's vocabulary.
/// - 256 for all `u8`s (token ids `0..=255`)
/// - `n` for special tokens (token ids `256..256+n`)
pub const MIN_VOCAB_SIZE: u32 = 256 + SpecialTokens::VARIANTS.len() as u32;
/// Reverse a merge map, returning an array of pairs.
/// `result[token_id - Self::MIN_VOCAB_SIZE`]` returns the pair that `token_id` represents.
fn reverse_merges(merges: &HashMap<Pair, u32>) -> Vec<Pair> {
let iter = merges.iter().map(|x| (*x.0, *x.1));
let mut array = Vec::from_iter(iter);
array.sort_by_key(|x| x.1);
let mut unmerges = Vec::with_capacity(array.len());
for (i, p) in array.iter().enumerate() {
assert_eq!(p.1, i as u32 + Self::MIN_VOCAB_SIZE);
unmerges.push(p.0);
}
unmerges
}
     /// Return the regex pattern used to split words
     #[inline]
     pub fn get_regex(&self) -> &str {
         &self.split_regex_string
     }
/// Return the total number of tokens this tokenizer can produce
#[inline]
pub fn vocab_size(&self) -> u32 {
self.vocab_size
}
/// Return this tokenizer's "beginning of seq" token
#[inline]
pub fn bos_token(&self) -> u32 {
256 + SpecialTokens::BeginningOfSequence.idx()
}
     /// Tokenize a string
-    pub fn encode(&self, text: &str) -> Vec<u32> {
+    pub fn tokenize(&self, text: &str) -> Vec<u32> {
         let mut all_ids = Vec::new();
-        let special = regex_segment(&self.special_regex, text);
-        for s in special {
-            if let Some(special) = SpecialTokens::from_str(s) {
-                all_ids.push(256 + special.idx());
-                continue;
-            }
-            for m in self.split_regex.find_iter(s) {
+        for m in self.split_regex.find_iter(text) {
             #[expect(clippy::unwrap_used)] // Shouldn't ever fail
             let word = m.unwrap().as_str();
             let mut word_tokens = word.bytes().map(|b| b as u32).collect::<Vec<_>>();

             // Apply merges
             while word_tokens.len() >= 2 {
                 // Merge the pair with the largest token idx
                 // (pair_start_idx, replace_with)
                 let mut best_pair: Option<(usize, u32)> = None;
                 for (i, pair) in word_tokens.windows(2).map(|x| (x[0], x[1])).enumerate() {
                     let new_id = match self.merges.get(&pair) {
                         None => continue,
                         Some(x) => *x,
                     };
                     #[expect(clippy::unwrap_used)]
                     if best_pair.is_none() || new_id < best_pair.unwrap().1 {
                         best_pair = Some((i, new_id));
                     }
                 }
                 match best_pair {
                     Some((idx, new_id)) => {
                         word_tokens[idx] = new_id;
                         word_tokens.remove(idx + 1);
                     }
                     None => {
                         // No merges possible
                         break;
                     }
                 }
             }
             all_ids.extend(word_tokens);
-            }
         }
         all_ids
     }
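For readers comparing the two bodies above: both the old `encode` and the new `tokenize` apply merges the same way. Below is a self-contained sketch of that inner loop with a hand-built merge table (the table and token ids are hypothetical; only the loop mirrors the code in this hunk). Each pass scans adjacent pairs and applies the candidate whose merged id is lowest, i.e. the merge that was learned earliest during training.

    use std::collections::HashMap;

    /// Greedy BPE merge application, mirroring the `while word_tokens.len() >= 2` loop above.
    fn apply_merges(mut tokens: Vec<u32>, merges: &HashMap<(u32, u32), u32>) -> Vec<u32> {
        while tokens.len() >= 2 {
            // Find the adjacent pair whose merged id is lowest (earliest-learned merge).
            let mut best: Option<(usize, u32)> = None;
            for (i, pair) in tokens.windows(2).map(|w| (w[0], w[1])).enumerate() {
                if let Some(&new_id) = merges.get(&pair) {
                    if best.map_or(true, |(_, id)| new_id < id) {
                        best = Some((i, new_id));
                    }
                }
            }
            match best {
                Some((i, new_id)) => {
                    tokens[i] = new_id;
                    tokens.remove(i + 1);
                }
                None => break, // no applicable merges left
            }
        }
        tokens
    }

    fn main() {
        // Hypothetical merge table: (b'a', b'b') learned first (id 256), then (256, b'c').
        let merges = HashMap::from([((b'a' as u32, b'b' as u32), 256), ((256, b'c' as u32), 257)]);
        let word: Vec<u32> = "abc".bytes().map(|b| b as u32).collect();
        assert_eq!(apply_merges(word, &merges), vec![257]);
    }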
/// Decode an array of tokens
pub fn decode(&self, tokens: &[u32]) -> String {
let mut str_bytes = Vec::new();
let mut buffer = VecDeque::from_iter(tokens.iter().copied());
while let Some(t) = buffer.pop_front() {
// This is a byte
if t < 256 {
str_bytes.push(t as u8);
continue;
}
let t = t - 256;
// This is a special token
if let Some(t) = SpecialTokens::VARIANTS.get(t as usize) {
str_bytes.extend_from_slice(t.as_str().as_bytes());
continue;
}
let t = t - SpecialTokens::VARIANTS.len() as u32;
// This is a compound token
#[expect(clippy::unwrap_used)]
let pair = self.unmerges.get(t as usize).unwrap();
buffer.push_front(pair.1);
buffer.push_front(pair.0);
}
return String::from_utf8_lossy(&str_bytes).to_string();
}
     //
     // MARK: main training code
     //
-    // TODO: pool config
     /// Given an array of words an an array of counts,
     /// - count the number of occurrences of each token pair
     /// - map each pair to the indices of the words that contain it
@@ -495,7 +335,7 @@ impl Tokenizer {
     /// - `counts`: same length as `words`, count per chunk.
     ///
     /// ## Notes:
-    /// - vocab size must be >= Self::MIN_VOCAB_SIZE
+    /// - vocab size must be >= 256
     /// - will panic if `words.len() != counts.len()`
     /// - will not behave correctly if words are repeated
     fn train_core(
@@ -504,13 +344,9 @@ impl Tokenizer {
         counts: Vec<i32>,
         vocab_size: u32,
     ) -> HashMap<Pair, u32> {
-        assert!(
-            vocab_size >= Self::MIN_VOCAB_SIZE,
-            "vocab_size must be at least {}",
-            Self::MIN_VOCAB_SIZE
-        );
+        assert!(vocab_size >= 256, "vocab_size must be at least 256");

-        let num_merges = vocab_size - Self::MIN_VOCAB_SIZE;
+        let num_merges = vocab_size - 256;
         let mut merges = HashMap::with_capacity(num_merges as usize);
         info!(message = "Training tokenizer", num_merges);
         let now = Instant::now();
@@ -567,7 +403,7 @@ impl Tokenizer {
             }

             // Record merge
-            let new_id = Self::MIN_VOCAB_SIZE + merges_done;
+            let new_id = 256 + merges_done;
             merges.insert(top.pair, new_id);

             // Merge this pair in all words where it occurs
@@ -624,20 +460,14 @@ impl Tokenizer {
         I: Iterator<Item = String> + ExactSizeIterator,
         I: ParallelBridge + Send,
     {
-        if vocab_size < Self::MIN_VOCAB_SIZE {
-            bail!(
-                "vocab_size must be at least {}, but it is {vocab_size}",
-                Self::MIN_VOCAB_SIZE
-            );
+        if vocab_size < 256 {
+            bail!("vocab_size must be at least 256, but it is {vocab_size}");
         }

         let split_regex_string = Self::DEFAULT_REGEX.to_owned();
         #[expect(clippy::unwrap_used)] // Default regex must be valid
         let split_regex = Regex::new(&split_regex_string).unwrap();
-        #[expect(clippy::unwrap_used)] // Special regex must be valid
-        let special_regex = Regex::new(Self::SPECIAL_REGEX).unwrap();

         let now = Instant::now();
         let n_train_texts = iterator.len() as u64;
         debug!("Counting words in {} texts", n_train_texts);
@@ -654,29 +484,18 @@ impl Tokenizer {
             .par_bridge()
             .map(|text| {
                 let mut local_counts: AHashMap<CompactString, i32> = AHashMap::new();
-                let special = regex_segment(&special_regex, &text);
-                for s in special {
-                    if let Some(special) = SpecialTokens::from_str(s) {
-                        *local_counts
-                            .entry(CompactString::from(special.as_str()))
-                            .or_default() += 1;
-                        continue;
-                    }
-                    for mat in split_regex.find_iter(s) {
-                        #[expect(clippy::unwrap_used)]
-                        let piece = mat.unwrap().as_str();
-                        *local_counts.entry(CompactString::from(piece)).or_default() += 1;
-                    }
+                for mat in split_regex.find_iter(&text) {
+                    #[expect(clippy::unwrap_used)]
+                    let piece = mat.unwrap().as_str();
+                    *local_counts.entry(CompactString::from(piece)).or_default() += 1;
                 }

                 let count = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
-                if let Some(ref pb) = pb
-                    && count % 1000 == 0
-                {
-                    pb.inc(1000);
+                if let Some(ref pb) = pb {
+                    if count % 1000 == 0 {
+                        pb.inc(1000);
+                    }
                 }

                 local_counts
@@ -703,16 +522,9 @@ impl Tokenizer {
         let mut cvec = Vec::with_capacity(counts.len());
         let mut n_train_words = 0u64;
         for (chunk, c) in counts.into_iter() {
-            let token_ids = match SpecialTokens::from_str(&chunk) {
-                Some(x) => vec![256 + x.idx()],
-                None => chunk
-                    .as_bytes()
-                    .iter()
-                    .map(|&b| b as u32)
-                    .collect::<Vec<_>>(),
-            };
-            words.push(Word::new(token_ids));
+            words.push(Word::new(
+                chunk.as_bytes().iter().map(|&b| b as u32).collect(),
+            ));
             cvec.push(c);
             n_train_words += c as u64;
         }
@@ -720,369 +532,12 @@ impl Tokenizer {
         let merges = Self::train_core(mp, words, cvec, vocab_size);

         Ok(Self {
+            merges,
             split_regex,
             split_regex_string,
-            special_regex,
             n_train_texts,
             n_train_words,
-            vocab_size: Self::MIN_VOCAB_SIZE + merges.len() as u32,
-            unmerges: Self::reverse_merges(&merges),
-            merges,
+            n_tokens: vocab_size,
         })
     }
 }
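A small sketch of the id-space arithmetic this file now assumes; the numbers follow from the constants visible in this diff, and the snippet itself is only illustrative. Raw bytes occupy ids 0..=255 and every learned merge gets the next id starting at 256, so a vocab_size of V leaves room for V - 256 merges. Before this change the 9 special tokens sat between the bytes and the merges, which is why the removed MIN_VOCAB_SIZE was 256 + 9 = 265 and merge ids started there instead.

    fn main() {
        // New layout: ids 0..=255 are raw bytes, merge k gets id 256 + k.
        let vocab_size: u32 = 1024; // the value hard-coded in the train-tokenizer command above
        assert_eq!(vocab_size - 256, 768); // merges that training will learn

        // Old layout: 9 special tokens sat between the bytes and the merges.
        let n_special: u32 = 9; // SpecialTokens::VARIANTS.len() in the removed code
        let min_vocab_size = 256 + n_special; // the removed MIN_VOCAB_SIZE
        assert_eq!(min_vocab_size, 265);
        assert_eq!(vocab_size - min_vocab_size, 759); // merges under the old layout
    }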
//
// MARK: tests
//
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn encode_decode() {
let tokenizer = Tokenizer::train(
None,
[
"hello world".to_string(),
"hello there".to_string(),
"hello hello world".to_string(),
]
.into_iter(),
300,
)
.unwrap();
let test_cases = vec![
"hello",
"world",
"hello world",
"hello there world",
"abc123",
"",
];
for test_str in test_cases {
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded, "Failed roundtrip for: '{}'", test_str);
}
}
//
// MARK: individual tests
//
fn get_test_tokenizer() -> Tokenizer {
Tokenizer::train(
None,
[
"hello world".to_string(),
"<|bos|> test".to_string(),
"user: <|user_start|>hi<|user_end|>".to_string(),
]
.into_iter(),
300,
)
.unwrap()
}
#[test]
fn bos_token() {
let tokenizer = get_test_tokenizer();
let test_str = "<|bos|>";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert_eq!(
[256 + SpecialTokens::BeginningOfSequence.idx()],
&encoded[..]
);
}
#[test]
fn user_start_token() {
let tokenizer = get_test_tokenizer();
let test_str = "<|user_start|>";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert_eq!([256 + SpecialTokens::UserStart.idx()], &encoded[..]);
}
#[test]
fn partial_bos_start() {
let tokenizer = get_test_tokenizer();
let test_str = "<|bos";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(!encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
assert_eq!(5, encoded.len())
}
#[test]
fn partial_bos_end() {
let tokenizer = get_test_tokenizer();
let test_str = "|bos|>";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(!encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
assert_eq!(6, encoded.len())
}
#[test]
fn malformed_user_star() {
let tokenizer = get_test_tokenizer();
let test_str = "<|user_star|>";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
assert_eq!(10, encoded.len())
}
#[test]
fn partial_user_start_no_end() {
let tokenizer = get_test_tokenizer();
let test_str = "<|user_start";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
assert_eq!(9, encoded.len())
}
#[test]
fn partial_user_start_no_begin() {
let tokenizer = get_test_tokenizer();
let test_str = "user_start|>";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
assert_eq!(9, encoded.len())
}
#[test]
fn nested_bos() {
let tokenizer = get_test_tokenizer();
let test_str = "<|<|bos|>|>";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
assert_eq!(5, encoded.len())
}
#[test]
fn double_nested_user() {
let tokenizer = get_test_tokenizer();
let test_str = "<<||user_start||>>";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
}
#[test]
fn conversation_flow() {
let tokenizer = get_test_tokenizer();
let test_str = "<|user_start|>Hello<|user_end|><|assistant_start|>Hi<|assistant_end|>";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
assert!(encoded.contains(&(256 + SpecialTokens::UserEnd.idx())));
assert!(encoded.contains(&(256 + SpecialTokens::AssistantStart.idx())));
assert!(encoded.contains(&(256 + SpecialTokens::AssistantEnd.idx())));
}
#[test]
fn bos_with_control_chars() {
let tokenizer = get_test_tokenizer();
let test_str = "<|bos|>\n\t\r";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
}
#[test]
fn user_with_emoji() {
let tokenizer = get_test_tokenizer();
let test_str = "<|user_start|>🚀💻<|user_end|>";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
assert!(encoded.contains(&(256 + SpecialTokens::UserEnd.idx())));
}
#[test]
fn python_code_block() {
let tokenizer = get_test_tokenizer();
let test_str = "<|python_start|>print('hello')\n<|python_end|>";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(encoded.contains(&(256 + SpecialTokens::PythonStart.idx())));
assert!(encoded.contains(&(256 + SpecialTokens::PythonEnd.idx())));
}
#[test]
fn bos_with_spaces() {
let tokenizer = get_test_tokenizer();
let test_str = " <|bos|> ";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
}
#[test]
fn user_start_with_newlines() {
let tokenizer = get_test_tokenizer();
let test_str = "\n<|user_start|>\n";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
}
#[test]
fn user_end_tab_assistant_start() {
let tokenizer = get_test_tokenizer();
let test_str = "<|user_end|>\t<|assistant_start|>";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(encoded.contains(&(256 + SpecialTokens::UserEnd.idx())));
assert!(encoded.contains(&(256 + SpecialTokens::AssistantStart.idx())));
}
#[test]
fn empty_angle_brackets() {
let tokenizer = get_test_tokenizer();
let test_str = "<|>";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
}
#[test]
fn no_angle_brackets() {
let tokenizer = get_test_tokenizer();
let test_str = "|user_start|";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
}
#[test]
fn regular_angle_brackets() {
let tokenizer = get_test_tokenizer();
let test_str = "<user_start>";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
}
#[test]
fn double_pipe_user_start() {
let tokenizer = get_test_tokenizer();
let test_str = "<|user_start||>";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
}
#[test]
fn double_pipe_prefix() {
let tokenizer = get_test_tokenizer();
let test_str = "<||user_start|>";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
}
#[test]
fn normal_text_with_bos() {
let tokenizer = get_test_tokenizer();
let test_str = "Normal text <|bos|> more text";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
}
#[test]
fn code_with_output() {
let tokenizer = get_test_tokenizer();
let test_str =
"Code: <|python_start|>x = 1<|python_end|> Result: <|output_start|>1<|output_end|>";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(encoded.contains(&(256 + SpecialTokens::PythonStart.idx())));
assert!(encoded.contains(&(256 + SpecialTokens::PythonEnd.idx())));
assert!(encoded.contains(&(256 + SpecialTokens::OutputStart.idx())));
assert!(encoded.contains(&(256 + SpecialTokens::OutputEnd.idx())));
}
#[test]
fn escaped_bos() {
let tokenizer = get_test_tokenizer();
let test_str = "\\<|bos|>";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
}
#[test]
fn escaped_pipe_bos() {
let tokenizer = get_test_tokenizer();
let test_str = "<|bos\\|>";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(!encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
}
#[test]
fn escaped_s_bos() {
let tokenizer = get_test_tokenizer();
let test_str = "<|bo\\s|>";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(!encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
}
#[test]
fn empty_string() {
let tokenizer = get_test_tokenizer();
let test_str = "";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
assert!(encoded.is_empty());
}
#[test]
fn spaces_only() {
let tokenizer = get_test_tokenizer();
let test_str = " ";
let encoded = tokenizer.encode(test_str);
let decoded = tokenizer.decode(&encoded);
assert_eq!(test_str, decoded);
}
}