use anyhow::{Context, Result};
use clap::Args;
use indicatif::MultiProgress;
use rayon::ThreadPoolBuilder;
use std::fs::File;
use std::path::PathBuf;
use tracing::info;

use crate::data_reader::DataReader;
use crate::tokenizer::Tokenizer;

#[derive(Debug, Args, Clone)]
pub struct TrainTokenizerArgs {
    /// Path to training data
    #[clap(default_value = "data")]
    data_dir: PathBuf,

    /// Where to save the tokenizer
    #[clap(default_value = "tokenizer.json")]
    target: PathBuf,

    /// Only train on the first n texts
    #[clap(long)]
    first_n: Option<usize>,

    /// Number of threads to use for training (0 lets rayon pick automatically)
    #[clap(long, default_value = "0")]
    threads: usize,
}

impl TrainTokenizerArgs {
    pub fn run(self, mp: Option<MultiProgress>) -> Result<()> {
        let iter = DataReader::new(5, &self.data_dir).context("while initializing data reader")?;

        #[expect(clippy::unwrap_used)] // Lazy error handling
        let iter = iter.map(|x| x.unwrap());

        // `num_threads(0)` tells rayon to size the pool automatically.
        let pool = ThreadPoolBuilder::new()
            .num_threads(self.threads)
            .build()
            .context("while building thread pool")?;

        // Run training inside `install` so rayon's parallel iterators use this pool.
        let tokenizer = pool
            .install(|| match self.first_n {
                Some(n) => Tokenizer::train(mp, iter.take(n), 1024),
                None => Tokenizer::train(mp, iter, 1024),
            })
            .context("while training tokenizer")?;

        info!("Saving to {}", self.target.display());
        let target = File::create(&self.target).context("while creating tokenizer file")?;
        serde_json::to_writer(target, &tokenizer).context("while saving tokenizer")?;

        Ok(())
    }
}
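
// --- Usage sketch (not part of this file) ---
// A minimal sketch of how `TrainTokenizerArgs` might be wired into a
// clap-based binary. The `Cli` and `Command` names and the subcommand
// layout below are assumptions for illustration, not this crate's
// actual CLI:
//
//     use clap::{Parser, Subcommand};
//
//     #[derive(Parser)]
//     struct Cli {
//         #[command(subcommand)]
//         command: Command,
//     }
//
//     #[derive(Subcommand)]
//     enum Command {
//         /// Train a new tokenizer
//         TrainTokenizer(TrainTokenizerArgs),
//     }
//
//     fn main() -> anyhow::Result<()> {
//         match Cli::parse().command {
//             // Pass `None` to run without a shared progress display.
//             Command::TrainTokenizer(args) => args.run(None),
//         }
//     }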