Files
pile/crates/pile-value/src/source/s3.rs
2026-03-16 22:24:30 -07:00

269 lines
6.1 KiB
Rust

use aws_sdk_s3::config::{BehaviorVersion, Credentials, Region};
use chrono::{DateTime, Utc};
use pile_config::{
Label, S3Credentials,
pattern::{GroupPattern, GroupSegment},
};
use smartstring::{LazyCompact, SmartString};
use std::{
collections::{HashMap, HashSet},
sync::{Arc, OnceLock},
};
use crate::{
extract::traits::ExtractState,
source::DataSource,
value::{Item, PileValue},
};
/// A data source backed by the objects of a single S3 bucket.
///
/// Instances are created by [`S3DataSource::new`], which lists the bucket and
/// fills `index` before returning.
#[derive(Debug)]
pub struct S3DataSource {
/// Label identifying this source (cloned from the `name` passed to `new`).
pub name: Label,
/// Name of the bucket that list requests are sent to.
pub bucket: SmartString<LazyCompact>,
/// Optional key prefix: list requests are restricted to it, and it is
/// stripped from object keys before they are stored in the index.
pub prefix: Option<SmartString<LazyCompact>>,
/// S3 client used for every request made by this source.
pub client: Arc<aws_sdk_s3::Client>,
/// Pattern used to derive, for each key, the keys of its group members.
pub pattern: GroupPattern,
/// Map from (prefix-stripped) object key to its item. Set once at the end of
/// `new`; the `DataSource` accessors panic if it is still unset.
pub index: OnceLock<HashMap<SmartString<LazyCompact>, Item>>,
}
impl S3DataSource {
pub async fn new(
name: &Label,
bucket: String,
prefix: Option<String>,
endpoint: Option<String>,
region: String,
credentials: &S3Credentials,
pattern: GroupPattern,
) -> Result<Arc<Self>, std::io::Error> {
let client = {
let creds = Credentials::new(
&credentials.access_key_id,
&credentials.secret_access_key,
None,
None,
"pile",
);
let mut s3_config = aws_sdk_s3::config::Builder::new()
.behavior_version(BehaviorVersion::latest())
.region(Region::new(region))
.credentials_provider(creds);
if let Some(ep) = endpoint {
s3_config = s3_config.endpoint_url(ep).force_path_style(true);
}
aws_sdk_s3::Client::from_conf(s3_config.build())
};
let source = Arc::new(Self {
name: name.clone(),
bucket: bucket.into(),
prefix: prefix.map(|x| x.into()),
client: Arc::new(client),
pattern,
index: OnceLock::new(),
});
//
// MARK: list keys
//
let mut all_keys: HashSet<SmartString<LazyCompact>> = HashSet::new();
let mut continuation_token: Option<String> = None;
loop {
let mut req = source
.client
.list_objects_v2()
.bucket(source.bucket.as_str());
if let Some(prefix) = &source.prefix {
req = req.prefix(prefix.as_str());
}
if let Some(token) = continuation_token {
req = req.continuation_token(token);
}
let resp = req.send().await.map_err(std::io::Error::other)?;
let next_token = resp.next_continuation_token().map(ToOwned::to_owned);
let is_truncated = resp.is_truncated().unwrap_or(false);
for obj in resp.contents() {
let Some(full_key) = obj.key() else { continue };
let key = strip_prefix(full_key, source.prefix.as_deref());
all_keys.insert(key.into());
}
if !is_truncated {
break;
}
continuation_token = next_token;
}
//
// MARK: resolve groups
//
let mut keys_grouped: HashSet<SmartString<LazyCompact>> = HashSet::new();
for key in &all_keys {
let groups = resolve_groups(&source.pattern, key).await;
for group_key in groups.into_values() {
if all_keys.contains(&group_key) {
keys_grouped.insert(group_key);
}
}
}
let mut index = HashMap::new();
for key in all_keys.difference(&keys_grouped) {
let groups = resolve_groups(&source.pattern, key).await;
let group = groups
.into_iter()
.filter(|(_, gk)| all_keys.contains(gk))
.map(|(label, gk)| {
(
label,
Box::new(Item::S3 {
source: Arc::clone(&source),
mime: mime_guess::from_path(gk.as_str()).first_or_octet_stream(),
key: gk,
group: Arc::new(HashMap::new()),
}),
)
})
.collect::<HashMap<_, _>>();
let item = Item::S3 {
source: Arc::clone(&source),
mime: mime_guess::from_path(key.as_str()).first_or_octet_stream(),
key: key.clone(),
group: Arc::new(group),
};
index.insert(item.key(), item);
}
source.index.get_or_init(|| index);
Ok(source)
}
}
impl DataSource for Arc<S3DataSource> {
    /// Returns a clone of the indexed item stored under `key`, or `None` if
    /// the index has no such key. No request is made to S3.
    ///
    /// # Panics
    /// Panics if the index has not been initialized.
    #[expect(clippy::expect_used)]
    async fn get(&self, key: &str) -> Result<Option<Item>, std::io::Error> {
        let index = self.index.get().expect("index should be initialized");
        Ok(index.get(key).cloned())
    }

    /// Iterates over every top-level item in the index.
    ///
    /// # Panics
    /// Panics if the index has not been initialized.
    #[expect(clippy::expect_used)]
    fn iter(&self) -> impl Iterator<Item = &Item> {
        self.index
            .get()
            .expect("index should be initialized")
            .values()
    }

    /// Returns the most recent `last_modified` timestamp among all objects in
    /// the bucket (restricted to `prefix` when set).
    ///
    /// This re-lists the bucket on every call. It is best-effort: if a list
    /// request fails, or the listing cannot be paged through completely, the
    /// result is `Ok(None)` rather than an error. `Ok(None)` is also returned
    /// when no object carries a usable timestamp.
    async fn latest_change(&self) -> Result<Option<DateTime<Utc>>, std::io::Error> {
        let mut latest: Option<DateTime<Utc>> = None;
        let mut continuation_token: Option<String> = None;

        loop {
            let mut req = self.client.list_objects_v2().bucket(self.bucket.as_str());

            if let Some(prefix) = &self.prefix {
                req = req.prefix(prefix.as_str());
            }

            if let Some(token) = continuation_token.take() {
                req = req.continuation_token(token);
            }

            // NOTE(review): request errors are deliberately reported as
            // "unknown" instead of being propagated.
            let Ok(resp) = req.send().await else {
                return Ok(None);
            };

            for obj in resp.contents() {
                let Some(last_modified) = obj.last_modified() else {
                    continue;
                };
                // Timestamps outside chrono's representable range are skipped.
                let Some(dt) =
                    DateTime::from_timestamp(last_modified.secs(), last_modified.subsec_nanos())
                else {
                    continue;
                };
                latest = Some(latest.map_or(dt, |prev| prev.max(dt)));
            }

            if !resp.is_truncated().unwrap_or(false) {
                break;
            }

            // A truncated page without a token cannot be continued; retrying
            // without one would loop on the first page forever. Treat it like
            // a failed request.
            match resp.next_continuation_token() {
                Some(token) => continuation_token = Some(token.to_owned()),
                None => return Ok(None),
            }
        }

        Ok(latest)
    }
}
/// Removes `prefix` (and the `/` separating it from the rest) from the front
/// of `key`.
///
/// A prefix without a trailing `/` only matches when it is followed by `/` in
/// `key`. If there is no prefix, or `key` does not start with it, `key` is
/// returned unchanged.
fn strip_prefix<'a>(key: &'a str, prefix: Option<&str>) -> &'a str {
    let Some(p) = prefix else {
        return key;
    };

    let remainder = if p.ends_with('/') {
        key.strip_prefix(p)
    } else {
        // Equivalent to stripping "{p}/", without building that string.
        key.strip_prefix(p).and_then(|tail| tail.strip_prefix('/'))
    };

    remainder.unwrap_or(key)
}
/// Resolves every entry of `pattern` against `key`, returning the group
/// member key produced for each label.
///
/// Each entry is a sequence of segments: literals are appended verbatim, and
/// path segments are queried against `key` (as a string value) and their
/// string result appended. An entry is omitted from the result as soon as one
/// of its path segments fails to query, yields nothing, or yields a
/// non-string value.
async fn resolve_groups(
    pattern: &GroupPattern,
    key: &str,
) -> HashMap<Label, SmartString<LazyCompact>> {
    let state = ExtractState { ignore_mime: false };
    // The queried value is the same for every entry, so build it once.
    let item = PileValue::String(Arc::new(key.into()));

    let mut resolved = HashMap::new();

    'pattern: for (label, segments) in &pattern.pattern {
        let mut target = String::new();

        for segment in segments {
            match segment {
                GroupSegment::Literal(text) => target.push_str(text),
                GroupSegment::Path(op) => {
                    let Ok(Some(value)) = item.query(&state, op).await else {
                        continue 'pattern;
                    };
                    let Some(text) = value.as_str() else {
                        continue 'pattern;
                    };
                    target.push_str(text);
                }
            }
        }

        resolved.insert(label.clone(), target.into());
    }

    resolved
}