use chrono::{DateTime, Utc}; use pile_config::{ Label, pattern::{GroupPattern, GroupSegment}, }; use pile_io::S3Client; use smartstring::{LazyCompact, SmartString}; use std::{ collections::{BTreeMap, HashMap, HashSet}, sync::{Arc, OnceLock}, }; use crate::{ extract::traits::ExtractState, source::DataSource, value::{Item, PileValue}, }; #[derive(Debug)] pub struct S3DataSource { pub name: Label, pub client: Arc, pub prefix: Option>, pub pattern: GroupPattern, pub encryption_key: Option<[u8; 32]>, pub index: OnceLock, Item>>, } impl S3DataSource { pub async fn new( name: &Label, bucket: &str, prefix: Option<&str>, endpoint: Option<&str>, region: &str, access_key_id: &str, secret_access_key: &str, cache_limit_bytes: usize, pattern: GroupPattern, encryption_key: Option<[u8; 32]>, ) -> Result, std::io::Error> { let client = S3Client::new( bucket, endpoint, region, access_key_id, secret_access_key, cache_limit_bytes, ) .await; let source = Arc::new(Self { name: name.clone(), client, prefix: prefix.map(|x| x.into()), pattern, encryption_key, index: OnceLock::new(), }); // // MARK: list keys // let mut all_keys: HashSet> = HashSet::new(); let mut continuation_token: Option = None; loop { let mut req = source .client .client .list_objects_v2() .bucket(source.client.bucket()); if let Some(prefix) = &source.prefix { req = req.prefix(prefix.as_str()); } if let Some(token) = continuation_token { req = req.continuation_token(token); } let resp = req.send().await.map_err(std::io::Error::other)?; let next_token = resp.next_continuation_token().map(ToOwned::to_owned); let is_truncated = resp.is_truncated().unwrap_or(false); for obj in resp.contents() { let Some(full_key) = obj.key() else { continue }; let raw_key = strip_prefix(full_key, source.prefix.as_deref()); let key = match &source.encryption_key { None => raw_key.into(), Some(enc_key) => match decrypt_path(enc_key, raw_key) { Some(decrypted) => decrypted.into(), None => continue, }, }; all_keys.insert(key); } if !is_truncated { break; } continuation_token = next_token; } // // MARK: resolve groups // let mut keys_grouped: HashSet> = HashSet::new(); for key in &all_keys { let groups = resolve_groups(&source.pattern, key).await; for group_key in groups.into_values() { if all_keys.contains(&group_key) { keys_grouped.insert(group_key); } } } let mut index = BTreeMap::new(); for key in all_keys.difference(&keys_grouped) { let groups = resolve_groups(&source.pattern, key).await; let group = groups .into_iter() .filter(|(_, gk)| all_keys.contains(gk)) .map(|(label, gk)| { ( label, Box::new(Item::S3 { source: Arc::clone(&source), mime: mime_guess::from_path(gk.as_str()).first_or_octet_stream(), key: gk, group: Arc::new(HashMap::new()), }), ) }) .collect::>(); let item = Item::S3 { source: Arc::clone(&source), mime: mime_guess::from_path(key.as_str()).first_or_octet_stream(), key: key.clone(), group: Arc::new(group), }; index.insert(item.key(), item); } source.index.get_or_init(|| index); Ok(source) } } impl DataSource for Arc { #[expect(clippy::expect_used)] fn len(&self) -> usize { self.index.get().expect("index should be initialized").len() } #[expect(clippy::expect_used)] async fn get(&self, key: &str) -> Result, std::io::Error> { return Ok(self .index .get() .expect("index should be initialized") .get(key) .cloned()); } #[expect(clippy::expect_used)] fn iter(&self) -> impl Iterator { self.index .get() .expect("index should be initialized") .values() } async fn latest_change(&self) -> Result>, std::io::Error> { let mut ts: Option> = None; let mut continuation_token: Option = None; loop { let mut req = self .client .client .list_objects_v2() .bucket(self.client.bucket()); if let Some(prefix) = &self.prefix { req = req.prefix(prefix.as_str()); } if let Some(token) = continuation_token { req = req.continuation_token(token); } let resp = match req.send().await { Err(_) => return Ok(None), Ok(resp) => resp, }; let next_token = resp.next_continuation_token().map(ToOwned::to_owned); let is_truncated = resp.is_truncated().unwrap_or(false); for obj in resp.contents() { if let Some(last_modified) = obj.last_modified() { let dt = DateTime::from_timestamp( last_modified.secs(), last_modified.subsec_nanos(), ); if let Some(dt) = dt { ts = Some(match ts { None => dt, Some(prev) => prev.max(dt), }); } } } if !is_truncated { break; } continuation_token = next_token; } Ok(ts) } } /// Derive an encryption key from a password pub fn string_to_key(password: &str) -> [u8; 32] { blake3::derive_key("pile s3 encryption", password.as_bytes()) } /// Encrypt a logical path to a base64 S3 key using a deterministic nonce. pub fn encrypt_path(enc_key: &[u8; 32], path: &str) -> String { use base64::Engine; use chacha20poly1305::{KeyInit, XChaCha20Poly1305, XNonce, aead::Aead}; let hash = blake3::keyed_hash(enc_key, path.as_bytes()); let nonce_bytes = &hash.as_bytes()[..24]; let nonce = XNonce::from_slice(nonce_bytes); let key = chacha20poly1305::Key::from_slice(enc_key); let cipher = XChaCha20Poly1305::new(key); #[expect(clippy::expect_used)] let ciphertext = cipher .encrypt(nonce, path.as_bytes()) .expect("path encryption should not fail"); let mut result = nonce_bytes.to_vec(); result.extend_from_slice(&ciphertext); base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(result) } /// Decrypt a base64 S3 key back to its logical path. fn decrypt_path(enc_key: &[u8; 32], encrypted: &str) -> Option { use base64::Engine; use chacha20poly1305::{KeyInit, XChaCha20Poly1305, XNonce, aead::Aead}; let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD .decode(encrypted) .ok()?; if bytes.len() < 24 + 16 { return None; } let (nonce_bytes, ciphertext) = bytes.split_at(24); let nonce = XNonce::from_slice(nonce_bytes); let key = chacha20poly1305::Key::from_slice(enc_key); let cipher = XChaCha20Poly1305::new(key); let plaintext = cipher.decrypt(nonce, ciphertext).ok()?; String::from_utf8(plaintext).ok() } fn strip_prefix<'a>(key: &'a str, prefix: Option<&str>) -> &'a str { match prefix { None => key, Some(p) => { let with_slash = if p.ends_with('/') { key.strip_prefix(p) } else { key.strip_prefix(&format!("{p}/")) }; with_slash.unwrap_or(key) } } } async fn resolve_groups( pattern: &GroupPattern, key: &str, ) -> HashMap> { let state = ExtractState { ignore_mime: false }; let mut group = HashMap::new(); 'pattern: for (l, pat) in &pattern.pattern { let item = PileValue::String(Arc::new(key.into())); let mut target = String::new(); for p in pat { match p { GroupSegment::Literal(x) => target.push_str(x), GroupSegment::Path(op) => { let res = match item.query(&state, op).await { Ok(Some(x)) => x, _ => continue 'pattern, }; let res = match res.as_str() { Some(x) => x, None => continue 'pattern, }; target.push_str(res); } } } group.insert(l.clone(), target.into()); } return group; }