diff --git a/Cargo.lock b/Cargo.lock index 1b7bb95..cb363dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2485,6 +2485,7 @@ version = "0.0.2" dependencies = [ "anstyle", "anyhow", + "aws-sdk-s3", "axum", "clap", "indicatif", diff --git a/crates/pile/Cargo.toml b/crates/pile/Cargo.toml index e8a9e89..10f26cc 100644 --- a/crates/pile/Cargo.toml +++ b/crates/pile/Cargo.toml @@ -13,6 +13,7 @@ pile-dataset = { workspace = true, features = ["axum", "pdfium"] } pile-value = { workspace = true, features = ["pdfium"] } pile-config = { workspace = true } +aws-sdk-s3 = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } tokio = { workspace = true } diff --git a/crates/pile/src/cli.rs b/crates/pile/src/cli.rs index 85e3783..b74ddbb 100644 --- a/crates/pile/src/cli.rs +++ b/crates/pile/src/cli.rs @@ -1,4 +1,5 @@ use anstyle::{AnsiColor, Color, Style}; +use indicatif::ProgressStyle; pub fn clap_styles() -> clap::builder::Styles { clap::builder::Styles::styled() @@ -36,7 +37,6 @@ pub fn clap_styles() -> clap::builder::Styles { .placeholder(Style::new().fg_color(Some(Color::Ansi(AnsiColor::White)))) } -/* #[expect(clippy::unwrap_used)] pub fn progress_big() -> ProgressStyle { return ProgressStyle::default_bar() @@ -50,6 +50,7 @@ pub fn progress_big() -> ProgressStyle { ]); } +/* #[expect(clippy::unwrap_used)] pub fn spinner_small() -> ProgressStyle { return ProgressStyle::default_bar() diff --git a/crates/pile/src/command/mod.rs b/crates/pile/src/command/mod.rs index 7f1ab46..050ff24 100644 --- a/crates/pile/src/command/mod.rs +++ b/crates/pile/src/command/mod.rs @@ -13,6 +13,7 @@ mod list; mod lookup; mod probe; mod serve; +mod upload; use crate::{Cli, GlobalContext}; @@ -60,7 +61,7 @@ pub enum SubCommand { }, /// Print an overview of all fields present in this dataset - Overview { + Fields { #[command(flatten)] cmd: fields::FieldsCommand, }, @@ -76,6 +77,12 @@ pub enum SubCommand { #[command(flatten)] cmd: serve::ServeCommand, }, + + /// 
Upload a filesystem source to an S3 source + Upload { + #[command(flatten)] + cmd: upload::UploadCommand, + }, } impl CliCmdDispatch for SubCommand { @@ -87,9 +94,10 @@ impl CliCmdDispatch for SubCommand { Self::Index { cmd } => cmd.start(ctx), Self::List { cmd } => cmd.start(ctx), Self::Lookup { cmd } => cmd.start(ctx), - Self::Overview { cmd } => cmd.start(ctx), + Self::Fields { cmd } => cmd.start(ctx), Self::Probe { cmd } => cmd.start(ctx), Self::Serve { cmd } => cmd.start(ctx), + Self::Upload { cmd } => cmd.start(ctx), Self::Docs {} => { print_help_recursively(&mut Cli::command(), None); diff --git a/crates/pile/src/command/upload.rs b/crates/pile/src/command/upload.rs new file mode 100644 index 0000000..d4d97dd --- /dev/null +++ b/crates/pile/src/command/upload.rs @@ -0,0 +1,272 @@ +use anyhow::{Context, Result}; +use aws_sdk_s3::primitives::ByteStream; +use clap::Args; +use indicatif::ProgressBar; +use pile_config::Label; +use pile_dataset::{Dataset, Datasets}; +use pile_toolbox::cancelabletask::{CancelFlag, CancelableTaskError}; +use pile_value::source::{DataSource, DirDataSource, S3DataSource}; +use std::{path::PathBuf, sync::Arc, time::Duration}; +use tokio::task::JoinSet; +use tokio_stream::StreamExt; +use tracing::info; + +use crate::{CliCmd, GlobalContext, cli::progress_big}; + +#[derive(Debug, Args)] +pub struct UploadCommand { + /// Name of the filesystem source to upload from + dir_source: String, + + /// Name of the S3 source to upload to + s3_source: String, + + /// Prefix path under the S3 source to upload files to + prefix: String, + + /// Path to dataset config + #[arg(long, short = 'c', default_value = "./pile.toml")] + config: PathBuf, + + /// Override the S3 bucket from pile.toml + #[arg(long)] + bucket: Option<String>, + + /// Allow overwriting files that already exist at the target prefix + #[arg(long)] + overwrite: bool, + + /// Delete all files at the target prefix before uploading + #[arg(long)] + delete_existing_forever: bool, + + /// Number of 
parallel upload jobs + #[arg(long, short = 'j', default_value = "5")] + jobs: usize, +} + +impl CliCmd for UploadCommand { + async fn run( + self, + ctx: GlobalContext, + flag: CancelFlag, + ) -> Result<u8, CancelableTaskError<anyhow::Error>> { + let ds = Datasets::open(&self.config) + .with_context(|| format!("while opening dataset for {}", self.config.display()))?; + + let dir_label = Label::new(&self.dir_source) + .ok_or_else(|| anyhow::anyhow!("invalid source name: {}", self.dir_source))?; + let s3_label = Label::new(&self.s3_source) + .ok_or_else(|| anyhow::anyhow!("invalid source name: {}", self.s3_source))?; + + let dir_ds: Arc<DirDataSource> = get_dir_source(&ds, &dir_label, &self.dir_source)?; + let s3_ds: Arc<S3DataSource> = get_s3_source(&ds, &s3_label, &self.s3_source)?; + + let bucket = self + .bucket + .as_deref() + .unwrap_or(s3_ds.bucket.as_str()) + .to_owned(); + let full_prefix = self.prefix.trim_matches('/').to_owned(); + + // Check for existing objects at the target prefix + let existing_keys = list_prefix(&s3_ds.client, &bucket, &full_prefix) + .await + .context("while checking for existing objects at target prefix")?; + + if !existing_keys.is_empty() { + if self.delete_existing_forever { + info!( + "Deleting {} existing object(s) at '{}'", + existing_keys.len(), + full_prefix + ); + for key in &existing_keys { + s3_ds + .client + .delete_object() + .bucket(&bucket) + .key(key) + .send() + .await + .with_context(|| format!("while deleting existing object '{key}'"))?; + } + } else if !self.overwrite { + return Err(anyhow::anyhow!( + "{} file(s) already exist at '{}'. 
\ + Pass --overwrite to allow overwriting, \ + or --delete-existing-forever to delete them first.", + existing_keys.len(), + full_prefix + ) + .into()); + } + } + + // Count total files before uploading so we can show accurate progress + let total = { + let mut count = 0u64; + let mut count_stream = Arc::clone(&dir_ds).iter(); + while let Some(result) = count_stream.next().await { + result.context("while counting filesystem source")?; + count += 1; + } + count + }; + + // Walk filesystem source and upload files in parallel + let jobs = self.jobs.max(1); + let mut uploaded: u64 = 0; + let mut stream = Arc::clone(&dir_ds).iter(); + let mut join_set: JoinSet<Result<String>> = JoinSet::new(); + + let pb = ctx.mp.add(ProgressBar::new(total)); + pb.set_style(progress_big()); + pb.enable_steady_tick(Duration::from_millis(100)); + pb.set_message(full_prefix.clone()); + + loop { + // Drain completed tasks before checking for cancellation or new work + while join_set.len() >= jobs { + match join_set.join_next().await { + Some(Ok(Ok(key))) => { + info!("Uploaded {key}"); + pb.set_message(key); + pb.inc(1); + uploaded += 1; + } + Some(Ok(Err(e))) => return Err(e.into()), + Some(Err(e)) => return Err(anyhow::anyhow!("upload task panicked: {e}").into()), + None => break, + } + } + + if flag.is_cancelled() { + join_set.abort_all(); + return Err(CancelableTaskError::Cancelled); + } + + let item = match stream.next().await { + None => break, + Some(Err(e)) => { + return Err(anyhow::Error::from(e) + .context("while iterating filesystem source") + .into()); + } + Some(Ok(item)) => item, + }; + + let item_path = PathBuf::from(item.key().as_str()); + let relative = item_path.strip_prefix(&dir_ds.dir).with_context(|| { + format!("path '{}' is not under source root", item_path.display()) + })?; + let relative_str = relative + .to_str() + .ok_or_else(|| anyhow::anyhow!("non-UTF-8 path: {}", item_path.display()))? 
+ .to_owned(); + + let key = format!("{full_prefix}/{relative_str}"); + let mime = item.mime().to_string(); + let client = Arc::clone(&s3_ds.client); + let bucket = bucket.clone(); + + join_set.spawn(async move { + let body = ByteStream::from_path(&item_path) + .await + .with_context(|| format!("while opening '{}'", item_path.display()))?; + + client + .put_object() + .bucket(&bucket) + .key(&key) + .content_type(&mime) + .body(body) + .send() + .await + .with_context(|| { + format!("while uploading '{}' to '{key}'", item_path.display()) + })?; + + Ok(key) + }); + } + + // Drain remaining tasks + while let Some(result) = join_set.join_next().await { + match result { + Ok(Ok(key)) => { + info!("Uploaded {key}"); + pb.set_message(key); + pb.inc(1); + uploaded += 1; + } + Ok(Err(e)) => return Err(e.into()), + Err(e) => return Err(anyhow::anyhow!("upload task panicked: {e}").into()), + } + } + + pb.finish_and_clear(); + info!("Done: uploaded {uploaded} file(s) to '{full_prefix}'"); + Ok(0) + } +} + +fn get_dir_source( + ds: &Datasets, + label: &Label, + name: &str, +) -> Result<Arc<DirDataSource>, anyhow::Error> { + match ds.sources.get(label) { + Some(Dataset::Dir(d)) => Ok(Arc::clone(d)), + Some(_) => Err(anyhow::anyhow!( + "source '{name}' is not a filesystem source" + )), + None => Err(anyhow::anyhow!("source '{name}' not found in config")), + } +} + +fn get_s3_source( + ds: &Datasets, + label: &Label, + name: &str, +) -> Result<Arc<S3DataSource>, anyhow::Error> { + match ds.sources.get(label) { + Some(Dataset::S3(s)) => Ok(Arc::clone(s)), + Some(_) => Err(anyhow::anyhow!("source '{name}' is not an S3 source")), + None => Err(anyhow::anyhow!("source '{name}' not found in config")), + } +} + +/// List all S3 object keys under the given prefix. 
+async fn list_prefix( + client: &aws_sdk_s3::Client, + bucket: &str, + prefix: &str, +) -> Result<Vec<String>> { + let mut keys = Vec::new(); + let mut continuation_token: Option<String> = None; + + loop { + let mut req = client.list_objects_v2().bucket(bucket).prefix(prefix); + + if let Some(token) = continuation_token { + req = req.continuation_token(token); + } + + let resp = req.send().await.context("list_objects_v2 failed")?; + + for obj in resp.contents() { + if let Some(k) = obj.key() { + keys.push(k.to_owned()); + } + } + + if !resp.is_truncated().unwrap_or(false) { + break; + } + + continuation_token = resp.next_continuation_token().map(ToOwned::to_owned); + } + + Ok(keys) +} diff --git a/crates/pile/src/main.rs b/crates/pile/src/main.rs index cfedace..7f219b4 100644 --- a/crates/pile/src/main.rs +++ b/crates/pile/src/main.rs @@ -35,8 +35,7 @@ struct Cli { #[derive(Clone)] pub struct GlobalContext { - #[expect(dead_code)] - mp: MultiProgress, + pub mp: MultiProgress, } fn main() -> ExitCode {