Files
nanochat-rs/src/command/download.rs
2025-10-16 07:03:54 -07:00

281 lines
7.2 KiB
Rust

use anyhow::Result;
use clap::Args;
use futures_util::StreamExt;
use indicatif::MultiProgress;
use indicatif::ProgressBar;
use rayon::prelude::*;
use std::fs;
use std::path::{Path, PathBuf};
use std::time::Duration;
use tokio::runtime::Runtime;
use tracing::{debug, error, info};
use url::Url;
use crate::cli::{progress_big, progress_bytes};
/// Root URL under which the FineWeb-Edu shard files are served on the Hugging Face Hub.
const BASE_URL: &str = concat!(
    "https://huggingface.co/datasets/",
    "karpathy/fineweb-edu-100b-shuffle",
    "/resolve/main"
);
/// Index of the last available shard; shard indices run over `0..=MAX_SHARD`.
const MAX_SHARD: usize = 1_822;
// Command-line arguments for the shard download command.
// Plain comment (not `///`) so the clap-generated help output stays unchanged.
#[derive(Debug, Args, Clone)]
pub struct DownloadArgs {
    /// Training data dir
    #[arg(default_value = "data")]
    data_dir: PathBuf,
    /// Number of shards to download (-1 for all)
    // `allow_negative_numbers` lets `-n -1` be parsed as a value; without it
    // clap rejects `-1` as an unknown short flag, contradicting the help text.
    #[arg(short = 'n', long, default_value = "-1", allow_negative_numbers = true)]
    num_files: isize,
    /// Number of parallel downloads
    #[arg(short = 't', long, default_value = "8")]
    threads: usize,
}
impl DownloadArgs {
pub fn run(self, mp: Option<MultiProgress>) -> Result<()> {
info!("Downloading files from {BASE_URL}");
fs::create_dir_all(&self.data_dir)?;
let num_shards_to_download = if self.num_files == -1 {
MAX_SHARD + 1
} else {
self.num_files.min((MAX_SHARD + 1) as isize) as usize
};
let ids_to_download: Vec<usize> = (0..num_shards_to_download).collect();
info!("Downloading {} shards...", ids_to_download.len(),);
info!("Target directory: {}", self.data_dir.display());
let main_pb = mp.as_ref().map(|mp| {
let pb = mp.add(ProgressBar::new(ids_to_download.len() as u64));
pb.set_style(progress_big());
pb.set_message("Downloading training data");
pb.enable_steady_tick(Duration::from_millis(100));
pb
});
let rt = Runtime::new()?;
let (tx, rx) = std::sync::mpsc::channel();
let pool = rayon::ThreadPoolBuilder::new()
.num_threads(self.threads)
.build()?;
pool.install(|| {
ids_to_download
.into_par_iter()
.for_each_with(tx, |tx, index| {
let target = self.data_dir.clone();
let main_pb = main_pb.clone();
let mp_clone = mp.clone();
let rt_handle = rt.handle().clone(); // Clone the runtime handle for each thread
let result = rt_handle.block_on(async {
download_single_file(index, &target, main_pb, mp_clone).await
});
// Send the result back to the main thread for aggregation
#[expect(clippy::unwrap_used)]
tx.send(result).unwrap();
});
});
// Wait for all downloads to finish and collect results
let mut successful_downloads = 0;
for _ in 0..num_shards_to_download {
if let Ok(Ok(_)) = rx.recv() {
// Receive the Result<(), String> wrapped in a Result from MPSC
successful_downloads += 1;
}
}
if let Some(pb) = main_pb.as_ref() {
pb.finish_and_clear();
info!("Downloads complete ({successful_downloads} successful)");
}
return Ok(());
}
}
async fn download_single_file(
index: usize,
target: &Path,
progress_bar: Option<ProgressBar>,
mp: Option<MultiProgress>,
) -> Result<(), String> {
let filename = format!("shard_{:05}.parquet", index);
let filepath = target.join(&filename);
if filepath.exists() {
info!("Skipping {} (already exists)", filepath.display());
if let Some(pb) = progress_bar.as_ref() {
pb.inc(1);
}
return Ok(());
}
#[expect(clippy::unwrap_used)]
let url = Url::parse(&format!("{BASE_URL}/{filename}")).unwrap();
info!("Downloading {} from {}", filename, url);
let max_attempts = 5;
'attempt_loop: for attempt in 1..=max_attempts {
let temp_filepath = filepath.with_extension("parquet.tmp");
let client = reqwest::Client::new();
match client.get(url.clone()).send().await {
Ok(response) => {
if !response.status().is_success() {
error!(
"Attempt {}/{}: Server responded with status {} for {}",
attempt,
max_attempts,
response.status(),
url
);
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await;
continue;
}
let total_size = response.content_length().unwrap_or(0);
debug!("Total size for {}: {}", filename, total_size);
// Create file progress bar
let file_pb = if total_size > 0
&& let Some(mp) = mp.as_ref()
{
Some({
let pb = mp.add(ProgressBar::new(total_size));
pb.set_style(progress_bytes());
pb.set_message(format!("Downloading {}", filename));
pb
})
} else {
None
};
let mut file = match tokio::fs::File::create(&temp_filepath).await {
Ok(file) => file,
Err(e) => {
error!(
"Attempt {}/{}: Failed to create temporary file {}: {}",
attempt,
max_attempts,
temp_filepath.display(),
e
);
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await;
continue;
}
};
let mut stream = response.bytes_stream();
let mut downloaded: u64 = 0;
while let Some(chunk_result) = StreamExt::next(&mut stream).await {
match chunk_result {
Ok(chunk) => {
match tokio::io::AsyncWriteExt::write_all(&mut file, &chunk).await {
Ok(_) => {
downloaded += chunk.len() as u64;
if let Some(pb) = &file_pb {
pb.set_position(downloaded);
}
}
Err(e) => {
error!(
"Attempt {}/{}: Failed to write to temporary file {}: {}",
attempt,
max_attempts,
temp_filepath.display(),
e
);
// Clean up
let _ = tokio::fs::remove_file(&temp_filepath).await;
if let Some(pb) = &file_pb {
pb.finish_and_clear();
}
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt)))
.await;
continue 'attempt_loop;
}
}
}
Err(e) => {
error!(
"Attempt {}/{}: Error reading chunk for {}: {}",
attempt, max_attempts, filename, e
);
let _ = tokio::fs::remove_file(&temp_filepath).await;
if let Some(pb) = &file_pb {
pb.finish_and_clear();
}
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await;
continue 'attempt_loop;
}
}
}
if let Some(pb) = &file_pb {
pb.finish_and_clear();
}
// Atomically rename the temporary file
match tokio::fs::rename(&temp_filepath, &filepath).await {
Ok(_) => {
info!("Successfully downloaded {}", filename);
if let Some(pb) = progress_bar.as_ref() {
pb.inc(1);
}
return Ok(());
}
Err(e) => {
error!(
"Attempt {}/{}: Failed to rename temporary file {} to {}: {}",
attempt,
max_attempts,
temp_filepath.display(),
filepath.display(),
e
);
let _ = tokio::fs::remove_file(&temp_filepath).await; // Clean up
}
}
}
Err(e) => {
error!(
"Attempt {}/{}: Failed to download {}: {}",
attempt, max_attempts, filename, e
);
// Clean up any partial files
let _ = tokio::fs::remove_file(&temp_filepath).await;
}
}
if attempt < max_attempts {
let wait_time = 2u64.pow(attempt);
info!("Waiting {} seconds before retry...", wait_time);
tokio::time::sleep(Duration::from_secs(wait_time)).await;
} else {
error!(
"Failed to download {} after {} attempts",
filename, max_attempts
);
return Err(format!("Failed to download {}", filename));
}
}
Err(format!("Failed to download {}", filename))
}