Compare commits
1 Commits
2ed4dc74ef
...
d5bf7ac5d1
Author | SHA1 | Date | |
---|---|---|---|
d5bf7ac5d1 |
5
.gitignore
vendored
5
.gitignore
vendored
@ -1,2 +1,5 @@
|
|||||||
/target
|
/target
|
||||||
*.ignore
|
*.ignore
|
||||||
|
|
||||||
|
/data
|
||||||
|
/tokenizer.json
|
2889
Cargo.lock
generated
Normal file
2889
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
79
Cargo.toml
Normal file
79
Cargo.toml
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
[package]
|
||||||
|
name = "nanochat-rs"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2024"
|
||||||
|
|
||||||
|
[lints.rust]
|
||||||
|
unused_import_braces = "deny"
|
||||||
|
unit_bindings = "deny"
|
||||||
|
single_use_lifetimes = "deny"
|
||||||
|
non_ascii_idents = "deny"
|
||||||
|
macro_use_extern_crate = "deny"
|
||||||
|
elided_lifetimes_in_paths = "deny"
|
||||||
|
absolute_paths_not_starting_with_crate = "deny"
|
||||||
|
explicit_outlives_requirements = "warn"
|
||||||
|
unused_crate_dependencies = "warn"
|
||||||
|
redundant_lifetimes = "warn"
|
||||||
|
missing_docs = "allow"
|
||||||
|
|
||||||
|
[lints.clippy]
|
||||||
|
todo = "deny"
|
||||||
|
uninlined_format_args = "allow"
|
||||||
|
result_large_err = "allow"
|
||||||
|
too_many_arguments = "allow"
|
||||||
|
upper_case_acronyms = "deny"
|
||||||
|
needless_return = "allow"
|
||||||
|
new_without_default = "allow"
|
||||||
|
tabs_in_doc_comments = "allow"
|
||||||
|
dbg_macro = "deny"
|
||||||
|
allow_attributes = "deny"
|
||||||
|
create_dir = "deny"
|
||||||
|
filetype_is_file = "deny"
|
||||||
|
integer_division = "allow"
|
||||||
|
lossy_float_literal = "deny"
|
||||||
|
map_err_ignore = "deny"
|
||||||
|
mutex_atomic = "deny"
|
||||||
|
needless_raw_strings = "deny"
|
||||||
|
str_to_string = "deny"
|
||||||
|
string_add = "deny"
|
||||||
|
string_to_string = "deny"
|
||||||
|
use_debug = "allow"
|
||||||
|
verbose_file_reads = "deny"
|
||||||
|
large_types_passed_by_value = "deny"
|
||||||
|
wildcard_dependencies = "deny"
|
||||||
|
negative_feature_names = "deny"
|
||||||
|
redundant_feature_names = "deny"
|
||||||
|
multiple_crate_versions = "allow"
|
||||||
|
missing_safety_doc = "warn"
|
||||||
|
identity_op = "allow"
|
||||||
|
print_stderr = "deny"
|
||||||
|
print_stdout = "deny"
|
||||||
|
comparison_chain = "allow"
|
||||||
|
unimplemented = "deny"
|
||||||
|
unwrap_used = "warn"
|
||||||
|
expect_used = "warn"
|
||||||
|
type_complexity = "allow"
|
||||||
|
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
ahash = "0.8.12"
|
||||||
|
anstyle = "1.0.13"
|
||||||
|
anyhow = "1.0.100"
|
||||||
|
clap = { version = "4.5.49", features = ["derive"] }
|
||||||
|
compact_str = "0.9.0"
|
||||||
|
dary_heap = "0.3.8"
|
||||||
|
fancy-regex = "0.16.2"
|
||||||
|
futures-util = "0.3.31"
|
||||||
|
indicatif = { version = "0.18.0", features = ["improved_unicode"] }
|
||||||
|
parking_lot = "0.12.5"
|
||||||
|
parquet = "56.2.0"
|
||||||
|
rayon = "1.11.0"
|
||||||
|
reqwest = { version = "0.12.24", features = ["json", "stream"] }
|
||||||
|
serde = { version = "1.0.228", features = ["derive"] }
|
||||||
|
serde_json = "1.0.145"
|
||||||
|
thiserror = "2.0.17"
|
||||||
|
tokio = { version = "1.48.0", features = ["full"] }
|
||||||
|
tracing = "0.1.41"
|
||||||
|
tracing-indicatif = "0.3.13"
|
||||||
|
tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] }
|
||||||
|
url = "2.5.7"
|
1
rustfmt.toml
Normal file
1
rustfmt.toml
Normal file
@ -0,0 +1 @@
|
|||||||
|
hard_tabs = true
|
64
src/cli.rs
Normal file
64
src/cli.rs
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
use anstyle::{AnsiColor, Color, Style};
|
||||||
|
use indicatif::ProgressStyle;
|
||||||
|
|
||||||
|
pub fn clap_styles() -> clap::builder::Styles {
|
||||||
|
clap::builder::Styles::styled()
|
||||||
|
.usage(
|
||||||
|
Style::new()
|
||||||
|
.bold()
|
||||||
|
.fg_color(Some(Color::Ansi(AnsiColor::Blue))),
|
||||||
|
)
|
||||||
|
.header(
|
||||||
|
Style::new()
|
||||||
|
.bold()
|
||||||
|
.fg_color(Some(Color::Ansi(AnsiColor::Blue))),
|
||||||
|
)
|
||||||
|
.literal(
|
||||||
|
Style::new()
|
||||||
|
.bold()
|
||||||
|
.fg_color(Some(Color::Ansi(AnsiColor::BrightBlack))),
|
||||||
|
)
|
||||||
|
.invalid(
|
||||||
|
Style::new()
|
||||||
|
.bold()
|
||||||
|
.fg_color(Some(Color::Ansi(AnsiColor::Red))),
|
||||||
|
)
|
||||||
|
.error(
|
||||||
|
Style::new()
|
||||||
|
.bold()
|
||||||
|
.fg_color(Some(Color::Ansi(AnsiColor::Red))),
|
||||||
|
)
|
||||||
|
.valid(
|
||||||
|
Style::new()
|
||||||
|
.bold()
|
||||||
|
.underline()
|
||||||
|
.fg_color(Some(Color::Ansi(AnsiColor::Green))),
|
||||||
|
)
|
||||||
|
.placeholder(Style::new().fg_color(Some(Color::Ansi(AnsiColor::White))))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[expect(clippy::unwrap_used)]
|
||||||
|
pub fn progress_big() -> ProgressStyle {
|
||||||
|
return ProgressStyle::default_bar()
|
||||||
|
.template(
|
||||||
|
" {spinner:.green} [{elapsed_precise}] [{bar:40.green/dim}] {pos:>7}/{len:7} {msg:.dim}",
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.progress_chars("=>-")
|
||||||
|
.tick_strings(&[
|
||||||
|
"⠉⠉", "⠈⠙", "⠀⠹", "⠀⢸", "⠀⣰", "⢀⣠", "⣀⣀", "⣄⡀", "⣆⠀", "⡇⠀", "⠏⠀", "⠋⠁", "⣏⣹",
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[expect(clippy::unwrap_used)]
|
||||||
|
pub fn progress_bytes() -> ProgressStyle {
|
||||||
|
return ProgressStyle::default_bar()
|
||||||
|
.template(
|
||||||
|
" {bar:16.red/white.dim} {elapsed_precise:.dim} {bytes:>10}/{total_bytes:>10} {msg:.dim} ({eta})",
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.progress_chars("---")
|
||||||
|
.tick_strings(&[
|
||||||
|
"⠉⠉", "⠈⠙", "⠀⠹", "⠀⢸", "⠀⣰", "⢀⣠", "⣀⣀", "⣄⡀", "⣆⠀", "⡇⠀", "⠏⠀", "⠋⠁", "⣏⣹",
|
||||||
|
]);
|
||||||
|
}
|
280
src/command/download.rs
Normal file
280
src/command/download.rs
Normal file
@ -0,0 +1,280 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use clap::Args;
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
use indicatif::MultiProgress;
|
||||||
|
use indicatif::ProgressBar;
|
||||||
|
use rayon::prelude::*;
|
||||||
|
use std::fs;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::runtime::Runtime;
|
||||||
|
use tracing::{debug, error, info};
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
|
use crate::cli::{progress_big, progress_bytes};
|
||||||
|
|
||||||
|
const BASE_URL: &str =
|
||||||
|
"https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main";
|
||||||
|
|
||||||
|
const MAX_SHARD: usize = 1822;
|
||||||
|
|
||||||
|
#[derive(Debug, Args, Clone)]
|
||||||
|
|
||||||
|
pub struct DownloadArgs {
|
||||||
|
/// Training data dir
|
||||||
|
#[clap(default_value = "data")]
|
||||||
|
data_dir: PathBuf,
|
||||||
|
|
||||||
|
/// Number of shards to download (-1 for all)
|
||||||
|
#[arg(short = 'n', long, default_value = "-1")]
|
||||||
|
num_files: isize,
|
||||||
|
|
||||||
|
/// Number of parallel downloads
|
||||||
|
#[arg(short = 't', long, default_value = "8")]
|
||||||
|
threads: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DownloadArgs {
|
||||||
|
pub fn run(self, mp: Option<MultiProgress>) -> Result<()> {
|
||||||
|
info!("Downloading files from {BASE_URL}");
|
||||||
|
fs::create_dir_all(&self.data_dir)?;
|
||||||
|
|
||||||
|
let num_shards_to_download = if self.num_files == -1 {
|
||||||
|
MAX_SHARD + 1
|
||||||
|
} else {
|
||||||
|
self.num_files.min((MAX_SHARD + 1) as isize) as usize
|
||||||
|
};
|
||||||
|
|
||||||
|
let ids_to_download: Vec<usize> = (0..num_shards_to_download).collect();
|
||||||
|
|
||||||
|
info!("Downloading {} shards...", ids_to_download.len(),);
|
||||||
|
info!("Target directory: {}", self.data_dir.display());
|
||||||
|
|
||||||
|
let main_pb = mp.as_ref().map(|mp| {
|
||||||
|
let pb = mp.add(ProgressBar::new(ids_to_download.len() as u64));
|
||||||
|
pb.set_style(progress_big());
|
||||||
|
pb.set_message("Downloading training data");
|
||||||
|
pb.enable_steady_tick(Duration::from_millis(100));
|
||||||
|
pb
|
||||||
|
});
|
||||||
|
|
||||||
|
let rt = Runtime::new()?;
|
||||||
|
|
||||||
|
let (tx, rx) = std::sync::mpsc::channel();
|
||||||
|
|
||||||
|
let pool = rayon::ThreadPoolBuilder::new()
|
||||||
|
.num_threads(self.threads)
|
||||||
|
.build()?;
|
||||||
|
|
||||||
|
pool.install(|| {
|
||||||
|
ids_to_download
|
||||||
|
.into_par_iter()
|
||||||
|
.for_each_with(tx, |tx, index| {
|
||||||
|
let target = self.data_dir.clone();
|
||||||
|
let main_pb = main_pb.clone();
|
||||||
|
let mp_clone = mp.clone();
|
||||||
|
let rt_handle = rt.handle().clone(); // Clone the runtime handle for each thread
|
||||||
|
|
||||||
|
let result = rt_handle.block_on(async {
|
||||||
|
download_single_file(index, &target, main_pb, mp_clone).await
|
||||||
|
});
|
||||||
|
|
||||||
|
// Send the result back to the main thread for aggregation
|
||||||
|
#[expect(clippy::unwrap_used)]
|
||||||
|
tx.send(result).unwrap();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// Wait for all downloads to finish and collect results
|
||||||
|
let mut successful_downloads = 0;
|
||||||
|
for _ in 0..num_shards_to_download {
|
||||||
|
if let Ok(Ok(_)) = rx.recv() {
|
||||||
|
// Receive the Result<(), String> wrapped in a Result from MPSC
|
||||||
|
successful_downloads += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(pb) = main_pb.as_ref() {
|
||||||
|
pb.finish_and_clear();
|
||||||
|
info!("Downloads complete ({successful_downloads} successful)");
|
||||||
|
}
|
||||||
|
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn download_single_file(
|
||||||
|
index: usize,
|
||||||
|
target: &Path,
|
||||||
|
progress_bar: Option<ProgressBar>,
|
||||||
|
mp: Option<MultiProgress>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
let filename = format!("shard_{:05}.parquet", index);
|
||||||
|
let filepath = target.join(&filename);
|
||||||
|
|
||||||
|
if filepath.exists() {
|
||||||
|
info!("Skipping {} (already exists)", filepath.display());
|
||||||
|
if let Some(pb) = progress_bar.as_ref() {
|
||||||
|
pb.inc(1);
|
||||||
|
}
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[expect(clippy::unwrap_used)]
|
||||||
|
let url = Url::parse(&format!("{BASE_URL}/{filename}")).unwrap();
|
||||||
|
|
||||||
|
info!("Downloading {} from {}", filename, url);
|
||||||
|
|
||||||
|
let max_attempts = 5;
|
||||||
|
'attempt_loop: for attempt in 1..=max_attempts {
|
||||||
|
let temp_filepath = filepath.with_extension("parquet.tmp");
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
match client.get(url.clone()).send().await {
|
||||||
|
Ok(response) => {
|
||||||
|
if !response.status().is_success() {
|
||||||
|
error!(
|
||||||
|
"Attempt {}/{}: Server responded with status {} for {}",
|
||||||
|
attempt,
|
||||||
|
max_attempts,
|
||||||
|
response.status(),
|
||||||
|
url
|
||||||
|
);
|
||||||
|
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let total_size = response.content_length().unwrap_or(0);
|
||||||
|
debug!("Total size for {}: {}", filename, total_size);
|
||||||
|
|
||||||
|
// Create file progress bar
|
||||||
|
let file_pb = if total_size > 0
|
||||||
|
&& let Some(mp) = mp.as_ref()
|
||||||
|
{
|
||||||
|
Some({
|
||||||
|
let pb = mp.add(ProgressBar::new(total_size));
|
||||||
|
pb.set_style(progress_bytes());
|
||||||
|
pb.set_message(format!("Downloading {}", filename));
|
||||||
|
pb
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut file = match tokio::fs::File::create(&temp_filepath).await {
|
||||||
|
Ok(file) => file,
|
||||||
|
Err(e) => {
|
||||||
|
error!(
|
||||||
|
"Attempt {}/{}: Failed to create temporary file {}: {}",
|
||||||
|
attempt,
|
||||||
|
max_attempts,
|
||||||
|
temp_filepath.display(),
|
||||||
|
e
|
||||||
|
);
|
||||||
|
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut stream = response.bytes_stream();
|
||||||
|
let mut downloaded: u64 = 0;
|
||||||
|
|
||||||
|
while let Some(chunk_result) = StreamExt::next(&mut stream).await {
|
||||||
|
match chunk_result {
|
||||||
|
Ok(chunk) => {
|
||||||
|
match tokio::io::AsyncWriteExt::write_all(&mut file, &chunk).await {
|
||||||
|
Ok(_) => {
|
||||||
|
downloaded += chunk.len() as u64;
|
||||||
|
if let Some(pb) = &file_pb {
|
||||||
|
pb.set_position(downloaded);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!(
|
||||||
|
"Attempt {}/{}: Failed to write to temporary file {}: {}",
|
||||||
|
attempt,
|
||||||
|
max_attempts,
|
||||||
|
temp_filepath.display(),
|
||||||
|
e
|
||||||
|
);
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
let _ = tokio::fs::remove_file(&temp_filepath).await;
|
||||||
|
if let Some(pb) = &file_pb {
|
||||||
|
pb.finish_and_clear();
|
||||||
|
}
|
||||||
|
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt)))
|
||||||
|
.await;
|
||||||
|
continue 'attempt_loop;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(e) => {
|
||||||
|
error!(
|
||||||
|
"Attempt {}/{}: Error reading chunk for {}: {}",
|
||||||
|
attempt, max_attempts, filename, e
|
||||||
|
);
|
||||||
|
|
||||||
|
let _ = tokio::fs::remove_file(&temp_filepath).await;
|
||||||
|
if let Some(pb) = &file_pb {
|
||||||
|
pb.finish_and_clear();
|
||||||
|
}
|
||||||
|
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await;
|
||||||
|
continue 'attempt_loop;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(pb) = &file_pb {
|
||||||
|
pb.finish_and_clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Atomically rename the temporary file
|
||||||
|
match tokio::fs::rename(&temp_filepath, &filepath).await {
|
||||||
|
Ok(_) => {
|
||||||
|
info!("Successfully downloaded {}", filename);
|
||||||
|
if let Some(pb) = progress_bar.as_ref() {
|
||||||
|
pb.inc(1);
|
||||||
|
}
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!(
|
||||||
|
"Attempt {}/{}: Failed to rename temporary file {} to {}: {}",
|
||||||
|
attempt,
|
||||||
|
max_attempts,
|
||||||
|
temp_filepath.display(),
|
||||||
|
filepath.display(),
|
||||||
|
e
|
||||||
|
);
|
||||||
|
let _ = tokio::fs::remove_file(&temp_filepath).await; // Clean up
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!(
|
||||||
|
"Attempt {}/{}: Failed to download {}: {}",
|
||||||
|
attempt, max_attempts, filename, e
|
||||||
|
);
|
||||||
|
|
||||||
|
// Clean up any partial files
|
||||||
|
let _ = tokio::fs::remove_file(&temp_filepath).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if attempt < max_attempts {
|
||||||
|
let wait_time = 2u64.pow(attempt);
|
||||||
|
info!("Waiting {} seconds before retry...", wait_time);
|
||||||
|
tokio::time::sleep(Duration::from_secs(wait_time)).await;
|
||||||
|
} else {
|
||||||
|
error!(
|
||||||
|
"Failed to download {} after {} attempts",
|
||||||
|
filename, max_attempts
|
||||||
|
);
|
||||||
|
return Err(format!("Failed to download {}", filename));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(format!("Failed to download {}", filename))
|
||||||
|
}
|
26
src/command/mod.rs
Normal file
26
src/command/mod.rs
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
mod download;
|
||||||
|
mod train_tokenizer;
|
||||||
|
|
||||||
|
#[derive(Debug, clap::Subcommand)]
|
||||||
|
pub enum SubCommand {
|
||||||
|
/// Download training set
|
||||||
|
Download {
|
||||||
|
#[command(flatten)]
|
||||||
|
args: download::DownloadArgs,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Train tokenizer
|
||||||
|
TrainTokenizer {
|
||||||
|
#[command(flatten)]
|
||||||
|
args: train_tokenizer::TrainTokenizerArgs,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SubCommand {
|
||||||
|
pub fn run(self, mp: Option<indicatif::MultiProgress>) -> anyhow::Result<()> {
|
||||||
|
match self {
|
||||||
|
Self::Download { args } => args.run(mp),
|
||||||
|
Self::TrainTokenizer { args } => args.run(mp),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
57
src/command/train_tokenizer.rs
Normal file
57
src/command/train_tokenizer.rs
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
use anyhow::{Context, Result};
|
||||||
|
use clap::Args;
|
||||||
|
use indicatif::MultiProgress;
|
||||||
|
use rayon::ThreadPoolBuilder;
|
||||||
|
use std::fs::File;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
use crate::data_reader::DataReader;
|
||||||
|
use crate::tokenizer::Tokenizer;
|
||||||
|
|
||||||
|
#[derive(Debug, Args, Clone)]
|
||||||
|
|
||||||
|
pub struct TrainTokenizerArgs {
|
||||||
|
/// Path to training data
|
||||||
|
#[clap(default_value = "data")]
|
||||||
|
data_dir: PathBuf,
|
||||||
|
|
||||||
|
/// Where to save tokenizer
|
||||||
|
#[clap(default_value = "tokenizer.json")]
|
||||||
|
target: PathBuf,
|
||||||
|
|
||||||
|
/// Only train on the first n texts
|
||||||
|
#[clap(long)]
|
||||||
|
first_n: Option<usize>,
|
||||||
|
|
||||||
|
/// Number of threads to use for training
|
||||||
|
#[clap(long, default_value = "0")]
|
||||||
|
threads: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TrainTokenizerArgs {
|
||||||
|
pub fn run(self, mp: Option<MultiProgress>) -> Result<()> {
|
||||||
|
let iter = DataReader::new(5, &self.data_dir).context("while initializing data reader")?;
|
||||||
|
|
||||||
|
#[expect(clippy::unwrap_used)] // Lazy error handling
|
||||||
|
let iter = iter.map(|x| x.unwrap());
|
||||||
|
|
||||||
|
let pool = ThreadPoolBuilder::new()
|
||||||
|
.num_threads(self.threads)
|
||||||
|
.build()
|
||||||
|
.context("while building thread pool")?;
|
||||||
|
|
||||||
|
let tokenizer = pool
|
||||||
|
.install(|| match self.first_n {
|
||||||
|
Some(n) => Tokenizer::train(mp, iter.take(n), 1024),
|
||||||
|
None => Tokenizer::train(mp, iter, 1024),
|
||||||
|
})
|
||||||
|
.context("while training tokenizer")?;
|
||||||
|
|
||||||
|
info!("Saving to {}", self.target.display());
|
||||||
|
let target = File::create(&self.target).context("while creating tokenizer file")?;
|
||||||
|
serde_json::to_writer(target, &tokenizer).context("while saving tokenizer")?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
187
src/data_reader.rs
Normal file
187
src/data_reader.rs
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use parking_lot::Mutex;
|
||||||
|
use parquet::errors::ParquetError;
|
||||||
|
use parquet::file::reader::{FileReader, SerializedFileReader};
|
||||||
|
use parquet::record::RowAccessor;
|
||||||
|
use std::fs::File;
|
||||||
|
use std::path::Path;
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
use std::sync::mpsc::Receiver;
|
||||||
|
use std::sync::{Arc, mpsc};
|
||||||
|
use thiserror::Error;
|
||||||
|
use tracing::{debug, trace};
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum DataReaderError {
|
||||||
|
#[error("i/o error: {0}")]
|
||||||
|
IoError(#[from] std::io::Error),
|
||||||
|
|
||||||
|
#[error("parquet error: {0}")]
|
||||||
|
ParquetError(#[from] ParquetError),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A data reader that uses `n` threads to read
|
||||||
|
/// rows from a directory of parquet files.
|
||||||
|
///
|
||||||
|
/// All parquet files have exactly one text column.
|
||||||
|
/// No promises about this struct's behavior if this is not the case.
|
||||||
|
pub struct DataReader {
|
||||||
|
rx: Receiver<Result<String, DataReaderError>>,
|
||||||
|
total_rows: usize,
|
||||||
|
consumed_rows: AtomicUsize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DataReader {
|
||||||
|
/// Create a new [DataReader] that reads all `*.parquet` files in the given directory.
|
||||||
|
/// All other files are ignored.
|
||||||
|
pub fn new(threads: usize, data_dir: impl AsRef<Path>) -> Result<Self, std::io::Error> {
|
||||||
|
let (files, total_rows) = {
|
||||||
|
let mut files = Vec::new();
|
||||||
|
let mut total_rows = 0usize;
|
||||||
|
let entries = std::fs::read_dir(data_dir)?;
|
||||||
|
for entry in entries {
|
||||||
|
let entry = entry?;
|
||||||
|
let path = entry.path();
|
||||||
|
if path.extension().and_then(|s| s.to_str()) == Some("parquet") {
|
||||||
|
// Count rows in this file
|
||||||
|
let file = File::open(&path)?;
|
||||||
|
let reader = SerializedFileReader::new(file)
|
||||||
|
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||||
|
let metadata = reader.metadata();
|
||||||
|
|
||||||
|
for row_group_index in 0..metadata.num_row_groups() {
|
||||||
|
let row_group_metadata = metadata.row_group(row_group_index);
|
||||||
|
total_rows += row_group_metadata.num_rows() as usize;
|
||||||
|
}
|
||||||
|
|
||||||
|
files.push(path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(Arc::new(Mutex::new(files)), total_rows)
|
||||||
|
};
|
||||||
|
|
||||||
|
#[expect(clippy::unwrap_used)]
|
||||||
|
let pool = rayon::ThreadPoolBuilder::new()
|
||||||
|
.num_threads(threads)
|
||||||
|
.thread_name(move |i| format!("DataReader-{i}-of-{threads}"))
|
||||||
|
.build()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
message = "Starting DataReader",
|
||||||
|
threads,
|
||||||
|
n_files = files.lock().len()
|
||||||
|
);
|
||||||
|
|
||||||
|
let buffer = threads.min(10) * 200;
|
||||||
|
let (tx, rx) = mpsc::sync_channel::<Result<String, DataReaderError>>(buffer);
|
||||||
|
pool.spawn_broadcast(move |_ctx| {
|
||||||
|
let tx = tx.clone();
|
||||||
|
let files = files.clone();
|
||||||
|
|
||||||
|
'outer: loop {
|
||||||
|
let file = match files.lock().pop() {
|
||||||
|
None => break 'outer,
|
||||||
|
Some(x) => x,
|
||||||
|
};
|
||||||
|
|
||||||
|
trace!("Reading rows from {}", file.display());
|
||||||
|
let file = match File::open(&file) {
|
||||||
|
Ok(x) => x,
|
||||||
|
Err(err) => {
|
||||||
|
let _ = tx.send(Err(err.into()));
|
||||||
|
break 'outer;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let reader = match SerializedFileReader::new(file) {
|
||||||
|
Ok(x) => x,
|
||||||
|
Err(err) => {
|
||||||
|
let _ = tx.send(Err(err.into()));
|
||||||
|
break 'outer;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let metadata = reader.metadata();
|
||||||
|
|
||||||
|
for row_group_index in 0..metadata.num_row_groups() {
|
||||||
|
let row_group_reader = match reader.get_row_group(row_group_index) {
|
||||||
|
Ok(x) => x,
|
||||||
|
Err(err) => {
|
||||||
|
let _ = tx.send(Err(err.into()));
|
||||||
|
break 'outer;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let row_iter = match row_group_reader.get_row_iter(None) {
|
||||||
|
Ok(x) => x,
|
||||||
|
Err(err) => {
|
||||||
|
let _ = tx.send(Err(err.into()));
|
||||||
|
break 'outer;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
for row_result in row_iter {
|
||||||
|
let row = match row_result {
|
||||||
|
Ok(x) => x,
|
||||||
|
Err(err) => {
|
||||||
|
let _ = tx.send(Err(err.into()));
|
||||||
|
break 'outer;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let text = match row.get_string(0) {
|
||||||
|
Ok(x) => x.clone(),
|
||||||
|
Err(err) => {
|
||||||
|
let _ = tx.send(Err(err.into()));
|
||||||
|
break 'outer;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if tx.send(Ok(text)).is_err() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
rx,
|
||||||
|
total_rows,
|
||||||
|
consumed_rows: AtomicUsize::new(0),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the next available row.
|
||||||
|
/// Order is arbitrary.
|
||||||
|
/// Returns `None` when all rows have been read.
|
||||||
|
pub fn recv(&self) -> Option<Result<String, DataReaderError>> {
|
||||||
|
self.rx.recv().ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
//pub fn try_recv(&self) -> Result<Result<String, DataReaderError>, TryRecvError> {
|
||||||
|
// self.rx.try_recv()
|
||||||
|
//}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Iterator for DataReader {
|
||||||
|
type Item = Result<String, DataReaderError>;
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
match self.recv() {
|
||||||
|
Some(item) => {
|
||||||
|
self.consumed_rows.fetch_add(1, Ordering::Relaxed);
|
||||||
|
Some(item)
|
||||||
|
}
|
||||||
|
None => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||||
|
let consumed = self.consumed_rows.load(Ordering::Relaxed);
|
||||||
|
let len = self.total_rows.saturating_sub(consumed);
|
||||||
|
return (len, Some(len));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ExactSizeIterator for DataReader {}
|
349
src/logging.rs
Normal file
349
src/logging.rs
Normal file
@ -0,0 +1,349 @@
|
|||||||
|
use anyhow::Result;
|
||||||
|
use clap::{Parser, ValueEnum};
|
||||||
|
use indicatif::MultiProgress;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use std::{fmt::Display, str::FromStr};
|
||||||
|
use tracing_indicatif::IndicatifWriter;
|
||||||
|
use tracing_subscriber::{
|
||||||
|
EnvFilter, Layer, fmt::MakeWriter, layer::SubscriberExt, util::SubscriberInitExt,
|
||||||
|
};
|
||||||
|
|
||||||
|
//
|
||||||
|
// MARK: loglevel
|
||||||
|
//
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, ValueEnum)]
|
||||||
|
pub enum LogLevel {
|
||||||
|
Trace,
|
||||||
|
Debug,
|
||||||
|
Info,
|
||||||
|
Warn,
|
||||||
|
Error,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for LogLevel {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::Info
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for LogLevel {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::Trace => write!(f, "trace"),
|
||||||
|
Self::Debug => write!(f, "debug"),
|
||||||
|
Self::Info => write!(f, "info"),
|
||||||
|
Self::Warn => write!(f, "warn"),
|
||||||
|
Self::Error => write!(f, "error"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// MARK: logconfig
|
||||||
|
//
|
||||||
|
|
||||||
|
/// Configures log levels for known sources
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize)]
|
||||||
|
pub struct LoggingConfig {
|
||||||
|
pub other: LogLevel,
|
||||||
|
pub silence: LogLevel,
|
||||||
|
|
||||||
|
// Bins
|
||||||
|
pub nanochat: LogLevel,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<LoggingConfig> for EnvFilter {
|
||||||
|
fn from(conf: LoggingConfig) -> Self {
|
||||||
|
// Should never fail
|
||||||
|
#[expect(clippy::unwrap_used)]
|
||||||
|
EnvFilter::from_str(
|
||||||
|
&[
|
||||||
|
//
|
||||||
|
// Silence
|
||||||
|
//
|
||||||
|
format!("hyper_util={}", conf.silence),
|
||||||
|
format!("h2={}", conf.silence),
|
||||||
|
format!("rustls={}", conf.silence),
|
||||||
|
format!("tower={}", conf.silence),
|
||||||
|
format!("html5ever={}", conf.silence),
|
||||||
|
format!("selectors={}", conf.silence),
|
||||||
|
format!("wgpu_core={}", conf.silence),
|
||||||
|
format!("naga={}", conf.silence),
|
||||||
|
format!("cubecl={}", conf.silence),
|
||||||
|
//
|
||||||
|
// Bins
|
||||||
|
//
|
||||||
|
format!("nanochat_rs={}", conf.nanochat),
|
||||||
|
conf.other.to_string(),
|
||||||
|
]
|
||||||
|
.join(","),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// MARK: LogCliVQ
|
||||||
|
//
|
||||||
|
|
||||||
|
/// Provides global -v and -q cli arguments.
|
||||||
|
/// Use with `#[arg(flatten)]`.
|
||||||
|
#[derive(Parser, Debug, Clone)]
|
||||||
|
pub struct LogCliVQ {
|
||||||
|
/// Increase verbosity (can be repeated)
|
||||||
|
#[arg(default_value = "0")]
|
||||||
|
#[arg(short, action = clap::ArgAction::Count,global = true)]
|
||||||
|
v: u8,
|
||||||
|
|
||||||
|
/// Decrease verbosity (can be repeated)
|
||||||
|
#[arg(default_value = "0")]
|
||||||
|
#[arg(short, action = clap::ArgAction::Count, global = true)]
|
||||||
|
q: u8,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LogCliVQ {
|
||||||
|
pub fn into_preset(self) -> LogFilterPreset {
|
||||||
|
let level_i: i16 = self.v as i16 - self.q as i16;
|
||||||
|
|
||||||
|
let preset;
|
||||||
|
if level_i <= -2 {
|
||||||
|
preset = LogFilterPreset::Error
|
||||||
|
} else if level_i == -1 {
|
||||||
|
preset = LogFilterPreset::Warn
|
||||||
|
} else if level_i == 0 {
|
||||||
|
preset = LogFilterPreset::Info
|
||||||
|
} else if level_i == 1 {
|
||||||
|
preset = LogFilterPreset::Debug
|
||||||
|
} else if level_i >= 2 {
|
||||||
|
preset = LogFilterPreset::Trace
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
|
||||||
|
return preset;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// MARK: logpreset
|
||||||
|
//
|
||||||
|
|
||||||
|
/// Provides preset configurations of [LoggingConfig]
|
||||||
|
#[derive(Debug, Deserialize, Clone, Copy)]
|
||||||
|
pub enum LogFilterPreset {
|
||||||
|
/// Standard "error" log level
|
||||||
|
Error,
|
||||||
|
|
||||||
|
/// Standard "warn" log level
|
||||||
|
Warn,
|
||||||
|
|
||||||
|
/// Standard "info" log level.
|
||||||
|
/// This is the default.
|
||||||
|
Info,
|
||||||
|
|
||||||
|
/// Standard "debug" log level
|
||||||
|
Debug,
|
||||||
|
|
||||||
|
/// Standard "trace" log level
|
||||||
|
Trace,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for LogFilterPreset {
|
||||||
|
fn default() -> Self {
|
||||||
|
return Self::Info;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<LogFilterPreset> for LoggingConfig {
|
||||||
|
fn from(val: LogFilterPreset) -> Self {
|
||||||
|
val.get_config()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LogFilterPreset {
|
||||||
|
pub fn get_config(&self) -> LoggingConfig {
|
||||||
|
match self {
|
||||||
|
Self::Error => LoggingConfig {
|
||||||
|
other: LogLevel::Error,
|
||||||
|
silence: LogLevel::Error,
|
||||||
|
nanochat: LogLevel::Error,
|
||||||
|
},
|
||||||
|
|
||||||
|
Self::Warn => LoggingConfig {
|
||||||
|
other: LogLevel::Warn,
|
||||||
|
silence: LogLevel::Warn,
|
||||||
|
nanochat: LogLevel::Warn,
|
||||||
|
},
|
||||||
|
|
||||||
|
Self::Info => LoggingConfig {
|
||||||
|
other: LogLevel::Warn,
|
||||||
|
silence: LogLevel::Warn,
|
||||||
|
nanochat: LogLevel::Info,
|
||||||
|
},
|
||||||
|
|
||||||
|
Self::Debug => LoggingConfig {
|
||||||
|
other: LogLevel::Warn,
|
||||||
|
silence: LogLevel::Warn,
|
||||||
|
nanochat: LogLevel::Debug,
|
||||||
|
},
|
||||||
|
|
||||||
|
Self::Trace => LoggingConfig {
|
||||||
|
other: LogLevel::Trace,
|
||||||
|
silence: LogLevel::Warn,
|
||||||
|
nanochat: LogLevel::Trace,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// MARK: initializer
|
||||||
|
//
|
||||||
|
|
||||||
|
/// Where to print logs
|
||||||
|
#[expect(clippy::allow_attributes)]
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub enum LoggingTarget {
|
||||||
|
/// Send logs to stdout
|
||||||
|
Stdout { format: LoggingFormat },
|
||||||
|
|
||||||
|
/// Send logs to stderr
|
||||||
|
Stderr { format: LoggingFormat },
|
||||||
|
|
||||||
|
/// Send logs to an IndicatifWriter.
|
||||||
|
///
|
||||||
|
/// This is the same as Stderr { format: Ansi {color:true} },
|
||||||
|
/// but uses an indicatifwriter with the given multiprogress.
|
||||||
|
Indicatif(MultiProgress),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// How to print logs
|
||||||
|
#[derive(Debug, Clone, Copy, Deserialize)]
|
||||||
|
pub enum LoggingFormat {
|
||||||
|
Ansi,
|
||||||
|
AnsiNoColor,
|
||||||
|
Json,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for LoggingFormat {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::Ansi
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct LoggingInitializer {
|
||||||
|
/// Log filter for printed logs
|
||||||
|
pub preset: LogFilterPreset,
|
||||||
|
|
||||||
|
/// Where to print logs
|
||||||
|
pub target: LoggingTarget,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LoggingInitializer {
|
||||||
|
pub fn initialize(self) -> Result<()> {
|
||||||
|
let mut stderr_ansi_layer = None;
|
||||||
|
let mut stderr_json_layer = None;
|
||||||
|
let mut stdout_ansi_layer = None;
|
||||||
|
let mut stdout_json_layer = None;
|
||||||
|
let mut indicatif_layer = None;
|
||||||
|
match self.target {
|
||||||
|
LoggingTarget::Stderr {
|
||||||
|
format: LoggingFormat::Ansi,
|
||||||
|
} => {
|
||||||
|
stderr_ansi_layer = Some(
|
||||||
|
tracing_subscriber::fmt::Layer::default()
|
||||||
|
.without_time()
|
||||||
|
.with_ansi(true)
|
||||||
|
.with_writer(std::io::stderr)
|
||||||
|
.with_filter::<EnvFilter>(self.preset.get_config().into()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
LoggingTarget::Stderr {
|
||||||
|
format: LoggingFormat::AnsiNoColor,
|
||||||
|
} => {
|
||||||
|
stderr_ansi_layer = Some(
|
||||||
|
tracing_subscriber::fmt::Layer::default()
|
||||||
|
.without_time()
|
||||||
|
.with_ansi(false)
|
||||||
|
.with_writer(std::io::stderr)
|
||||||
|
.with_filter::<EnvFilter>(self.preset.get_config().into()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
LoggingTarget::Stderr {
|
||||||
|
format: LoggingFormat::Json,
|
||||||
|
} => {
|
||||||
|
stderr_json_layer = Some(
|
||||||
|
tracing_subscriber::fmt::Layer::default()
|
||||||
|
.without_time()
|
||||||
|
.with_ansi(false)
|
||||||
|
.json()
|
||||||
|
.with_writer(std::io::stderr)
|
||||||
|
.with_filter::<EnvFilter>(self.preset.get_config().into()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
LoggingTarget::Stdout {
|
||||||
|
format: LoggingFormat::Ansi,
|
||||||
|
} => {
|
||||||
|
stdout_ansi_layer = Some(
|
||||||
|
tracing_subscriber::fmt::Layer::default()
|
||||||
|
.without_time()
|
||||||
|
.with_ansi(true)
|
||||||
|
.with_writer(std::io::stdout)
|
||||||
|
.with_filter::<EnvFilter>(self.preset.get_config().into()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
LoggingTarget::Stdout {
|
||||||
|
format: LoggingFormat::AnsiNoColor,
|
||||||
|
} => {
|
||||||
|
stdout_ansi_layer = Some(
|
||||||
|
tracing_subscriber::fmt::Layer::default()
|
||||||
|
.without_time()
|
||||||
|
.with_ansi(false)
|
||||||
|
.with_writer(std::io::stdout)
|
||||||
|
.with_filter::<EnvFilter>(self.preset.get_config().into()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
LoggingTarget::Stdout {
|
||||||
|
format: LoggingFormat::Json,
|
||||||
|
} => {
|
||||||
|
stdout_json_layer = Some(
|
||||||
|
tracing_subscriber::fmt::Layer::default()
|
||||||
|
.without_time()
|
||||||
|
.with_ansi(false)
|
||||||
|
.json()
|
||||||
|
.with_writer(std::io::stdout)
|
||||||
|
.with_filter::<EnvFilter>(self.preset.get_config().into()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
LoggingTarget::Indicatif(mp) => {
|
||||||
|
let writer: IndicatifWriter<tracing_indicatif::writer::Stderr> =
|
||||||
|
IndicatifWriter::new(mp);
|
||||||
|
|
||||||
|
indicatif_layer = Some(
|
||||||
|
tracing_subscriber::fmt::Layer::default()
|
||||||
|
.without_time()
|
||||||
|
.with_ansi(true)
|
||||||
|
.with_writer(writer.make_writer())
|
||||||
|
.with_filter::<EnvFilter>(self.preset.get_config().into()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing_subscriber::registry()
|
||||||
|
.with(stdout_ansi_layer)
|
||||||
|
.with(stdout_json_layer)
|
||||||
|
.with(stderr_ansi_layer)
|
||||||
|
.with(stderr_json_layer)
|
||||||
|
.with(indicatif_layer)
|
||||||
|
.init();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
63
src/main.rs
Normal file
63
src/main.rs
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
use clap::Parser;
|
||||||
|
use indicatif::MultiProgress;
|
||||||
|
use tracing::error;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
command::SubCommand,
|
||||||
|
logging::{LogCliVQ, LoggingFormat, LoggingInitializer, LoggingTarget},
|
||||||
|
};
|
||||||
|
|
||||||
|
mod cli;
|
||||||
|
mod command;
|
||||||
|
mod data_reader;
|
||||||
|
mod logging;
|
||||||
|
mod tokenizer;
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(version, about, long_about = None, styles=crate::cli::clap_styles())]
|
||||||
|
struct Cli {
|
||||||
|
#[clap(flatten)]
|
||||||
|
vq: LogCliVQ,
|
||||||
|
|
||||||
|
/// If true, never show progress bars.
|
||||||
|
#[arg(long)]
|
||||||
|
noprogress: bool,
|
||||||
|
|
||||||
|
#[command(subcommand)]
|
||||||
|
command: SubCommand,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let cli = Cli::parse();
|
||||||
|
let mp = (!cli.noprogress).then(MultiProgress::new);
|
||||||
|
|
||||||
|
{
|
||||||
|
let res = LoggingInitializer {
|
||||||
|
preset: cli.vq.into_preset(),
|
||||||
|
target: match mp.clone() {
|
||||||
|
None => LoggingTarget::Stderr {
|
||||||
|
format: LoggingFormat::Ansi,
|
||||||
|
},
|
||||||
|
Some(mp) => LoggingTarget::Indicatif(mp.clone()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
.initialize();
|
||||||
|
|
||||||
|
if let Err(e) = res {
|
||||||
|
#[expect(clippy::print_stderr)]
|
||||||
|
for e in e.chain() {
|
||||||
|
eprintln!("{e}");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(e) = cli.command.run(mp) {
|
||||||
|
error!("Error: {e:?}");
|
||||||
|
for e in e.chain() {
|
||||||
|
error!(e);
|
||||||
|
}
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
}
|
543
src/tokenizer.rs
Normal file
543
src/tokenizer.rs
Normal file
@ -0,0 +1,543 @@
|
|||||||
|
use ahash::{AHashMap, AHashSet};
|
||||||
|
use anyhow::{Result, bail};
|
||||||
|
use compact_str::CompactString;
|
||||||
|
use dary_heap::DaryHeap;
|
||||||
|
use fancy_regex::Regex;
|
||||||
|
use indicatif::{MultiProgress, ProgressBar};
|
||||||
|
use rayon::iter::{
|
||||||
|
IndexedParallelIterator, IntoParallelRefIterator, ParallelBridge, ParallelIterator,
|
||||||
|
};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::{cmp::Ordering, collections::HashMap, sync::atomic::AtomicUsize, time::Instant};
|
||||||
|
use tracing::{debug, info};
|
||||||
|
|
||||||
|
use crate::cli::progress_big;
|
||||||
|
|
||||||
|
// Notes:
|
||||||
|
//
|
||||||
|
// ## Why ahash?
|
||||||
|
// See docs.
|
||||||
|
// Ahash provides a uses a higher performance hash fn,
|
||||||
|
// but is not resistant to DOS and thus cannot be used
|
||||||
|
// with untrusted input.
|
||||||
|
//
|
||||||
|
// ## Changes from original impl
|
||||||
|
// - idiomatic Rust
|
||||||
|
// - significantly less memory usage
|
||||||
|
|
||||||
|
/// A pair of adjacent tokens
|
||||||
|
type Pair = (u32, u32);
|
||||||
|
|
||||||
|
//
|
||||||
|
// MARK: word
|
||||||
|
//
|
||||||
|
|
||||||
|
/// A segment of text that is undergoing tokenization.
|
||||||
|
///
|
||||||
|
/// `ids` contains the ids of the tokens this word consists of.
|
||||||
|
/// Initially, these are ascii u8s. More are added as we merge pairs.
|
||||||
|
///
|
||||||
|
/// This implementation of BPE does not cross word boundaries.
|
||||||
|
/// - one token cannot represent more than one word, which may be a feature.
|
||||||
|
/// - this keeps multi-byte unicode together, since the split regex handles it correctly.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
struct Word {
|
||||||
|
ids: Vec<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Word {
|
||||||
|
#[inline]
|
||||||
|
fn new(ids: Vec<u32>) -> Self {
|
||||||
|
Self { ids }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return an iterator over all adjacent pairs
|
||||||
|
/// of tokens in this word
|
||||||
|
#[inline]
|
||||||
|
fn pairs<'a>(&'a self) -> impl Iterator<Item = Pair> + 'a {
|
||||||
|
self.ids.windows(2).map(|w| (w[0], w[1]))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Merge all non-overlapping occurrences of pair -> new_id.
|
||||||
|
/// Returns a small Vec of local pair-count deltas for THIS word only:
|
||||||
|
/// -1 for removed pairs, +1 for newly created pairs.
|
||||||
|
///
|
||||||
|
/// NOTE: this version deliberately avoids a HashMap in the hot loop.
|
||||||
|
fn merge_pair(&mut self, pair: Pair, new_id: u32) -> Vec<(Pair, i32)> {
|
||||||
|
let (a, b) = pair;
|
||||||
|
let n = self.ids.len();
|
||||||
|
if n < 2 {
|
||||||
|
return Vec::new();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut out: Vec<u32> = Vec::with_capacity(n);
|
||||||
|
let mut deltas: Vec<(Pair, i32)> = Vec::with_capacity(6);
|
||||||
|
|
||||||
|
let mut i = 0;
|
||||||
|
while i < n {
|
||||||
|
if i + 1 < n && self.ids[i] == a && self.ids[i + 1] == b {
|
||||||
|
let left = out.last().copied();
|
||||||
|
let right = if i + 2 < n {
|
||||||
|
Some(self.ids[i + 2])
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// remove old pairs
|
||||||
|
if let Some(x) = left {
|
||||||
|
deltas.push(((x, a), -1));
|
||||||
|
deltas.push(((x, new_id), 1));
|
||||||
|
}
|
||||||
|
deltas.push(((a, b), -1));
|
||||||
|
if let Some(y) = right {
|
||||||
|
deltas.push(((b, y), -1));
|
||||||
|
deltas.push(((new_id, y), 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
// write merged token
|
||||||
|
out.push(new_id);
|
||||||
|
i += 2; // skip 'a' and 'b'
|
||||||
|
} else {
|
||||||
|
out.push(self.ids[i]);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.ids = out;
|
||||||
|
deltas
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// MARK: mergejob
|
||||||
|
//
|
||||||
|
|
||||||
|
#[derive(Debug, Eq)]
|
||||||
|
struct MergeJob {
|
||||||
|
pair: Pair,
|
||||||
|
count: u64,
|
||||||
|
|
||||||
|
/// set of word indices where this pair may occur and needs processing
|
||||||
|
pos: AHashSet<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq for MergeJob {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
self.count == other.count && self.pair == other.pair
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialOrd for MergeJob {
|
||||||
|
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||||
|
Some(self.cmp(other))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Ord for MergeJob {
|
||||||
|
fn cmp(&self, other: &Self) -> Ordering {
|
||||||
|
self.count
|
||||||
|
.cmp(&other.count)
|
||||||
|
.then_with(|| other.pair.cmp(&self.pair))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// MARK: tokenizer
|
||||||
|
//
|
||||||
|
|
||||||
|
/// A Byte Pair Encoding tokenizer that matches the GPT-4 style implementation
|
||||||
|
///
|
||||||
|
/// BPE's goal is to build compound tokens for a text corpus.
|
||||||
|
/// It starts with a simple initial state---in this case, each ascii u8
|
||||||
|
/// is a token. We then examine all adjacent pairs of tokens, replacing
|
||||||
|
/// common pairs with a new token id. This is repeated until our vocabulary
|
||||||
|
/// is as large as it needs to be.
|
||||||
|
pub struct Tokenizer {
|
||||||
|
/// Maps pairs of token IDs to their merged token ID
|
||||||
|
merges: HashMap<Pair, u32>,
|
||||||
|
|
||||||
|
n_tokens: u32,
|
||||||
|
|
||||||
|
/// The regex pattern used for text splitting
|
||||||
|
#[expect(dead_code)]
|
||||||
|
split_regex: Regex,
|
||||||
|
|
||||||
|
/// Source of split_regex
|
||||||
|
/// (debug info)
|
||||||
|
split_regex_string: String,
|
||||||
|
|
||||||
|
/// The number of texts this tokenizer was trained on
|
||||||
|
/// (debug info)
|
||||||
|
n_train_texts: u64,
|
||||||
|
|
||||||
|
/// The number of words this tokenizer was trained on
|
||||||
|
/// (debug info)
|
||||||
|
n_train_words: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Serialize for Tokenizer {
|
||||||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: serde::Serializer,
|
||||||
|
{
|
||||||
|
use serde::ser::SerializeStruct;
|
||||||
|
|
||||||
|
let mut state = serializer.serialize_struct("Tokenizer", 4)?;
|
||||||
|
|
||||||
|
// Convert HashMap to Vec for serialization
|
||||||
|
// (only string keys are valid in json)
|
||||||
|
let merges_vec: Vec<(Pair, u32)> = self.merges.iter().map(|(&k, &v)| (k, v)).collect();
|
||||||
|
state.serialize_field("merges", &merges_vec)?;
|
||||||
|
|
||||||
|
state.serialize_field("split_regex_string", &self.split_regex_string)?;
|
||||||
|
state.serialize_field("n_train_texts", &self.n_train_texts)?;
|
||||||
|
state.serialize_field("n_train_words", &self.n_train_words)?;
|
||||||
|
state.end()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom deserializer that automatically compiles split_regex
|
||||||
|
impl<'de> Deserialize<'de> for Tokenizer {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'de>,
|
||||||
|
{
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct TokenizerData {
|
||||||
|
merges: Vec<(Pair, u32)>,
|
||||||
|
split_regex_string: String,
|
||||||
|
n_train_texts: u64,
|
||||||
|
n_train_words: u64,
|
||||||
|
n_tokens: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
let data = TokenizerData::deserialize(deserializer)?;
|
||||||
|
let split_regex = Regex::new(&data.split_regex_string).map_err(serde::de::Error::custom)?;
|
||||||
|
let merges: HashMap<Pair, u32> = data.merges.into_iter().collect();
|
||||||
|
|
||||||
|
Ok(Tokenizer {
|
||||||
|
merges,
|
||||||
|
split_regex,
|
||||||
|
split_regex_string: data.split_regex_string,
|
||||||
|
n_train_texts: data.n_train_texts,
|
||||||
|
n_train_words: data.n_train_words,
|
||||||
|
n_tokens: data.n_tokens,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Tokenizer {
|
||||||
|
/// Default regex pattern for splitting text
|
||||||
|
const DEFAULT_REGEX: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+";
|
||||||
|
|
||||||
|
/// Return the regex pattern used to split words
|
||||||
|
#[inline]
|
||||||
|
pub fn get_regex(&self) -> &str {
|
||||||
|
&self.split_regex_string
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tokenize a string
|
||||||
|
pub fn tokenize(&self, text: &str) -> Vec<u32> {
|
||||||
|
let mut all_ids = Vec::new();
|
||||||
|
|
||||||
|
for m in self.split_regex.find_iter(text) {
|
||||||
|
#[expect(clippy::unwrap_used)] // Shouldn't ever fail
|
||||||
|
let word = m.unwrap().as_str();
|
||||||
|
let mut word_tokens = word.bytes().map(|b| b as u32).collect::<Vec<_>>();
|
||||||
|
|
||||||
|
// Apply merges
|
||||||
|
while word_tokens.len() >= 2 {
|
||||||
|
// Merge the pair with the largest token idx
|
||||||
|
// (pair_start_idx, replace_with)
|
||||||
|
let mut best_pair: Option<(usize, u32)> = None;
|
||||||
|
|
||||||
|
for (i, pair) in word_tokens.windows(2).map(|x| (x[0], x[1])).enumerate() {
|
||||||
|
let new_id = match self.merges.get(&pair) {
|
||||||
|
None => continue,
|
||||||
|
Some(x) => *x,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[expect(clippy::unwrap_used)]
|
||||||
|
if best_pair.is_none() || new_id < best_pair.unwrap().1 {
|
||||||
|
best_pair = Some((i, new_id));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match best_pair {
|
||||||
|
Some((idx, new_id)) => {
|
||||||
|
word_tokens[idx] = new_id;
|
||||||
|
word_tokens.remove(idx + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
None => {
|
||||||
|
// No merges possible
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
all_ids.extend(word_tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
all_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// MARK: main training code
|
||||||
|
//
|
||||||
|
|
||||||
|
// TODO: pool config
|
||||||
|
|
||||||
|
/// Given an array of words an an array of counts,
|
||||||
|
/// - count the number of occurrences of each token pair
|
||||||
|
/// - map each pair to the indices of the words that contain it
|
||||||
|
///
|
||||||
|
/// ## Notes:
|
||||||
|
/// - will panic if `words.len() != counts.len()`
|
||||||
|
/// - will not behave correctly if words are repeated
|
||||||
|
#[inline]
|
||||||
|
fn count_pairs(
|
||||||
|
words: &[Word],
|
||||||
|
counts: &[i32],
|
||||||
|
) -> (AHashMap<Pair, i32>, AHashMap<Pair, AHashSet<usize>>) {
|
||||||
|
words
|
||||||
|
.par_iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, w)| {
|
||||||
|
let mut pair_counts: AHashMap<Pair, i32> = AHashMap::new();
|
||||||
|
let mut word_map: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
|
||||||
|
|
||||||
|
if w.ids.len() >= 2 && counts[i] != 0 {
|
||||||
|
for (a, b) in w.pairs() {
|
||||||
|
*pair_counts.entry((a, b)).or_default() += counts[i];
|
||||||
|
word_map.entry((a, b)).or_default().insert(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(pair_counts, word_map)
|
||||||
|
})
|
||||||
|
.reduce(
|
||||||
|
|| (AHashMap::new(), AHashMap::new()),
|
||||||
|
|(mut acc_pc, mut acc_wtu), (pc, wtu)| {
|
||||||
|
for (k, v) in pc {
|
||||||
|
*acc_pc.entry(k).or_default() += v;
|
||||||
|
}
|
||||||
|
for (k, s) in wtu {
|
||||||
|
acc_wtu.entry(k).or_default().extend(s);
|
||||||
|
}
|
||||||
|
(acc_pc, acc_wtu)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Core incremental BPE training given unique words and their counts.
|
||||||
|
/// - `words`: one entry per unique chunk (Vec<u32> of token-ids/bytes).
|
||||||
|
/// - `counts`: same length as `words`, count per chunk.
|
||||||
|
///
|
||||||
|
/// ## Notes:
|
||||||
|
/// - vocab size must be >= 256
|
||||||
|
/// - will panic if `words.len() != counts.len()`
|
||||||
|
/// - will not behave correctly if words are repeated
|
||||||
|
fn train_core(
|
||||||
|
mp: Option<MultiProgress>,
|
||||||
|
mut words: Vec<Word>,
|
||||||
|
counts: Vec<i32>,
|
||||||
|
vocab_size: u32,
|
||||||
|
) -> HashMap<Pair, u32> {
|
||||||
|
assert!(vocab_size >= 256, "vocab_size must be at least 256");
|
||||||
|
|
||||||
|
let num_merges = vocab_size - 256;
|
||||||
|
let mut merges = HashMap::with_capacity(num_merges as usize);
|
||||||
|
info!(message = "Training tokenizer", num_merges);
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
info!(
|
||||||
|
message = "Computing initial pair counts",
|
||||||
|
n_words = words.len()
|
||||||
|
);
|
||||||
|
let (mut pair_counts, mut where_to_update) = Self::count_pairs(&words, &counts);
|
||||||
|
|
||||||
|
let mut heap = {
|
||||||
|
debug!("Building heap with {} unique pairs", pair_counts.len());
|
||||||
|
|
||||||
|
let mut heap = DaryHeap::<_, 8>::with_capacity(pair_counts.len());
|
||||||
|
for (pair, pos) in where_to_update.drain() {
|
||||||
|
let c = *pair_counts.get(&pair).unwrap_or(&0);
|
||||||
|
if c > 0 {
|
||||||
|
heap.push(MergeJob {
|
||||||
|
pair,
|
||||||
|
count: c as u64,
|
||||||
|
pos,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
heap
|
||||||
|
};
|
||||||
|
|
||||||
|
let pb = mp.as_ref().map(|mp| {
|
||||||
|
let pb = mp.add(ProgressBar::new(num_merges as u64));
|
||||||
|
pb.set_style(progress_big());
|
||||||
|
pb.set_message("Computing merges");
|
||||||
|
pb
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut merges_done = 0u32;
|
||||||
|
|
||||||
|
while merges_done < num_merges {
|
||||||
|
let Some(mut top) = heap.pop() else {
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Lazy refresh
|
||||||
|
let current = *pair_counts.get(&top.pair).unwrap_or(&0);
|
||||||
|
if top.count != current as u64 {
|
||||||
|
top.count = current as u64;
|
||||||
|
if top.count > 0 {
|
||||||
|
heap.push(top);
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if top.count == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record merge
|
||||||
|
let new_id = 256 + merges_done;
|
||||||
|
merges.insert(top.pair, new_id);
|
||||||
|
|
||||||
|
// Merge this pair in all words where it occurs
|
||||||
|
let mut local_pos_updates: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
|
||||||
|
for &word_idx in &top.pos {
|
||||||
|
// Apply merge to this word and collect pair-count deltas
|
||||||
|
let changes = words[word_idx].merge_pair(top.pair, new_id);
|
||||||
|
// Update global pair counts based on this word's count
|
||||||
|
for (pair, delta) in changes {
|
||||||
|
let delta_total = delta * counts[word_idx];
|
||||||
|
if delta_total != 0 {
|
||||||
|
*pair_counts.entry(pair).or_default() += delta_total;
|
||||||
|
if delta > 0 {
|
||||||
|
local_pos_updates.entry(pair).or_default().insert(word_idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the updated pair counts back to the heap
|
||||||
|
for (pair, pos) in local_pos_updates {
|
||||||
|
let cnt = *pair_counts.get(&pair).unwrap_or(&0);
|
||||||
|
if cnt > 0 {
|
||||||
|
heap.push(MergeJob {
|
||||||
|
pair,
|
||||||
|
count: cnt as u64,
|
||||||
|
pos,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
merges_done += 1;
|
||||||
|
|
||||||
|
if let Some(pb) = pb.as_ref() {
|
||||||
|
pb.set_position(merges_done as u64);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(pb) = pb {
|
||||||
|
pb.finish_and_clear();
|
||||||
|
}
|
||||||
|
info!("Computed merges in {:.03?}", now.elapsed());
|
||||||
|
|
||||||
|
return merges;
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// MARK: train
|
||||||
|
//
|
||||||
|
|
||||||
|
/// Train a new tokenizer from an iterator of texts.
|
||||||
|
pub fn train<I>(mp: Option<MultiProgress>, iterator: I, vocab_size: u32) -> Result<Self>
|
||||||
|
where
|
||||||
|
I: Iterator<Item = String> + ExactSizeIterator,
|
||||||
|
I: ParallelBridge + Send,
|
||||||
|
{
|
||||||
|
if vocab_size < 256 {
|
||||||
|
bail!("vocab_size must be at least 256, but it is {vocab_size}");
|
||||||
|
}
|
||||||
|
|
||||||
|
let split_regex_string = Self::DEFAULT_REGEX.to_owned();
|
||||||
|
#[expect(clippy::unwrap_used)] // Default regex must be valid
|
||||||
|
let split_regex = Regex::new(&split_regex_string).unwrap();
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
let n_train_texts = iterator.len() as u64;
|
||||||
|
debug!("Counting words in {} texts", n_train_texts);
|
||||||
|
let pb = mp.as_ref().map(|mp| {
|
||||||
|
let pb = mp.add(ProgressBar::new(iterator.len() as u64));
|
||||||
|
pb.set_style(progress_big());
|
||||||
|
pb.set_message("Counting words");
|
||||||
|
pb
|
||||||
|
});
|
||||||
|
|
||||||
|
// Process texts in parallel and collect chunk counts
|
||||||
|
let counter = AtomicUsize::new(0);
|
||||||
|
let counts: AHashMap<CompactString, i32> = iterator
|
||||||
|
.par_bridge()
|
||||||
|
.map(|text| {
|
||||||
|
let mut local_counts: AHashMap<CompactString, i32> = AHashMap::new();
|
||||||
|
for mat in split_regex.find_iter(&text) {
|
||||||
|
#[expect(clippy::unwrap_used)]
|
||||||
|
let piece = mat.unwrap().as_str();
|
||||||
|
|
||||||
|
*local_counts.entry(CompactString::from(piece)).or_default() += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
let count = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
if let Some(ref pb) = pb {
|
||||||
|
if count % 1000 == 0 {
|
||||||
|
pb.inc(1000);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
local_counts
|
||||||
|
})
|
||||||
|
.reduce(AHashMap::new, |mut acc, local_counts| {
|
||||||
|
for (k, v) in local_counts {
|
||||||
|
*acc.entry(k).or_default() += v;
|
||||||
|
}
|
||||||
|
|
||||||
|
acc
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Some(pb) = pb {
|
||||||
|
pb.finish_and_clear();
|
||||||
|
}
|
||||||
|
info!(
|
||||||
|
"Counted {} unique words in {:.03?}",
|
||||||
|
counts.len(),
|
||||||
|
now.elapsed()
|
||||||
|
);
|
||||||
|
|
||||||
|
// Materialize words & counts
|
||||||
|
let mut words = Vec::with_capacity(counts.len());
|
||||||
|
let mut cvec = Vec::with_capacity(counts.len());
|
||||||
|
let mut n_train_words = 0u64;
|
||||||
|
for (chunk, c) in counts.into_iter() {
|
||||||
|
words.push(Word::new(
|
||||||
|
chunk.as_bytes().iter().map(|&b| b as u32).collect(),
|
||||||
|
));
|
||||||
|
cvec.push(c);
|
||||||
|
n_train_words += c as u64;
|
||||||
|
}
|
||||||
|
|
||||||
|
let merges = Self::train_core(mp, words, cvec, vocab_size);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
merges,
|
||||||
|
split_regex,
|
||||||
|
split_regex_string,
|
||||||
|
n_train_texts,
|
||||||
|
n_train_words,
|
||||||
|
n_tokens: vocab_size,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user