Files
pile/crates/pile-io/src/s3reader.rs
2026-03-23 21:09:22 -07:00

182 lines
3.9 KiB
Rust

use aws_sdk_s3::config::{BehaviorVersion, Credentials, Region};
use smartstring::{LazyCompact, SmartString};
use std::{fmt::Debug, io::SeekFrom, sync::Arc};
use crate::{AsyncReader, AsyncSeekReader};
//
// MARK: client
//
/// An interface to an S3 bucket.
///
/// Bundles an AWS SDK client with the name of the single bucket it is used
/// against.
///
/// TODO: S3 is slow and expensive. Ideally, we'll have this struct cache data
/// so we don't have to download anything twice. This is, however, complicated,
/// and doesn't fully solve the "expensive" problem.
pub struct S3Client {
/// Underlying AWS SDK client used to issue requests.
pub client: aws_sdk_s3::Client,
/// Name of the bucket requests are sent to.
bucket: SmartString<LazyCompact>,
/// Maximum number of bytes to use for cached data.
///
/// NOTE(review): caching is still a TODO (see the struct docs); in the code
/// reviewed here this value is only stored and printed by `Debug`.
cache_limit_bytes: usize,
}
impl Debug for S3Client {
    /// Formats the bucket name and cache limit; the SDK client is left out
    /// of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure so that adding a field forces a decision about
        // whether it should appear in the debug output.
        let Self {
            client: _,
            bucket,
            cache_limit_bytes,
        } = self;

        let mut out = f.debug_struct("S3Client");
        out.field("bucket", bucket);
        out.field("cache_limit_bytes", cache_limit_bytes);
        out.finish()
    }
}
impl S3Client {
    /// Build a client for `bucket` using static credentials.
    ///
    /// * `bucket` - name of the bucket every request is sent to.
    /// * `endpoint` - optional custom endpoint URL. When set, path-style
    ///   addressing is forced on; when `None`, the SDK builder's endpoint
    ///   defaults are left untouched.
    /// * `region` - region name passed to the SDK configuration.
    /// * `access_key_id` / `secret_access_key` - static credentials.
    /// * `cache_limit_bytes` - stored on the client as the cache size limit.
    ///
    /// The body performs no `.await`; the function stays `async` so that
    /// existing callers which await it keep compiling.
    pub async fn new(
        bucket: &str,
        endpoint: Option<&str>,
        region: &str,
        access_key_id: &str,
        secret_access_key: &str,
        cache_limit_bytes: usize,
    ) -> Arc<Self> {
        let client = {
            let mut s3_config = aws_sdk_s3::config::Builder::new()
                .behavior_version(BehaviorVersion::latest())
                .region(Region::new(region.to_owned()))
                .credentials_provider(Credentials::new(
                    access_key_id,
                    secret_access_key,
                    None,
                    None,
                    "pile",
                ));

            if let Some(ep) = endpoint {
                s3_config = s3_config.endpoint_url(ep).force_path_style(true);
            }

            aws_sdk_s3::Client::from_conf(s3_config.build())
        };

        Arc::new(Self {
            bucket: bucket.into(),
            client,
            cache_limit_bytes,
        })
    }

    /// Name of the bucket this client operates on.
    pub fn bucket(&self) -> &str {
        &self.bucket
    }

    /// Open the object stored under `key` for reading.
    ///
    /// Sends a HeadObject request to learn the object's size and returns an
    /// [`S3Reader`] positioned at byte 0. A response without a content length
    /// is treated as an empty object.
    ///
    /// # Errors
    ///
    /// Returns the SDK error wrapped with [`std::io::Error::other`] if the
    /// HeadObject request fails, and an `InvalidData` error if the reported
    /// content length is negative.
    pub async fn get(self: &Arc<Self>, key: &str) -> Result<S3Reader, std::io::Error> {
        let head = self
            .client
            .head_object()
            .bucket(self.bucket.as_str())
            .key(key)
            .send()
            .await
            .map_err(std::io::Error::other)?;

        // The SDK reports the length as a signed integer. A plain `as u64`
        // cast would turn a negative value into an enormous size, so convert
        // with a checked conversion instead.
        let size = match head.content_length() {
            None => 0,
            Some(len) => u64::try_from(len).map_err(|_| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "negative content length",
                )
            })?,
        };

        Ok(S3Reader {
            client: Arc::clone(self),
            bucket: self.bucket.clone(),
            key: key.into(),
            cursor: 0,
            size,
        })
    }
}
//
// MARK: reader
//
/// A reader over a single S3 object that tracks its own byte position.
///
/// Reads fetch data with ranged GetObject requests starting at `cursor`;
/// seeking only moves `cursor`.
pub struct S3Reader {
/// Client used to issue requests for this object.
pub client: Arc<S3Client>,
/// Name of the bucket that holds the object.
pub bucket: SmartString<LazyCompact>,
/// Key of the object within `bucket`.
pub key: SmartString<LazyCompact>,
/// Byte offset at which the next read starts.
pub cursor: u64,
/// Object size in bytes. When the reader comes from `S3Client::get`, this
/// is the content length reported by the HeadObject request.
pub size: u64,
}
impl AsyncReader for S3Reader {
    /// Read up to `buf.len()` bytes starting at the current cursor.
    ///
    /// Issues one ranged GetObject request per call, copies the returned
    /// bytes into the front of `buf`, and advances the cursor by the number
    /// of bytes copied. Returns `Ok(0)` when `buf` is empty or the cursor is
    /// at or past the end of the object.
    ///
    /// # Errors
    ///
    /// SDK and body-collection failures are wrapped with
    /// [`std::io::Error::other`].
    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
        if buf.is_empty() {
            return Ok(0);
        }

        let remaining = self.size.saturating_sub(self.cursor);
        if remaining == 0 {
            return Ok(0);
        }

        // HTTP byte ranges are inclusive on both ends, hence the `- 1`.
        let first = self.cursor;
        let wanted = remaining.min(buf.len() as u64);
        let last = first + wanted - 1;
        let range_header = format!("bytes={first}-{last}");

        let response = self
            .client
            .client
            .get_object()
            .bucket(self.bucket.as_str())
            .key(self.key.as_str())
            .range(range_header)
            .send()
            .await
            .map_err(std::io::Error::other)?;

        let payload = response
            .body
            .collect()
            .await
            .map_err(std::io::Error::other)?
            .into_bytes();

        // Never copy more than the caller's buffer can hold, even if the
        // response is longer than requested.
        let copied = payload.len().min(buf.len());
        buf[..copied].copy_from_slice(&payload[..copied]);
        self.cursor += copied as u64;

        Ok(copied)
    }
}
impl AsyncSeekReader for S3Reader {
    /// Move the cursor and return its new position.
    ///
    /// The target is computed relative to the start, the current cursor, or
    /// the object size, depending on `pos`. Targets beyond the end of the
    /// object are clamped to `size`. No request is sent.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` if a negative offset would place the cursor
    /// before byte 0; the cursor is left unchanged in that case.
    async fn seek(&mut self, pos: SeekFrom) -> Result<u64, std::io::Error> {
        // Reduce every variant to "base offset plus signed delta".
        let (base, delta) = match pos {
            SeekFrom::Start(x) => (x, 0),
            SeekFrom::Current(x) => (self.cursor, x),
            SeekFrom::End(x) => (self.size, x),
        };

        let magnitude = delta.unsigned_abs();
        let target = if delta < 0 {
            base.checked_sub(magnitude).ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "cannot seek past start",
                )
            })?
        } else {
            // Saturate rather than overflow: the result is clamped to `size`
            // below, so anything at or above `u64::MAX` ends up at the end.
            base.saturating_add(magnitude)
        };

        self.cursor = target.min(self.size);
        Ok(self.cursor)
    }
}