use smartstring::{LazyCompact, SmartString}; use std::{ fs::File, io::{Read, Seek, SeekFrom}, sync::Arc, }; use tokio::runtime::Handle; // // MARK: traits // pub trait AsyncReader: Send { /// Read a chunk of bytes. fn read( &mut self, buf: &mut [u8], ) -> impl Future> + Send; /// Read all remaining bytes into a `Vec`. fn read_to_end(&mut self) -> impl Future, std::io::Error>> + Send { async { let mut buf = Vec::new(); let mut chunk = vec![0u8; 65536]; loop { let n = self.read(&mut chunk).await?; if n == 0 { break; } buf.extend_from_slice(&chunk[..n]); } Ok(buf) } } } pub trait AsyncSeekReader: AsyncReader { fn seek(&mut self, pos: SeekFrom) -> impl Future> + Send; } // // MARK: sync bridge // /// Turn an async [Reader] into a sync [Read] + [Seek]. /// /// Never use this outside of [tokio::task::spawn_blocking], /// the async runtime will deadlock if this struct blocks /// the runtime. pub struct SyncReadBridge { inner: R, handle: Handle, } impl SyncReadBridge { /// Creates a new adapter using a handle to the current runtime. /// Panics if called outside of tokio pub fn new_current(inner: R) -> Self { Self::new(inner, Handle::current()) } /// Creates a new adapter using a handle to an existing runtime. pub fn new(inner: R, handle: Handle) -> Self { Self { inner, handle } } } impl Read for SyncReadBridge { fn read(&mut self, buf: &mut [u8]) -> Result { self.handle.block_on(self.inner.read(buf)) } } impl Seek for SyncReadBridge { fn seek(&mut self, pos: SeekFrom) -> Result { self.handle.block_on(self.inner.seek(pos)) } } // // MARK: itemreader // pub enum ItemReader { File(File), S3(S3Reader), } impl AsyncReader for ItemReader { async fn read(&mut self, buf: &mut [u8]) -> Result { match self { Self::File(x) => std::io::Read::read(x, buf), Self::S3(x) => x.read(buf).await, } } } impl AsyncSeekReader for ItemReader { async fn seek(&mut self, pos: std::io::SeekFrom) -> Result { match self { Self::File(x) => x.seek(pos), Self::S3(x) => x.seek(pos).await, } } } // // MARK: S3Reader // pub struct S3Reader { pub client: Arc, pub bucket: SmartString, pub key: SmartString, pub cursor: u64, pub size: u64, } impl AsyncReader for S3Reader { async fn read(&mut self, buf: &mut [u8]) -> Result { let len_left = self.size.saturating_sub(self.cursor); if len_left == 0 || buf.is_empty() { return Ok(0); } let start_byte = self.cursor; let len_to_read = (buf.len() as u64).min(len_left); let end_byte = start_byte + len_to_read - 1; let resp = self .client .get_object() .bucket(self.bucket.as_str()) .key(self.key.as_str()) .range(format!("bytes={start_byte}-{end_byte}")) .send() .await .map_err(std::io::Error::other)?; let bytes = resp .body .collect() .await .map(|x| x.into_bytes()) .map_err(std::io::Error::other)?; let n = bytes.len().min(buf.len()); buf[..n].copy_from_slice(&bytes[..n]); self.cursor += n as u64; Ok(n) } } impl AsyncSeekReader for S3Reader { async fn seek(&mut self, pos: SeekFrom) -> Result { match pos { SeekFrom::Start(x) => self.cursor = x.min(self.size), SeekFrom::Current(x) => { if x < 0 { let abs = x.unsigned_abs(); if abs > self.cursor { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, "cannot seek past start", )); } self.cursor -= abs; } else { self.cursor += x as u64; } } std::io::SeekFrom::End(x) => { if x < 0 { let abs = x.unsigned_abs(); if abs > self.size { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, "cannot seek past start", )); } self.cursor = self.size - abs; } else { self.cursor = self.size + x as u64; } } } self.cursor = self.cursor.min(self.size); Ok(self.cursor) } }