use std::io::{Seek, SeekFrom, Write}; use crate::{ChaChaHeader, ChaChaReaderConfig}; /// Generate a random 32-byte encryption key suitable for use with [`ChaChaWriter`]. pub fn generate_key() -> [u8; 32] { use chacha20poly1305::aead::OsRng; use chacha20poly1305::{KeyInit, XChaCha20Poly1305}; XChaCha20Poly1305::generate_key(&mut OsRng).into() } pub struct ChaChaWriter { inner: W, config: ChaChaReaderConfig, encryption_key: [u8; 32], buffer: Vec, plaintext_bytes_written: u64, } impl ChaChaWriter { pub fn new(mut inner: W, encryption_key: [u8; 32]) -> Result { use binrw::BinWriterExt; let config = ChaChaReaderConfig::default(); let header = ChaChaHeader { chunk_size: config.chunk_size, nonce_size: config.nonce_size, tag_size: config.tag_size, plaintext_size: 0, }; inner.write_le(&header).map_err(std::io::Error::other)?; Ok(Self { inner, config, encryption_key, buffer: Vec::new(), plaintext_bytes_written: 0, }) } /// Encrypt and write any buffered plaintext, patch the header with the /// final `plaintext_size`, then return the inner writer. pub fn finish(mut self) -> Result { use binrw::BinWriterExt; self.flush_buffer()?; self.inner.seek(SeekFrom::Start(0))?; let header = ChaChaHeader { chunk_size: self.config.chunk_size, nonce_size: self.config.nonce_size, tag_size: self.config.tag_size, plaintext_size: self.plaintext_bytes_written, }; self.inner .write_le(&header) .map_err(std::io::Error::other)?; Ok(self.inner) } fn encrypt_chunk(&self, plaintext: &[u8]) -> Result, std::io::Error> { use chacha20poly1305::{ XChaCha20Poly1305, aead::{Aead, AeadCore, KeyInit, OsRng}, }; let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng); let key = chacha20poly1305::Key::from_slice(&self.encryption_key); let cipher = XChaCha20Poly1305::new(key); let ciphertext = cipher .encrypt(&nonce, plaintext) .map_err(|_| std::io::Error::other("encryption failed"))?; let mut output = Vec::with_capacity(nonce.len() + ciphertext.len()); output.extend_from_slice(&nonce); output.extend_from_slice(&ciphertext); Ok(output) } fn flush_buffer(&mut self) -> Result<(), std::io::Error> { if !self.buffer.is_empty() { let encrypted = self.encrypt_chunk(&self.buffer)?; self.inner.write_all(&encrypted)?; self.buffer.clear(); } Ok(()) } } impl Write for ChaChaWriter { fn write(&mut self, buf: &[u8]) -> Result { self.buffer.extend_from_slice(buf); self.plaintext_bytes_written += buf.len() as u64; let chunk_size = self.config.chunk_size as usize; while self.buffer.len() >= chunk_size { let encrypted = self.encrypt_chunk(&self.buffer[..chunk_size])?; self.inner.write_all(&encrypted)?; self.buffer.drain(..chunk_size); } Ok(buf.len()) } /// Encrypts and flushes any buffered plaintext as a partial chunk. /// /// Prefer [`finish`](Self::finish) to retrieve the inner writer after /// all data has been written. Calling `flush` multiple times will produce /// multiple small encrypted chunks for the same partial data. fn flush(&mut self) -> Result<(), std::io::Error> { self.flush_buffer()?; self.inner.flush() } } #[cfg(test)] #[expect(clippy::unwrap_used)] mod tests { use std::io::{Cursor, SeekFrom, Write}; use super::ChaChaWriter; use crate::{AsyncReader, AsyncSeekReader, ChaChaReader}; const KEY: [u8; 32] = [42u8; 32]; fn encrypt(data: &[u8]) -> Cursor> { let mut writer = ChaChaWriter::new(Cursor::new(Vec::new()), KEY).unwrap(); writer.write_all(data).unwrap(); let mut buf = writer.finish().unwrap(); buf.set_position(0); buf } async fn decrypt_all(buf: Cursor>) -> Vec { let mut reader = ChaChaReader::new(buf, KEY).unwrap(); reader.read_to_end().await.unwrap() } #[tokio::test] async fn roundtrip_empty() { let buf = encrypt(&[]); // Header present but no chunks assert!(!buf.get_ref().is_empty()); assert!(decrypt_all(buf).await.is_empty()); } #[tokio::test] async fn roundtrip_small() { let data = b"hello, world!"; assert_eq!(decrypt_all(encrypt(data)).await, data); } #[tokio::test] async fn roundtrip_exact_chunk() { let data = vec![0xABu8; 65536]; assert_eq!(decrypt_all(encrypt(&data)).await, data); } #[tokio::test] async fn roundtrip_multi_chunk() { // 2.5 chunks let data: Vec = (0u8..=255).cycle().take(65536 * 2 + 1000).collect(); assert_eq!(decrypt_all(encrypt(&data)).await, data); } #[tokio::test] async fn roundtrip_incremental_writes() { // Write one byte at a time let data: Vec = (0u8..200).collect(); let mut writer = ChaChaWriter::new(Cursor::new(Vec::new()), KEY).unwrap(); for byte in &data { writer.write_all(&[*byte]).unwrap(); } let mut buf = writer.finish().unwrap(); buf.set_position(0); assert_eq!(decrypt_all(buf).await, data); } #[tokio::test] async fn wrong_key_fails() { let buf = encrypt(b"secret data"); let mut reader = ChaChaReader::new(buf, [0u8; 32]).unwrap(); assert!(reader.read_to_end().await.is_err()); } #[tokio::test] async fn header_magic_checked() { // Corrupt the magic bytes — reader should fail let mut buf = encrypt(b"data"); buf.get_mut()[0] = 0xFF; buf.set_position(0); assert!(ChaChaReader::new(buf, KEY).is_err()); } #[tokio::test] async fn seek_from_start() { let data: Vec = (0u8..100).collect(); let mut reader = ChaChaReader::new(encrypt(&data), KEY).unwrap(); reader.seek(SeekFrom::Start(50)).await.unwrap(); let mut buf = [0u8; 10]; let mut read = 0; while read < buf.len() { read += reader.read(&mut buf[read..]).await.unwrap(); } assert_eq!(buf, data[50..60]); } #[tokio::test] async fn seek_from_end() { let data: Vec = (0u8..100).collect(); let mut reader = ChaChaReader::new(encrypt(&data), KEY).unwrap(); reader.seek(SeekFrom::End(-10)).await.unwrap(); assert_eq!(reader.read_to_end().await.unwrap(), &data[90..]); } #[tokio::test] async fn seek_across_chunk_boundary() { // Seek to 6 bytes before the end of chunk 0, read 12 bytes spanning into chunk 1 let data: Vec = (0u8..=255).cycle().take(65536 + 500).collect(); let mut reader = ChaChaReader::new(encrypt(&data), KEY).unwrap(); reader.seek(SeekFrom::Start(65530)).await.unwrap(); let mut buf = vec![0u8; 12]; let mut read = 0; while read < buf.len() { read += reader.read(&mut buf[read..]).await.unwrap(); } assert_eq!(buf, data[65530..65542]); } #[tokio::test] async fn seek_current() { let data: Vec = (0u8..=255).cycle().take(200).collect(); let mut reader = ChaChaReader::new(encrypt(&data), KEY).unwrap(); // Read 10, seek back 5, read 5 — should get bytes 5..10 let mut first = [0u8; 10]; let mut n = 0; while n < first.len() { n += reader.read(&mut first[n..]).await.unwrap(); } reader.seek(SeekFrom::Current(-5)).await.unwrap(); let mut second = [0u8; 5]; n = 0; while n < second.len() { n += reader.read(&mut second[n..]).await.unwrap(); } assert_eq!(second, data[5..10]); } #[tokio::test] async fn seek_past_end_clamps() { let data = b"hello"; let mut reader = ChaChaReader::new(encrypt(data), KEY).unwrap(); let pos = reader.seek(SeekFrom::Start(9999)).await.unwrap(); assert_eq!(pos, data.len() as u64); assert_eq!(reader.read_to_end().await.unwrap(), b""); } }