use std::io::SeekFrom; use crate::{AsyncReader, AsyncSeekReader, chacha::ChaChaHeaderv1}; pub struct ChaChaReaderv1Async { inner: R, header: ChaChaHeaderv1, data_offset: u64, encryption_key: [u8; 32], cursor: u64, plaintext_size: u64, cached_chunk: Option<(u64, Vec)>, } impl ChaChaReaderv1Async { pub async fn new(mut inner: R, encryption_key: [u8; 32]) -> Result { use binrw::BinReaderExt; use std::io::Cursor; inner.seek(SeekFrom::Start(0)).await?; let mut buf = [0u8; ChaChaHeaderv1::SIZE]; read_exact(&mut inner, &mut buf).await?; let header: ChaChaHeaderv1 = Cursor::new(&buf[..]) .read_le() .map_err(std::io::Error::other)?; Ok(Self { inner, header, data_offset: buf.len() as u64, encryption_key, cursor: 0, plaintext_size: header.plaintext_size, cached_chunk: None, }) } async fn fetch_chunk(&mut self, chunk_index: u64) -> Result<(), std::io::Error> { use chacha20poly1305::{KeyInit, XChaCha20Poly1305, XNonce, aead::Aead}; let enc_start = self.data_offset + chunk_index * self.header.config.enc_chunk_size(); self.inner.seek(SeekFrom::Start(enc_start)).await?; let mut encrypted = vec![0u8; self.header.config.enc_chunk_size() as usize]; let n = read_exact_or_eof(&mut self.inner, &mut encrypted).await?; encrypted.truncate(n); if encrypted.len() < (self.header.config.nonce_size + self.header.config.tag_size) as usize { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, "encrypted chunk too short", )); } let (nonce_bytes, ciphertext) = encrypted.split_at(self.header.config.nonce_size as usize); let nonce = XNonce::from_slice(nonce_bytes); let key = chacha20poly1305::Key::from_slice(&self.encryption_key); let cipher = XChaCha20Poly1305::new(key); let plaintext = cipher.decrypt(nonce, ciphertext).map_err(|_| { std::io::Error::new(std::io::ErrorKind::InvalidData, "decryption failed") })?; self.cached_chunk = Some((chunk_index, plaintext)); Ok(()) } } async fn read_exact(inner: &mut R, buf: &mut [u8]) -> Result<(), std::io::Error> { let n = read_exact_or_eof(inner, buf).await?; if n < buf.len() { return Err(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, "unexpected EOF reading header", )); } Ok(()) } async fn read_exact_or_eof( inner: &mut R, buf: &mut [u8], ) -> Result { let mut total = 0; while total < buf.len() { match inner.read(&mut buf[total..]).await? { 0 => break, n => total += n, } } Ok(total) } impl AsyncReader for ChaChaReaderv1Async { async fn read(&mut self, buf: &mut [u8]) -> Result { let remaining = self.plaintext_size.saturating_sub(self.cursor); if remaining == 0 || buf.is_empty() { return Ok(0); } let chunk_index = self.cursor / self.header.config.chunk_size; let need_fetch = match &self.cached_chunk { None => true, Some((idx, _)) => *idx != chunk_index, }; if need_fetch { self.fetch_chunk(chunk_index).await?; } #[expect(clippy::unwrap_used)] let (_, chunk_data) = self.cached_chunk.as_ref().unwrap(); let offset_in_chunk = (self.cursor % self.header.config.chunk_size) as usize; let available = chunk_data.len() - offset_in_chunk; let to_copy = available.min(buf.len()); buf[..to_copy].copy_from_slice(&chunk_data[offset_in_chunk..offset_in_chunk + to_copy]); self.cursor += to_copy as u64; Ok(to_copy) } } impl AsyncSeekReader for ChaChaReaderv1Async { async fn seek(&mut self, pos: SeekFrom) -> Result { match pos { SeekFrom::Start(x) => self.cursor = x.min(self.plaintext_size), SeekFrom::Current(x) => { if x < 0 { let abs = x.unsigned_abs(); if abs > self.cursor { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, "cannot seek past start", )); } self.cursor -= abs; } else { self.cursor += x as u64; } } SeekFrom::End(x) => { if x < 0 { let abs = x.unsigned_abs(); if abs > self.plaintext_size { return Err(std::io::Error::new( std::io::ErrorKind::InvalidInput, "cannot seek past start", )); } self.cursor = self.plaintext_size - abs; } else { self.cursor = self.plaintext_size + x as u64; } } } self.cursor = self.cursor.min(self.plaintext_size); Ok(self.cursor) } }