Tokenizer
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -2,4 +2,6 @@
|
|||||||
**/.env
|
**/.env
|
||||||
**/*.ignore
|
**/*.ignore
|
||||||
**/.DS_Store
|
**/.DS_Store
|
||||||
|
|
||||||
.lycheecache
|
.lycheecache
|
||||||
|
**/tokenizer.json
|
||||||
|
|||||||
6180
Cargo.lock
generated
Normal file
6180
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
97
Cargo.toml
Normal file
97
Cargo.toml
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
[workspace]
|
||||||
|
members = ["crates/*"]
|
||||||
|
exclude = ["**.ignore"]
|
||||||
|
resolver = "2"
|
||||||
|
|
||||||
|
[workspace.package]
|
||||||
|
rust-version = "1.91.0"
|
||||||
|
edition = "2024"
|
||||||
|
version = "0.0.1"
|
||||||
|
|
||||||
|
[workspace.lints.rust]
|
||||||
|
unused_import_braces = "deny"
|
||||||
|
unit_bindings = "deny"
|
||||||
|
single_use_lifetimes = "deny"
|
||||||
|
non_ascii_idents = "deny"
|
||||||
|
macro_use_extern_crate = "deny"
|
||||||
|
elided_lifetimes_in_paths = "deny"
|
||||||
|
absolute_paths_not_starting_with_crate = "deny"
|
||||||
|
explicit_outlives_requirements = "warn"
|
||||||
|
unused_crate_dependencies = "warn"
|
||||||
|
redundant_lifetimes = "warn"
|
||||||
|
missing_docs = "allow"
|
||||||
|
|
||||||
|
[workspace.lints.clippy]
|
||||||
|
todo = "deny"
|
||||||
|
uninlined_format_args = "allow"
|
||||||
|
result_large_err = "allow"
|
||||||
|
too_many_arguments = "allow"
|
||||||
|
upper_case_acronyms = "deny"
|
||||||
|
needless_return = "allow"
|
||||||
|
new_without_default = "allow"
|
||||||
|
tabs_in_doc_comments = "allow"
|
||||||
|
dbg_macro = "deny"
|
||||||
|
allow_attributes = "deny"
|
||||||
|
create_dir = "deny"
|
||||||
|
filetype_is_file = "deny"
|
||||||
|
integer_division = "allow"
|
||||||
|
lossy_float_literal = "deny"
|
||||||
|
map_err_ignore = "deny"
|
||||||
|
mutex_atomic = "deny"
|
||||||
|
needless_raw_strings = "deny"
|
||||||
|
str_to_string = "deny"
|
||||||
|
string_add = "deny"
|
||||||
|
implicit_clone = "deny"
|
||||||
|
use_debug = "allow"
|
||||||
|
verbose_file_reads = "deny"
|
||||||
|
large_types_passed_by_value = "deny"
|
||||||
|
wildcard_dependencies = "deny"
|
||||||
|
negative_feature_names = "deny"
|
||||||
|
redundant_feature_names = "deny"
|
||||||
|
multiple_crate_versions = "allow"
|
||||||
|
missing_safety_doc = "warn"
|
||||||
|
identity_op = "allow"
|
||||||
|
print_stderr = "deny"
|
||||||
|
print_stdout = "deny"
|
||||||
|
comparison_chain = "allow"
|
||||||
|
unimplemented = "deny"
|
||||||
|
unwrap_used = "warn"
|
||||||
|
expect_used = "warn"
|
||||||
|
type_complexity = "allow"
|
||||||
|
obfuscated_if_else = "allow"
|
||||||
|
|
||||||
|
#
|
||||||
|
# MARK: dependencies
|
||||||
|
#
|
||||||
|
|
||||||
|
[workspace.dependencies]
|
||||||
|
tokenizer = { path = "crates/tokenizer" }
|
||||||
|
|
||||||
|
anstyle = "1.0.13"
|
||||||
|
anyhow = "1.0.100"
|
||||||
|
ahash = "0.8.12"
|
||||||
|
clap = { version = "4.5.49", features = ["derive"] }
|
||||||
|
compact_str = "0.9.0"
|
||||||
|
dary_heap = "0.3.8"
|
||||||
|
fancy-regex = "0.16.2"
|
||||||
|
indicatif = { version = "0.18.3", features = ["improved_unicode"] }
|
||||||
|
futures-util = "0.3.31"
|
||||||
|
ndarray = { version = "0.16.1", features = ["serde"] }
|
||||||
|
parking_lot = "0.12.5"
|
||||||
|
parquet = "56.2.0"
|
||||||
|
rayon = "1.11.0"
|
||||||
|
reqwest = { version = "0.12.24", features = ["json", "stream"] }
|
||||||
|
serde = "1.0.228"
|
||||||
|
serde_json = "1.0.145"
|
||||||
|
strum = { version = "0.27.2", features = ["derive"] }
|
||||||
|
thiserror = "2.0.17"
|
||||||
|
tokio = { version = "1.48.0", features = ["full"] }
|
||||||
|
tracing = "0.1.43"
|
||||||
|
tracing-indicatif = "0.3.13"
|
||||||
|
tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] }
|
||||||
|
url = "2.5.7"
|
||||||
|
|
||||||
|
[workspace.dependencies.burn]
|
||||||
|
version = "0.19.1"
|
||||||
|
default-features = false
|
||||||
|
features = ["std", "fusion", "ndarray", "webgpu", "cuda"]
|
||||||
29
crates/llmfs/Cargo.toml
Normal file
29
crates/llmfs/Cargo.toml
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
[package]
|
||||||
|
name = "llmfs"
|
||||||
|
version = { workspace = true }
|
||||||
|
rust-version = { workspace = true }
|
||||||
|
edition = { workspace = true }
|
||||||
|
|
||||||
|
[lints]
|
||||||
|
workspace = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
tokenizer = { workspace = true }
|
||||||
|
|
||||||
|
anstyle = { workspace = true }
|
||||||
|
anyhow = { workspace = true }
|
||||||
|
clap = { workspace = true }
|
||||||
|
futures-util = { workspace = true }
|
||||||
|
indicatif = { workspace = true }
|
||||||
|
parking_lot = { workspace = true }
|
||||||
|
parquet = { workspace = true }
|
||||||
|
rayon = { workspace = true }
|
||||||
|
reqwest = { workspace = true }
|
||||||
|
serde = { workspace = true }
|
||||||
|
serde_json = { workspace = true }
|
||||||
|
thiserror = { workspace = true }
|
||||||
|
tokio = { workspace = true }
|
||||||
|
tracing = { workspace = true }
|
||||||
|
tracing-indicatif = { workspace = true }
|
||||||
|
tracing-subscriber = { workspace = true }
|
||||||
|
url = { workspace = true }
|
||||||
21
crates/tokenizer/Cargo.toml
Normal file
21
crates/tokenizer/Cargo.toml
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
[package]
|
||||||
|
name = "tokenizer"
|
||||||
|
version = { workspace = true }
|
||||||
|
rust-version = { workspace = true }
|
||||||
|
edition = { workspace = true }
|
||||||
|
|
||||||
|
[lints]
|
||||||
|
workspace = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
|
||||||
|
ahash = { workspace = true }
|
||||||
|
compact_str = { workspace = true }
|
||||||
|
dary_heap = { workspace = true }
|
||||||
|
fancy-regex = { workspace = true }
|
||||||
|
rayon = { workspace = true }
|
||||||
|
serde = { workspace = true }
|
||||||
|
strum = { workspace = true }
|
||||||
|
thiserror = { workspace = true }
|
||||||
|
tracing = { workspace = true }
|
||||||
|
indicatif = { workspace = true }
|
||||||
24
crates/tokenizer/src/lib.rs
Normal file
24
crates/tokenizer/src/lib.rs
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
mod split;
|
||||||
|
mod tokenizer;
|
||||||
|
|
||||||
|
use indicatif::ProgressStyle;
|
||||||
|
pub use tokenizer::*;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests_split;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests_tokenizer;
|
||||||
|
|
||||||
|
#[expect(clippy::unwrap_used)]
|
||||||
|
pub(crate) fn progress_big() -> ProgressStyle {
|
||||||
|
return ProgressStyle::default_bar()
|
||||||
|
.template(
|
||||||
|
" {spinner:.green} [{elapsed_precise}] [{bar:40.green/dim}] {pos:>7}/{len:7} {msg:.dim} ({eta})",
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.progress_chars("=>-")
|
||||||
|
.tick_strings(&[
|
||||||
|
"⠉⠉", "⠈⠙", "⠀⠹", "⠀⢸", "⠀⣰", "⢀⣠", "⣀⣀", "⣄⡀", "⣆⠀", "⡇⠀", "⠏⠀", "⠋⠁", "⣏⣹",
|
||||||
|
]);
|
||||||
|
}
|
||||||
25
crates/tokenizer/src/split.rs
Normal file
25
crates/tokenizer/src/split.rs
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
use fancy_regex::Regex;
|
||||||
|
|
||||||
|
/// Split text using a regex while keeping both the matched parts and the parts between matches
|
||||||
|
pub fn regex_segment<'a>(re: &Regex, text: &'a str) -> Vec<&'a str> {
|
||||||
|
let mut result = Vec::new();
|
||||||
|
let mut last = 0;
|
||||||
|
|
||||||
|
for mat in re.find_iter(text) {
|
||||||
|
#[expect(clippy::unwrap_used)]
|
||||||
|
let mat = mat.unwrap();
|
||||||
|
if mat.start() > last {
|
||||||
|
result.push(&text[last..mat.start()]);
|
||||||
|
}
|
||||||
|
result.push(mat.as_str());
|
||||||
|
last = mat.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
if last < text.len() {
|
||||||
|
result.push(&text[last..]);
|
||||||
|
}
|
||||||
|
|
||||||
|
result.retain(|x| !x.is_empty());
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
234
crates/tokenizer/src/tests_split.rs
Normal file
234
crates/tokenizer/src/tests_split.rs
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
use fancy_regex::Regex;
|
||||||
|
|
||||||
|
use crate::{Tokenizer, split::regex_segment};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn basic() {
|
||||||
|
let re = Regex::new(r"[,;]").unwrap();
|
||||||
|
let text = "apple,banana;cherry";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
assert_eq!(result, vec!["apple", ",", "banana", ";", "cherry"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn empty_string() {
|
||||||
|
let re = Regex::new(r"[,;]").unwrap();
|
||||||
|
let text = "";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
assert_eq!(result, Vec::<&str>::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn no_matches() {
|
||||||
|
let re = Regex::new(r"[,;]").unwrap();
|
||||||
|
let text = "apple banana cherry";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
assert_eq!(result, vec!["apple banana cherry"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn only_matches() {
|
||||||
|
let re = Regex::new(r"[,;]").unwrap();
|
||||||
|
let text = ",;,";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
assert_eq!(result, vec![",", ";", ","]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn starts_with_match() {
|
||||||
|
let re = Regex::new(r"[,;]").unwrap();
|
||||||
|
let text = ",apple;banana";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
assert_eq!(result, vec![",", "apple", ";", "banana"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ends_with_match() {
|
||||||
|
let re = Regex::new(r"[,;]").unwrap();
|
||||||
|
let text = "apple,banana;";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
assert_eq!(result, vec!["apple", ",", "banana", ";"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn consecutive_matches() {
|
||||||
|
let re = Regex::new(r"[,;]").unwrap();
|
||||||
|
let text = "apple,,banana";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
assert_eq!(result, vec!["apple", ",", ",", "banana"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn word_boundaries() {
|
||||||
|
let re = Regex::new(r"\b").unwrap();
|
||||||
|
let text = "hello world";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
// Word boundaries are zero-width, so we get empty matches between word chars and non-word chars
|
||||||
|
assert_eq!(result, vec!["hello", " ", "world"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn digits() {
|
||||||
|
let re = Regex::new(r"\d+").unwrap();
|
||||||
|
let text = "abc123def456ghi";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
assert_eq!(result, vec!["abc", "123", "def", "456", "ghi"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn special_tokens() {
|
||||||
|
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
|
||||||
|
let text = "Hello <|user_start|>world<|user_end|> test";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
assert_eq!(
|
||||||
|
result,
|
||||||
|
vec!["Hello ", "<|user_start|>", "world", "<|user_end|>", " test"]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn unicode() {
|
||||||
|
let re = Regex::new(r"[=<3D>=<3D>]+").unwrap();
|
||||||
|
let text = "Hello=<3D>world=<3D>test";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
assert_eq!(result, vec!["Hello", "=<3D>", "world", "=<3D>", "test"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn single_char() {
|
||||||
|
let re = Regex::new(r"x").unwrap();
|
||||||
|
let text = "x";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
assert_eq!(result, vec!["x"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn multichar_match() {
|
||||||
|
let re = Regex::new(r"abc").unwrap();
|
||||||
|
let text = "123abc456abc789";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
assert_eq!(result, vec!["123", "abc", "456", "abc", "789"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bos_token() {
|
||||||
|
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
|
||||||
|
let text = "<|bos|>This is a document";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
assert_eq!(result, vec!["<|bos|>", "This is a document"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn conversation_flow() {
|
||||||
|
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
|
||||||
|
let text = "<|user_start|>Hello<|user_end|><|assistant_start|>Hi there!<|assistant_end|>";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
assert_eq!(
|
||||||
|
result,
|
||||||
|
vec![
|
||||||
|
"<|user_start|>",
|
||||||
|
"Hello",
|
||||||
|
"<|user_end|>",
|
||||||
|
"<|assistant_start|>",
|
||||||
|
"Hi there!",
|
||||||
|
"<|assistant_end|>"
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn python_code_block() {
|
||||||
|
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
|
||||||
|
let text = "Code: <|python_start|>print('hello')<|python_end|> Output: <|output_start|>hello<|output_end|>";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
assert_eq!(
|
||||||
|
result,
|
||||||
|
vec![
|
||||||
|
"Code: ",
|
||||||
|
"<|python_start|>",
|
||||||
|
"print('hello')",
|
||||||
|
"<|python_end|>",
|
||||||
|
" Output: ",
|
||||||
|
"<|output_start|>",
|
||||||
|
"hello",
|
||||||
|
"<|output_end|>"
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn mixed_special_tokens() {
|
||||||
|
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
|
||||||
|
let text =
|
||||||
|
"<|bos|><|user_start|>Question<|user_end|><|assistant_start|>Answer<|assistant_end|>";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
assert_eq!(
|
||||||
|
result,
|
||||||
|
vec![
|
||||||
|
"<|bos|>",
|
||||||
|
"<|user_start|>",
|
||||||
|
"Question",
|
||||||
|
"<|user_end|>",
|
||||||
|
"<|assistant_start|>",
|
||||||
|
"Answer",
|
||||||
|
"<|assistant_end|>"
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn no_special_tokens() {
|
||||||
|
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
|
||||||
|
let text = "This is just regular text with no special tokens";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
assert_eq!(
|
||||||
|
result,
|
||||||
|
vec!["This is just regular text with no special tokens"]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn malformed_special_tokens() {
|
||||||
|
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
|
||||||
|
let text = "This has <|invalid_token> and <user_start> which shouldn't match";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
assert_eq!(
|
||||||
|
result,
|
||||||
|
vec!["This has <|invalid_token> and <user_start> which shouldn't match"]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn special_tokens_with_whitespace() {
|
||||||
|
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
|
||||||
|
let text = " <|bos|> \n<|user_start|>\tHello\n<|user_end|> ";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
assert_eq!(
|
||||||
|
result,
|
||||||
|
vec![
|
||||||
|
" ",
|
||||||
|
"<|bos|>",
|
||||||
|
" \n",
|
||||||
|
"<|user_start|>",
|
||||||
|
"\tHello\n",
|
||||||
|
"<|user_end|>",
|
||||||
|
" "
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn only_special_tokens() {
|
||||||
|
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
|
||||||
|
let text = "<|bos|><|user_start|><|user_end|>";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
assert_eq!(result, vec!["<|bos|>", "<|user_start|>", "<|user_end|>"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn nested() {
|
||||||
|
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
|
||||||
|
let text = "<|<|bos|>|>";
|
||||||
|
let result = regex_segment(&re, text);
|
||||||
|
assert_eq!(result, vec!["<|", "<|bos|>", "|>"]);
|
||||||
|
}
|
||||||
346
crates/tokenizer/src/tests_tokenizer.rs
Normal file
346
crates/tokenizer/src/tests_tokenizer.rs
Normal file
@@ -0,0 +1,346 @@
|
|||||||
|
use crate::{SpecialTokens, Tokenizer};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn encode_decode() {
|
||||||
|
let tokenizer = Tokenizer::train(
|
||||||
|
None,
|
||||||
|
[
|
||||||
|
"hello world".to_string(),
|
||||||
|
"hello there".to_string(),
|
||||||
|
"hello hello world".to_string(),
|
||||||
|
]
|
||||||
|
.into_iter(),
|
||||||
|
300,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let test_cases = vec![
|
||||||
|
"hello",
|
||||||
|
"world",
|
||||||
|
"hello world",
|
||||||
|
"hello there world",
|
||||||
|
"abc123",
|
||||||
|
"",
|
||||||
|
];
|
||||||
|
|
||||||
|
for test_str in test_cases {
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded, "Failed roundtrip for: '{}'", test_str);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// MARK: individual tests
|
||||||
|
//
|
||||||
|
|
||||||
|
fn get_test_tokenizer() -> Tokenizer {
|
||||||
|
Tokenizer::train(
|
||||||
|
None,
|
||||||
|
[
|
||||||
|
"hello world".to_string(),
|
||||||
|
"<|bos|> test".to_string(),
|
||||||
|
"user: <|user_start|>hi<|user_end|>".to_string(),
|
||||||
|
]
|
||||||
|
.into_iter(),
|
||||||
|
300,
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bos_token() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "<|bos|>";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert_eq!(
|
||||||
|
[256 + SpecialTokens::BeginningOfSequence.idx()],
|
||||||
|
&encoded[..]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn user_start_token() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "<|user_start|>";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert_eq!([256 + SpecialTokens::UserStart.idx()], &encoded[..]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn partial_bos_start() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "<|bos";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(!encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
|
||||||
|
assert_eq!(5, encoded.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn partial_bos_end() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "|bos|>";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(!encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
|
||||||
|
assert_eq!(6, encoded.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn malformed_user_star() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "<|user_star|>";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
||||||
|
assert_eq!(10, encoded.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn partial_user_start_no_end() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "<|user_start";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
||||||
|
assert_eq!(9, encoded.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn partial_user_start_no_begin() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "user_start|>";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
||||||
|
assert_eq!(9, encoded.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn nested_bos() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "<|<|bos|>|>";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
|
||||||
|
assert_eq!(5, encoded.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn double_nested_user() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "<<||user_start||>>";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn conversation_flow() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "<|user_start|>Hello<|user_end|><|assistant_start|>Hi<|assistant_end|>";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
||||||
|
assert!(encoded.contains(&(256 + SpecialTokens::UserEnd.idx())));
|
||||||
|
assert!(encoded.contains(&(256 + SpecialTokens::AssistantStart.idx())));
|
||||||
|
assert!(encoded.contains(&(256 + SpecialTokens::AssistantEnd.idx())));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bos_with_control_chars() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "<|bos|>\n\t\r";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn user_with_emoji() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "<|user_start|>🚀💻<|user_end|>";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
||||||
|
assert!(encoded.contains(&(256 + SpecialTokens::UserEnd.idx())));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn python_code_block() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "<|python_start|>print('hello')\n<|python_end|>";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(encoded.contains(&(256 + SpecialTokens::PythonStart.idx())));
|
||||||
|
assert!(encoded.contains(&(256 + SpecialTokens::PythonEnd.idx())));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bos_with_spaces() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = " <|bos|> ";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn user_start_with_newlines() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "\n<|user_start|>\n";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn user_end_tab_assistant_start() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "<|user_end|>\t<|assistant_start|>";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(encoded.contains(&(256 + SpecialTokens::UserEnd.idx())));
|
||||||
|
assert!(encoded.contains(&(256 + SpecialTokens::AssistantStart.idx())));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn empty_angle_brackets() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "<|>";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn no_angle_brackets() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "|user_start|";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn regular_angle_brackets() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "<user_start>";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn double_pipe_user_start() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "<|user_start||>";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn double_pipe_prefix() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "<||user_start|>";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(!encoded.contains(&(256 + SpecialTokens::UserStart.idx())));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn normal_text_with_bos() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "Normal text <|bos|> more text";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn code_with_output() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str =
|
||||||
|
"Code: <|python_start|>x = 1<|python_end|> Result: <|output_start|>1<|output_end|>";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(encoded.contains(&(256 + SpecialTokens::PythonStart.idx())));
|
||||||
|
assert!(encoded.contains(&(256 + SpecialTokens::PythonEnd.idx())));
|
||||||
|
assert!(encoded.contains(&(256 + SpecialTokens::OutputStart.idx())));
|
||||||
|
assert!(encoded.contains(&(256 + SpecialTokens::OutputEnd.idx())));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn escaped_bos() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "\\<|bos|>";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn escaped_pipe_bos() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "<|bos\\|>";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(!encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn escaped_s_bos() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "<|bo\\s|>";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(!encoded.contains(&(256 + SpecialTokens::BeginningOfSequence.idx())));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn empty_string() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = "";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
assert!(encoded.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn spaces_only() {
|
||||||
|
let tokenizer = get_test_tokenizer();
|
||||||
|
let test_str = " ";
|
||||||
|
let encoded = tokenizer.encode(test_str);
|
||||||
|
let decoded = tokenizer.decode(&encoded);
|
||||||
|
assert_eq!(test_str, decoded);
|
||||||
|
}
|
||||||
720
crates/tokenizer/src/tokenizer.rs
Normal file
720
crates/tokenizer/src/tokenizer.rs
Normal file
@@ -0,0 +1,720 @@
|
|||||||
|
use ahash::{AHashMap, AHashSet};
|
||||||
|
use compact_str::CompactString;
|
||||||
|
use dary_heap::DaryHeap;
|
||||||
|
use fancy_regex::Regex;
|
||||||
|
use indicatif::{MultiProgress, ProgressBar};
|
||||||
|
use rayon::iter::{
|
||||||
|
IndexedParallelIterator, IntoParallelRefIterator, ParallelBridge, ParallelIterator,
|
||||||
|
};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::{
|
||||||
|
cmp::Ordering,
|
||||||
|
collections::{HashMap, VecDeque},
|
||||||
|
str::FromStr,
|
||||||
|
sync::atomic::AtomicUsize,
|
||||||
|
time::Instant,
|
||||||
|
};
|
||||||
|
use strum::{AsRefStr, EnumString};
|
||||||
|
use tracing::{debug, info};
|
||||||
|
|
||||||
|
use crate::{progress_big, split::regex_segment};
|
||||||
|
|
||||||
|
// TODO:
|
||||||
|
// - maybe don't use regex
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, thiserror::Error)]
|
||||||
|
pub enum TokenizerTrainError {
|
||||||
|
#[error("vocab_size must be at least {is}, but it is {min}")]
|
||||||
|
InvalidVocabularySize { is: u32, min: u32 },
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A pair of adjacent tokens
|
||||||
|
type Pair = (u32, u32);
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, EnumString, AsRefStr)]
|
||||||
|
pub enum SpecialTokens {
|
||||||
|
/// every document begins with the Beginning of Sequence (BOS) token that delimits documents
|
||||||
|
#[strum(to_string = "<|bos|>")]
|
||||||
|
BeginningOfSequence,
|
||||||
|
|
||||||
|
// User messages
|
||||||
|
#[strum(to_string = "<|user_start|>")]
|
||||||
|
UserStart,
|
||||||
|
#[strum(to_string = "<|user_end|>")]
|
||||||
|
UserEnd,
|
||||||
|
|
||||||
|
// Assistant messages
|
||||||
|
#[strum(to_string = "<|assistant_start|>")]
|
||||||
|
AssistantStart,
|
||||||
|
#[strum(to_string = "<|assistant_end|>")]
|
||||||
|
AssistantEnd,
|
||||||
|
|
||||||
|
/// Assistant invokes python REPL tool
|
||||||
|
#[strum(to_string = "<|python_start|>")]
|
||||||
|
PythonStart,
|
||||||
|
#[strum(to_string = "<|python_end|>")]
|
||||||
|
PythonEnd,
|
||||||
|
|
||||||
|
// Python REPL outputs back to assistant
|
||||||
|
#[strum(to_string = "<|output_start|>")]
|
||||||
|
OutputStart,
|
||||||
|
#[strum(to_string = "<|output_end|>")]
|
||||||
|
OutputEnd,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SpecialTokens {
|
||||||
|
/// An array that contains every varaiant of this enum
|
||||||
|
pub const VARIANTS: &[Self] = &[
|
||||||
|
Self::BeginningOfSequence,
|
||||||
|
Self::UserStart,
|
||||||
|
Self::UserEnd,
|
||||||
|
Self::AssistantStart,
|
||||||
|
Self::AssistantEnd,
|
||||||
|
Self::PythonStart,
|
||||||
|
Self::PythonEnd,
|
||||||
|
Self::OutputStart,
|
||||||
|
Self::OutputEnd,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Get the index of this variant in [Self::VARIANTS]
|
||||||
|
#[expect(clippy::unwrap_used)]
|
||||||
|
#[inline]
|
||||||
|
pub fn idx(&self) -> u32 {
|
||||||
|
Self::VARIANTS.iter().position(|p| p == self).unwrap() as u32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// MARK: word
|
||||||
|
//
|
||||||
|
|
||||||
|
/// A segment of text that is undergoing tokenization.
|
||||||
|
///
|
||||||
|
/// `ids` contains the ids of the tokens this word consists of.
|
||||||
|
/// Initially, these are ascii u8s. More are added as we merge pairs.
|
||||||
|
///
|
||||||
|
/// This implementation of BPE does not cross word boundaries.
|
||||||
|
/// - one token cannot represent more than one word, which may be a feature.
|
||||||
|
/// - this keeps multi-byte unicode together, since the split regex handles it correctly.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
struct Word {
|
||||||
|
ids: Vec<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Word {
|
||||||
|
#[inline]
|
||||||
|
fn new(ids: Vec<u32>) -> Self {
|
||||||
|
Self { ids }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return an iterator over all adjacent pairs
|
||||||
|
/// of tokens in this word
|
||||||
|
#[inline]
|
||||||
|
fn pairs<'a>(&'a self) -> impl Iterator<Item = Pair> + 'a {
|
||||||
|
self.ids.windows(2).map(|w| (w[0], w[1]))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Merge all non-overlapping occurrences of pair -> new_id.
|
||||||
|
/// Returns a small Vec of local pair-count deltas for THIS word only:
|
||||||
|
/// -1 for removed pairs, +1 for newly created pairs.
|
||||||
|
///
|
||||||
|
/// NOTE: this version deliberately avoids a HashMap in the hot loop.
|
||||||
|
fn merge_pair(&mut self, pair: Pair, new_id: u32) -> Vec<(Pair, i32)> {
|
||||||
|
let (a, b) = pair;
|
||||||
|
let n = self.ids.len();
|
||||||
|
if n < 2 {
|
||||||
|
return Vec::new();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut out: Vec<u32> = Vec::with_capacity(n);
|
||||||
|
let mut deltas: Vec<(Pair, i32)> = Vec::with_capacity(6);
|
||||||
|
|
||||||
|
let mut i = 0;
|
||||||
|
while i < n {
|
||||||
|
if i + 1 < n && self.ids[i] == a && self.ids[i + 1] == b {
|
||||||
|
let left = out.last().copied();
|
||||||
|
let right = if i + 2 < n {
|
||||||
|
Some(self.ids[i + 2])
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// remove old pairs
|
||||||
|
if let Some(x) = left {
|
||||||
|
deltas.push(((x, a), -1));
|
||||||
|
deltas.push(((x, new_id), 1));
|
||||||
|
}
|
||||||
|
deltas.push(((a, b), -1));
|
||||||
|
if let Some(y) = right {
|
||||||
|
deltas.push(((b, y), -1));
|
||||||
|
deltas.push(((new_id, y), 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
// write merged token
|
||||||
|
out.push(new_id);
|
||||||
|
i += 2; // skip 'a' and 'b'
|
||||||
|
} else {
|
||||||
|
out.push(self.ids[i]);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.ids = out;
|
||||||
|
deltas
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// MARK: mergejob
|
||||||
|
//
|
||||||
|
|
||||||
|
#[derive(Debug, Eq)]
|
||||||
|
struct MergeJob {
|
||||||
|
pair: Pair,
|
||||||
|
count: u64,
|
||||||
|
|
||||||
|
/// set of word indices where this pair may occur and needs processing
|
||||||
|
pos: AHashSet<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq for MergeJob {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
self.count == other.count && self.pair == other.pair
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialOrd for MergeJob {
|
||||||
|
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||||
|
Some(self.cmp(other))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Ord for MergeJob {
|
||||||
|
fn cmp(&self, other: &Self) -> Ordering {
|
||||||
|
self.count
|
||||||
|
.cmp(&other.count)
|
||||||
|
.then_with(|| other.pair.cmp(&self.pair))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// MARK: tokenizer
|
||||||
|
//
|
||||||
|
|
||||||
|
/// A Byte Pair Encoding tokenizer that matches the GPT-4 style implementation
|
||||||
|
///
|
||||||
|
/// BPE's goal is to build compound tokens for a text corpus.
|
||||||
|
/// It starts with a simple initial state---in this case, each ascii u8
|
||||||
|
/// is a token. We then examine all adjacent pairs of tokens, replacing
|
||||||
|
/// common pairs with a new token id. This is repeated until our vocabulary
|
||||||
|
/// is as large as it needs to be.
|
||||||
|
pub struct Tokenizer {
|
||||||
|
/// Maps pairs of token IDs to their merged token ID
|
||||||
|
merges: HashMap<Pair, u32>,
|
||||||
|
|
||||||
|
/// Inverse of merges
|
||||||
|
unmerges: Vec<Pair>,
|
||||||
|
|
||||||
|
vocab_size: u32,
|
||||||
|
|
||||||
|
/// The regex pattern used for text splitting
|
||||||
|
split_regex: Regex,
|
||||||
|
special_regex: Regex,
|
||||||
|
|
||||||
|
/// Source of split_regex
|
||||||
|
/// (debug info)
|
||||||
|
split_regex_string: String,
|
||||||
|
|
||||||
|
/// The number of texts this tokenizer was trained on
|
||||||
|
/// (debug info)
|
||||||
|
n_train_texts: u64,
|
||||||
|
|
||||||
|
/// The number of words this tokenizer was trained on
|
||||||
|
/// (debug info)
|
||||||
|
n_train_words: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Serialize for Tokenizer {
|
||||||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: serde::Serializer,
|
||||||
|
{
|
||||||
|
use serde::ser::SerializeStruct;
|
||||||
|
|
||||||
|
let mut state = serializer.serialize_struct("Tokenizer", 4)?;
|
||||||
|
|
||||||
|
// Convert HashMap to Vec for serialization
|
||||||
|
// (only string keys are valid in json)
|
||||||
|
let merges_vec: Vec<(Pair, u32)> = self.merges.iter().map(|(&k, &v)| (k, v)).collect();
|
||||||
|
state.serialize_field("merges", &merges_vec)?;
|
||||||
|
|
||||||
|
state.serialize_field("split_regex_string", &self.split_regex_string)?;
|
||||||
|
state.serialize_field("n_train_texts", &self.n_train_texts)?;
|
||||||
|
state.serialize_field("n_train_words", &self.n_train_words)?;
|
||||||
|
state.serialize_field("vocab_size", &self.vocab_size)?;
|
||||||
|
state.end()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom deserializer that automatically compiles split_regex
|
||||||
|
impl<'de> Deserialize<'de> for Tokenizer {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'de>,
|
||||||
|
{
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct TokenizerData {
|
||||||
|
merges: Vec<(Pair, u32)>,
|
||||||
|
split_regex_string: String,
|
||||||
|
n_train_texts: u64,
|
||||||
|
n_train_words: u64,
|
||||||
|
vocab_size: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
let data = TokenizerData::deserialize(deserializer)?;
|
||||||
|
let split_regex = Regex::new(&data.split_regex_string).map_err(serde::de::Error::custom)?;
|
||||||
|
|
||||||
|
let merges: HashMap<Pair, u32> = data.merges.into_iter().collect();
|
||||||
|
|
||||||
|
#[expect(clippy::unwrap_used)] // Special regex must be valid
|
||||||
|
let special_regex = Regex::new(Self::SPECIAL_REGEX).unwrap();
|
||||||
|
|
||||||
|
Ok(Tokenizer {
|
||||||
|
unmerges: Self::reverse_merges(&merges),
|
||||||
|
merges,
|
||||||
|
|
||||||
|
split_regex,
|
||||||
|
split_regex_string: data.split_regex_string,
|
||||||
|
special_regex,
|
||||||
|
|
||||||
|
n_train_texts: data.n_train_texts,
|
||||||
|
n_train_words: data.n_train_words,
|
||||||
|
vocab_size: data.vocab_size,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Tokenizer {
|
||||||
|
/// Default regex pattern for splitting text
|
||||||
|
const DEFAULT_REGEX: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+";
|
||||||
|
|
||||||
|
/// Regex for splitting special tokens
|
||||||
|
pub const SPECIAL_REGEX: &str = "<[|][a-z_]+[|]>";
|
||||||
|
|
||||||
|
/// Minimum size of this tokenizer's vocabulary.
|
||||||
|
/// - 256 for all `u8`s (token ids `0..=255`)
|
||||||
|
/// - `n` for special tokens (token ids `256..256+n`)
|
||||||
|
pub const MIN_VOCAB_SIZE: u32 = 256 + SpecialTokens::VARIANTS.len() as u32;
|
||||||
|
|
||||||
|
/// Reverse a merge map, returning an array of pairs.
|
||||||
|
/// `result[token_id - Self::MIN_VOCAB_SIZE`]` returns the pair that `token_id` represents.
|
||||||
|
fn reverse_merges(merges: &HashMap<Pair, u32>) -> Vec<Pair> {
|
||||||
|
let iter = merges.iter().map(|x| (*x.0, *x.1));
|
||||||
|
let mut array = Vec::from_iter(iter);
|
||||||
|
array.sort_by_key(|x| x.1);
|
||||||
|
|
||||||
|
let mut unmerges = Vec::with_capacity(array.len());
|
||||||
|
for (i, p) in array.iter().enumerate() {
|
||||||
|
assert_eq!(p.1, i as u32 + Self::MIN_VOCAB_SIZE);
|
||||||
|
unmerges.push(p.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
unmerges
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the regex pattern used to split words
|
||||||
|
#[inline]
|
||||||
|
pub fn get_regex(&self) -> &str {
|
||||||
|
&self.split_regex_string
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the total number of tokens this tokenizer can produce
|
||||||
|
#[inline]
|
||||||
|
pub fn vocab_size(&self) -> u32 {
|
||||||
|
self.vocab_size
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return this tokenizer's "beginning of seq" token
|
||||||
|
#[inline]
|
||||||
|
pub fn bos_token(&self) -> u32 {
|
||||||
|
256 + SpecialTokens::BeginningOfSequence.idx()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tokenize a string
|
||||||
|
pub fn encode(&self, text: &str) -> Vec<u32> {
|
||||||
|
let mut all_ids = Vec::new();
|
||||||
|
|
||||||
|
let special = regex_segment(&self.special_regex, text);
|
||||||
|
for s in special {
|
||||||
|
if let Ok(special) = SpecialTokens::from_str(s) {
|
||||||
|
all_ids.push(256 + special.idx());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
for m in self.split_regex.find_iter(s) {
|
||||||
|
#[expect(clippy::unwrap_used)] // Shouldn't ever fail
|
||||||
|
let word = m.unwrap().as_str();
|
||||||
|
let mut word_tokens = word.bytes().map(|b| b as u32).collect::<Vec<_>>();
|
||||||
|
|
||||||
|
// Apply merges
|
||||||
|
while word_tokens.len() >= 2 {
|
||||||
|
// Merge the pair with the largest token idx
|
||||||
|
// (pair_start_idx, replace_with)
|
||||||
|
let mut best_pair: Option<(usize, u32)> = None;
|
||||||
|
|
||||||
|
for (i, pair) in word_tokens.windows(2).map(|x| (x[0], x[1])).enumerate() {
|
||||||
|
let new_id = match self.merges.get(&pair) {
|
||||||
|
None => continue,
|
||||||
|
Some(x) => *x,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[expect(clippy::unwrap_used)]
|
||||||
|
if best_pair.is_none() || new_id < best_pair.unwrap().1 {
|
||||||
|
best_pair = Some((i, new_id));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match best_pair {
|
||||||
|
Some((idx, new_id)) => {
|
||||||
|
word_tokens[idx] = new_id;
|
||||||
|
word_tokens.remove(idx + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
None => {
|
||||||
|
// No merges possible
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
all_ids.extend(word_tokens);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
all_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decode an array of tokens
|
||||||
|
pub fn decode(&self, tokens: &[u32]) -> String {
|
||||||
|
let mut str_bytes = Vec::new();
|
||||||
|
let mut buffer = VecDeque::from_iter(tokens.iter().copied());
|
||||||
|
|
||||||
|
while let Some(t) = buffer.pop_front() {
|
||||||
|
// This is a byte
|
||||||
|
if t < 256 {
|
||||||
|
str_bytes.push(t as u8);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let t = t - 256;
|
||||||
|
|
||||||
|
// This is a special token
|
||||||
|
if let Some(t) = SpecialTokens::VARIANTS.get(t as usize) {
|
||||||
|
str_bytes.extend_from_slice(t.as_ref().as_bytes());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let t = t - SpecialTokens::VARIANTS.len() as u32;
|
||||||
|
|
||||||
|
// This is a compound token
|
||||||
|
#[expect(clippy::unwrap_used)]
|
||||||
|
let pair = self.unmerges.get(t as usize).unwrap();
|
||||||
|
buffer.push_front(pair.1);
|
||||||
|
buffer.push_front(pair.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
return String::from_utf8_lossy(&str_bytes).to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// MARK: main training code
|
||||||
|
//
|
||||||
|
|
||||||
|
/// Given an array of words an an array of counts,
|
||||||
|
/// - count the number of occurrences of each token pair
|
||||||
|
/// - map each pair to the indices of the words that contain it
|
||||||
|
///
|
||||||
|
/// ## Notes:
|
||||||
|
/// - will panic if `words.len() != counts.len()`
|
||||||
|
/// - will not behave correctly if words are repeated
|
||||||
|
#[inline]
|
||||||
|
fn count_pairs(
|
||||||
|
words: &[Word],
|
||||||
|
counts: &[i32],
|
||||||
|
) -> (AHashMap<Pair, i32>, AHashMap<Pair, AHashSet<usize>>) {
|
||||||
|
words
|
||||||
|
.par_iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, w)| {
|
||||||
|
let mut pair_counts: AHashMap<Pair, i32> = AHashMap::new();
|
||||||
|
let mut word_map: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
|
||||||
|
|
||||||
|
if w.ids.len() >= 2 && counts[i] != 0 {
|
||||||
|
for (a, b) in w.pairs() {
|
||||||
|
*pair_counts.entry((a, b)).or_default() += counts[i];
|
||||||
|
word_map.entry((a, b)).or_default().insert(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(pair_counts, word_map)
|
||||||
|
})
|
||||||
|
.reduce(
|
||||||
|
|| (AHashMap::new(), AHashMap::new()),
|
||||||
|
|(mut acc_pc, mut acc_wtu), (pc, wtu)| {
|
||||||
|
for (k, v) in pc {
|
||||||
|
*acc_pc.entry(k).or_default() += v;
|
||||||
|
}
|
||||||
|
for (k, s) in wtu {
|
||||||
|
acc_wtu.entry(k).or_default().extend(s);
|
||||||
|
}
|
||||||
|
(acc_pc, acc_wtu)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Core incremental BPE training given unique words and their counts.
|
||||||
|
/// - `words`: one entry per unique chunk (Vec<u32> of token-ids/bytes).
|
||||||
|
/// - `counts`: same length as `words`, count per chunk.
|
||||||
|
///
|
||||||
|
/// ## Notes:
|
||||||
|
/// - vocab size must be >= Self::MIN_VOCAB_SIZE
|
||||||
|
/// - will panic if `words.len() != counts.len()`
|
||||||
|
/// - will not behave correctly if words are repeated
|
||||||
|
fn train_core(
|
||||||
|
mp: Option<MultiProgress>,
|
||||||
|
mut words: Vec<Word>,
|
||||||
|
counts: Vec<i32>,
|
||||||
|
vocab_size: u32,
|
||||||
|
) -> HashMap<Pair, u32> {
|
||||||
|
assert!(
|
||||||
|
vocab_size >= Self::MIN_VOCAB_SIZE,
|
||||||
|
"vocab_size must be at least {}",
|
||||||
|
Self::MIN_VOCAB_SIZE
|
||||||
|
);
|
||||||
|
|
||||||
|
let num_merges = vocab_size - Self::MIN_VOCAB_SIZE;
|
||||||
|
let mut merges = HashMap::with_capacity(num_merges as usize);
|
||||||
|
info!(message = "Training tokenizer", num_merges);
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
info!(
|
||||||
|
message = "Computing initial pair counts",
|
||||||
|
n_words = words.len()
|
||||||
|
);
|
||||||
|
let (mut pair_counts, mut where_to_update) = Self::count_pairs(&words, &counts);
|
||||||
|
|
||||||
|
let mut heap = {
|
||||||
|
debug!("Building heap with {} unique pairs", pair_counts.len());
|
||||||
|
|
||||||
|
let mut heap = DaryHeap::<_, 8>::with_capacity(pair_counts.len());
|
||||||
|
for (pair, pos) in where_to_update.drain() {
|
||||||
|
let c = *pair_counts.get(&pair).unwrap_or(&0);
|
||||||
|
if c > 0 {
|
||||||
|
heap.push(MergeJob {
|
||||||
|
pair,
|
||||||
|
count: c as u64,
|
||||||
|
pos,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
heap
|
||||||
|
};
|
||||||
|
|
||||||
|
let pb = mp.as_ref().map(|mp| {
|
||||||
|
let pb = mp.add(ProgressBar::new(num_merges as u64));
|
||||||
|
pb.set_style(progress_big());
|
||||||
|
pb.set_message("Computing merges");
|
||||||
|
pb
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut merges_done = 0u32;
|
||||||
|
|
||||||
|
while merges_done < num_merges {
|
||||||
|
let Some(mut top) = heap.pop() else {
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Lazy refresh
|
||||||
|
let current = *pair_counts.get(&top.pair).unwrap_or(&0);
|
||||||
|
if top.count != current as u64 {
|
||||||
|
top.count = current as u64;
|
||||||
|
if top.count > 0 {
|
||||||
|
heap.push(top);
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if top.count == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record merge
|
||||||
|
let new_id = Self::MIN_VOCAB_SIZE + merges_done;
|
||||||
|
merges.insert(top.pair, new_id);
|
||||||
|
|
||||||
|
// Merge this pair in all words where it occurs
|
||||||
|
let mut local_pos_updates: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
|
||||||
|
for &word_idx in &top.pos {
|
||||||
|
// Apply merge to this word and collect pair-count deltas
|
||||||
|
let changes = words[word_idx].merge_pair(top.pair, new_id);
|
||||||
|
// Update global pair counts based on this word's count
|
||||||
|
for (pair, delta) in changes {
|
||||||
|
let delta_total = delta * counts[word_idx];
|
||||||
|
if delta_total != 0 {
|
||||||
|
*pair_counts.entry(pair).or_default() += delta_total;
|
||||||
|
if delta > 0 {
|
||||||
|
local_pos_updates.entry(pair).or_default().insert(word_idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the updated pair counts back to the heap
|
||||||
|
for (pair, pos) in local_pos_updates {
|
||||||
|
let cnt = *pair_counts.get(&pair).unwrap_or(&0);
|
||||||
|
if cnt > 0 {
|
||||||
|
heap.push(MergeJob {
|
||||||
|
pair,
|
||||||
|
count: cnt as u64,
|
||||||
|
pos,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
merges_done += 1;
|
||||||
|
|
||||||
|
if let Some(pb) = pb.as_ref() {
|
||||||
|
pb.set_position(merges_done as u64);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(pb) = pb {
|
||||||
|
pb.finish_and_clear();
|
||||||
|
}
|
||||||
|
info!("Computed merges in {:.03?}", now.elapsed());
|
||||||
|
|
||||||
|
return merges;
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// MARK: train
|
||||||
|
//
|
||||||
|
|
||||||
|
/// Train a new tokenizer from an iterator of texts.
|
||||||
|
pub fn train<I>(
|
||||||
|
mp: Option<MultiProgress>,
|
||||||
|
iterator: I,
|
||||||
|
vocab_size: u32,
|
||||||
|
) -> Result<Self, TokenizerTrainError>
|
||||||
|
where
|
||||||
|
I: Iterator<Item = String> + ExactSizeIterator,
|
||||||
|
I: ParallelBridge + Send,
|
||||||
|
{
|
||||||
|
if vocab_size < Self::MIN_VOCAB_SIZE {
|
||||||
|
return Err(TokenizerTrainError::InvalidVocabularySize {
|
||||||
|
is: vocab_size,
|
||||||
|
min: Self::MIN_VOCAB_SIZE,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let split_regex_string = Self::DEFAULT_REGEX.to_owned();
|
||||||
|
#[expect(clippy::unwrap_used)] // Default regex must be valid
|
||||||
|
let split_regex = Regex::new(&split_regex_string).unwrap();
|
||||||
|
|
||||||
|
#[expect(clippy::unwrap_used)] // Special regex must be valid
|
||||||
|
let special_regex = Regex::new(Self::SPECIAL_REGEX).unwrap();
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
let n_train_texts = iterator.len() as u64;
|
||||||
|
debug!("Counting words in {} texts", n_train_texts);
|
||||||
|
let pb = mp.as_ref().map(|mp| {
|
||||||
|
let pb = mp.add(ProgressBar::new(iterator.len() as u64));
|
||||||
|
pb.set_style(progress_big());
|
||||||
|
pb.set_message("Counting words");
|
||||||
|
pb
|
||||||
|
});
|
||||||
|
|
||||||
|
// Process texts in parallel and collect chunk counts
|
||||||
|
let counter = AtomicUsize::new(0);
|
||||||
|
let counts: AHashMap<CompactString, i32> = iterator
|
||||||
|
.par_bridge()
|
||||||
|
.map(|text| {
|
||||||
|
let mut local_counts: AHashMap<CompactString, i32> = AHashMap::new();
|
||||||
|
|
||||||
|
let special = regex_segment(&special_regex, &text);
|
||||||
|
for s in special {
|
||||||
|
if let Ok(special) = SpecialTokens::from_str(s) {
|
||||||
|
*local_counts
|
||||||
|
.entry(CompactString::from(special.as_ref()))
|
||||||
|
.or_default() += 1;
|
||||||
|
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
for mat in split_regex.find_iter(s) {
|
||||||
|
#[expect(clippy::unwrap_used)]
|
||||||
|
let piece = mat.unwrap().as_str();
|
||||||
|
*local_counts.entry(CompactString::from(piece)).or_default() += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let count = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
if let Some(ref pb) = pb
|
||||||
|
&& count.is_multiple_of(1000)
|
||||||
|
{
|
||||||
|
pb.inc(1000);
|
||||||
|
}
|
||||||
|
|
||||||
|
local_counts
|
||||||
|
})
|
||||||
|
.reduce(AHashMap::new, |mut acc, local_counts| {
|
||||||
|
for (k, v) in local_counts {
|
||||||
|
*acc.entry(k).or_default() += v;
|
||||||
|
}
|
||||||
|
|
||||||
|
acc
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Some(pb) = pb {
|
||||||
|
pb.finish_and_clear();
|
||||||
|
}
|
||||||
|
info!(
|
||||||
|
"Counted {} unique words in {:.03?}",
|
||||||
|
counts.len(),
|
||||||
|
now.elapsed()
|
||||||
|
);
|
||||||
|
|
||||||
|
// Materialize words & counts
|
||||||
|
let mut words = Vec::with_capacity(counts.len());
|
||||||
|
let mut cvec = Vec::with_capacity(counts.len());
|
||||||
|
let mut n_train_words = 0u64;
|
||||||
|
for (chunk, c) in counts.into_iter() {
|
||||||
|
let token_ids = match SpecialTokens::from_str(&chunk).ok() {
|
||||||
|
Some(x) => vec![256 + x.idx()],
|
||||||
|
None => chunk
|
||||||
|
.as_bytes()
|
||||||
|
.iter()
|
||||||
|
.map(|&b| b as u32)
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
};
|
||||||
|
|
||||||
|
words.push(Word::new(token_ids));
|
||||||
|
cvec.push(c);
|
||||||
|
n_train_words += c as u64;
|
||||||
|
}
|
||||||
|
|
||||||
|
let merges = Self::train_core(mp, words, cvec, vocab_size);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
split_regex,
|
||||||
|
split_regex_string,
|
||||||
|
special_regex,
|
||||||
|
n_train_texts,
|
||||||
|
n_train_words,
|
||||||
|
vocab_size: Self::MIN_VOCAB_SIZE + merges.len() as u32,
|
||||||
|
|
||||||
|
unmerges: Self::reverse_merges(&merges),
|
||||||
|
merges,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user