use aws_sdk_s3::config::{BehaviorVersion, Credentials, Region}; use smartstring::{LazyCompact, SmartString}; use std::{fmt::Debug, io::SeekFrom, sync::Arc}; use crate::{AsyncReader, AsyncSeekReader}; // // MARK: client // /// An interface to an S3 bucket. /// /// TODO: S3 is slow and expensive. Ideally, we'll have this struct cache data /// so we don't have to download anything twice. This is, however, complicated, /// and doesn't fully solve the "expensive" problem. pub struct S3Client { pub client: aws_sdk_s3::Client, bucket: SmartString, /// maximum number of bytes to use for cached data cache_limit_bytes: usize, } impl Debug for S3Client { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("S3Client") .field("bucket", &self.bucket) .field("cache_limit_bytes", &self.cache_limit_bytes) .finish() } } impl S3Client { pub async fn new( bucket: &str, endpoint: Option<&str>, region: &str, access_key_id: &str, secret_access_key: &str, cache_limit_bytes: usize, ) -> Arc { let client = { let mut s3_config = aws_sdk_s3::config::Builder::new() .behavior_version(BehaviorVersion::latest()) .region(Region::new(region.to_owned())) .credentials_provider(Credentials::new( access_key_id, secret_access_key, None, None, "pile", )); if let Some(ep) = endpoint { s3_config = s3_config.endpoint_url(ep).force_path_style(true); } aws_sdk_s3::Client::from_conf(s3_config.build()) }; return Arc::new(Self { bucket: bucket.into(), client, cache_limit_bytes, }); } pub fn bucket(&self) -> &str { &self.bucket } pub async fn get(self: &Arc, key: &str) -> Result { let head = self .client .head_object() .bucket(self.bucket.as_str()) .key(key) .send() .await .map_err(std::io::Error::other)?; let size = head.content_length().unwrap_or(0) as u64; Ok(S3Reader { client: self.clone(), bucket: self.bucket.clone(), key: key.into(), cursor: 0, size, }) } } // // MARK: reader // pub struct S3Reader { pub client: Arc, pub bucket: SmartString, pub key: SmartString, pub cursor: u64, pub size: u64, } impl AsyncReader for S3Reader { async fn read(&mut self, buf: &mut [u8]) -> Result { let len_left = self.size.saturating_sub(self.cursor); if len_left == 0 || buf.is_empty() { return Ok(0); } let start_byte = self.cursor; let len_to_read = (buf.len() as u64).min(len_left); let end_byte = start_byte + len_to_read - 1; let resp = self .client .client .get_object() .bucket(self.bucket.as_str()) .key(self.key.as_str()) .range(format!("bytes={start_byte}-{end_byte}")) .send() .await .map_err(std::io::Error::other)?; let bytes = resp .body .collect() .await .map(|x| x.into_bytes()) .map_err(std::io::Error::other)?; let n = bytes.len().min(buf.len()); buf[..n].copy_from_slice(&bytes[..n]); self.cursor += n as u64; Ok(n) } } impl AsyncSeekReader for S3Reader { async fn seek(&mut self, pos: SeekFrom) -> Result { match pos { SeekFrom::Start(x) => self.cursor = x.min(self.size), SeekFrom::Current(x) => { if x < 0 { let abs = x.unsigned_abs(); if abs > self.cursor { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, "cannot seek past start", )); } self.cursor -= abs; } else { self.cursor += x as u64; } } std::io::SeekFrom::End(x) => { if x < 0 { let abs = x.unsigned_abs(); if abs > self.size { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, "cannot seek past start", )); } self.cursor = self.size - abs; } else { self.cursor = self.size + x as u64; } } } self.cursor = self.cursor.min(self.size); Ok(self.cursor) } }