182 lines
3.9 KiB
Rust
182 lines
3.9 KiB
Rust
use aws_sdk_s3::config::{BehaviorVersion, Credentials, Region};
|
|
use smartstring::{LazyCompact, SmartString};
|
|
use std::{fmt::Debug, io::SeekFrom, sync::Arc};
|
|
|
|
use crate::{AsyncReader, AsyncSeekReader};
|
|
|
|
//
|
|
// MARK: client
|
|
//
|
|
|
|
/// An interface to an S3 bucket.
|
|
///
|
|
/// TODO: S3 is slow and expensive. Ideally, we'll have this struct cache data
|
|
/// so we don't have to download anything twice. This is, however, complicated,
|
|
/// and doesn't fully solve the "expensive" problem.
|
|
pub struct S3Client {
|
|
pub client: aws_sdk_s3::Client,
|
|
bucket: SmartString<LazyCompact>,
|
|
|
|
/// maximum number of bytes to use for cached data
|
|
cache_limit_bytes: usize,
|
|
}
|
|
|
|
impl Debug for S3Client {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
f.debug_struct("S3Client")
|
|
.field("bucket", &self.bucket)
|
|
.field("cache_limit_bytes", &self.cache_limit_bytes)
|
|
.finish()
|
|
}
|
|
}
|
|
|
|
impl S3Client {
|
|
pub async fn new(
|
|
bucket: &str,
|
|
endpoint: Option<&str>,
|
|
region: &str,
|
|
access_key_id: &str,
|
|
secret_access_key: &str,
|
|
cache_limit_bytes: usize,
|
|
) -> Arc<Self> {
|
|
let client = {
|
|
let mut s3_config = aws_sdk_s3::config::Builder::new()
|
|
.behavior_version(BehaviorVersion::latest())
|
|
.region(Region::new(region.to_owned()))
|
|
.credentials_provider(Credentials::new(
|
|
access_key_id,
|
|
secret_access_key,
|
|
None,
|
|
None,
|
|
"pile",
|
|
));
|
|
|
|
if let Some(ep) = endpoint {
|
|
s3_config = s3_config.endpoint_url(ep).force_path_style(true);
|
|
}
|
|
|
|
aws_sdk_s3::Client::from_conf(s3_config.build())
|
|
};
|
|
|
|
return Arc::new(Self {
|
|
bucket: bucket.into(),
|
|
client,
|
|
cache_limit_bytes,
|
|
});
|
|
}
|
|
|
|
pub fn bucket(&self) -> &str {
|
|
&self.bucket
|
|
}
|
|
|
|
pub async fn get(self: &Arc<Self>, key: &str) -> Result<S3Reader, std::io::Error> {
|
|
let head = self
|
|
.client
|
|
.head_object()
|
|
.bucket(self.bucket.as_str())
|
|
.key(key)
|
|
.send()
|
|
.await
|
|
.map_err(std::io::Error::other)?;
|
|
|
|
let size = head.content_length().unwrap_or(0) as u64;
|
|
|
|
Ok(S3Reader {
|
|
client: self.clone(),
|
|
bucket: self.bucket.clone(),
|
|
key: key.into(),
|
|
cursor: 0,
|
|
size,
|
|
})
|
|
}
|
|
}
|
|
|
|
//
|
|
// MARK: reader
|
|
//
|
|
|
|
pub struct S3Reader {
|
|
pub client: Arc<S3Client>,
|
|
pub bucket: SmartString<LazyCompact>,
|
|
pub key: SmartString<LazyCompact>,
|
|
pub cursor: u64,
|
|
pub size: u64,
|
|
}
|
|
|
|
impl AsyncReader for S3Reader {
    /// Reads up to `buf.len()` bytes starting at the current cursor.
    ///
    /// Each call issues one ranged `GetObject` request, collects the whole
    /// response body in memory, copies it into `buf`, and advances the cursor
    /// by the number of bytes copied.
    ///
    /// Returns `Ok(0)` without making a request when the cursor is at (or
    /// past) the end of the object, or when `buf` is empty.
    ///
    /// # Errors
    /// Request and body-collection failures are wrapped with
    /// [`std::io::Error::other`].
    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
        let remaining = self.size.saturating_sub(self.cursor);
        if buf.is_empty() || remaining == 0 {
            return Ok(0);
        }

        // HTTP byte ranges are inclusive on both ends, hence the `- 1`.
        // `wanted` is at least 1 here, so the subtraction cannot underflow.
        let wanted = remaining.min(buf.len() as u64);
        let first = self.cursor;
        let last = first + wanted - 1;
        let range_header = format!("bytes={first}-{last}");

        let response = self
            .client
            .client
            .get_object()
            .bucket(self.bucket.as_str())
            .key(self.key.as_str())
            .range(range_header)
            .send()
            .await
            .map_err(std::io::Error::other)?;

        let payload = response
            .body
            .collect()
            .await
            .map_err(std::io::Error::other)?
            .into_bytes();

        // Never copy more than the caller's buffer can hold, even if the
        // response carries more bytes than were requested.
        let copied = payload.len().min(buf.len());
        buf[..copied].copy_from_slice(&payload[..copied]);
        self.cursor += copied as u64;

        Ok(copied)
    }
}
|
|
|
|
impl AsyncSeekReader for S3Reader {
|
|
async fn seek(&mut self, pos: SeekFrom) -> Result<u64, std::io::Error> {
|
|
match pos {
|
|
SeekFrom::Start(x) => self.cursor = x.min(self.size),
|
|
|
|
SeekFrom::Current(x) => {
|
|
if x < 0 {
|
|
let abs = x.unsigned_abs();
|
|
if abs > self.cursor {
|
|
return Err(std::io::Error::new(
|
|
std::io::ErrorKind::InvalidInput,
|
|
"cannot seek past start",
|
|
));
|
|
}
|
|
self.cursor -= abs;
|
|
} else {
|
|
self.cursor += x as u64;
|
|
}
|
|
}
|
|
|
|
std::io::SeekFrom::End(x) => {
|
|
if x < 0 {
|
|
let abs = x.unsigned_abs();
|
|
if abs > self.size {
|
|
return Err(std::io::Error::new(
|
|
std::io::ErrorKind::InvalidInput,
|
|
"cannot seek past start",
|
|
));
|
|
}
|
|
self.cursor = self.size - abs;
|
|
} else {
|
|
self.cursor = self.size + x as u64;
|
|
}
|
|
}
|
|
}
|
|
|
|
self.cursor = self.cursor.min(self.size);
|
|
Ok(self.cursor)
|
|
}
|
|
}
|