Compare commits

...

2 Commits

Author SHA1 Message Date
2ed4dc74ef Special tokens 2025-10-16 09:52:37 -07:00
6b7b410dda Train tokenizer 2025-10-16 09:51:15 -07:00
13 changed files with 5358 additions and 1 deletion

3
.gitignore vendored

@@ -1,2 +1,5 @@
/target
*.ignore
/data
/tokenizer.json

2889
Cargo.lock generated Normal file

File diff suppressed because it is too large

79
Cargo.toml Normal file

@@ -0,0 +1,79 @@
[package]
name = "nanochat-rs"
version = "0.1.0"
edition = "2024"

[lints.rust]
unused_import_braces = "deny"
unit_bindings = "deny"
single_use_lifetimes = "deny"
non_ascii_idents = "deny"
macro_use_extern_crate = "deny"
elided_lifetimes_in_paths = "deny"
absolute_paths_not_starting_with_crate = "deny"
explicit_outlives_requirements = "warn"
unused_crate_dependencies = "warn"
redundant_lifetimes = "warn"
missing_docs = "allow"

[lints.clippy]
todo = "deny"
uninlined_format_args = "allow"
result_large_err = "allow"
too_many_arguments = "allow"
upper_case_acronyms = "deny"
needless_return = "allow"
new_without_default = "allow"
tabs_in_doc_comments = "allow"
dbg_macro = "deny"
allow_attributes = "deny"
create_dir = "deny"
filetype_is_file = "deny"
integer_division = "allow"
lossy_float_literal = "deny"
map_err_ignore = "deny"
mutex_atomic = "deny"
needless_raw_strings = "deny"
str_to_string = "deny"
string_add = "deny"
string_to_string = "deny"
use_debug = "allow"
verbose_file_reads = "deny"
large_types_passed_by_value = "deny"
wildcard_dependencies = "deny"
negative_feature_names = "deny"
redundant_feature_names = "deny"
multiple_crate_versions = "allow"
missing_safety_doc = "warn"
identity_op = "allow"
print_stderr = "deny"
print_stdout = "deny"
comparison_chain = "allow"
unimplemented = "deny"
unwrap_used = "warn"
expect_used = "warn"
type_complexity = "allow"

[dependencies]
ahash = "0.8.12"
anstyle = "1.0.13"
anyhow = "1.0.100"
clap = { version = "4.5.49", features = ["derive"] }
compact_str = "0.9.0"
dary_heap = "0.3.8"
fancy-regex = "0.16.2"
futures-util = "0.3.31"
indicatif = { version = "0.18.0", features = ["improved_unicode"] }
parking_lot = "0.12.5"
parquet = "56.2.0"
rayon = "1.11.0"
reqwest = { version = "0.12.24", features = ["json", "stream"] }
serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.145"
thiserror = "2.0.17"
tokio = { version = "1.48.0", features = ["full"] }
tracing = "0.1.41"
tracing-indicatif = "0.3.13"
tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] }
url = "2.5.7"
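
A side note on this lint table: because `allow_attributes = "deny"` while `unwrap_used` and `expect_used` are only warnings, the source files below opt out at individual call sites with the checked `#[expect(...)]` form rather than `#[allow(...)]`. A minimal illustration of the pattern (the function is hypothetical, not part of this commit):

#[expect(clippy::unwrap_used)] // the lint must actually fire here, or rustc flags the unfulfilled expectation
fn parse_port(s: &str) -> u16 {
	s.parse().unwrap()
}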

1
rustfmt.toml Normal file

@@ -0,0 +1 @@
hard_tabs = true

64
src/cli.rs Normal file

@@ -0,0 +1,64 @@
use anstyle::{AnsiColor, Color, Style};
use indicatif::ProgressStyle;
pub fn clap_styles() -> clap::builder::Styles {
clap::builder::Styles::styled()
.usage(
Style::new()
.bold()
.fg_color(Some(Color::Ansi(AnsiColor::Blue))),
)
.header(
Style::new()
.bold()
.fg_color(Some(Color::Ansi(AnsiColor::Blue))),
)
.literal(
Style::new()
.bold()
.fg_color(Some(Color::Ansi(AnsiColor::BrightBlack))),
)
.invalid(
Style::new()
.bold()
.fg_color(Some(Color::Ansi(AnsiColor::Red))),
)
.error(
Style::new()
.bold()
.fg_color(Some(Color::Ansi(AnsiColor::Red))),
)
.valid(
Style::new()
.bold()
.underline()
.fg_color(Some(Color::Ansi(AnsiColor::Green))),
)
.placeholder(Style::new().fg_color(Some(Color::Ansi(AnsiColor::White))))
}
#[expect(clippy::unwrap_used)]
pub fn progress_big() -> ProgressStyle {
return ProgressStyle::default_bar()
.template(
" {spinner:.green} [{elapsed_precise}] [{bar:40.green/dim}] {pos:>7}/{len:7} {msg:.dim} ({eta})",
)
.unwrap()
.progress_chars("=>-")
.tick_strings(&[
"⠉⠉", "⠈⠙", "⠀⠹", "⠀⢸", "⠀⣰", "⢀⣠", "⣀⣀", "⣄⡀", "⣆⠀", "⡇⠀", "⠏⠀", "⠋⠁", "⣏⣹",
]);
}
#[expect(clippy::unwrap_used)]
pub fn progress_bytes() -> ProgressStyle {
return ProgressStyle::default_bar()
.template(
" {bar:16.red/white.dim} {elapsed_precise:.dim} {bytes:>10}/{total_bytes:>10} {msg:.dim} ({eta})",
)
.unwrap()
.progress_chars("---")
.tick_strings(&[
"⠉⠉", "⠈⠙", "⠀⠹", "⠀⢸", "⠀⣰", "⢀⣠", "⣀⣀", "⣄⡀", "⣆⠀", "⡇⠀", "⠏⠀", "⠋⠁", "⣏⣹",
]);
}

280
src/command/download.rs Normal file

@@ -0,0 +1,280 @@
use anyhow::Result;
use clap::Args;
use futures_util::StreamExt;
use indicatif::MultiProgress;
use indicatif::ProgressBar;
use rayon::prelude::*;
use std::fs;
use std::path::{Path, PathBuf};
use std::time::Duration;
use tokio::runtime::Runtime;
use tracing::{debug, error, info};
use url::Url;
use crate::cli::{progress_big, progress_bytes};
const BASE_URL: &str =
"https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main";
const MAX_SHARD: usize = 1822;
#[derive(Debug, Args, Clone)]
pub struct DownloadArgs {
/// Training data dir
#[clap(default_value = "data")]
data_dir: PathBuf,
/// Number of shards to download (-1 for all)
#[arg(short = 'n', long, default_value = "-1")]
num_files: isize,
/// Number of parallel downloads
#[arg(short = 't', long, default_value = "8")]
threads: usize,
}
impl DownloadArgs {
pub fn run(self, mp: Option<MultiProgress>) -> Result<()> {
info!("Downloading files from {BASE_URL}");
fs::create_dir_all(&self.data_dir)?;
let num_shards_to_download = if self.num_files == -1 {
MAX_SHARD + 1
} else {
self.num_files.min((MAX_SHARD + 1) as isize) as usize
};
let ids_to_download: Vec<usize> = (0..num_shards_to_download).collect();
info!("Downloading {} shards...", ids_to_download.len(),);
info!("Target directory: {}", self.data_dir.display());
let main_pb = mp.as_ref().map(|mp| {
let pb = mp.add(ProgressBar::new(ids_to_download.len() as u64));
pb.set_style(progress_big());
pb.set_message("Downloading training data");
pb.enable_steady_tick(Duration::from_millis(100));
pb
});
let rt = Runtime::new()?;
let (tx, rx) = std::sync::mpsc::channel();
let pool = rayon::ThreadPoolBuilder::new()
.num_threads(self.threads)
.build()?;
pool.install(|| {
ids_to_download
.into_par_iter()
.for_each_with(tx, |tx, index| {
let target = self.data_dir.clone();
let main_pb = main_pb.clone();
let mp_clone = mp.clone();
let rt_handle = rt.handle().clone(); // Clone the runtime handle for each thread
let result = rt_handle.block_on(async {
download_single_file(index, &target, main_pb, mp_clone).await
});
// Send the result back to the main thread for aggregation
#[expect(clippy::unwrap_used)]
tx.send(result).unwrap();
});
});
// Wait for all downloads to finish and collect results
let mut successful_downloads = 0;
for _ in 0..num_shards_to_download {
if let Ok(Ok(_)) = rx.recv() {
// Receive the Result<(), String> wrapped in a Result from MPSC
successful_downloads += 1;
}
}
if let Some(pb) = main_pb.as_ref() {
pb.finish_and_clear();
info!("Downloads complete ({successful_downloads} successful)");
}
return Ok(());
}
}
async fn download_single_file(
index: usize,
target: &Path,
progress_bar: Option<ProgressBar>,
mp: Option<MultiProgress>,
) -> Result<(), String> {
let filename = format!("shard_{:05}.parquet", index);
let filepath = target.join(&filename);
if filepath.exists() {
info!("Skipping {} (already exists)", filepath.display());
if let Some(pb) = progress_bar.as_ref() {
pb.inc(1);
}
return Ok(());
}
#[expect(clippy::unwrap_used)]
let url = Url::parse(&format!("{BASE_URL}/{filename}")).unwrap();
info!("Downloading {} from {}", filename, url);
let max_attempts = 5;
'attempt_loop: for attempt in 1..=max_attempts {
let temp_filepath = filepath.with_extension("parquet.tmp");
let client = reqwest::Client::new();
match client.get(url.clone()).send().await {
Ok(response) => {
if !response.status().is_success() {
error!(
"Attempt {}/{}: Server responded with status {} for {}",
attempt,
max_attempts,
response.status(),
url
);
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await;
continue;
}
let total_size = response.content_length().unwrap_or(0);
debug!("Total size for {}: {}", filename, total_size);
// Create file progress bar
let file_pb = if total_size > 0
&& let Some(mp) = mp.as_ref()
{
Some({
let pb = mp.add(ProgressBar::new(total_size));
pb.set_style(progress_bytes());
pb.set_message(format!("Downloading {}", filename));
pb
})
} else {
None
};
let mut file = match tokio::fs::File::create(&temp_filepath).await {
Ok(file) => file,
Err(e) => {
error!(
"Attempt {}/{}: Failed to create temporary file {}: {}",
attempt,
max_attempts,
temp_filepath.display(),
e
);
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await;
continue;
}
};
let mut stream = response.bytes_stream();
let mut downloaded: u64 = 0;
while let Some(chunk_result) = StreamExt::next(&mut stream).await {
match chunk_result {
Ok(chunk) => {
match tokio::io::AsyncWriteExt::write_all(&mut file, &chunk).await {
Ok(_) => {
downloaded += chunk.len() as u64;
if let Some(pb) = &file_pb {
pb.set_position(downloaded);
}
}
Err(e) => {
error!(
"Attempt {}/{}: Failed to write to temporary file {}: {}",
attempt,
max_attempts,
temp_filepath.display(),
e
);
// Clean up
let _ = tokio::fs::remove_file(&temp_filepath).await;
if let Some(pb) = &file_pb {
pb.finish_and_clear();
}
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt)))
.await;
continue 'attempt_loop;
}
}
}
Err(e) => {
error!(
"Attempt {}/{}: Error reading chunk for {}: {}",
attempt, max_attempts, filename, e
);
let _ = tokio::fs::remove_file(&temp_filepath).await;
if let Some(pb) = &file_pb {
pb.finish_and_clear();
}
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await;
continue 'attempt_loop;
}
}
}
if let Some(pb) = &file_pb {
pb.finish_and_clear();
}
// Atomically rename the temporary file
match tokio::fs::rename(&temp_filepath, &filepath).await {
Ok(_) => {
info!("Successfully downloaded {}", filename);
if let Some(pb) = progress_bar.as_ref() {
pb.inc(1);
}
return Ok(());
}
Err(e) => {
error!(
"Attempt {}/{}: Failed to rename temporary file {} to {}: {}",
attempt,
max_attempts,
temp_filepath.display(),
filepath.display(),
e
);
let _ = tokio::fs::remove_file(&temp_filepath).await; // Clean up
}
}
}
Err(e) => {
error!(
"Attempt {}/{}: Failed to download {}: {}",
attempt, max_attempts, filename, e
);
// Clean up any partial files
let _ = tokio::fs::remove_file(&temp_filepath).await;
}
}
if attempt < max_attempts {
let wait_time = 2u64.pow(attempt);
info!("Waiting {} seconds before retry...", wait_time);
tokio::time::sleep(Duration::from_secs(wait_time)).await;
} else {
error!(
"Failed to download {} after {} attempts",
filename, max_attempts
);
return Err(format!("Failed to download {}", filename));
}
}
Err(format!("Failed to download {}", filename))
}
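
For reference, the `2u64.pow(attempt)` backoff used throughout `download_single_file` waits 2, 4, 8, 16, and 32 seconds across the five attempts. A small illustrative check of that arithmetic (not part of this commit):

#[cfg(test)]
mod backoff_illustration {
	#[test]
	fn waits_double_each_attempt() {
		// Mirrors the exponential backoff in download_single_file above.
		let waits: Vec<u64> = (1..=5u32).map(|attempt| 2u64.pow(attempt)).collect();
		assert_eq!(waits, vec![2, 4, 8, 16, 32]);
	}
}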

26
src/command/mod.rs Normal file

@@ -0,0 +1,26 @@
mod download;
mod train_tokenizer;
#[derive(Debug, clap::Subcommand)]
pub enum SubCommand {
/// Download training set
Download {
#[command(flatten)]
args: download::DownloadArgs,
},
/// Train tokenizer
TrainTokenizer {
#[command(flatten)]
args: train_tokenizer::TrainTokenizerArgs,
},
}
impl SubCommand {
pub fn run(self, mp: Option<indicatif::MultiProgress>) -> anyhow::Result<()> {
match self {
Self::Download { args } => args.run(mp),
Self::TrainTokenizer { args } => args.run(mp),
}
}
}

61
src/command/train_tokenizer.rs Normal file

@@ -0,0 +1,61 @@
use anyhow::{Context, Result};
use clap::Args;
use indicatif::MultiProgress;
use rayon::ThreadPoolBuilder;
use std::fs::File;
use std::path::PathBuf;
use tracing::info;
use crate::data_reader::DataReader;
use crate::tokenizer::Tokenizer;
#[derive(Debug, Args, Clone)]
pub struct TrainTokenizerArgs {
/// Path to training data
#[clap(default_value = "data")]
data_dir: PathBuf,
/// Where to save tokenizer
#[clap(default_value = "tokenizer.json")]
target: PathBuf,
/// Only train on the first n texts
#[clap(long)]
first_n: Option<usize>,
/// Number of threads to use for training
#[clap(long, default_value = "0")]
threads: usize,
/// Tokenizer vocabulary size
#[clap(long, default_value = "65535")]
n_tokens: u32,
}
impl TrainTokenizerArgs {
pub fn run(self, mp: Option<MultiProgress>) -> Result<()> {
let iter = DataReader::new(5, &self.data_dir).context("while initializing data reader")?;
#[expect(clippy::unwrap_used)] // Lazy error handling
let iter = iter.map(|x| x.unwrap());
let pool = ThreadPoolBuilder::new()
.num_threads(self.threads)
.build()
.context("while building thread pool")?;
let tokenizer = pool
.install(|| match self.first_n {
Some(n) => Tokenizer::train(mp, iter.take(n), self.n_tokens),
None => Tokenizer::train(mp, iter, self.n_tokens),
})
.context("while training tokenizer")?;
info!("Saving to {}", self.target.display());
let target = File::create(&self.target).context("while creating tokenizer file")?;
serde_json::to_writer(target, &tokenizer).context("while saving tokenizer")?;
Ok(())
}
}

187
src/data_reader.rs Normal file

@@ -0,0 +1,187 @@
use anyhow::Result;
use parking_lot::Mutex;
use parquet::errors::ParquetError;
use parquet::file::reader::{FileReader, SerializedFileReader};
use parquet::record::RowAccessor;
use std::fs::File;
use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::Receiver;
use std::sync::{Arc, mpsc};
use thiserror::Error;
use tracing::{debug, trace};
#[derive(Debug, Error)]
pub enum DataReaderError {
#[error("i/o error: {0}")]
IoError(#[from] std::io::Error),
#[error("parquet error: {0}")]
ParquetError(#[from] ParquetError),
}
/// A data reader that uses `n` threads to read
/// rows from a directory of parquet files.
///
/// All parquet files have exactly one text column.
/// No promises about this struct's behavior if this is not the case.
pub struct DataReader {
rx: Receiver<Result<String, DataReaderError>>,
total_rows: usize,
consumed_rows: AtomicUsize,
}
impl DataReader {
/// Create a new [DataReader] that reads all `*.parquet` files in the given directory.
/// All other files are ignored.
pub fn new(threads: usize, data_dir: impl AsRef<Path>) -> Result<Self, std::io::Error> {
let (files, total_rows) = {
let mut files = Vec::new();
let mut total_rows = 0usize;
let entries = std::fs::read_dir(data_dir)?;
for entry in entries {
let entry = entry?;
let path = entry.path();
if path.extension().and_then(|s| s.to_str()) == Some("parquet") {
// Count rows in this file
let file = File::open(&path)?;
let reader = SerializedFileReader::new(file)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
let metadata = reader.metadata();
for row_group_index in 0..metadata.num_row_groups() {
let row_group_metadata = metadata.row_group(row_group_index);
total_rows += row_group_metadata.num_rows() as usize;
}
files.push(path);
}
}
(Arc::new(Mutex::new(files)), total_rows)
};
#[expect(clippy::unwrap_used)]
let pool = rayon::ThreadPoolBuilder::new()
.num_threads(threads)
.thread_name(move |i| format!("DataReader-{i}-of-{threads}"))
.build()
.unwrap();
debug!(
message = "Starting DataReader",
threads,
n_files = files.lock().len()
);
let buffer = threads.min(10) * 200;
let (tx, rx) = mpsc::sync_channel::<Result<String, DataReaderError>>(buffer);
pool.spawn_broadcast(move |_ctx| {
let tx = tx.clone();
let files = files.clone();
'outer: loop {
let file = match files.lock().pop() {
None => break 'outer,
Some(x) => x,
};
trace!("Reading rows from {}", file.display());
let file = match File::open(&file) {
Ok(x) => x,
Err(err) => {
let _ = tx.send(Err(err.into()));
break 'outer;
}
};
let reader = match SerializedFileReader::new(file) {
Ok(x) => x,
Err(err) => {
let _ = tx.send(Err(err.into()));
break 'outer;
}
};
let metadata = reader.metadata();
for row_group_index in 0..metadata.num_row_groups() {
let row_group_reader = match reader.get_row_group(row_group_index) {
Ok(x) => x,
Err(err) => {
let _ = tx.send(Err(err.into()));
break 'outer;
}
};
let row_iter = match row_group_reader.get_row_iter(None) {
Ok(x) => x,
Err(err) => {
let _ = tx.send(Err(err.into()));
break 'outer;
}
};
for row_result in row_iter {
let row = match row_result {
Ok(x) => x,
Err(err) => {
let _ = tx.send(Err(err.into()));
break 'outer;
}
};
let text = match row.get_string(0) {
Ok(x) => x.clone(),
Err(err) => {
let _ = tx.send(Err(err.into()));
break 'outer;
}
};
if tx.send(Ok(text)).is_err() {
return;
}
}
}
}
});
Ok(Self {
rx,
total_rows,
consumed_rows: AtomicUsize::new(0),
})
}
/// Get the next available row.
/// Order is arbitrary.
/// Returns `None` when all rows have been read.
pub fn recv(&self) -> Option<Result<String, DataReaderError>> {
self.rx.recv().ok()
}
//pub fn try_recv(&self) -> Result<Result<String, DataReaderError>, TryRecvError> {
// self.rx.try_recv()
//}
}
impl Iterator for DataReader {
type Item = Result<String, DataReaderError>;
fn next(&mut self) -> Option<Self::Item> {
match self.recv() {
Some(item) => {
self.consumed_rows.fetch_add(1, Ordering::Relaxed);
Some(item)
}
None => None,
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
let consumed = self.consumed_rows.load(Ordering::Relaxed);
let len = self.total_rows.saturating_sub(consumed);
return (len, Some(len));
}
}
impl ExactSizeIterator for DataReader {}
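
Putting the pieces above together, the reader behaves like any other `ExactSizeIterator` over rows: the length comes from the parquet metadata, and rows stream in whatever order the worker threads produce them. A minimal usage sketch (the directory and thread count are placeholders, mirroring how train_tokenizer.rs drives it; not part of this commit):

fn count_rows() -> Result<(), DataReaderError> {
	// Four reader threads; rows arrive in arbitrary order.
	let reader = DataReader::new(4, "data")?;
	// The total row count is known up front from the parquet metadata.
	let total = reader.len();
	let mut read = 0usize;
	for row in reader {
		let _text: String = row?; // each item is one text cell
		read += 1;
	}
	assert_eq!(read, total);
	Ok(())
}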

349
src/logging.rs Normal file

@@ -0,0 +1,349 @@
use anyhow::Result;
use clap::{Parser, ValueEnum};
use indicatif::MultiProgress;
use serde::Deserialize;
use std::{fmt::Display, str::FromStr};
use tracing_indicatif::IndicatifWriter;
use tracing_subscriber::{
EnvFilter, Layer, fmt::MakeWriter, layer::SubscriberExt, util::SubscriberInitExt,
};
//
// MARK: loglevel
//
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, ValueEnum)]
pub enum LogLevel {
Trace,
Debug,
Info,
Warn,
Error,
}
impl Default for LogLevel {
fn default() -> Self {
Self::Info
}
}
impl Display for LogLevel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Trace => write!(f, "trace"),
Self::Debug => write!(f, "debug"),
Self::Info => write!(f, "info"),
Self::Warn => write!(f, "warn"),
Self::Error => write!(f, "error"),
}
}
}
//
// MARK: logconfig
//
/// Configures log levels for known sources
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize)]
pub struct LoggingConfig {
pub other: LogLevel,
pub silence: LogLevel,
// Bins
pub nanochat: LogLevel,
}
impl From<LoggingConfig> for EnvFilter {
fn from(conf: LoggingConfig) -> Self {
// Should never fail
#[expect(clippy::unwrap_used)]
EnvFilter::from_str(
&[
//
// Silence
//
format!("hyper_util={}", conf.silence),
format!("h2={}", conf.silence),
format!("rustls={}", conf.silence),
format!("tower={}", conf.silence),
format!("html5ever={}", conf.silence),
format!("selectors={}", conf.silence),
format!("wgpu_core={}", conf.silence),
format!("naga={}", conf.silence),
format!("cubecl={}", conf.silence),
//
// Bins
//
format!("nanochat_rs={}", conf.nanochat),
conf.other.to_string(),
]
.join(","),
)
.unwrap()
}
}
//
// MARK: LogCliVQ
//
/// Provides global -v and -q cli arguments.
/// Use with `#[arg(flatten)]`.
#[derive(Parser, Debug, Clone)]
pub struct LogCliVQ {
/// Increase verbosity (can be repeated)
#[arg(default_value = "0")]
#[arg(short, action = clap::ArgAction::Count, global = true)]
v: u8,
/// Decrease verbosity (can be repeated)
#[arg(default_value = "0")]
#[arg(short, action = clap::ArgAction::Count, global = true)]
q: u8,
}
impl LogCliVQ {
pub fn into_preset(self) -> LogFilterPreset {
let level_i: i16 = self.v as i16 - self.q as i16;
let preset;
if level_i <= -2 {
preset = LogFilterPreset::Error
} else if level_i == -1 {
preset = LogFilterPreset::Warn
} else if level_i == 0 {
preset = LogFilterPreset::Info
} else if level_i == 1 {
preset = LogFilterPreset::Debug
} else if level_i >= 2 {
preset = LogFilterPreset::Trace
} else {
unreachable!()
}
return preset;
}
}
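
A quick check of the -v/-q arithmetic above; a sketch that assumes it sits in this same module so the private fields are reachable (illustrative, not part of this commit):

#[cfg(test)]
mod vq_mapping {
	use super::*;

	#[test]
	fn verbosity_flags_map_to_presets() {
		// -vv => Trace, -q => Warn, no flags => Info.
		assert!(matches!(LogCliVQ { v: 2, q: 0 }.into_preset(), LogFilterPreset::Trace));
		assert!(matches!(LogCliVQ { v: 0, q: 1 }.into_preset(), LogFilterPreset::Warn));
		assert!(matches!(LogCliVQ { v: 0, q: 0 }.into_preset(), LogFilterPreset::Info));
	}
}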
//
// MARK: logpreset
//
/// Provides preset configurations of [LoggingConfig]
#[derive(Debug, Deserialize, Clone, Copy)]
pub enum LogFilterPreset {
/// Standard "error" log level
Error,
/// Standard "warn" log level
Warn,
/// Standard "info" log level.
/// This is the default.
Info,
/// Standard "debug" log level
Debug,
/// Standard "trace" log level
Trace,
}
impl Default for LogFilterPreset {
fn default() -> Self {
return Self::Info;
}
}
impl From<LogFilterPreset> for LoggingConfig {
fn from(val: LogFilterPreset) -> Self {
val.get_config()
}
}
impl LogFilterPreset {
pub fn get_config(&self) -> LoggingConfig {
match self {
Self::Error => LoggingConfig {
other: LogLevel::Error,
silence: LogLevel::Error,
nanochat: LogLevel::Error,
},
Self::Warn => LoggingConfig {
other: LogLevel::Warn,
silence: LogLevel::Warn,
nanochat: LogLevel::Warn,
},
Self::Info => LoggingConfig {
other: LogLevel::Warn,
silence: LogLevel::Warn,
nanochat: LogLevel::Info,
},
Self::Debug => LoggingConfig {
other: LogLevel::Warn,
silence: LogLevel::Warn,
nanochat: LogLevel::Debug,
},
Self::Trace => LoggingConfig {
other: LogLevel::Trace,
silence: LogLevel::Warn,
nanochat: LogLevel::Trace,
},
}
}
}
//
// MARK: initializer
//
/// Where to print logs
#[expect(clippy::allow_attributes)]
#[allow(dead_code)]
pub enum LoggingTarget {
/// Send logs to stdout
Stdout { format: LoggingFormat },
/// Send logs to stderr
Stderr { format: LoggingFormat },
/// Send logs to an IndicatifWriter.
///
/// This behaves like Stderr { format: Ansi },
/// but writes through an IndicatifWriter built from the given MultiProgress.
Indicatif(MultiProgress),
}
/// How to print logs
#[derive(Debug, Clone, Copy, Deserialize)]
pub enum LoggingFormat {
Ansi,
AnsiNoColor,
Json,
}
impl Default for LoggingFormat {
fn default() -> Self {
Self::Ansi
}
}
pub struct LoggingInitializer {
/// Log filter for printed logs
pub preset: LogFilterPreset,
/// Where to print logs
pub target: LoggingTarget,
}
impl LoggingInitializer {
pub fn initialize(self) -> Result<()> {
let mut stderr_ansi_layer = None;
let mut stderr_json_layer = None;
let mut stdout_ansi_layer = None;
let mut stdout_json_layer = None;
let mut indicatif_layer = None;
match self.target {
LoggingTarget::Stderr {
format: LoggingFormat::Ansi,
} => {
stderr_ansi_layer = Some(
tracing_subscriber::fmt::Layer::default()
.without_time()
.with_ansi(true)
.with_writer(std::io::stderr)
.with_filter::<EnvFilter>(self.preset.get_config().into()),
)
}
LoggingTarget::Stderr {
format: LoggingFormat::AnsiNoColor,
} => {
stderr_ansi_layer = Some(
tracing_subscriber::fmt::Layer::default()
.without_time()
.with_ansi(false)
.with_writer(std::io::stderr)
.with_filter::<EnvFilter>(self.preset.get_config().into()),
)
}
LoggingTarget::Stderr {
format: LoggingFormat::Json,
} => {
stderr_json_layer = Some(
tracing_subscriber::fmt::Layer::default()
.without_time()
.with_ansi(false)
.json()
.with_writer(std::io::stderr)
.with_filter::<EnvFilter>(self.preset.get_config().into()),
)
}
LoggingTarget::Stdout {
format: LoggingFormat::Ansi,
} => {
stdout_ansi_layer = Some(
tracing_subscriber::fmt::Layer::default()
.without_time()
.with_ansi(true)
.with_writer(std::io::stdout)
.with_filter::<EnvFilter>(self.preset.get_config().into()),
)
}
LoggingTarget::Stdout {
format: LoggingFormat::AnsiNoColor,
} => {
stdout_ansi_layer = Some(
tracing_subscriber::fmt::Layer::default()
.without_time()
.with_ansi(false)
.with_writer(std::io::stdout)
.with_filter::<EnvFilter>(self.preset.get_config().into()),
)
}
LoggingTarget::Stdout {
format: LoggingFormat::Json,
} => {
stdout_json_layer = Some(
tracing_subscriber::fmt::Layer::default()
.without_time()
.with_ansi(false)
.json()
.with_writer(std::io::stdout)
.with_filter::<EnvFilter>(self.preset.get_config().into()),
)
}
LoggingTarget::Indicatif(mp) => {
let writer: IndicatifWriter<tracing_indicatif::writer::Stderr> =
IndicatifWriter::new(mp);
indicatif_layer = Some(
tracing_subscriber::fmt::Layer::default()
.without_time()
.with_ansi(true)
.with_writer(writer.make_writer())
.with_filter::<EnvFilter>(self.preset.get_config().into()),
)
}
}
tracing_subscriber::registry()
.with(stdout_ansi_layer)
.with(stdout_json_layer)
.with(stderr_ansi_layer)
.with(stderr_json_layer)
.with(indicatif_layer)
.init();
Ok(())
}
}

64
src/main.rs Normal file

@@ -0,0 +1,64 @@
use clap::Parser;
use indicatif::MultiProgress;
use tracing::error;
use crate::{
command::SubCommand,
logging::{LogCliVQ, LoggingFormat, LoggingInitializer, LoggingTarget},
};
mod cli;
mod command;
mod data_reader;
mod logging;
mod split;
mod tokenizer;
#[derive(Parser, Debug)]
#[command(version, about, long_about = None, styles=crate::cli::clap_styles())]
struct Cli {
#[clap(flatten)]
vq: LogCliVQ,
/// If true, never show progress bars.
#[arg(long)]
noprogress: bool,
#[command(subcommand)]
command: SubCommand,
}
fn main() {
let cli = Cli::parse();
let mp = (!cli.noprogress).then(MultiProgress::new);
{
let res = LoggingInitializer {
preset: cli.vq.into_preset(),
target: match mp.clone() {
None => LoggingTarget::Stderr {
format: LoggingFormat::Ansi,
},
Some(mp) => LoggingTarget::Indicatif(mp.clone()),
},
}
.initialize();
if let Err(e) = res {
#[expect(clippy::print_stderr)]
for e in e.chain() {
eprintln!("{e}");
}
std::process::exit(1);
}
}
if let Err(e) = cli.command.run(mp) {
error!("Error: {e:?}");
for e in e.chain() {
error!(e);
}
std::process::exit(1);
}
}

266
src/split.rs Normal file

@@ -0,0 +1,266 @@
use fancy_regex::Regex;
/// Split text using a regex while keeping both the matched parts and the parts between matches
pub fn regex_segment<'a>(re: &Regex, text: &'a str) -> Vec<&'a str> {
let mut result = Vec::new();
let mut last = 0;
for mat in re.find_iter(text) {
#[expect(clippy::unwrap_used)]
let mat = mat.unwrap();
if mat.start() > last {
result.push(&text[last..mat.start()]);
}
result.push(mat.as_str());
last = mat.end();
}
if last < text.len() {
result.push(&text[last..]);
}
result.retain(|x| !x.is_empty());
result
}
//
// MARK: tests
//
#[cfg(test)]
mod tests {
use super::*;
use crate::tokenizer::Tokenizer;
#[test]
fn basic() {
let re = Regex::new(r"[,;]").unwrap();
let text = "apple,banana;cherry";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["apple", ",", "banana", ";", "cherry"]);
}
#[test]
fn empty_string() {
let re = Regex::new(r"[,;]").unwrap();
let text = "";
let result = regex_segment(&re, text);
assert_eq!(result, Vec::<&str>::new());
}
#[test]
fn no_matches() {
let re = Regex::new(r"[,;]").unwrap();
let text = "apple banana cherry";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["apple banana cherry"]);
}
#[test]
fn only_matches() {
let re = Regex::new(r"[,;]").unwrap();
let text = ",;,";
let result = regex_segment(&re, text);
assert_eq!(result, vec![",", ";", ","]);
}
#[test]
fn starts_with_match() {
let re = Regex::new(r"[,;]").unwrap();
let text = ",apple;banana";
let result = regex_segment(&re, text);
assert_eq!(result, vec![",", "apple", ";", "banana"]);
}
#[test]
fn ends_with_match() {
let re = Regex::new(r"[,;]").unwrap();
let text = "apple,banana;";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["apple", ",", "banana", ";"]);
}
#[test]
fn consecutive_matches() {
let re = Regex::new(r"[,;]").unwrap();
let text = "apple,,banana";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["apple", ",", ",", "banana"]);
}
#[test]
fn word_boundaries() {
let re = Regex::new(r"\b").unwrap();
let text = "hello world";
let result = regex_segment(&re, text);
// Word boundaries are zero-width, so their matches are empty strings that the retain filter drops, leaving only the word and non-word runs
assert_eq!(result, vec!["hello", " ", "world"]);
}
#[test]
fn digits() {
let re = Regex::new(r"\d+").unwrap();
let text = "abc123def456ghi";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["abc", "123", "def", "456", "ghi"]);
}
#[test]
fn special_tokens() {
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
let text = "Hello <|user_start|>world<|user_end|> test";
let result = regex_segment(&re, text);
assert_eq!(
result,
vec!["Hello ", "<|user_start|>", "world", "<|user_end|>", " test"]
);
}
#[test]
fn unicode() {
let re = Regex::new(r"[=<3D>=<3D>]+").unwrap();
let text = "Hello=<3D>world=<3D>test";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["Hello", "=<3D>", "world", "=<3D>", "test"]);
}
#[test]
fn single_char() {
let re = Regex::new(r"x").unwrap();
let text = "x";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["x"]);
}
#[test]
fn multichar_match() {
let re = Regex::new(r"abc").unwrap();
let text = "123abc456abc789";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["123", "abc", "456", "abc", "789"]);
}
#[test]
fn bos_token() {
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
let text = "<|bos|>This is a document";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["<|bos|>", "This is a document"]);
}
#[test]
fn conversation_flow() {
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
let text = "<|user_start|>Hello<|user_end|><|assistant_start|>Hi there!<|assistant_end|>";
let result = regex_segment(&re, text);
assert_eq!(
result,
vec![
"<|user_start|>",
"Hello",
"<|user_end|>",
"<|assistant_start|>",
"Hi there!",
"<|assistant_end|>"
]
);
}
#[test]
fn python_code_block() {
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
let text = "Code: <|python_start|>print('hello')<|python_end|> Output: <|output_start|>hello<|output_end|>";
let result = regex_segment(&re, text);
assert_eq!(
result,
vec![
"Code: ",
"<|python_start|>",
"print('hello')",
"<|python_end|>",
" Output: ",
"<|output_start|>",
"hello",
"<|output_end|>"
]
);
}
#[test]
fn mixed_special_tokens() {
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
let text =
"<|bos|><|user_start|>Question<|user_end|><|assistant_start|>Answer<|assistant_end|>";
let result = regex_segment(&re, text);
assert_eq!(
result,
vec![
"<|bos|>",
"<|user_start|>",
"Question",
"<|user_end|>",
"<|assistant_start|>",
"Answer",
"<|assistant_end|>"
]
);
}
#[test]
fn no_special_tokens() {
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
let text = "This is just regular text with no special tokens";
let result = regex_segment(&re, text);
assert_eq!(
result,
vec!["This is just regular text with no special tokens"]
);
}
#[test]
fn malformed_special_tokens() {
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
let text = "This has <|invalid_token> and <user_start> which shouldn't match";
let result = regex_segment(&re, text);
assert_eq!(
result,
vec!["This has <|invalid_token> and <user_start> which shouldn't match"]
);
}
#[test]
fn special_tokens_with_whitespace() {
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
let text = " <|bos|> \n<|user_start|>\tHello\n<|user_end|> ";
let result = regex_segment(&re, text);
assert_eq!(
result,
vec![
" ",
"<|bos|>",
" \n",
"<|user_start|>",
"\tHello\n",
"<|user_end|>",
" "
]
);
}
#[test]
fn only_special_tokens() {
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
let text = "<|bos|><|user_start|><|user_end|>";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["<|bos|>", "<|user_start|>", "<|user_end|>"]);
}
#[test]
fn nested() {
let re = Regex::new(Tokenizer::SPECIAL_REGEX).unwrap();
let text = "<|<|bos|>|>";
let result = regex_segment(&re, text);
assert_eq!(result, vec!["<|", "<|bos|>", "|>"]);
}
}

1088
src/tokenizer.rs Normal file

File diff suppressed because it is too large