281 lines
7.2 KiB
Rust
281 lines
7.2 KiB
Rust
use anyhow::Result;
|
|
use clap::Args;
|
|
use futures_util::StreamExt;
|
|
use indicatif::MultiProgress;
|
|
use indicatif::ProgressBar;
|
|
use rayon::prelude::*;
|
|
use std::fs;
|
|
use std::path::{Path, PathBuf};
|
|
use std::time::Duration;
|
|
use tokio::runtime::Runtime;
|
|
use tracing::{debug, error, info};
|
|
use url::Url;
|
|
|
|
use crate::cli::{progress_big, progress_bytes};
|
|
|
|
/// Base URL of the dataset repository; a shard's download URL is this prefix
/// followed by `/` and the shard file name.
const BASE_URL: &str =
    "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main";

/// Highest shard index available. Shards are numbered from 0, so downloading
/// "all" shards fetches `MAX_SHARD + 1` files.
const MAX_SHARD: usize = 1822;
|
|
|
|
#[derive(Debug, Args, Clone)]
|
|
|
|
pub struct DownloadArgs {
|
|
/// Training data dir
|
|
#[clap(default_value = "data")]
|
|
data_dir: PathBuf,
|
|
|
|
/// Number of shards to download (-1 for all)
|
|
#[arg(short = 'n', long, default_value = "-1")]
|
|
num_files: isize,
|
|
|
|
/// Number of parallel downloads
|
|
#[arg(short = 't', long, default_value = "8")]
|
|
threads: usize,
|
|
}
|
|
|
|
impl DownloadArgs {
|
|
pub fn run(self, mp: Option<MultiProgress>) -> Result<()> {
|
|
info!("Downloading files from {BASE_URL}");
|
|
fs::create_dir_all(&self.data_dir)?;
|
|
|
|
let num_shards_to_download = if self.num_files == -1 {
|
|
MAX_SHARD + 1
|
|
} else {
|
|
self.num_files.min((MAX_SHARD + 1) as isize) as usize
|
|
};
|
|
|
|
let ids_to_download: Vec<usize> = (0..num_shards_to_download).collect();
|
|
|
|
info!("Downloading {} shards...", ids_to_download.len(),);
|
|
info!("Target directory: {}", self.data_dir.display());
|
|
|
|
let main_pb = mp.as_ref().map(|mp| {
|
|
let pb = mp.add(ProgressBar::new(ids_to_download.len() as u64));
|
|
pb.set_style(progress_big());
|
|
pb.set_message("Downloading training data");
|
|
pb.enable_steady_tick(Duration::from_millis(100));
|
|
pb
|
|
});
|
|
|
|
let rt = Runtime::new()?;
|
|
|
|
let (tx, rx) = std::sync::mpsc::channel();
|
|
|
|
let pool = rayon::ThreadPoolBuilder::new()
|
|
.num_threads(self.threads)
|
|
.build()?;
|
|
|
|
pool.install(|| {
|
|
ids_to_download
|
|
.into_par_iter()
|
|
.for_each_with(tx, |tx, index| {
|
|
let target = self.data_dir.clone();
|
|
let main_pb = main_pb.clone();
|
|
let mp_clone = mp.clone();
|
|
let rt_handle = rt.handle().clone(); // Clone the runtime handle for each thread
|
|
|
|
let result = rt_handle.block_on(async {
|
|
download_single_file(index, &target, main_pb, mp_clone).await
|
|
});
|
|
|
|
// Send the result back to the main thread for aggregation
|
|
#[expect(clippy::unwrap_used)]
|
|
tx.send(result).unwrap();
|
|
});
|
|
});
|
|
|
|
// Wait for all downloads to finish and collect results
|
|
let mut successful_downloads = 0;
|
|
for _ in 0..num_shards_to_download {
|
|
if let Ok(Ok(_)) = rx.recv() {
|
|
// Receive the Result<(), String> wrapped in a Result from MPSC
|
|
successful_downloads += 1;
|
|
}
|
|
}
|
|
|
|
if let Some(pb) = main_pb.as_ref() {
|
|
pb.finish_and_clear();
|
|
info!("Downloads complete ({successful_downloads} successful)");
|
|
}
|
|
|
|
return Ok(());
|
|
}
|
|
}
|
|
|
|
async fn download_single_file(
|
|
index: usize,
|
|
target: &Path,
|
|
progress_bar: Option<ProgressBar>,
|
|
mp: Option<MultiProgress>,
|
|
) -> Result<(), String> {
|
|
let filename = format!("shard_{:05}.parquet", index);
|
|
let filepath = target.join(&filename);
|
|
|
|
if filepath.exists() {
|
|
info!("Skipping {} (already exists)", filepath.display());
|
|
if let Some(pb) = progress_bar.as_ref() {
|
|
pb.inc(1);
|
|
}
|
|
return Ok(());
|
|
}
|
|
|
|
#[expect(clippy::unwrap_used)]
|
|
let url = Url::parse(&format!("{BASE_URL}/{filename}")).unwrap();
|
|
|
|
info!("Downloading {} from {}", filename, url);
|
|
|
|
let max_attempts = 5;
|
|
'attempt_loop: for attempt in 1..=max_attempts {
|
|
let temp_filepath = filepath.with_extension("parquet.tmp");
|
|
let client = reqwest::Client::new();
|
|
|
|
match client.get(url.clone()).send().await {
|
|
Ok(response) => {
|
|
if !response.status().is_success() {
|
|
error!(
|
|
"Attempt {}/{}: Server responded with status {} for {}",
|
|
attempt,
|
|
max_attempts,
|
|
response.status(),
|
|
url
|
|
);
|
|
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await;
|
|
continue;
|
|
}
|
|
|
|
let total_size = response.content_length().unwrap_or(0);
|
|
debug!("Total size for {}: {}", filename, total_size);
|
|
|
|
// Create file progress bar
|
|
let file_pb = if total_size > 0
|
|
&& let Some(mp) = mp.as_ref()
|
|
{
|
|
Some({
|
|
let pb = mp.add(ProgressBar::new(total_size));
|
|
pb.set_style(progress_bytes());
|
|
pb.set_message(format!("Downloading {}", filename));
|
|
pb
|
|
})
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let mut file = match tokio::fs::File::create(&temp_filepath).await {
|
|
Ok(file) => file,
|
|
Err(e) => {
|
|
error!(
|
|
"Attempt {}/{}: Failed to create temporary file {}: {}",
|
|
attempt,
|
|
max_attempts,
|
|
temp_filepath.display(),
|
|
e
|
|
);
|
|
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await;
|
|
continue;
|
|
}
|
|
};
|
|
|
|
let mut stream = response.bytes_stream();
|
|
let mut downloaded: u64 = 0;
|
|
|
|
while let Some(chunk_result) = StreamExt::next(&mut stream).await {
|
|
match chunk_result {
|
|
Ok(chunk) => {
|
|
match tokio::io::AsyncWriteExt::write_all(&mut file, &chunk).await {
|
|
Ok(_) => {
|
|
downloaded += chunk.len() as u64;
|
|
if let Some(pb) = &file_pb {
|
|
pb.set_position(downloaded);
|
|
}
|
|
}
|
|
Err(e) => {
|
|
error!(
|
|
"Attempt {}/{}: Failed to write to temporary file {}: {}",
|
|
attempt,
|
|
max_attempts,
|
|
temp_filepath.display(),
|
|
e
|
|
);
|
|
|
|
// Clean up
|
|
let _ = tokio::fs::remove_file(&temp_filepath).await;
|
|
if let Some(pb) = &file_pb {
|
|
pb.finish_and_clear();
|
|
}
|
|
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt)))
|
|
.await;
|
|
continue 'attempt_loop;
|
|
}
|
|
}
|
|
}
|
|
|
|
Err(e) => {
|
|
error!(
|
|
"Attempt {}/{}: Error reading chunk for {}: {}",
|
|
attempt, max_attempts, filename, e
|
|
);
|
|
|
|
let _ = tokio::fs::remove_file(&temp_filepath).await;
|
|
if let Some(pb) = &file_pb {
|
|
pb.finish_and_clear();
|
|
}
|
|
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await;
|
|
continue 'attempt_loop;
|
|
}
|
|
}
|
|
}
|
|
|
|
if let Some(pb) = &file_pb {
|
|
pb.finish_and_clear();
|
|
}
|
|
|
|
// Atomically rename the temporary file
|
|
match tokio::fs::rename(&temp_filepath, &filepath).await {
|
|
Ok(_) => {
|
|
info!("Successfully downloaded {}", filename);
|
|
if let Some(pb) = progress_bar.as_ref() {
|
|
pb.inc(1);
|
|
}
|
|
return Ok(());
|
|
}
|
|
Err(e) => {
|
|
error!(
|
|
"Attempt {}/{}: Failed to rename temporary file {} to {}: {}",
|
|
attempt,
|
|
max_attempts,
|
|
temp_filepath.display(),
|
|
filepath.display(),
|
|
e
|
|
);
|
|
let _ = tokio::fs::remove_file(&temp_filepath).await; // Clean up
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
error!(
|
|
"Attempt {}/{}: Failed to download {}: {}",
|
|
attempt, max_attempts, filename, e
|
|
);
|
|
|
|
// Clean up any partial files
|
|
let _ = tokio::fs::remove_file(&temp_filepath).await;
|
|
}
|
|
}
|
|
|
|
if attempt < max_attempts {
|
|
let wait_time = 2u64.pow(attempt);
|
|
info!("Waiting {} seconds before retry...", wait_time);
|
|
tokio::time::sleep(Duration::from_secs(wait_time)).await;
|
|
} else {
|
|
error!(
|
|
"Failed to download {} after {} attempts",
|
|
filename, max_attempts
|
|
);
|
|
return Err(format!("Failed to download {}", filename));
|
|
}
|
|
}
|
|
|
|
Err(format!("Failed to download {}", filename))
|
|
}
|