use aws_sdk_s3::config::{BehaviorVersion, Credentials, Region}; use chrono::{DateTime, Utc}; use pile_config::{ Label, S3Credentials, pattern::{GroupPattern, GroupSegment}, }; use smartstring::{LazyCompact, SmartString}; use std::{ collections::{HashMap, HashSet}, sync::{Arc, OnceLock}, }; use crate::{ extract::traits::ExtractState, source::DataSource, value::{Item, PileValue}, }; #[derive(Debug)] pub struct S3DataSource { pub name: Label, pub bucket: SmartString, pub prefix: Option>, pub client: Arc, pub pattern: GroupPattern, pub index: OnceLock, Item>>, } impl S3DataSource { pub async fn new( name: &Label, bucket: String, prefix: Option, endpoint: Option, region: String, credentials: &S3Credentials, pattern: GroupPattern, ) -> Result, std::io::Error> { let client = { let creds = Credentials::new( &credentials.access_key_id, &credentials.secret_access_key, None, None, "pile", ); let mut s3_config = aws_sdk_s3::config::Builder::new() .behavior_version(BehaviorVersion::latest()) .region(Region::new(region)) .credentials_provider(creds); if let Some(ep) = endpoint { s3_config = s3_config.endpoint_url(ep).force_path_style(true); } aws_sdk_s3::Client::from_conf(s3_config.build()) }; let source = Arc::new(Self { name: name.clone(), bucket: bucket.into(), prefix: prefix.map(|x| x.into()), client: Arc::new(client), pattern, index: OnceLock::new(), }); // // MARK: list keys // let mut all_keys: HashSet> = HashSet::new(); let mut continuation_token: Option = None; loop { let mut req = source .client .list_objects_v2() .bucket(source.bucket.as_str()); if let Some(prefix) = &source.prefix { req = req.prefix(prefix.as_str()); } if let Some(token) = continuation_token { req = req.continuation_token(token); } let resp = req.send().await.map_err(std::io::Error::other)?; let next_token = resp.next_continuation_token().map(ToOwned::to_owned); let is_truncated = resp.is_truncated().unwrap_or(false); for obj in resp.contents() { let Some(full_key) = obj.key() else { continue }; let key = strip_prefix(full_key, source.prefix.as_deref()); all_keys.insert(key.into()); } if !is_truncated { break; } continuation_token = next_token; } // // MARK: resolve groups // let mut keys_grouped: HashSet> = HashSet::new(); for key in &all_keys { let groups = resolve_groups(&source.pattern, key).await; for group_key in groups.into_values() { if all_keys.contains(&group_key) { keys_grouped.insert(group_key); } } } let mut index = HashMap::new(); for key in all_keys.difference(&keys_grouped) { let groups = resolve_groups(&source.pattern, key).await; let group = groups .into_iter() .filter(|(_, gk)| all_keys.contains(gk)) .map(|(label, gk)| { ( label, Box::new(Item::S3 { source: Arc::clone(&source), mime: mime_guess::from_path(gk.as_str()).first_or_octet_stream(), key: gk, group: Arc::new(HashMap::new()), }), ) }) .collect::>(); let item = Item::S3 { source: Arc::clone(&source), mime: mime_guess::from_path(key.as_str()).first_or_octet_stream(), key: key.clone(), group: Arc::new(group), }; index.insert(item.key(), item); } source.index.get_or_init(|| index); Ok(source) } } impl DataSource for Arc { #[expect(clippy::expect_used)] async fn get(&self, key: &str) -> Result, std::io::Error> { return Ok(self .index .get() .expect("index should be initialized") .get(key) .cloned()); } #[expect(clippy::expect_used)] fn iter(&self) -> impl Iterator { self.index .get() .expect("index should be initialized") .values() } async fn latest_change(&self) -> Result>, std::io::Error> { let mut ts: Option> = None; let mut continuation_token: Option = None; loop { let mut req = self.client.list_objects_v2().bucket(self.bucket.as_str()); if let Some(prefix) = &self.prefix { req = req.prefix(prefix.as_str()); } if let Some(token) = continuation_token { req = req.continuation_token(token); } let resp = match req.send().await { Err(_) => return Ok(None), Ok(resp) => resp, }; let next_token = resp.next_continuation_token().map(ToOwned::to_owned); let is_truncated = resp.is_truncated().unwrap_or(false); for obj in resp.contents() { if let Some(last_modified) = obj.last_modified() { let dt = DateTime::from_timestamp( last_modified.secs(), last_modified.subsec_nanos(), ); if let Some(dt) = dt { ts = Some(match ts { None => dt, Some(prev) => prev.max(dt), }); } } } if !is_truncated { break; } continuation_token = next_token; } Ok(ts) } } fn strip_prefix<'a>(key: &'a str, prefix: Option<&str>) -> &'a str { match prefix { None => key, Some(p) => { let with_slash = if p.ends_with('/') { key.strip_prefix(p) } else { key.strip_prefix(&format!("{p}/")) }; with_slash.unwrap_or(key) } } } async fn resolve_groups( pattern: &GroupPattern, key: &str, ) -> HashMap> { let state = ExtractState { ignore_mime: false }; let mut group = HashMap::new(); 'pattern: for (l, pat) in &pattern.pattern { let item = PileValue::String(Arc::new(key.into())); let mut target = String::new(); for p in pat { match p { GroupSegment::Literal(x) => target.push_str(x), GroupSegment::Path(op) => { let res = match item.query(&state, op).await { Ok(Some(x)) => x, _ => continue 'pattern, }; let res = match res.as_str() { Some(x) => x, None => continue 'pattern, }; target.push_str(res); } } } group.insert(l.clone(), target.into()); } return group; }