From 35f64c16c15b22019ba0fd70ae58bb8a45fd19e7 Mon Sep 17 00:00:00 2001
From: rm-dr <96270320+rm-dr@users.noreply.github.com>
Date: Thu, 11 Dec 2025 17:41:40 -0800
Subject: [PATCH] Initial CLI

---
 crates/llmfs/src/cli.rs                     |  64 ++++
 crates/llmfs/src/command/download.rs        | 280 ++++++++++++++++
 crates/llmfs/src/command/mod.rs             |  26 ++
 crates/llmfs/src/command/train_tokenizer.rs |  61 ++++
 crates/llmfs/src/data_reader.rs             | 187 +++++++++++
 crates/llmfs/src/logging.rs                 | 343 ++++++++++++++++++++
 crates/llmfs/src/main.rs                    |  62 ++++
 7 files changed, 1023 insertions(+)
 create mode 100644 crates/llmfs/src/cli.rs
 create mode 100644 crates/llmfs/src/command/download.rs
 create mode 100644 crates/llmfs/src/command/mod.rs
 create mode 100644 crates/llmfs/src/command/train_tokenizer.rs
 create mode 100644 crates/llmfs/src/data_reader.rs
 create mode 100644 crates/llmfs/src/logging.rs
 create mode 100644 crates/llmfs/src/main.rs

diff --git a/crates/llmfs/src/cli.rs b/crates/llmfs/src/cli.rs
new file mode 100644
index 0000000..01d17bc
--- /dev/null
+++ b/crates/llmfs/src/cli.rs
@@ -0,0 +1,64 @@
use anstyle::{AnsiColor, Color, Style};
use indicatif::ProgressStyle;

pub fn clap_styles() -> clap::builder::Styles {
    clap::builder::Styles::styled()
        .usage(
            Style::new()
                .bold()
                .fg_color(Some(Color::Ansi(AnsiColor::Blue))),
        )
        .header(
            Style::new()
                .bold()
                .fg_color(Some(Color::Ansi(AnsiColor::Blue))),
        )
        .literal(
            Style::new()
                .bold()
                .fg_color(Some(Color::Ansi(AnsiColor::BrightBlack))),
        )
        .invalid(
            Style::new()
                .bold()
                .fg_color(Some(Color::Ansi(AnsiColor::Red))),
        )
        .error(
            Style::new()
                .bold()
                .fg_color(Some(Color::Ansi(AnsiColor::Red))),
        )
        .valid(
            Style::new()
                .bold()
                .underline()
                .fg_color(Some(Color::Ansi(AnsiColor::Green))),
        )
        .placeholder(Style::new().fg_color(Some(Color::Ansi(AnsiColor::White))))
}

#[expect(clippy::unwrap_used)]
pub fn progress_big() -> ProgressStyle {
    return ProgressStyle::default_bar()
        .template(
            " {spinner:.green} [{elapsed_precise}] [{bar:40.green/dim}] {pos:>7}/{len:7} {msg:.dim} ({eta})",
        )
        .unwrap()
        .progress_chars("=>-")
        .tick_strings(&[
            "⠉⠉", "⠈⠙", "⠀⠹", "⠀⢸", "⠀⣰", "⢀⣠", "⣀⣀", "⣄⡀", "⣆⠀", "⡇⠀", "⠏⠀", "⠋⠁", "⣏⣹",
        ]);
}

#[expect(clippy::unwrap_used)]
pub fn progress_bytes() -> ProgressStyle {
    return ProgressStyle::default_bar()
        .template(
            " {bar:16.red/white.dim} {elapsed_precise:.dim} {bytes:>10}/{total_bytes:>10} {msg:.dim} ({eta})",
        )
        .unwrap()
        .progress_chars("---")
        .tick_strings(&[
            "⠉⠉", "⠈⠙", "⠀⠹", "⠀⢸", "⠀⣰", "⢀⣠", "⣀⣀", "⣄⡀", "⣆⠀", "⡇⠀", "⠏⠀", "⠋⠁", "⣏⣹",
        ]);
}
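Note: the two styles above are only definitions; a bar must be attached to a MultiProgress and driven by the caller. A minimal consumption sketch, assuming indicatif's standard API (the loop body and the count of 100 are illustrative, not taken from the patch):

    use indicatif::{MultiProgress, ProgressBar};

    fn demo_progress() {
        let mp = MultiProgress::new();
        let pb = mp.add(ProgressBar::new(100));
        pb.set_style(crate::cli::progress_big());
        for _ in 0..100 {
            pb.inc(1); // one tick per completed unit of work
        }
        pb.finish_and_clear();
    }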
default_value = "-1")] + num_files: isize, + + /// Number of parallel downloads + #[arg(short = 't', long, default_value = "8")] + threads: usize, +} + +impl DownloadArgs { + pub fn run(self, mp: Option) -> Result<()> { + info!("Downloading files from {BASE_URL}"); + fs::create_dir_all(&self.data_dir)?; + + let num_shards_to_download = if self.num_files == -1 { + MAX_SHARD + 1 + } else { + self.num_files.min((MAX_SHARD + 1) as isize) as usize + }; + + let ids_to_download: Vec = (0..num_shards_to_download).collect(); + + info!("Downloading {} shards...", ids_to_download.len(),); + info!("Target directory: {}", self.data_dir.display()); + + let main_pb = mp.as_ref().map(|mp| { + let pb = mp.add(ProgressBar::new(ids_to_download.len() as u64)); + pb.set_style(progress_big()); + pb.set_message("Downloading training data"); + pb.enable_steady_tick(Duration::from_millis(100)); + pb + }); + + let rt = Runtime::new()?; + + let (tx, rx) = std::sync::mpsc::channel(); + + let pool = rayon::ThreadPoolBuilder::new() + .num_threads(self.threads) + .build()?; + + pool.install(|| { + ids_to_download + .into_par_iter() + .for_each_with(tx, |tx, index| { + let target = self.data_dir.clone(); + let main_pb = main_pb.clone(); + let mp_clone = mp.clone(); + let rt_handle = rt.handle().clone(); // Clone the runtime handle for each thread + + let result = rt_handle.block_on(async { + download_single_file(index, &target, main_pb, mp_clone).await + }); + + // Send the result back to the main thread for aggregation + #[expect(clippy::unwrap_used)] + tx.send(result).unwrap(); + }); + }); + + // Wait for all downloads to finish and collect results + let mut successful_downloads = 0; + for _ in 0..num_shards_to_download { + if let Ok(Ok(_)) = rx.recv() { + // Receive the Result<(), String> wrapped in a Result from MPSC + successful_downloads += 1; + } + } + + if let Some(pb) = main_pb.as_ref() { + pb.finish_and_clear(); + info!("Downloads complete ({successful_downloads} successful)"); + } + + return Ok(()); + } +} + +async fn download_single_file( + index: usize, + target: &Path, + progress_bar: Option, + mp: Option, +) -> Result<(), String> { + let filename = format!("shard_{:05}.parquet", index); + let filepath = target.join(&filename); + + if filepath.exists() { + info!("Skipping {} (already exists)", filepath.display()); + if let Some(pb) = progress_bar.as_ref() { + pb.inc(1); + } + return Ok(()); + } + + #[expect(clippy::unwrap_used)] + let url = Url::parse(&format!("{BASE_URL}/{filename}")).unwrap(); + + info!("Downloading {} from {}", filename, url); + + let max_attempts = 5; + 'attempt_loop: for attempt in 1..=max_attempts { + let temp_filepath = filepath.with_extension("parquet.tmp"); + let client = reqwest::Client::new(); + + match client.get(url.clone()).send().await { + Ok(response) => { + if !response.status().is_success() { + error!( + "Attempt {}/{}: Server responded with status {} for {}", + attempt, + max_attempts, + response.status(), + url + ); + tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await; + continue; + } + + let total_size = response.content_length().unwrap_or(0); + debug!("Total size for {}: {}", filename, total_size); + + // Create file progress bar + let file_pb = if total_size > 0 + && let Some(mp) = mp.as_ref() + { + Some({ + let pb = mp.add(ProgressBar::new(total_size)); + pb.set_style(progress_bytes()); + pb.set_message(format!("Downloading {}", filename)); + pb + }) + } else { + None + }; + + let mut file = match tokio::fs::File::create(&temp_filepath).await { + 
diff --git a/crates/llmfs/src/command/mod.rs b/crates/llmfs/src/command/mod.rs
new file mode 100644
index 0000000..290a636
--- /dev/null
+++ b/crates/llmfs/src/command/mod.rs
@@ -0,0 +1,26 @@
mod download;
mod train_tokenizer;

#[derive(Debug, clap::Subcommand)]
pub enum SubCommand {
    /// Download training set
    Download {
        #[command(flatten)]
        args: download::DownloadArgs,
    },

    /// Train tokenizer
    TrainTokenizer {
        #[command(flatten)]
        args: train_tokenizer::TrainTokenizerArgs,
    },
}

impl SubCommand {
    pub fn run(self, mp: Option<indicatif::MultiProgress>) -> anyhow::Result<()> {
        match self {
            Self::Download { args } => args.run(mp),
            Self::TrainTokenizer { args } => args.run(mp),
        }
    }
}
diff --git a/crates/llmfs/src/command/train_tokenizer.rs b/crates/llmfs/src/command/train_tokenizer.rs
new file mode 100644
index 0000000..8e74b8d
--- /dev/null
+++ b/crates/llmfs/src/command/train_tokenizer.rs
@@ -0,0 +1,61 @@
use anyhow::{Context, Result};
use clap::Args;
use indicatif::MultiProgress;
use rayon::ThreadPoolBuilder;
use std::fs::File;
use std::path::PathBuf;
use tokenizer::Tokenizer;
use tracing::info;

use crate::data_reader::DataReader;

#[derive(Debug, Args, Clone)]
pub struct TrainTokenizerArgs {
    /// Path to training data
    #[clap(default_value = "data")]
    data_dir: PathBuf,

    /// Where to save tokenizer
    #[clap(default_value = "tokenizer.json")]
    target: PathBuf,

    /// Only train on the first n texts
    #[clap(long)]
    first_n: Option<usize>,

    /// Number of threads to use for training
    #[clap(long, default_value = "0")]
    threads: usize,

    /// Tokenizer vocabulary size
    #[clap(long, default_value = "65535")]
    n_tokens: u32,
}

impl TrainTokenizerArgs {
    pub fn run(self, mp: Option<MultiProgress>) -> Result<()> {
        let iter = DataReader::new(5, &self.data_dir).context("while initializing data reader")?;

        #[expect(clippy::unwrap_used)] // Lazy error handling
        let iter = iter.map(|x| x.unwrap());

        let pool = ThreadPoolBuilder::new()
            .num_threads(self.threads)
            .build()
            .context("while building thread pool")?;

        let tokenizer = pool
            .install(|| match self.first_n {
                Some(n) => Tokenizer::train(mp, iter.take(n), self.n_tokens),
                None => Tokenizer::train(mp, iter, self.n_tokens),
            })
            .context("while training tokenizer")?;

        info!("Saving to {}", self.target.display());
        let target = File::create(&self.target).context("while creating tokenizer file")?;
        serde_json::to_writer(target, &tokenizer).context("while saving tokenizer")?;

        Ok(())
    }
}
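For orientation: clap derives kebab-case names for these subcommands, so (assuming the binary is built as llmfs; the name and the argument values below are illustrative, not taken from the patch) invocations would look roughly like:

    llmfs download -n 8 -t 4
    llmfs train-tokenizer --first-n 100000 --n-tokens 65535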
diff --git a/crates/llmfs/src/data_reader.rs b/crates/llmfs/src/data_reader.rs
new file mode 100644
index 0000000..7e1188d
--- /dev/null
+++ b/crates/llmfs/src/data_reader.rs
@@ -0,0 +1,187 @@
use anyhow::Result;
use parking_lot::Mutex;
use parquet::errors::ParquetError;
use parquet::file::reader::{FileReader, SerializedFileReader};
use parquet::record::RowAccessor;
use std::fs::File;
use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::Receiver;
use std::sync::{Arc, mpsc};
use thiserror::Error;
use tracing::{debug, trace};

#[derive(Debug, Error)]
pub enum DataReaderError {
    #[error("i/o error: {0}")]
    IoError(#[from] std::io::Error),

    #[error("parquet error: {0}")]
    ParquetError(#[from] ParquetError),
}

/// A data reader that uses `n` threads to read
/// rows from a directory of parquet files.
///
/// All parquet files have exactly one text column.
/// No promises about this struct's behavior if this is not the case.
pub struct DataReader {
    rx: Receiver<Result<String, DataReaderError>>,
    total_rows: usize,
    consumed_rows: AtomicUsize,
}

impl DataReader {
    /// Create a new [DataReader] that reads all `*.parquet` files in the given directory.
    /// All other files are ignored.
    pub fn new(threads: usize, data_dir: impl AsRef<Path>) -> Result<Self> {
        let (files, total_rows) = {
            let mut files = Vec::new();
            let mut total_rows = 0usize;
            let entries = std::fs::read_dir(data_dir)?;
            for entry in entries {
                let entry = entry?;
                let path = entry.path();
                if path.extension().and_then(|s| s.to_str()) == Some("parquet") {
                    // Count rows in this file
                    let file = File::open(&path)?;
                    let reader = SerializedFileReader::new(file)
                        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
                    let metadata = reader.metadata();

                    for row_group_index in 0..metadata.num_row_groups() {
                        let row_group_metadata = metadata.row_group(row_group_index);
                        total_rows += row_group_metadata.num_rows() as usize;
                    }

                    files.push(path);
                }
            }
            (Arc::new(Mutex::new(files)), total_rows)
        };

        #[expect(clippy::unwrap_used)]
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(threads)
            .thread_name(move |i| format!("DataReader-{i}-of-{threads}"))
            .build()
            .unwrap();

        debug!(
            message = "Starting DataReader",
            threads,
            n_files = files.lock().len()
        );

        let buffer = threads.min(10) * 200;
        let (tx, rx) = mpsc::sync_channel::<Result<String, DataReaderError>>(buffer);
        pool.spawn_broadcast(move |_ctx| {
            let tx = tx.clone();
            let files = files.clone();

            'outer: loop {
                let file = match files.lock().pop() {
                    None => break 'outer,
                    Some(x) => x,
                };

                trace!("Reading rows from {}", file.display());
                let file = match File::open(&file) {
                    Ok(x) => x,
                    Err(err) => {
                        let _ = tx.send(Err(err.into()));
                        break 'outer;
                    }
                };

                let reader = match SerializedFileReader::new(file) {
                    Ok(x) => x,
                    Err(err) => {
                        let _ = tx.send(Err(err.into()));
                        break 'outer;
                    }
                };

                let metadata = reader.metadata();

                for row_group_index in 0..metadata.num_row_groups() {
                    let row_group_reader = match reader.get_row_group(row_group_index) {
                        Ok(x) => x,
                        Err(err) => {
                            let _ = tx.send(Err(err.into()));
                            break 'outer;
                        }
                    };

                    let row_iter = match row_group_reader.get_row_iter(None) {
                        Ok(x) => x,
                        Err(err) => {
                            let _ = tx.send(Err(err.into()));
                            break 'outer;
                        }
                    };

                    for row_result in row_iter {
                        let row = match row_result {
                            Ok(x) => x,
                            Err(err) => {
                                let _ = tx.send(Err(err.into()));
                                break 'outer;
                            }
                        };

                        let text = match row.get_string(0) {
                            Ok(x) => x.clone(),
                            Err(err) => {
                                let _ = tx.send(Err(err.into()));
                                break 'outer;
                            }
                        };

                        if tx.send(Ok(text)).is_err() {
                            return;
                        }
                    }
                }
            }
        });

        Ok(Self {
            rx,
            total_rows,
            consumed_rows: AtomicUsize::new(0),
        })
    }

    /// Get the next available row.
    /// Order is arbitrary.
    /// Returns `None` when all rows have been read.
    pub fn recv(&self) -> Option<Result<String, DataReaderError>> {
        self.rx.recv().ok()
    }

    //pub fn try_recv(&self) -> Result<Result<String, DataReaderError>, TryRecvError> {
    //    self.rx.try_recv()
    //}
}

impl Iterator for DataReader {
    type Item = Result<String, DataReaderError>;
    fn next(&mut self) -> Option<Self::Item> {
        match self.recv() {
            Some(item) => {
                self.consumed_rows.fetch_add(1, Ordering::Relaxed);
                Some(item)
            }
            None => None,
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let consumed = self.consumed_rows.load(Ordering::Relaxed);
        let len = self.total_rows.saturating_sub(consumed);
        return (len, Some(len));
    }
}

impl ExactSizeIterator for DataReader {}
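Since DataReader implements Iterator (and ExactSizeIterator, via the row counts gathered in new), it can be consumed like any other iterator. A minimal sketch using only the API defined above (the directory name and row count are illustrative):

    fn preview_rows() -> anyhow::Result<()> {
        let reader = DataReader::new(4, "data")?;
        for row in reader.take(3) {
            // Each item is a Result<String, DataReaderError>;
            // rows arrive in arbitrary order.
            let text = row?;
            println!("{} chars", text.len());
        }
        Ok(())
    }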
diff --git a/crates/llmfs/src/logging.rs b/crates/llmfs/src/logging.rs
new file mode 100644
index 0000000..7660b9e
--- /dev/null
+++ b/crates/llmfs/src/logging.rs
@@ -0,0 +1,343 @@
use anyhow::Result;
use clap::{Parser, ValueEnum};
use indicatif::MultiProgress;
use serde::Deserialize;
use std::{fmt::Display, str::FromStr};
use tracing_indicatif::IndicatifWriter;
use tracing_subscriber::{
    EnvFilter, Layer, fmt::MakeWriter, layer::SubscriberExt, util::SubscriberInitExt,
};

//
// MARK: loglevel
//

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, ValueEnum)]
#[derive(Default)]
pub enum LogLevel {
    Trace,
    Debug,
    #[default]
    Info,
    Warn,
    Error,
}

impl Display for LogLevel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Trace => write!(f, "trace"),
            Self::Debug => write!(f, "debug"),
            Self::Info => write!(f, "info"),
            Self::Warn => write!(f, "warn"),
            Self::Error => write!(f, "error"),
        }
    }
}

//
// MARK: logconfig
//

/// Configures log levels for known sources
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize)]
pub struct LoggingConfig {
    pub other: LogLevel,
    pub silence: LogLevel,

    // Bins
    pub nanochat: LogLevel,
}

impl From<LoggingConfig> for EnvFilter {
    fn from(conf: LoggingConfig) -> Self {
        // Should never fail
        #[expect(clippy::unwrap_used)]
        EnvFilter::from_str(
            &[
                //
                // Silence
                //
                format!("hyper_util={}", conf.silence),
                format!("h2={}", conf.silence),
                format!("rustls={}", conf.silence),
                format!("tower={}", conf.silence),
                format!("html5ever={}", conf.silence),
                format!("selectors={}", conf.silence),
                format!("wgpu_core={}", conf.silence),
                format!("naga={}", conf.silence),
                format!("cubecl={}", conf.silence),
                //
                // Bins
                //
                format!("nanochat_rs={}", conf.nanochat),
                conf.other.to_string(),
            ]
            .join(","),
        )
        .unwrap()
    }
}

//
// MARK: LogCliVQ
//

/// Provides global -v and -q cli arguments.
/// Use with `#[arg(flatten)]`.
#[derive(Parser, Debug, Clone)]
pub struct LogCliVQ {
    /// Increase verbosity (can be repeated)
    #[arg(default_value = "0")]
    #[arg(short, action = clap::ArgAction::Count, global = true)]
    v: u8,

    /// Decrease verbosity (can be repeated)
    #[arg(default_value = "0")]
    #[arg(short, action = clap::ArgAction::Count, global = true)]
    q: u8,
}

impl LogCliVQ {
    pub fn into_preset(self) -> LogFilterPreset {
        let level_i: i16 = self.v as i16 - self.q as i16;

        let preset;
        if level_i <= -2 {
            preset = LogFilterPreset::Error
        } else if level_i == -1 {
            preset = LogFilterPreset::Warn
        } else if level_i == 0 {
            preset = LogFilterPreset::Info
        } else if level_i == 1 {
            preset = LogFilterPreset::Debug
        } else if level_i >= 2 {
            preset = LogFilterPreset::Trace
        } else {
            unreachable!()
        }

        return preset;
    }
}

//
// MARK: logpreset
//

/// Provides preset configurations of [LoggingConfig]
#[derive(Debug, Deserialize, Clone, Copy)]
pub enum LogFilterPreset {
    /// Standard "error" log level
    Error,

    /// Standard "warn" log level
    Warn,

    /// Standard "info" log level.
    /// This is the default.
    Info,

    /// Standard "debug" log level
    Debug,

    /// Standard "trace" log level
    Trace,
}

impl Default for LogFilterPreset {
    fn default() -> Self {
        return Self::Info;
    }
}

impl From<LogFilterPreset> for LoggingConfig {
    fn from(val: LogFilterPreset) -> Self {
        val.get_config()
    }
}

impl LogFilterPreset {
    pub fn get_config(&self) -> LoggingConfig {
        match self {
            Self::Error => LoggingConfig {
                other: LogLevel::Error,
                silence: LogLevel::Error,
                nanochat: LogLevel::Error,
            },

            Self::Warn => LoggingConfig {
                other: LogLevel::Warn,
                silence: LogLevel::Warn,
                nanochat: LogLevel::Warn,
            },

            Self::Info => LoggingConfig {
                other: LogLevel::Warn,
                silence: LogLevel::Warn,
                nanochat: LogLevel::Info,
            },

            Self::Debug => LoggingConfig {
                other: LogLevel::Warn,
                silence: LogLevel::Warn,
                nanochat: LogLevel::Debug,
            },

            Self::Trace => LoggingConfig {
                other: LogLevel::Trace,
                silence: LogLevel::Warn,
                nanochat: LogLevel::Trace,
            },
        }
    }
}

//
// MARK: initializer
//

/// Where to print logs
#[expect(clippy::allow_attributes)]
#[allow(dead_code)]
pub enum LoggingTarget {
    /// Send logs to stdout
    Stdout { format: LoggingFormat },

    /// Send logs to stderr
    Stderr { format: LoggingFormat },

    /// Send logs to an IndicatifWriter.
    ///
    /// This is the same as `Stderr { format: LoggingFormat::Ansi }`,
    /// but uses an IndicatifWriter with the given MultiProgress.
    Indicatif(MultiProgress),
}

/// How to print logs
#[derive(Debug, Clone, Copy, Deserialize)]
#[derive(Default)]
pub enum LoggingFormat {
    #[default]
    Ansi,
    AnsiNoColor,
    Json,
}

pub struct LoggingInitializer {
    /// Log filter for printed logs
    pub preset: LogFilterPreset,

    /// Where to print logs
    pub target: LoggingTarget,
}

impl LoggingInitializer {
    pub fn initialize(self) -> Result<()> {
        let mut stderr_ansi_layer = None;
        let mut stderr_json_layer = None;
        let mut stdout_ansi_layer = None;
        let mut stdout_json_layer = None;
        let mut indicatif_layer = None;

        match self.target {
            LoggingTarget::Stderr {
                format: LoggingFormat::Ansi,
            } => {
                stderr_ansi_layer = Some(
                    tracing_subscriber::fmt::Layer::default()
                        .without_time()
                        .with_ansi(true)
                        .with_writer(std::io::stderr)
                        .with_filter::<EnvFilter>(self.preset.get_config().into()),
                )
            }

            LoggingTarget::Stderr {
                format: LoggingFormat::AnsiNoColor,
            } => {
                stderr_ansi_layer = Some(
                    tracing_subscriber::fmt::Layer::default()
                        .without_time()
                        .with_ansi(false)
                        .with_writer(std::io::stderr)
                        .with_filter::<EnvFilter>(self.preset.get_config().into()),
                )
            }

            LoggingTarget::Stderr {
                format: LoggingFormat::Json,
            } => {
                stderr_json_layer = Some(
                    tracing_subscriber::fmt::Layer::default()
                        .without_time()
                        .with_ansi(false)
                        .json()
                        .with_writer(std::io::stderr)
                        .with_filter::<EnvFilter>(self.preset.get_config().into()),
                )
            }

            LoggingTarget::Stdout {
                format: LoggingFormat::Ansi,
            } => {
                stdout_ansi_layer = Some(
                    tracing_subscriber::fmt::Layer::default()
                        .without_time()
                        .with_ansi(true)
                        .with_writer(std::io::stdout)
                        .with_filter::<EnvFilter>(self.preset.get_config().into()),
                )
            }

            LoggingTarget::Stdout {
                format: LoggingFormat::AnsiNoColor,
            } => {
                stdout_ansi_layer = Some(
                    tracing_subscriber::fmt::Layer::default()
                        .without_time()
                        .with_ansi(false)
                        .with_writer(std::io::stdout)
                        .with_filter::<EnvFilter>(self.preset.get_config().into()),
                )
            }

            LoggingTarget::Stdout {
                format: LoggingFormat::Json,
            } => {
                stdout_json_layer = Some(
                    tracing_subscriber::fmt::Layer::default()
                        .without_time()
                        .with_ansi(false)
                        .json()
                        .with_writer(std::io::stdout)
                        .with_filter::<EnvFilter>(self.preset.get_config().into()),
                )
            }

            LoggingTarget::Indicatif(mp) => {
                let writer: IndicatifWriter = IndicatifWriter::new(mp);

                indicatif_layer = Some(
                    tracing_subscriber::fmt::Layer::default()
                        .without_time()
                        .with_ansi(true)
                        .with_writer(writer.make_writer())
                        .with_filter::<EnvFilter>(self.preset.get_config().into()),
                )
            }
        }

        tracing_subscriber::registry()
            .with(stdout_ansi_layer)
            .with(stdout_json_layer)
            .with(stderr_ansi_layer)
            .with(stderr_json_layer)
            .with(indicatif_layer)
            .init();

        Ok(())
    }
}
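Putting the pieces together, initialization is a single struct literal plus initialize(). A sketch assuming this module's types are in scope (the Debug preset is an illustrative choice):

    use indicatif::MultiProgress;

    fn init_logs() -> anyhow::Result<MultiProgress> {
        let mp = MultiProgress::new();
        LoggingInitializer {
            preset: LogFilterPreset::Debug,
            target: LoggingTarget::Indicatif(mp.clone()),
        }
        .initialize()?;
        tracing::info!("logs are now routed through the MultiProgress");
        Ok(mp)
    }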
diff --git a/crates/llmfs/src/main.rs b/crates/llmfs/src/main.rs
new file mode 100644
index 0000000..2297c77
--- /dev/null
+++ b/crates/llmfs/src/main.rs
@@ -0,0 +1,62 @@
use clap::Parser;
use indicatif::MultiProgress;
use tracing::error;

use crate::{
    command::SubCommand,
    logging::{LogCliVQ, LoggingFormat, LoggingInitializer, LoggingTarget},
};

mod cli;
mod command;
mod data_reader;
mod logging;

#[derive(Parser, Debug)]
#[command(version, about, long_about = None, styles=crate::cli::clap_styles())]
struct Cli {
    #[clap(flatten)]
    vq: LogCliVQ,

    /// If set, never show progress bars.
    #[arg(long)]
    noprogress: bool,

    #[command(subcommand)]
    command: SubCommand,
}
fn main() {
    let cli = Cli::parse();
    let mp = (!cli.noprogress).then(MultiProgress::new);

    {
        let res = LoggingInitializer {
            preset: cli.vq.into_preset(),
            target: match mp.clone() {
                None => LoggingTarget::Stderr {
                    format: LoggingFormat::Ansi,
                },
                Some(mp) => LoggingTarget::Indicatif(mp.clone()),
            },
        }
        .initialize();

        if let Err(e) = res {
            #[expect(clippy::print_stderr)]
            for e in e.chain() {
                eprintln!("{e}");
            }

            std::process::exit(1);
        }
    }

    if let Err(e) = cli.command.run(mp) {
        error!("Error: {e:?}");
        for e in e.chain() {
            error!(e);
        }
        std::process::exit(1);
    }
}
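Because LogCliVQ is flattened in with global = true, the repeatable -v/-q flags may sit on either side of the subcommand. Reading into_preset above (binary name assumed, as before):

    llmfs -vv download        # v - q = 2  -> LogFilterPreset::Trace
    llmfs -q train-tokenizer  # v - q = -1 -> LogFilterPreset::Warn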