265 lines
7.3 KiB
Rust
265 lines
7.3 KiB
Rust
use std::io::{Seek, SeekFrom, Write};
|
|
|
|
use crate::{ChaChaHeader, ChaChaReaderConfig};
|
|
|
|
/// Generate a random 32-byte encryption key suitable for use with [`ChaChaWriter`].
|
|
pub fn generate_key() -> [u8; 32] {
|
|
use chacha20poly1305::aead::OsRng;
|
|
use chacha20poly1305::{KeyInit, XChaCha20Poly1305};
|
|
XChaCha20Poly1305::generate_key(&mut OsRng).into()
|
|
}
|
|
|
|
pub struct ChaChaWriter<W: Write + Seek> {
|
|
inner: W,
|
|
config: ChaChaReaderConfig,
|
|
encryption_key: [u8; 32],
|
|
buffer: Vec<u8>,
|
|
plaintext_bytes_written: u64,
|
|
}
|
|
|
|
impl<W: Write + Seek> ChaChaWriter<W> {
|
|
pub fn new(mut inner: W, encryption_key: [u8; 32]) -> Result<Self, std::io::Error> {
|
|
use binrw::BinWriterExt;
|
|
|
|
let config = ChaChaReaderConfig::default();
|
|
let header = ChaChaHeader {
|
|
chunk_size: config.chunk_size,
|
|
nonce_size: config.nonce_size,
|
|
tag_size: config.tag_size,
|
|
plaintext_size: 0,
|
|
};
|
|
inner.write_le(&header).map_err(std::io::Error::other)?;
|
|
|
|
Ok(Self {
|
|
inner,
|
|
config,
|
|
encryption_key,
|
|
buffer: Vec::new(),
|
|
plaintext_bytes_written: 0,
|
|
})
|
|
}
|
|
|
|
/// Encrypt and write any buffered plaintext, patch the header with the
|
|
/// final `plaintext_size`, then return the inner writer.
|
|
pub fn finish(mut self) -> Result<W, std::io::Error> {
|
|
use binrw::BinWriterExt;
|
|
|
|
self.flush_buffer()?;
|
|
|
|
self.inner.seek(SeekFrom::Start(0))?;
|
|
let header = ChaChaHeader {
|
|
chunk_size: self.config.chunk_size,
|
|
nonce_size: self.config.nonce_size,
|
|
tag_size: self.config.tag_size,
|
|
plaintext_size: self.plaintext_bytes_written,
|
|
};
|
|
self.inner
|
|
.write_le(&header)
|
|
.map_err(std::io::Error::other)?;
|
|
|
|
Ok(self.inner)
|
|
}
|
|
|
|
fn encrypt_chunk(&self, plaintext: &[u8]) -> Result<Vec<u8>, std::io::Error> {
|
|
use chacha20poly1305::{
|
|
XChaCha20Poly1305,
|
|
aead::{Aead, AeadCore, KeyInit, OsRng},
|
|
};
|
|
|
|
let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
|
|
let key = chacha20poly1305::Key::from_slice(&self.encryption_key);
|
|
let cipher = XChaCha20Poly1305::new(key);
|
|
let ciphertext = cipher
|
|
.encrypt(&nonce, plaintext)
|
|
.map_err(|_| std::io::Error::other("encryption failed"))?;
|
|
|
|
let mut output = Vec::with_capacity(nonce.len() + ciphertext.len());
|
|
output.extend_from_slice(&nonce);
|
|
output.extend_from_slice(&ciphertext);
|
|
Ok(output)
|
|
}
|
|
|
|
fn flush_buffer(&mut self) -> Result<(), std::io::Error> {
|
|
if !self.buffer.is_empty() {
|
|
let encrypted = self.encrypt_chunk(&self.buffer)?;
|
|
self.inner.write_all(&encrypted)?;
|
|
self.buffer.clear();
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
impl<W: Write + Seek> Write for ChaChaWriter<W> {
|
|
fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
|
|
self.buffer.extend_from_slice(buf);
|
|
self.plaintext_bytes_written += buf.len() as u64;
|
|
|
|
let chunk_size = self.config.chunk_size as usize;
|
|
while self.buffer.len() >= chunk_size {
|
|
let encrypted = self.encrypt_chunk(&self.buffer[..chunk_size])?;
|
|
self.inner.write_all(&encrypted)?;
|
|
self.buffer.drain(..chunk_size);
|
|
}
|
|
|
|
Ok(buf.len())
|
|
}
|
|
|
|
/// Encrypts and flushes any buffered plaintext as a partial chunk.
|
|
///
|
|
/// Prefer [`finish`](Self::finish) to retrieve the inner writer after
|
|
/// all data has been written. Calling `flush` multiple times will produce
|
|
/// multiple small encrypted chunks for the same partial data.
|
|
fn flush(&mut self) -> Result<(), std::io::Error> {
|
|
self.flush_buffer()?;
|
|
self.inner.flush()
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use std::io::{Cursor, SeekFrom, Write};

    use super::ChaChaWriter;
    use crate::{AsyncReader, AsyncSeekReader, ChaChaReader};

    /// Key shared by every test; its value is arbitrary.
    const KEY: [u8; 32] = [42u8; 32];

    /// Plaintext chunk size that the chunk-boundary tests are written against.
    /// NOTE(review): assumed to equal the default `ChaChaReaderConfig` chunk size.
    const CHUNK: usize = 65536;

    /// Await `$reader.read` repeatedly until every byte of `$buf` is filled.
    macro_rules! fill {
        ($reader:expr, $buf:expr) => {{
            let mut filled = 0;
            while filled < $buf.len() {
                filled += $reader.read(&mut $buf[filled..]).await.unwrap();
            }
        }};
    }

    /// Encrypt `data` under `KEY` and rewind the ciphertext so it can be read back.
    fn encrypt(data: &[u8]) -> Cursor<Vec<u8>> {
        let mut writer = ChaChaWriter::new(Cursor::new(Vec::new()), KEY).unwrap();
        writer.write_all(data).unwrap();
        let mut sealed = writer.finish().unwrap();
        sealed.set_position(0);
        sealed
    }

    /// Decrypt an entire stream produced by `encrypt`.
    async fn decrypt_all(sealed: Cursor<Vec<u8>>) -> Vec<u8> {
        let mut reader = ChaChaReader::new(sealed, KEY).unwrap();
        reader.read_to_end().await.unwrap()
    }

    #[tokio::test]
    async fn roundtrip_empty() {
        let sealed = encrypt(&[]);
        // The header is written even without plaintext, but no chunks follow it.
        assert!(!sealed.get_ref().is_empty());
        let plaintext = decrypt_all(sealed).await;
        assert!(plaintext.is_empty());
    }

    #[tokio::test]
    async fn roundtrip_small() {
        let message = b"hello, world!";
        let plaintext = decrypt_all(encrypt(message)).await;
        assert_eq!(plaintext, message);
    }

    #[tokio::test]
    async fn roundtrip_exact_chunk() {
        let data = vec![0xAB_u8; CHUNK];
        let plaintext = decrypt_all(encrypt(&data)).await;
        assert_eq!(plaintext, data);
    }

    #[tokio::test]
    async fn roundtrip_multi_chunk() {
        // Two full chunks followed by a 1000-byte partial chunk.
        let data: Vec<u8> = (0u8..=255).cycle().take(2 * CHUNK + 1000).collect();
        let plaintext = decrypt_all(encrypt(&data)).await;
        assert_eq!(plaintext, data);
    }

    #[tokio::test]
    async fn roundtrip_incremental_writes() {
        // Feed the writer a single byte per call.
        let data: Vec<u8> = (0u8..200).collect();
        let mut writer = ChaChaWriter::new(Cursor::new(Vec::new()), KEY).unwrap();
        for single_byte in data.chunks(1) {
            writer.write_all(single_byte).unwrap();
        }
        let mut sealed = writer.finish().unwrap();
        sealed.set_position(0);
        assert_eq!(decrypt_all(sealed).await, data);
    }

    #[tokio::test]
    async fn wrong_key_fails() {
        let wrong_key = [0u8; 32];
        let mut reader = ChaChaReader::new(encrypt(b"secret data"), wrong_key).unwrap();
        assert!(reader.read_to_end().await.is_err());
    }

    #[tokio::test]
    async fn header_magic_checked() {
        // Corrupting the first header byte (the magic) must make the reader reject the stream.
        let mut sealed = encrypt(b"data");
        sealed.get_mut()[0] = 0xFF;
        sealed.set_position(0);
        assert!(ChaChaReader::new(sealed, KEY).is_err());
    }

    #[tokio::test]
    async fn seek_from_start() {
        let data: Vec<u8> = (0u8..100).collect();
        let mut reader = ChaChaReader::new(encrypt(&data), KEY).unwrap();

        reader.seek(SeekFrom::Start(50)).await.unwrap();
        let mut window = [0u8; 10];
        fill!(reader, window);
        assert_eq!(window, data[50..60]);
    }

    #[tokio::test]
    async fn seek_from_end() {
        let data: Vec<u8> = (0u8..100).collect();
        let mut reader = ChaChaReader::new(encrypt(&data), KEY).unwrap();

        reader.seek(SeekFrom::End(-10)).await.unwrap();
        let tail = reader.read_to_end().await.unwrap();
        assert_eq!(tail, &data[90..]);
    }

    #[tokio::test]
    async fn seek_across_chunk_boundary() {
        // Start 6 bytes before the end of chunk 0 and read 12 bytes, finishing 6 bytes into chunk 1.
        let data: Vec<u8> = (0u8..=255).cycle().take(CHUNK + 500).collect();
        let mut reader = ChaChaReader::new(encrypt(&data), KEY).unwrap();

        let start = CHUNK - 6;
        reader.seek(SeekFrom::Start(start as u64)).await.unwrap();
        let mut window = vec![0u8; 12];
        fill!(reader, window);
        assert_eq!(window, data[start..start + 12]);
    }

    #[tokio::test]
    async fn seek_current() {
        let data: Vec<u8> = (0u8..=255).cycle().take(200).collect();
        let mut reader = ChaChaReader::new(encrypt(&data), KEY).unwrap();

        // Read bytes 0..10, step back 5, then read 5 more: expect bytes 5..10 again.
        let mut first = [0u8; 10];
        fill!(reader, first);
        reader.seek(SeekFrom::Current(-5)).await.unwrap();
        let mut second = [0u8; 5];
        fill!(reader, second);
        assert_eq!(second, data[5..10]);
    }

    #[tokio::test]
    async fn seek_past_end_clamps() {
        let message = b"hello";
        let mut reader = ChaChaReader::new(encrypt(message), KEY).unwrap();

        let landed_at = reader.seek(SeekFrom::Start(9999)).await.unwrap();
        assert_eq!(landed_at, message.len() as u64);
        assert_eq!(reader.read_to_end().await.unwrap(), b"");
    }
}
|