166 lines
4.4 KiB
Rust
166 lines
4.4 KiB
Rust
use std::io::SeekFrom;
|
|
|
|
use crate::{AsyncReader, AsyncSeekReader, chacha::ChaChaHeaderv1};
|
|
|
|
pub struct ChaChaReaderv1Async<R: AsyncSeekReader> {
|
|
inner: R,
|
|
header: ChaChaHeaderv1,
|
|
|
|
data_offset: u64,
|
|
encryption_key: [u8; 32],
|
|
cursor: u64,
|
|
plaintext_size: u64,
|
|
cached_chunk: Option<(u64, Vec<u8>)>,
|
|
}
|
|
|
|
impl<R: AsyncSeekReader> ChaChaReaderv1Async<R> {
|
|
pub async fn new(mut inner: R, encryption_key: [u8; 32]) -> Result<Self, std::io::Error> {
|
|
use binrw::BinReaderExt;
|
|
use std::io::Cursor;
|
|
|
|
inner.seek(SeekFrom::Start(0)).await?;
|
|
let mut buf = [0u8; ChaChaHeaderv1::SIZE];
|
|
read_exact(&mut inner, &mut buf).await?;
|
|
let header: ChaChaHeaderv1 = Cursor::new(&buf[..])
|
|
.read_le()
|
|
.map_err(std::io::Error::other)?;
|
|
|
|
Ok(Self {
|
|
inner,
|
|
header,
|
|
data_offset: buf.len() as u64,
|
|
encryption_key,
|
|
cursor: 0,
|
|
plaintext_size: header.plaintext_size,
|
|
cached_chunk: None,
|
|
})
|
|
}
|
|
|
|
async fn fetch_chunk(&mut self, chunk_index: u64) -> Result<(), std::io::Error> {
|
|
use chacha20poly1305::{KeyInit, XChaCha20Poly1305, XNonce, aead::Aead};
|
|
|
|
let enc_start = self.data_offset + chunk_index * self.header.config.enc_chunk_size();
|
|
self.inner.seek(SeekFrom::Start(enc_start)).await?;
|
|
|
|
let mut encrypted = vec![0u8; self.header.config.enc_chunk_size() as usize];
|
|
let n = read_exact_or_eof(&mut self.inner, &mut encrypted).await?;
|
|
encrypted.truncate(n);
|
|
|
|
if encrypted.len() < (self.header.config.nonce_size + self.header.config.tag_size) as usize
|
|
{
|
|
return Err(std::io::Error::new(
|
|
std::io::ErrorKind::InvalidData,
|
|
"encrypted chunk too short",
|
|
));
|
|
}
|
|
|
|
let (nonce_bytes, ciphertext) = encrypted.split_at(self.header.config.nonce_size as usize);
|
|
let nonce = XNonce::from_slice(nonce_bytes);
|
|
let key = chacha20poly1305::Key::from_slice(&self.encryption_key);
|
|
let cipher = XChaCha20Poly1305::new(key);
|
|
let plaintext = cipher.decrypt(nonce, ciphertext).map_err(|_| {
|
|
std::io::Error::new(std::io::ErrorKind::InvalidData, "decryption failed")
|
|
})?;
|
|
|
|
self.cached_chunk = Some((chunk_index, plaintext));
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
async fn read_exact<R: AsyncReader>(inner: &mut R, buf: &mut [u8]) -> Result<(), std::io::Error> {
|
|
let n = read_exact_or_eof(inner, buf).await?;
|
|
if n < buf.len() {
|
|
return Err(std::io::Error::new(
|
|
std::io::ErrorKind::UnexpectedEof,
|
|
"unexpected EOF reading header",
|
|
));
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn read_exact_or_eof<R: AsyncReader>(
|
|
inner: &mut R,
|
|
buf: &mut [u8],
|
|
) -> Result<usize, std::io::Error> {
|
|
let mut total = 0;
|
|
while total < buf.len() {
|
|
match inner.read(&mut buf[total..]).await? {
|
|
0 => break,
|
|
n => total += n,
|
|
}
|
|
}
|
|
Ok(total)
|
|
}
|
|
|
|
impl<R: AsyncSeekReader> AsyncReader for ChaChaReaderv1Async<R> {
    /// Copies plaintext at the current position into `buf` and advances the position
    /// by the number of bytes copied.
    ///
    /// At most one chunk is used per call. The chunk containing the cursor is decrypted,
    /// or reused from the cache. Bytes are copied up to whichever comes first: the end of
    /// that chunk, the end of `buf`, or the declared plaintext size. Returns `Ok(0)` once
    /// the cursor has reached the declared plaintext size, or when `buf` is empty.
    ///
    /// # Errors
    ///
    /// - `InvalidData` if the header declares a chunk size of zero.
    /// - `UnexpectedEof` if the decrypted chunk ends before the declared plaintext size.
    /// - Any error raised while fetching or decrypting the chunk.
    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
        let remaining = self.plaintext_size.saturating_sub(self.cursor);
        if remaining == 0 || buf.is_empty() {
            return Ok(0);
        }

        let chunk_size = self.header.config.chunk_size;
        // The chunk size comes from the file header; guard the division below.
        if chunk_size == 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "chunk size is zero",
            ));
        }
        let chunk_index = self.cursor / chunk_size;

        let cached = matches!(&self.cached_chunk, Some((idx, _)) if *idx == chunk_index);
        if !cached {
            self.fetch_chunk(chunk_index).await?;
        }

        // Either the cache already held this chunk or `fetch_chunk` just stored it.
        #[expect(clippy::unwrap_used)]
        let (_, chunk_data) = self.cached_chunk.as_ref().unwrap();

        let offset_in_chunk = (self.cursor % chunk_size) as usize;
        // Less decrypted data than `plaintext_size` promises means a truncated file or an
        // inconsistent header: report it instead of underflowing or faking end of file.
        let available = chunk_data
            .len()
            .checked_sub(offset_in_chunk)
            .filter(|&n| n > 0)
            .ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "decrypted chunk shorter than declared plaintext size",
                )
            })?;
        // Never return bytes beyond the declared plaintext size, even if the chunk
        // decrypted to more than that.
        let remaining = usize::try_from(remaining).unwrap_or(usize::MAX);
        let to_copy = available.min(buf.len()).min(remaining);

        buf[..to_copy].copy_from_slice(&chunk_data[offset_in_chunk..offset_in_chunk + to_copy]);
        self.cursor += to_copy as u64;
        Ok(to_copy)
    }
}
|
|
|
|
impl<R: AsyncSeekReader> AsyncSeekReader for ChaChaReaderv1Async<R> {
|
|
async fn seek(&mut self, pos: SeekFrom) -> Result<u64, std::io::Error> {
|
|
match pos {
|
|
SeekFrom::Start(x) => self.cursor = x.min(self.plaintext_size),
|
|
|
|
SeekFrom::Current(x) => {
|
|
if x < 0 {
|
|
let abs = x.unsigned_abs();
|
|
if abs > self.cursor {
|
|
return Err(std::io::Error::new(
|
|
std::io::ErrorKind::InvalidInput,
|
|
"cannot seek past start",
|
|
));
|
|
}
|
|
self.cursor -= abs;
|
|
} else {
|
|
self.cursor += x as u64;
|
|
}
|
|
}
|
|
|
|
SeekFrom::End(x) => {
|
|
if x < 0 {
|
|
let abs = x.unsigned_abs();
|
|
if abs > self.plaintext_size {
|
|
return Err(std::io::Error::new(
|
|
std::io::ErrorKind::InvalidInput,
|
|
"cannot seek past start",
|
|
));
|
|
}
|
|
self.cursor = self.plaintext_size - abs;
|
|
} else {
|
|
self.cursor = self.plaintext_size + x as u64;
|
|
}
|
|
}
|
|
}
|
|
|
|
self.cursor = self.cursor.min(self.plaintext_size);
|
|
Ok(self.cursor)
|
|
}
|
|
}
|