// nanochat-rs/src/command/train_tokenizer.rs

use std::fs::File;
use std::path::PathBuf;

use anyhow::{Context, Result};
use clap::Args;
use indicatif::MultiProgress;
use rayon::ThreadPoolBuilder;
use tracing::info;

use crate::data_reader::DataReader;
use crate::tokenizer::Tokenizer;

#[derive(Debug, Args, Clone)]
pub struct TrainTokenizerArgs {
    /// Path to the training data directory
    #[clap(default_value = "data")]
    data_dir: PathBuf,
    /// Where to save the trained tokenizer
    #[clap(default_value = "tokenizer.json")]
    target: PathBuf,
    /// Only train on the first n texts
    #[clap(long)]
    first_n: Option<usize>,
    /// Number of threads to use for training (0 lets rayon choose
    /// automatically, typically one thread per logical core)
    #[clap(long, default_value = "0")]
    threads: usize,
}

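// Example invocation, assuming this subcommand is exposed as `train-tokenizer`
// on a `nanochat` binary (both names are assumptions; clap derives the
// `--first-n` and `--threads` flag names from the fields above, and the two
// positional arguments fall back to "data" and "tokenizer.json" when omitted):
//
//     nanochat train-tokenizer data tokenizer.json --first-n 100000 --threads 8
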
impl TrainTokenizerArgs {
    pub fn run(self, mp: Option<MultiProgress>) -> Result<()> {
        let iter = DataReader::new(5, &self.data_dir).context("while initializing data reader")?;
        #[expect(clippy::unwrap_used)] // Lazy error handling: abort on any read error
        let iter = iter.map(|x| x.unwrap());
        let pool = ThreadPoolBuilder::new()
            .num_threads(self.threads) // 0 means rayon's default thread count
            .build()
            .context("while building thread pool")?;
        // Run training inside the dedicated pool so rayon parallelism is
        // capped by --threads rather than by the global default pool.
        let tokenizer = pool
            .install(|| match self.first_n {
                Some(n) => Tokenizer::train(mp, iter.take(n), 1024),
                None => Tokenizer::train(mp, iter, 1024),
            })
            .context("while training tokenizer")?;
        info!("Saving to {}", self.target.display());
        let target = File::create(&self.target).context("while creating tokenizer file")?;
        serde_json::to_writer(target, &tokenizer).context("while saving tokenizer")?;
        Ok(())
    }
}
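
// A minimal sketch of how this type might be wired into the binary's entry
// point. The `Cli` and `Command` names below are assumptions, not part of
// this file; only `TrainTokenizerArgs` (and its module path, taken from this
// file's location) comes from the code above.
//
// use clap::{Parser, Subcommand};
// use indicatif::MultiProgress;
// use crate::command::train_tokenizer::TrainTokenizerArgs;
//
// #[derive(Parser)]
// struct Cli {
//     #[clap(subcommand)]
//     command: Command,
// }
//
// #[derive(Subcommand)]
// enum Command {
//     /// Train the tokenizer on the text corpus
//     TrainTokenizer(TrainTokenizerArgs),
// }
//
// fn main() -> anyhow::Result<()> {
//     match Cli::parse().command {
//         // Pass a MultiProgress so training can render progress bars;
//         // `None` would run without them.
//         Command::TrainTokenizer(args) => args.run(Some(MultiProgress::new())),
//     }
// }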