|
|
|
@@ -8,26 +8,100 @@ use rayon::iter::{
|
|
|
|
|
IndexedParallelIterator, IntoParallelRefIterator, ParallelBridge, ParallelIterator,
|
|
|
|
|
};
|
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
|
use std::{cmp::Ordering, collections::HashMap, sync::atomic::AtomicUsize, time::Instant};
|
|
|
|
|
use std::{
|
|
|
|
|
cmp::Ordering,
|
|
|
|
|
collections::{HashMap, VecDeque},
|
|
|
|
|
sync::atomic::AtomicUsize,
|
|
|
|
|
time::Instant,
|
|
|
|
|
};
|
|
|
|
|
use tracing::{debug, info};
|
|
|
|
|
|
|
|
|
|
use crate::cli::progress_big;
|
|
|
|
|
use crate::split::regex_segment;
|
|
|
|
|
|
|
|
|
|
// Notes:
|
|
|
|
|
//
|
|
|
|
|
// ## Why ahash?
|
|
|
|
|
// See docs.
|
|
|
|
|
// Ahash provides a higher-performance hash function,
// but it is not resistant to DoS attacks and thus cannot be used
// with untrusted input.
|
|
|
|
|
//
|
|
|
|
|
// ## Changes from original impl
|
|
|
|
|
// - idiomatic Rust
|
|
|
|
|
// - significantly less memory usage
|
|
|
|
|
// TODO:
|
|
|
|
|
// - maybe don't use regex
|
|
|
|
|
|
|
|
|
|
/// A pair of adjacent tokens
|
|
|
|
|
type Pair = (u32, u32);
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
|
|
|
|
pub enum SpecialTokens {
|
|
|
|
|
/// Every document begins with the Beginning of Sequence (BOS) token, which delimits documents
|
|
|
|
|
BeginningOfSequence,
|
|
|
|
|
|
|
|
|
|
// User messages
|
|
|
|
|
UserStart,
|
|
|
|
|
UserEnd,
|
|
|
|
|
|
|
|
|
|
// Assistant messages
|
|
|
|
|
AssistantStart,
|
|
|
|
|
AssistantEnd,
|
|
|
|
|
|
|
|
|
|
/// Assistant invokes python REPL tool
|
|
|
|
|
PythonStart,
|
|
|
|
|
PythonEnd,
|
|
|
|
|
|
|
|
|
|
// Python REPL outputs back to assistant
|
|
|
|
|
OutputStart,
|
|
|
|
|
OutputEnd,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl SpecialTokens {
|
|
|
|
|
/// An array that contains every variant of this enum
|
|
|
|
|
pub const VARIANTS: &[Self] = &[
|
|
|
|
|
Self::BeginningOfSequence,
|
|
|
|
|
Self::UserStart,
|
|
|
|
|
Self::UserEnd,
|
|
|
|
|
Self::AssistantStart,
|
|
|
|
|
Self::AssistantEnd,
|
|
|
|
|
Self::PythonStart,
|
|
|
|
|
Self::PythonEnd,
|
|
|
|
|
Self::OutputStart,
|
|
|
|
|
Self::OutputEnd,
|
|
|
|
|
];
|
|
|
|
|
|
|
|
|
|
/// Get the string representation of this variant
|
|
|
|
|
#[inline]
|
|
|
|
|
pub fn as_str(&self) -> &'static str {
|
|
|
|
|
match self {
|
|
|
|
|
Self::BeginningOfSequence => "<|bos|>",
|
|
|
|
|
Self::UserStart => "<|user_start|>",
|
|
|
|
|
Self::UserEnd => "<|user_end|>",
|
|
|
|
|
Self::AssistantStart => "<|assistant_start|>",
|
|
|
|
|
Self::AssistantEnd => "<|assistant_end|>",
|
|
|
|
|
Self::PythonStart => "<|python_start|>",
|
|
|
|
|
Self::PythonEnd => "<|python_end|>",
|
|
|
|
|
Self::OutputStart => "<|output_start|>",
|
|
|
|
|
Self::OutputEnd => "<|output_end|>",
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Parse a variant from its string representation, returning `None` for unknown strings
|
|
|
|
|
#[inline]
|
|
|
|
|
pub fn from_str(s: &str) -> Option<Self> {
|
|
|
|
|
Some(match s {
|
|
|
|
|
"<|bos|>" => Self::BeginningOfSequence,
|
|
|
|
|
"<|user_start|>" => Self::UserStart,
|
|
|
|
|
"<|user_end|>" => Self::UserEnd,
|
|
|
|
|
"<|assistant_start|>" => Self::AssistantStart,
|
|
|
|
|
"<|assistant_end|>" => Self::AssistantEnd,
|
|
|
|
|
"<|python_start|>" => Self::PythonStart,
|
|
|
|
|
"<|python_end|>" => Self::PythonEnd,
|
|
|
|
|
"<|output_start|>" => Self::OutputStart,
|
|
|
|
|
"<|output_end|>" => Self::OutputEnd,
|
|
|
|
|
_ => return None,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Get the index of this variant in [Self::VARIANTS]
|
|
|
|
|
#[expect(clippy::unwrap_used)]
|
|
|
|
|
#[inline]
|
|
|
|
|
pub fn idx(&self) -> u32 {
|
|
|
|
|
Self::VARIANTS.iter().position(|p| p == self).unwrap() as u32
|
|
|
|
|
}
|
|
|
|
|
}
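// A minimal round-trip sketch (for illustration; assumes only the API above): `from_str`
// inverts `as_str`, and `idx` is the variant's position in `VARIANTS`, which the
// tokenizer later offsets by 256 to form the special token id.
//
//     let t = SpecialTokens::UserStart;
//     assert_eq!(SpecialTokens::from_str(t.as_str()), Some(t));
//     assert_eq!(SpecialTokens::VARIANTS[t.idx() as usize], t);
//     assert_eq!(SpecialTokens::from_str("<|not_a_token|>"), None);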
|
|
|
|
|
|
|
|
|
|
//
|
|
|
|
|
// MARK: word
|
|
|
|
|
//
|
|
|
|
@@ -156,11 +230,14 @@ pub struct Tokenizer {
|
|
|
|
|
/// Maps pairs of token IDs to their merged token ID
|
|
|
|
|
merges: HashMap<Pair, u32>,
|
|
|
|
|
|
|
|
|
|
n_tokens: u32,
|
|
|
|
|
/// Inverse of merges
|
|
|
|
|
unmerges: Vec<Pair>,
|
|
|
|
|
|
|
|
|
|
vocab_size: u32,
|
|
|
|
|
|
|
|
|
|
/// The regex pattern used for text splitting
|
|
|
|
|
#[expect(dead_code)]
|
|
|
|
|
split_regex: Regex,
|
|
|
|
|
special_regex: Regex,
|
|
|
|
|
|
|
|
|
|
/// Source of split_regex
|
|
|
|
|
/// (debug info)
|
|
|
|
@@ -192,6 +269,7 @@ impl Serialize for Tokenizer {
|
|
|
|
|
state.serialize_field("split_regex_string", &self.split_regex_string)?;
|
|
|
|
|
state.serialize_field("n_train_texts", &self.n_train_texts)?;
|
|
|
|
|
state.serialize_field("n_train_words", &self.n_train_words)?;
|
|
|
|
|
state.serialize_field("vocab_size", &self.vocab_size)?;
|
|
|
|
|
state.end()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@@ -208,20 +286,28 @@ impl<'de> Deserialize<'de> for Tokenizer {
|
|
|
|
|
split_regex_string: String,
|
|
|
|
|
n_train_texts: u64,
|
|
|
|
|
n_train_words: u64,
|
|
|
|
|
n_tokens: u32,
|
|
|
|
|
vocab_size: u32,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let data = TokenizerData::deserialize(deserializer)?;
|
|
|
|
|
let split_regex = Regex::new(&data.split_regex_string).map_err(serde::de::Error::custom)?;
|
|
|
|
|
|
|
|
|
|
let merges: HashMap<Pair, u32> = data.merges.into_iter().collect();
|
|
|
|
|
|
|
|
|
|
#[expect(clippy::unwrap_used)] // Special regex must be valid
|
|
|
|
|
let special_regex = Regex::new(Self::SPECIAL_REGEX).unwrap();
|
|
|
|
|
|
|
|
|
|
Ok(Tokenizer {
|
|
|
|
|
unmerges: Self::reverse_merges(&merges),
|
|
|
|
|
merges,
|
|
|
|
|
|
|
|
|
|
split_regex,
|
|
|
|
|
split_regex_string: data.split_regex_string,
|
|
|
|
|
special_regex,
|
|
|
|
|
|
|
|
|
|
n_train_texts: data.n_train_texts,
|
|
|
|
|
n_train_words: data.n_train_words,
|
|
|
|
|
n_tokens: data.n_tokens,
|
|
|
|
|
vocab_size: data.vocab_size,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@@ -230,64 +316,138 @@ impl Tokenizer {
|
|
|
|
|
/// Default regex pattern for splitting text
|
|
|
|
|
const DEFAULT_REGEX: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+";
|
|
|
|
|
|
|
|
|
|
/// Regex for splitting text on special tokens
|
|
|
|
|
pub const SPECIAL_REGEX: &str = "<[|][a-z_]+[|]>";
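// Note (sketch, not part of this change): this pattern matches any `<|...|>` span of
// lowercase letters and underscores, not only the known tokens; `SpecialTokens::from_str`
// is what decides whether a match becomes a special id or falls back to byte-level
// encoding. Assuming the `Regex` in scope is `fancy_regex::Regex` (matches are fallible
// and the split pattern uses possessive quantifiers):
//
//     let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
//     assert!(re.is_match("<|bos|>").unwrap());        // known special token
//     assert!(re.is_match("<|user_star|>").unwrap());  // matches, but from_str -> None
//     assert!(!re.is_match("<user_start>").unwrap());  // no pipes, not matched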
|
|
|
|
|
|
|
|
|
|
/// Minimum size of this tokenizer's vocabulary.
|
|
|
|
|
/// - 256 for all `u8`s (token ids `0..=255`)
|
|
|
|
|
/// - `n` for special tokens (token ids `256..256+n`)
|
|
|
|
|
pub const MIN_VOCAB_SIZE: u32 = 256 + SpecialTokens::VARIANTS.len() as u32;
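// Id-space sketch (illustration only): byte tokens occupy ids 0..=255, the 9 special
// tokens occupy 256..=264, and learned merges are assigned ids starting at
// MIN_VOCAB_SIZE (265), in the order the merges were learned.
//
//     assert_eq!(Tokenizer::MIN_VOCAB_SIZE, 256 + SpecialTokens::VARIANTS.len() as u32);
//     // b'h' -> 104, "<|bos|>" -> 256, first learned merge -> 265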
|
|
|
|
|
|
|
|
|
|
/// Reverse a merge map, returning an array of pairs.
|
|
|
|
|
/// `result[token_id - Self::MIN_VOCAB_SIZE]` returns the pair that `token_id` represents.
|
|
|
|
|
fn reverse_merges(merges: &HashMap<Pair, u32>) -> Vec<Pair> {
|
|
|
|
|
let iter = merges.iter().map(|x| (*x.0, *x.1));
|
|
|
|
|
let mut array = Vec::from_iter(iter);
|
|
|
|
|
array.sort_by_key(|x| x.1);
|
|
|
|
|
|
|
|
|
|
let mut unmerges = Vec::with_capacity(array.len());
|
|
|
|
|
for (i, p) in array.iter().enumerate() {
|
|
|
|
|
assert_eq!(p.1, i as u32 + Self::MIN_VOCAB_SIZE);
|
|
|
|
|
unmerges.push(p.0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
unmerges
|
|
|
|
|
}
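// Indexing sketch (hypothetical ids, for illustration): if training recorded
// merges = {(104, 101) -> 265, (265, 108) -> 266}, then `reverse_merges` returns
// unmerges = [(104, 101), (265, 108)], so a merged id unpacks as
//
//     let pair = unmerges[(token_id - Tokenizer::MIN_VOCAB_SIZE) as usize];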
|
|
|
|
|
|
|
|
|
|
/// Return the regex pattern used to split words
|
|
|
|
|
#[inline]
|
|
|
|
|
pub fn get_regex(&self) -> &str {
|
|
|
|
|
&self.split_regex_string
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Return the total number of tokens this tokenizer can produce
|
|
|
|
|
#[inline]
|
|
|
|
|
pub fn vocab_size(&self) -> u32 {
|
|
|
|
|
self.vocab_size
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Return this tokenizer's "beginning of seq" token
|
|
|
|
|
#[inline]
|
|
|
|
|
pub fn bos_token(&self) -> u32 {
|
|
|
|
|
256 + SpecialTokens::BeginningOfSequence.idx()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Encode a string into token ids
|
|
|
|
|
pub fn tokenize(&self, text: &str) -> Vec<u32> {
|
|
|
|
|
pub fn encode(&self, text: &str) -> Vec<u32> {
|
|
|
|
|
let mut all_ids = Vec::new();
|
|
|
|
|
|
|
|
|
|
for m in self.split_regex.find_iter(text) {
|
|
|
|
|
#[expect(clippy::unwrap_used)] // Shouldn't ever fail
|
|
|
|
|
let word = m.unwrap().as_str();
|
|
|
|
|
let mut word_tokens = word.bytes().map(|b| b as u32).collect::<Vec<_>>();
|
|
|
|
|
|
|
|
|
|
// Apply merges
|
|
|
|
|
while word_tokens.len() >= 2 {
|
|
|
|
|
// Merge the pair with the smallest merged token id (the earliest-learned merge)
|
|
|
|
|
// (pair_start_idx, replace_with)
|
|
|
|
|
let mut best_pair: Option<(usize, u32)> = None;
|
|
|
|
|
|
|
|
|
|
for (i, pair) in word_tokens.windows(2).map(|x| (x[0], x[1])).enumerate() {
|
|
|
|
|
let new_id = match self.merges.get(&pair) {
|
|
|
|
|
None => continue,
|
|
|
|
|
Some(x) => *x,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#[expect(clippy::unwrap_used)]
|
|
|
|
|
if best_pair.is_none() || new_id < best_pair.unwrap().1 {
|
|
|
|
|
best_pair = Some((i, new_id));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
match best_pair {
|
|
|
|
|
Some((idx, new_id)) => {
|
|
|
|
|
word_tokens[idx] = new_id;
|
|
|
|
|
word_tokens.remove(idx + 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
None => {
|
|
|
|
|
// No merges possible
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
let special = regex_segment(&self.special_regex, text);
|
|
|
|
|
for s in special {
|
|
|
|
|
if let Some(special) = SpecialTokens::from_str(s) {
|
|
|
|
|
all_ids.push(256 + special.idx());
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
all_ids.extend(word_tokens);
|
|
|
|
|
for m in self.split_regex.find_iter(s) {
|
|
|
|
|
#[expect(clippy::unwrap_used)] // Shouldn't ever fail
|
|
|
|
|
let word = m.unwrap().as_str();
|
|
|
|
|
let mut word_tokens = word.bytes().map(|b| b as u32).collect::<Vec<_>>();
|
|
|
|
|
|
|
|
|
|
// Apply merges
|
|
|
|
|
while word_tokens.len() >= 2 {
|
|
|
|
|
// Merge the pair with the smallest merged token id (the earliest-learned merge)
|
|
|
|
|
// (pair_start_idx, replace_with)
|
|
|
|
|
let mut best_pair: Option<(usize, u32)> = None;
|
|
|
|
|
|
|
|
|
|
for (i, pair) in word_tokens.windows(2).map(|x| (x[0], x[1])).enumerate() {
|
|
|
|
|
let new_id = match self.merges.get(&pair) {
|
|
|
|
|
None => continue,
|
|
|
|
|
Some(x) => *x,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#[expect(clippy::unwrap_used)]
|
|
|
|
|
if best_pair.is_none() || new_id < best_pair.unwrap().1 {
|
|
|
|
|
best_pair = Some((i, new_id));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
match best_pair {
|
|
|
|
|
Some((idx, new_id)) => {
|
|
|
|
|
word_tokens[idx] = new_id;
|
|
|
|
|
word_tokens.remove(idx + 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
None => {
|
|
|
|
|
// No merges possible
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
all_ids.extend(word_tokens);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
all_ids
|
|
|
|
|
}
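// Worked example (hypothetical merge ids, sketch only): with
// merges = {(104, 101) -> 265, (265, 121) -> 266}, the chunk "hey" starts as byte ids
// [104, 101, 121]; the smallest merged id is always applied first, so the loop reduces
// [104, 101, 121] -> [265, 121] -> [266], replaying merges in the order they were
// learned during training.
//
//     let ids = tokenizer.encode("hey");   // assuming a tokenizer with those merges
//     assert_eq!(ids, vec![266]);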
|
|
|
|
|
|
|
|
|
|
/// Decode an array of token ids back into a string
|
|
|
|
|
pub fn decode(&self, tokens: &[u32]) -> String {
|
|
|
|
|
let mut str_bytes = Vec::new();
|
|
|
|
|
let mut buffer = VecDeque::from_iter(tokens.iter().copied());
|
|
|
|
|
|
|
|
|
|
while let Some(t) = buffer.pop_front() {
|
|
|
|
|
// This is a byte
|
|
|
|
|
if t < 256 {
|
|
|
|
|
str_bytes.push(t as u8);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let t = t - 256;
|
|
|
|
|
|
|
|
|
|
// This is a special token
|
|
|
|
|
if let Some(t) = SpecialTokens::VARIANTS.get(t as usize) {
|
|
|
|
|
str_bytes.extend_from_slice(t.as_str().as_bytes());
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let t = t - SpecialTokens::VARIANTS.len() as u32;
|
|
|
|
|
|
|
|
|
|
// This is a compound token
|
|
|
|
|
#[expect(clippy::unwrap_used)]
|
|
|
|
|
let pair = self.unmerges.get(t as usize).unwrap();
|
|
|
|
|
buffer.push_front(pair.1);
|
|
|
|
|
buffer.push_front(pair.0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
String::from_utf8_lossy(&str_bytes).to_string()
|
|
|
|
|
}
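// Decode sketch (same hypothetical merges as above): id 266 is expanded front-first, so
// the work queue evolves as [266] -> [265, 121] -> [104, 101, 121], i.e. the UTF-8 bytes
// of "hey"; special ids (256..=264) are emitted via `as_str` instead of being expanded.
//
//     assert_eq!(tokenizer.decode(&[266]), "hey");   // with those hypothetical merges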
|
|
|
|
|
|
|
|
|
|
//
|
|
|
|
|
// MARK: main training code
|
|
|
|
|
//
|
|
|
|
|
|
|
|
|
|
// TODO: pool config
|
|
|
|
|
|
|
|
|
|
/// Given an array of words and an array of counts,
|
|
|
|
|
/// - count the number of occurrences of each token pair
|
|
|
|
|
/// - map each pair to the indices of the words that contain it
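/// A sketch of the intended shape (hypothetical names, for illustration only):
///
///     for (word_idx, word) in words.iter().enumerate() {
///         for pair in word.token_ids().windows(2).map(|w| (w[0], w[1])) {
///             *pair_counts.entry(pair).or_default() += counts[word_idx] as i64;
///             pair_to_words.entry(pair).or_default().push(word_idx);
///         }
///     }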
|
|
|
|
@@ -335,7 +495,7 @@ impl Tokenizer {
|
|
|
|
|
/// - `counts`: same length as `words`, count per chunk.
|
|
|
|
|
///
|
|
|
|
|
/// ## Notes:
|
|
|
|
|
/// - vocab size must be >= 256
|
|
|
|
|
/// - vocab size must be >= Self::MIN_VOCAB_SIZE
|
|
|
|
|
/// - will panic if `words.len() != counts.len()`
|
|
|
|
|
/// - will not behave correctly if words are repeated
|
|
|
|
|
fn train_core(
|
|
|
|
@@ -344,9 +504,13 @@ impl Tokenizer {
|
|
|
|
|
counts: Vec<i32>,
|
|
|
|
|
vocab_size: u32,
|
|
|
|
|
) -> HashMap<Pair, u32> {
|
|
|
|
|
assert!(vocab_size >= 256, "vocab_size must be at least 256");
|
|
|
|
|
assert!(
|
|
|
|
|
vocab_size >= Self::MIN_VOCAB_SIZE,
|
|
|
|
|
"vocab_size must be at least {}",
|
|
|
|
|
Self::MIN_VOCAB_SIZE
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
let num_merges = vocab_size - 256;
|
|
|
|
|
let num_merges = vocab_size - Self::MIN_VOCAB_SIZE;
|
|
|
|
|
let mut merges = HashMap::with_capacity(num_merges as usize);
|
|
|
|
|
info!(message = "Training tokenizer", num_merges);
|
|
|
|
|
let now = Instant::now();
|
|
|
|
@@ -403,7 +567,7 @@ impl Tokenizer {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Record merge
|
|
|
|
|
let new_id = 256 + merges_done;
|
|
|
|
|
let new_id = Self::MIN_VOCAB_SIZE + merges_done;
|
|
|
|
|
merges.insert(top.pair, new_id);
|
|
|
|
|
|
|
|
|
|
// Merge this pair in all words where it occurs
|
|
|
|
@@ -460,14 +624,20 @@ impl Tokenizer {
|
|
|
|
|
I: Iterator<Item = String> + ExactSizeIterator,
|
|
|
|
|
I: ParallelBridge + Send,
|
|
|
|
|
{
|
|
|
|
|
if vocab_size < 256 {
|
|
|
|
|
bail!("vocab_size must be at least 256, but it is {vocab_size}");
|
|
|
|
|
if vocab_size < Self::MIN_VOCAB_SIZE {
|
|
|
|
|
bail!(
|
|
|
|
|
"vocab_size must be at least {}, but it is {vocab_size}",
|
|
|
|
|
Self::MIN_VOCAB_SIZE
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let split_regex_string = Self::DEFAULT_REGEX.to_owned();
|
|
|
|
|
#[expect(clippy::unwrap_used)] // Default regex must be valid
|
|
|
|
|
let split_regex = Regex::new(&split_regex_string).unwrap();
|
|
|
|
|
|
|
|
|
|
#[expect(clippy::unwrap_used)] // Special regex must be valid
|
|
|
|
|
let special_regex = Regex::new(Self::SPECIAL_REGEX).unwrap();
|
|
|
|
|
|
|
|
|
|
let now = Instant::now();
|
|
|
|
|
let n_train_texts = iterator.len() as u64;
|
|
|
|
|
debug!("Counting words in {} texts", n_train_texts);
|
|
|
|
@@ -484,18 +654,29 @@ impl Tokenizer {
|
|
|
|
|
.par_bridge()
|
|
|
|
|
.map(|text| {
|
|
|
|
|
let mut local_counts: AHashMap<CompactString, i32> = AHashMap::new();
|
|
|
|
|
for mat in split_regex.find_iter(&text) {
|
|
|
|
|
#[expect(clippy::unwrap_used)]
|
|
|
|
|
let piece = mat.unwrap().as_str();
|
|
|
|
|
|
|
|
|
|
*local_counts.entry(CompactString::from(piece)).or_default() += 1;
|
|
|
|
|
let special = regex_segment(&special_regex, &text);
|
|
|
|
|
for s in special {
|
|
|
|
|
if let Some(special) = SpecialTokens::from_str(s) {
|
|
|
|
|
*local_counts
|
|
|
|
|
.entry(CompactString::from(special.as_str()))
|
|
|
|
|
.or_default() += 1;
|
|
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for mat in split_regex.find_iter(s) {
|
|
|
|
|
#[expect(clippy::unwrap_used)]
|
|
|
|
|
let piece = mat.unwrap().as_str();
|
|
|
|
|
*local_counts.entry(CompactString::from(piece)).or_default() += 1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let count = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
|
|
|
|
if let Some(ref pb) = pb {
|
|
|
|
|
if count % 1000 == 0 {
|
|
|
|
|
pb.inc(1000);
|
|
|
|
|
}
|
|
|
|
|
if let Some(ref pb) = pb
|
|
|
|
|
&& count % 1000 == 0
|
|
|
|
|
{
|
|
|
|
|
pb.inc(1000);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
local_counts
|
|
|
|
@@ -522,9 +703,16 @@ impl Tokenizer {
|
|
|
|
|
let mut cvec = Vec::with_capacity(counts.len());
|
|
|
|
|
let mut n_train_words = 0u64;
|
|
|
|
|
for (chunk, c) in counts.into_iter() {
|
|
|
|
|
words.push(Word::new(
|
|
|
|
|
chunk.as_bytes().iter().map(|&b| b as u32).collect(),
|
|
|
|
|
));
|
|
|
|
|
let token_ids = match SpecialTokens::from_str(&chunk) {
|
|
|
|
|
Some(x) => vec![256 + x.idx()],
|
|
|
|
|
None => chunk
|
|
|
|
|
.as_bytes()
|
|
|
|
|
.iter()
|
|
|
|
|
.map(|&b| b as u32)
|
|
|
|
|
.collect::<Vec<_>>(),
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
words.push(Word::new(token_ids));
|
|
|
|
|
cvec.push(c);
|
|
|
|
|
n_train_words += c as u64;
|
|
|
|
|
}
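// Sketch of the chunk -> initial-token mapping above (illustration only): a chunk that
// is a known special token becomes a single id, everything else starts as raw bytes.
//
//     "<|bos|>"  ->  vec![256]                              // 256 + BOS idx of 0
//     "hello"    ->  vec![104, 101, 108, 108, 111]          // UTF-8 bytes as u32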
|
|
|
|
@@ -532,12 +720,369 @@ impl Tokenizer {
|
|
|
|
|
let merges = Self::train_core(mp, words, cvec, vocab_size);
|
|
|
|
|
|
|
|
|
|
Ok(Self {
|
|
|
|
|
merges,
|
|
|
|
|
split_regex,
|
|
|
|
|
split_regex_string,
|
|
|
|
|
special_regex,
|
|
|
|
|
n_train_texts,
|
|
|
|
|
n_train_words,
|
|
|
|
|
n_tokens: vocab_size,
|
|
|
|
|
vocab_size: Self::MIN_VOCAB_SIZE + merges.len() as u32,
|
|
|
|
|
|
|
|
|
|
unmerges: Self::reverse_merges(&merges),
|
|
|
|
|
merges,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
//
|
|
|
|
|
// MARK: tests
|
|
|
|
|
//
|
|
|
|
|
|
|
|
|
|
#[cfg(test)]
|
|
|
|
|
mod tests {
|
|
|
|
|
use super::*;
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn encode_decode() {
|
|
|
|
|
let tokenizer = Tokenizer::train(
|
|
|
|
|
None,
|
|
|
|
|
[
|
|
|
|
|
"hello world".to_string(),
|
|
|
|
|
"hello there".to_string(),
|
|
|
|
|
"hello hello world".to_string(),
|
|
|
|
|
]
|
|
|
|
|
.into_iter(),
|
|
|
|
|
300,
|
|
|
|
|
)
|
|
|
|
|
.unwrap();
|
|
|
|
|
|
|
|
|
|
let test_cases = vec![
|
|
|
|
|
"hello",
|
|
|
|
|
"world",
|
|
|
|
|
"hello world",
|
|
|
|
|
"hello there world",
|
|
|
|
|
"abc123",
|
|
|
|
|
"",
|
|
|
|
|
];
|
|
|
|
|
|
|
|
|
|
for test_str in test_cases {
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded, "Failed roundtrip for: '{}'", test_str);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
//
|
|
|
|
|
// MARK: individual tests
|
|
|
|
|
//
|
|
|
|
|
|
|
|
|
|
fn get_test_tokenizer() -> Tokenizer {
|
|
|
|
|
Tokenizer::train(
|
|
|
|
|
None,
|
|
|
|
|
[
|
|
|
|
|
"hello world".to_string(),
|
|
|
|
|
"<|bos|> test".to_string(),
|
|
|
|
|
"user: <|user_start|>hi<|user_end|>".to_string(),
|
|
|
|
|
]
|
|
|
|
|
.into_iter(),
|
|
|
|
|
300,
|
|
|
|
|
)
|
|
|
|
|
.unwrap()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn bos_token() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "<|bos|>";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert_eq!(
|
|
|
|
|
[256 + SpecialTokens::BeginningOfSequence.idx()],
|
|
|
|
|
&encoded[..]
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn user_start_token() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "<|user_start|>";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert_eq!([256 + SpecialTokens::UserStart.idx()], &encoded[..]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn partial_bos_start() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "<|bos";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(!encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
|
|
|
|
|
assert_eq!(5, encoded.len())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn partial_bos_end() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "|bos|>";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(!encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
|
|
|
|
|
assert_eq!(6, encoded.len())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn malformed_user_star() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "<|user_star|>";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
|
|
|
|
assert_eq!(10, encoded.len())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn partial_user_start_no_end() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "<|user_start";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
|
|
|
|
assert_eq!(9, encoded.len())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn partial_user_start_no_begin() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "user_start|>";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
|
|
|
|
assert_eq!(9, encoded.len())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn nested_bos() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "<|<|bos|>|>";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
|
|
|
|
|
assert_eq!(5, encoded.len())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn double_nested_user() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "<<||user_start||>>";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn conversation_flow() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "<|user_start|>Hello<|user_end|><|assistant_start|>Hi<|assistant_end|>";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
|
|
|
|
assert!(encoded.contains(&(256 + SpecialTokens::UserEnd.idx())));
|
|
|
|
|
assert!(encoded.contains(&(256 + SpecialTokens::AssistantStart.idx())));
|
|
|
|
|
assert!(encoded.contains(&(256 + SpecialTokens::AssistantEnd.idx())));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn bos_with_control_chars() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "<|bos|>\n\t\r";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn user_with_emoji() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "<|user_start|>🚀💻<|user_end|>";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
|
|
|
|
assert!(encoded.contains(&(256 + SpecialTokens::UserEnd.idx())));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn python_code_block() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "<|python_start|>print('hello')\n<|python_end|>";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(encoded.contains(&(256 + SpecialTokens::PythonStart.idx())));
|
|
|
|
|
assert!(encoded.contains(&(256 + SpecialTokens::PythonEnd.idx())));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn bos_with_spaces() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = " <|bos|> ";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn user_start_with_newlines() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "\n<|user_start|>\n";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn user_end_tab_assistant_start() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "<|user_end|>\t<|assistant_start|>";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(encoded.contains(&(256 + SpecialTokens::UserEnd.idx())));
|
|
|
|
|
assert!(encoded.contains(&(256 + SpecialTokens::AssistantStart.idx())));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn empty_angle_brackets() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "<|>";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn no_angle_brackets() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "|user_start|";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn regular_angle_brackets() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "<user_start>";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn double_pipe_user_start() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "<|user_start||>";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn double_pipe_prefix() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "<||user_start|>";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn normal_text_with_bos() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "Normal text <|bos|> more text";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn code_with_output() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str =
|
|
|
|
|
"Code: <|python_start|>x = 1<|python_end|> Result: <|output_start|>1<|output_end|>";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(encoded.contains(&(256 + SpecialTokens::PythonStart.idx())));
|
|
|
|
|
assert!(encoded.contains(&(256 + SpecialTokens::PythonEnd.idx())));
|
|
|
|
|
assert!(encoded.contains(&(256 + SpecialTokens::OutputStart.idx())));
|
|
|
|
|
assert!(encoded.contains(&(256 + SpecialTokens::OutputEnd.idx())));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn escaped_bos() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "\\<|bos|>";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn escaped_pipe_bos() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "<|bos\\|>";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(!encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn escaped_s_bos() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "<|bo\\s|>";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(!encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn empty_string() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = "";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
assert!(encoded.is_empty());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
fn spaces_only() {
|
|
|
|
|
let tokenizer = get_test_tokenizer();
|
|
|
|
|
let test_str = " ";
|
|
|
|
|
let encoded = tokenizer.encode(test_str);
|
|
|
|
|
let decoded = tokenizer.decode(&encoded);
|
|
|
|
|
assert_eq!(test_str, decoded);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|