Compare commits

..

1 Commits

Author SHA1 Message Date
6b4fbf0ae9 Fuzzy search
All checks were successful
CI / Typos (push) Successful in 20s
CI / Clippy (push) Successful in 1m30s
CI / Build and test (all features) (push) Successful in 5m8s
CI / Build and test (push) Successful in 5m55s
2026-03-23 14:48:52 -07:00
19 changed files with 352 additions and 345 deletions

1
Cargo.lock generated
View File

@@ -2653,6 +2653,7 @@ version = "0.0.2"
dependencies = [
"anyhow",
"async-trait",
"aws-sdk-s3",
"base64",
"blake3",
"chacha20poly1305",

View File

@@ -81,7 +81,6 @@ pub struct Datasets {
pub config: ConfigToml,
pub sources: HashMap<Label, Dataset>,
pub disabled_sources: HashMap<Label, Dataset>,
}
impl Datasets {
@@ -115,8 +114,6 @@ impl Datasets {
};
let mut sources = HashMap::new();
let mut disabled_sources = HashMap::new();
for (label, source) in &config.dataset.source {
match source {
Source::Filesystem {
@@ -124,12 +121,11 @@ impl Datasets {
path,
pattern,
} => {
let target = match enabled {
true => &mut sources,
false => &mut disabled_sources,
};
if !enabled {
continue;
}
target.insert(
sources.insert(
label.clone(),
Dataset::Dir(
DirDataSource::new(label, path_parent.join(path), pattern.clone())
@@ -148,29 +144,26 @@ impl Datasets {
pattern,
encryption_key,
} => {
let target = match enabled {
true => &mut sources,
false => &mut disabled_sources,
};
if !enabled {
continue;
}
let encryption_key = encryption_key.as_ref().map(|x| string_to_key(x));
match S3DataSource::new(
label,
bucket,
prefix.as_ref().map(|x| x.as_str()),
endpoint.as_ref().map(|x| x.as_str()),
region,
&credentials.access_key_id,
&credentials.secret_access_key,
10_000_000,
bucket.clone(),
prefix.clone(),
endpoint.clone(),
region.clone(),
credentials,
pattern.clone(),
encryption_key,
)
.await
{
Ok(ds) => {
target.insert(label.clone(), Dataset::S3(ds));
sources.insert(label.clone(), Dataset::S3(ds));
}
Err(err) => {
warn!("Could not open S3 source {label}: {err}");
@@ -186,7 +179,6 @@ impl Datasets {
path_parent,
config,
sources,
disabled_sources,
});
}
@@ -227,7 +219,6 @@ impl Datasets {
.join(config.dataset.name.as_str());
let mut sources = HashMap::new();
let mut disabled_sources = HashMap::new();
for (label, source) in &config.dataset.source {
match source {
Source::Filesystem {
@@ -235,12 +226,11 @@ impl Datasets {
path,
pattern,
} => {
let target = match enabled {
true => &mut sources,
false => &mut disabled_sources,
};
if !enabled {
continue;
}
target.insert(
sources.insert(
label.clone(),
Dataset::Dir(
DirDataSource::new(label, path_parent.join(path), pattern.clone())
@@ -259,29 +249,26 @@ impl Datasets {
pattern,
encryption_key,
} => {
let target = match enabled {
true => &mut sources,
false => &mut disabled_sources,
};
if !enabled {
continue;
}
let encryption_key = encryption_key.as_ref().map(|x| string_to_key(x));
match S3DataSource::new(
label,
bucket,
prefix.as_ref().map(|x| x.as_str()),
endpoint.as_ref().map(|x| x.as_str()),
region,
&credentials.access_key_id,
&credentials.secret_access_key,
10_000_000,
bucket.clone(),
prefix.clone(),
endpoint.clone(),
region.clone(),
credentials,
pattern.clone(),
encryption_key,
)
.await
{
Ok(ds) => {
target.insert(label.clone(), Dataset::S3(ds));
sources.insert(label.clone(), Dataset::S3(ds));
}
Err(err) => {
warn!("Could not open S3 source {label}: {err}");
@@ -297,7 +284,6 @@ impl Datasets {
path_parent,
config,
sources,
disabled_sources,
});
}
@@ -474,6 +460,37 @@ impl Datasets {
return Ok(results);
}
/// Run a fuzzy full-text search over this dataset's fts index.
///
/// Each whitespace-separated term of `query` is fuzzy-matched against the
/// index; at most `top_n` results are returned, best-scoring first.
///
/// Returns an empty result set when no workdir is configured (there is
/// nothing to search) or when `top_n` is zero.
///
/// # Errors
/// [`DatasetError::NoFtsIndex`] when the fts directory does not exist, an
/// I/O error when the fts path is not a directory, or any error bubbled up
/// from the underlying index lookup.
pub fn fts_lookup_fuzzy(
    &self,
    query: &str,
    top_n: usize,
) -> Result<Vec<FtsLookupResult>, DatasetError> {
    let workdir = match self.path_workdir.as_ref() {
        Some(x) => x,
        None => {
            warn!("Skipping fts_lookup_fuzzy, no workdir");
            return Ok(Vec::new());
        }
    };
    // TopDocs::with_limit panics on a limit of zero; treat it as "no results".
    if top_n == 0 {
        return Ok(Vec::new());
    }
    let fts_dir = workdir.join("fts");
    if !fts_dir.exists() {
        return Err(DatasetError::NoFtsIndex);
    }
    if !fts_dir.is_dir() {
        return Err(std::io::Error::new(
            ErrorKind::NotADirectory,
            format!("fts index {} is not a directory", fts_dir.display()),
        )
        .into());
    }
    let db_index = DbFtsIndex::new(&fts_dir, &self.config);
    // tantivy's FuzzyTermQuery supports a Levenshtein distance of at most 2;
    // a larger value makes every search fail with InvalidArgument.
    let results = db_index.lookup_fuzzy(query, Arc::new(TopDocs::with_limit(top_n)), 2)?;
    return Ok(results);
}
/// Time at which fts was created
pub fn ts_fts(&self) -> Result<Option<DateTime<Utc>>, std::io::Error> {
let workdir = match self.path_workdir.as_ref() {

View File

@@ -7,8 +7,8 @@ use std::{path::PathBuf, sync::LazyLock};
use tantivy::{
DocAddress, Index, ReloadPolicy, TantivyDocument, TantivyError,
collector::Collector,
query::QueryParser,
schema::{self, Schema, Value as TantivyValue},
query::{BooleanQuery, FuzzyTermQuery, Occur, QueryParser},
schema::{self, Schema, Term, Value as TantivyValue},
};
use tracing::warn;
@@ -168,6 +168,74 @@ impl DbFtsIndex {
let query = query_parser.parse_query(&query)?;
let res = searcher.search(&query, collector.as_ref())?;
return Self::collect_results(&schema, &searcher, res);
}
/// Run a fuzzy query on this table's fts index.
/// Each whitespace-separated term is matched with edit distance 1.
///
/// See [`Self::lookup`] for caveats about concurrent writes.
pub fn lookup_fuzzy<C>(
&self,
query: impl Into<String>,
collector: impl AsRef<C> + Send + 'static,
distance: u8,
) -> Result<Vec<FtsLookupResult>, TantivyError>
where
C: Collector,
C::Fruit: IntoIterator<Item = (f32, DocAddress)>,
{
if !self.path.exists() {
return Ok(Vec::new());
}
if !self.path.is_dir() {
warn!("fts index at {} is not a directory?!", self.path.display());
return Ok(Vec::new());
}
let index = Index::open_in_dir(&self.path)?;
let reader = index
.reader_builder()
.reload_policy(ReloadPolicy::OnCommitWithDelay)
.try_into()?;
let schema = index.schema();
let search_fields: Vec<_> = self
.schema
.fields()
.filter(|(_, entry)| !entry.name().starts_with("_meta_"))
.map(|(field, _)| field)
.collect();
let query: String = query.into();
let mut clauses: Vec<(Occur, Box<dyn tantivy::query::Query>)> = Vec::new();
for term_str in query.split_whitespace() {
for &field in &search_fields {
let term = Term::from_field_text(field, term_str);
clauses.push((
Occur::Should,
Box::new(FuzzyTermQuery::new(term, distance, true)),
));
}
}
if clauses.is_empty() {
return Ok(Vec::new());
}
let searcher = reader.searcher();
let res = searcher.search(&BooleanQuery::new(clauses), collector.as_ref())?;
return Self::collect_results(&schema, &searcher, res);
}
fn collect_results(
schema: &Schema,
searcher: &tantivy::Searcher,
res: impl IntoIterator<Item = (f32, DocAddress)>,
) -> Result<Vec<FtsLookupResult>, TantivyError> {
let mut out = Vec::new();
for (score, doc) in res {
let retrieved_doc: TantivyDocument = searcher.doc(doc)?;

View File

@@ -18,8 +18,6 @@ pub struct FieldQuery {
source: String,
key: String,
path: String,
#[serde(default)]
download: bool,
}
/// Extract a specific field from an item's metadata
@@ -81,38 +79,21 @@ pub async fn get_field(
time_ms = start.elapsed().as_millis()
);
let disposition = if params.download {
"attachment"
} else {
"inline"
};
match value {
PileValue::String(s) => (
StatusCode::OK,
[
(header::CONTENT_TYPE, "text/plain".to_owned()),
(header::CONTENT_DISPOSITION, disposition.to_owned()),
],
[(header::CONTENT_TYPE, "text/plain")],
s.to_string(),
)
.into_response(),
PileValue::Blob { mime, bytes } => (
StatusCode::OK,
[
(header::CONTENT_TYPE, mime.to_string()),
(header::CONTENT_DISPOSITION, disposition.to_owned()),
],
[(header::CONTENT_TYPE, mime.to_string())],
bytes.as_ref().clone(),
)
.into_response(),
_ => match value.to_json(&state).await {
Ok(json) => (
StatusCode::OK,
[(header::CONTENT_DISPOSITION, disposition.to_owned())],
Json(json),
)
.into_response(),
Ok(json) => (StatusCode::OK, Json(json)).into_response(),
Err(e) => (StatusCode::INTERNAL_SERVER_ERROR, format!("{e:?}")).into_response(),
},
}

View File

@@ -19,8 +19,6 @@ use crate::Datasets;
pub struct ItemQuery {
source: String,
key: String,
#[serde(default)]
download: bool,
}
/// Parse a `Range: bytes=...` header value.
@@ -163,18 +161,11 @@ pub async fn item_get(
StatusCode::OK
};
let disposition = if params.download {
"attachment"
} else {
"inline"
};
let mut builder = axum::http::Response::builder()
.status(status)
.header(header::CONTENT_TYPE, mime)
.header(header::ACCEPT_RANGES, "bytes")
.header(header::CONTENT_LENGTH, length)
.header(header::CONTENT_DISPOSITION, disposition);
.header(header::CONTENT_LENGTH, length);
if is_range {
builder = builder.header(

View File

@@ -1,47 +0,0 @@
use binrw::{binrw, meta::ReadMagic};
/// On-disk header of the v1 encrypted container format.
///
/// Serialized little-endian by binrw as: 12-byte magic `PileChaChav1`,
/// then the config fields, then the plaintext size.
#[binrw]
#[brw(little, magic = b"PileChaChav1")]
#[derive(Debug, Clone, Copy)]
pub struct ChaChaHeaderv1 {
    // Chunking/nonce/tag parameters used when the file was written.
    pub config: ChaChaConfigv1,
    // Total number of plaintext bytes in the container.
    pub plaintext_size: u64,
}

impl ChaChaHeaderv1 {
    /// Serialized size in bytes: magic (12) + config (3 x u64) + plaintext_size (8).
    pub const SIZE: usize = ChaChaHeaderv1::MAGIC.len() + std::mem::size_of::<ChaChaConfigv1>() + 8;
}

// NOTE(review): SIZE (44, includes the 12-byte magic) is compared against the
// in-memory size of the struct (which has no magic bytes) — confirm this
// assertion can actually hold and what it is meant to check.
#[test]
fn chachaheader_size() {
    assert_eq!(ChaChaHeaderv1::SIZE, std::mem::size_of::<ChaChaHeaderv1>())
}
//
// MARK: config
//
/// Tunable parameters of the v1 encrypted container format.
#[binrw]
#[brw(little)]
#[derive(Debug, Clone, Copy)]
pub struct ChaChaConfigv1 {
    // Plaintext bytes per encrypted chunk.
    pub chunk_size: u64,
    // Nonce length in bytes, stored before each chunk's ciphertext.
    pub nonce_size: u64,
    // Authentication tag length in bytes, appended to each chunk.
    pub tag_size: u64,
}

impl Default for ChaChaConfigv1 {
    fn default() -> Self {
        Self {
            // 64 KiB plaintext chunks; 24-byte nonce and 16-byte tag match
            // the XChaCha20-Poly1305 cipher used for the chunks.
            chunk_size: 64 * 1024,
            nonce_size: 24,
            tag_size: 16,
        }
    }
}

impl ChaChaConfigv1 {
    /// Size in bytes of one encrypted chunk on disk: nonce + ciphertext + tag.
    pub(crate) fn enc_chunk_size(&self) -> u64 {
        self.chunk_size + self.nonce_size + self.tag_size
    }
}

View File

@@ -1,9 +0,0 @@
mod reader;
mod reader_async;
mod writer;
mod writer_async;
pub use {reader::*, reader_async::*, writer::*, writer_async::*};
mod format;
pub use format::*;

View File

@@ -1,15 +1,70 @@
use std::io::{Read, Seek, SeekFrom};
use crate::{AsyncReader, AsyncSeekReader, chacha::ChaChaHeaderv1};
use binrw::binrw;
use crate::{AsyncReader, AsyncSeekReader};
//
// MARK: header
//
/// Serialized size of [`ChaChaHeader`] in bytes: 12 magic + 3×8 config + 8 plaintext_size.
pub const HEADER_SIZE: usize = 44;

/// On-disk header of the encrypted container format (magic `PileChaChav1`).
///
/// Written little-endian by binrw; its serialized size must stay equal to
/// [`HEADER_SIZE`].
#[binrw]
#[brw(little, magic = b"PileChaChav1")]
#[derive(Debug, Clone, Copy)]
pub struct ChaChaHeader {
    // Plaintext bytes per encrypted chunk.
    pub chunk_size: u64,
    // Nonce length in bytes stored before each chunk's ciphertext.
    pub nonce_size: u64,
    // Authentication tag length in bytes appended to each chunk.
    pub tag_size: u64,
    // Total number of plaintext bytes in the container.
    pub plaintext_size: u64,
}
//
// MARK: config
//
/// Chunking parameters used when reading or writing an encrypted container.
#[derive(Debug, Clone, Copy)]
pub struct ChaChaReaderConfig {
    // Plaintext bytes per encrypted chunk.
    pub chunk_size: u64,
    // Nonce length in bytes stored before each chunk's ciphertext.
    pub nonce_size: u64,
    // Authentication tag length in bytes appended to each chunk.
    pub tag_size: u64,
}

impl Default for ChaChaReaderConfig {
    fn default() -> Self {
        Self {
            chunk_size: 1_048_576, // 1MiB
            // 24-byte nonce and 16-byte tag match the XChaCha20-Poly1305
            // cipher used for the chunks.
            nonce_size: 24,
            tag_size: 16,
        }
    }
}

impl ChaChaReaderConfig {
    /// Size in bytes of one encrypted chunk on disk: nonce + ciphertext + tag.
    pub(crate) fn enc_chunk_size(&self) -> u64 {
        self.chunk_size + self.nonce_size + self.tag_size
    }
}
impl From<ChaChaHeader> for ChaChaReaderConfig {
fn from(h: ChaChaHeader) -> Self {
Self {
chunk_size: h.chunk_size,
nonce_size: h.nonce_size,
tag_size: h.tag_size,
}
}
}
//
// MARK: reader
//
pub struct ChaChaReaderv1<R: Read + Seek> {
pub struct ChaChaReader<R: Read + Seek> {
inner: R,
header: ChaChaHeaderv1,
config: ChaChaReaderConfig,
data_offset: u64,
encryption_key: [u8; 32],
cursor: u64,
@@ -17,17 +72,17 @@ pub struct ChaChaReaderv1<R: Read + Seek> {
cached_chunk: Option<(u64, Vec<u8>)>,
}
impl<R: Read + Seek> ChaChaReaderv1<R> {
impl<R: Read + Seek> ChaChaReader<R> {
pub fn new(mut inner: R, encryption_key: [u8; 32]) -> Result<Self, std::io::Error> {
use binrw::BinReaderExt;
inner.seek(SeekFrom::Start(0))?;
let header: ChaChaHeaderv1 = inner.read_le().map_err(std::io::Error::other)?;
let header: ChaChaHeader = inner.read_le().map_err(std::io::Error::other)?;
let data_offset = inner.stream_position()?;
Ok(Self {
inner,
header,
config: header.into(),
data_offset,
encryption_key,
cursor: 0,
@@ -39,22 +94,21 @@ impl<R: Read + Seek> ChaChaReaderv1<R> {
fn fetch_chunk(&mut self, chunk_index: u64) -> Result<(), std::io::Error> {
use chacha20poly1305::{KeyInit, XChaCha20Poly1305, XNonce, aead::Aead};
let enc_start = self.data_offset + chunk_index * self.header.config.enc_chunk_size();
let enc_start = self.data_offset + chunk_index * self.config.enc_chunk_size();
self.inner.seek(SeekFrom::Start(enc_start))?;
let mut encrypted = vec![0u8; self.header.config.enc_chunk_size() as usize];
let mut encrypted = vec![0u8; self.config.enc_chunk_size() as usize];
let n = self.read_exact_or_eof(&mut encrypted)?;
encrypted.truncate(n);
if encrypted.len() < (self.header.config.nonce_size + self.header.config.tag_size) as usize
{
if encrypted.len() < (self.config.nonce_size + self.config.tag_size) as usize {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"encrypted chunk too short",
));
}
let (nonce_bytes, ciphertext) = encrypted.split_at(self.header.config.nonce_size as usize);
let (nonce_bytes, ciphertext) = encrypted.split_at(self.config.nonce_size as usize);
let nonce = XNonce::from_slice(nonce_bytes);
let key = chacha20poly1305::Key::from_slice(&self.encryption_key);
let cipher = XChaCha20Poly1305::new(key);
@@ -78,14 +132,14 @@ impl<R: Read + Seek> ChaChaReaderv1<R> {
}
}
impl<R: Read + Seek + Send> AsyncReader for ChaChaReaderv1<R> {
impl<R: Read + Seek + Send> AsyncReader for ChaChaReader<R> {
async fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
let remaining = self.plaintext_size.saturating_sub(self.cursor);
if remaining == 0 || buf.is_empty() {
return Ok(0);
}
let chunk_index = self.cursor / self.header.config.chunk_size;
let chunk_index = self.cursor / self.config.chunk_size;
let need_fetch = match &self.cached_chunk {
None => true,
@@ -99,7 +153,7 @@ impl<R: Read + Seek + Send> AsyncReader for ChaChaReaderv1<R> {
#[expect(clippy::unwrap_used)]
let (_, chunk_data) = self.cached_chunk.as_ref().unwrap();
let offset_in_chunk = (self.cursor % self.header.config.chunk_size) as usize;
let offset_in_chunk = (self.cursor % self.config.chunk_size) as usize;
let available = chunk_data.len() - offset_in_chunk;
let to_copy = available.min(buf.len());
@@ -109,7 +163,7 @@ impl<R: Read + Seek + Send> AsyncReader for ChaChaReaderv1<R> {
}
}
impl<R: Read + Seek + Send> AsyncSeekReader for ChaChaReaderv1<R> {
impl<R: Read + Seek + Send> AsyncSeekReader for ChaChaReader<R> {
async fn seek(&mut self, pos: SeekFrom) -> Result<u64, std::io::Error> {
match pos {
SeekFrom::Start(x) => self.cursor = x.min(self.plaintext_size),

View File

@@ -1,11 +1,10 @@
use std::io::SeekFrom;
use crate::{AsyncReader, AsyncSeekReader, chacha::ChaChaHeaderv1};
use crate::{AsyncReader, AsyncSeekReader, ChaChaHeader, ChaChaReaderConfig, HEADER_SIZE};
pub struct ChaChaReaderv1Async<R: AsyncSeekReader> {
pub struct ChaChaReaderAsync<R: AsyncSeekReader> {
inner: R,
header: ChaChaHeaderv1,
config: ChaChaReaderConfig,
data_offset: u64,
encryption_key: [u8; 32],
cursor: u64,
@@ -13,22 +12,22 @@ pub struct ChaChaReaderv1Async<R: AsyncSeekReader> {
cached_chunk: Option<(u64, Vec<u8>)>,
}
impl<R: AsyncSeekReader> ChaChaReaderv1Async<R> {
impl<R: AsyncSeekReader> ChaChaReaderAsync<R> {
pub async fn new(mut inner: R, encryption_key: [u8; 32]) -> Result<Self, std::io::Error> {
use binrw::BinReaderExt;
use std::io::Cursor;
inner.seek(SeekFrom::Start(0)).await?;
let mut buf = [0u8; ChaChaHeaderv1::SIZE];
let mut buf = [0u8; HEADER_SIZE];
read_exact(&mut inner, &mut buf).await?;
let header: ChaChaHeaderv1 = Cursor::new(&buf[..])
let header: ChaChaHeader = Cursor::new(&buf[..])
.read_le()
.map_err(std::io::Error::other)?;
Ok(Self {
inner,
header,
data_offset: buf.len() as u64,
config: header.into(),
data_offset: HEADER_SIZE as u64,
encryption_key,
cursor: 0,
plaintext_size: header.plaintext_size,
@@ -39,22 +38,21 @@ impl<R: AsyncSeekReader> ChaChaReaderv1Async<R> {
async fn fetch_chunk(&mut self, chunk_index: u64) -> Result<(), std::io::Error> {
use chacha20poly1305::{KeyInit, XChaCha20Poly1305, XNonce, aead::Aead};
let enc_start = self.data_offset + chunk_index * self.header.config.enc_chunk_size();
let enc_start = self.data_offset + chunk_index * self.config.enc_chunk_size();
self.inner.seek(SeekFrom::Start(enc_start)).await?;
let mut encrypted = vec![0u8; self.header.config.enc_chunk_size() as usize];
let mut encrypted = vec![0u8; self.config.enc_chunk_size() as usize];
let n = read_exact_or_eof(&mut self.inner, &mut encrypted).await?;
encrypted.truncate(n);
if encrypted.len() < (self.header.config.nonce_size + self.header.config.tag_size) as usize
{
if encrypted.len() < (self.config.nonce_size + self.config.tag_size) as usize {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"encrypted chunk too short",
));
}
let (nonce_bytes, ciphertext) = encrypted.split_at(self.header.config.nonce_size as usize);
let (nonce_bytes, ciphertext) = encrypted.split_at(self.config.nonce_size as usize);
let nonce = XNonce::from_slice(nonce_bytes);
let key = chacha20poly1305::Key::from_slice(&self.encryption_key);
let cipher = XChaCha20Poly1305::new(key);
@@ -92,14 +90,14 @@ async fn read_exact_or_eof<R: AsyncReader>(
Ok(total)
}
impl<R: AsyncSeekReader> AsyncReader for ChaChaReaderv1Async<R> {
impl<R: AsyncSeekReader> AsyncReader for ChaChaReaderAsync<R> {
async fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
let remaining = self.plaintext_size.saturating_sub(self.cursor);
if remaining == 0 || buf.is_empty() {
return Ok(0);
}
let chunk_index = self.cursor / self.header.config.chunk_size;
let chunk_index = self.cursor / self.config.chunk_size;
let need_fetch = match &self.cached_chunk {
None => true,
@@ -113,7 +111,7 @@ impl<R: AsyncSeekReader> AsyncReader for ChaChaReaderv1Async<R> {
#[expect(clippy::unwrap_used)]
let (_, chunk_data) = self.cached_chunk.as_ref().unwrap();
let offset_in_chunk = (self.cursor % self.header.config.chunk_size) as usize;
let offset_in_chunk = (self.cursor % self.config.chunk_size) as usize;
let available = chunk_data.len() - offset_in_chunk;
let to_copy = available.min(buf.len());
@@ -123,7 +121,7 @@ impl<R: AsyncSeekReader> AsyncReader for ChaChaReaderv1Async<R> {
}
}
impl<R: AsyncSeekReader> AsyncSeekReader for ChaChaReaderv1Async<R> {
impl<R: AsyncSeekReader> AsyncSeekReader for ChaChaReaderAsync<R> {
async fn seek(&mut self, pos: SeekFrom) -> Result<u64, std::io::Error> {
match pos {
SeekFrom::Start(x) => self.cursor = x.min(self.plaintext_size),

View File

@@ -1,6 +1,6 @@
use std::io::{Seek, SeekFrom, Write};
use crate::chacha::{ChaChaConfigv1, ChaChaHeaderv1};
use crate::{ChaChaHeader, ChaChaReaderConfig};
/// Generate a random 32-byte encryption key suitable for use with [`ChaChaWriter`].
pub fn generate_key() -> [u8; 32] {
@@ -9,28 +9,30 @@ pub fn generate_key() -> [u8; 32] {
XChaCha20Poly1305::generate_key(&mut OsRng).into()
}
pub struct ChaChaWriterv1<W: Write + Seek> {
pub struct ChaChaWriter<W: Write + Seek> {
inner: W,
header: ChaChaHeaderv1,
config: ChaChaReaderConfig,
encryption_key: [u8; 32],
buffer: Vec<u8>,
plaintext_bytes_written: u64,
}
impl<W: Write + Seek> ChaChaWriterv1<W> {
impl<W: Write + Seek> ChaChaWriter<W> {
pub fn new(mut inner: W, encryption_key: [u8; 32]) -> Result<Self, std::io::Error> {
use binrw::BinWriterExt;
let header = ChaChaHeaderv1 {
config: ChaChaConfigv1::default(),
let config = ChaChaReaderConfig::default();
let header = ChaChaHeader {
chunk_size: config.chunk_size,
nonce_size: config.nonce_size,
tag_size: config.tag_size,
plaintext_size: 0,
};
inner.write_le(&header).map_err(std::io::Error::other)?;
Ok(Self {
inner,
header,
config,
encryption_key,
buffer: Vec::new(),
plaintext_bytes_written: 0,
@@ -45,8 +47,10 @@ impl<W: Write + Seek> ChaChaWriterv1<W> {
self.flush_buffer()?;
self.inner.seek(SeekFrom::Start(0))?;
let header = ChaChaHeaderv1 {
config: self.header.config,
let header = ChaChaHeader {
chunk_size: self.config.chunk_size,
nonce_size: self.config.nonce_size,
tag_size: self.config.tag_size,
plaintext_size: self.plaintext_bytes_written,
};
self.inner
@@ -85,12 +89,12 @@ impl<W: Write + Seek> ChaChaWriterv1<W> {
}
}
impl<W: Write + Seek> Write for ChaChaWriterv1<W> {
impl<W: Write + Seek> Write for ChaChaWriter<W> {
fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
self.buffer.extend_from_slice(buf);
self.plaintext_bytes_written += buf.len() as u64;
let chunk_size = self.header.config.chunk_size as usize;
let chunk_size = self.config.chunk_size as usize;
while self.buffer.len() >= chunk_size {
let encrypted = self.encrypt_chunk(&self.buffer[..chunk_size])?;
self.inner.write_all(&encrypted)?;
@@ -116,13 +120,13 @@ impl<W: Write + Seek> Write for ChaChaWriterv1<W> {
mod tests {
use std::io::{Cursor, SeekFrom, Write};
use super::ChaChaWriterv1;
use crate::{AsyncReader, AsyncSeekReader, chacha::ChaChaReaderv1};
use super::ChaChaWriter;
use crate::{AsyncReader, AsyncSeekReader, ChaChaReader};
const KEY: [u8; 32] = [42u8; 32];
fn encrypt(data: &[u8]) -> Cursor<Vec<u8>> {
let mut writer = ChaChaWriterv1::new(Cursor::new(Vec::new()), KEY).unwrap();
let mut writer = ChaChaWriter::new(Cursor::new(Vec::new()), KEY).unwrap();
writer.write_all(data).unwrap();
let mut buf = writer.finish().unwrap();
buf.set_position(0);
@@ -130,7 +134,7 @@ mod tests {
}
async fn decrypt_all(buf: Cursor<Vec<u8>>) -> Vec<u8> {
let mut reader = ChaChaReaderv1::new(buf, KEY).unwrap();
let mut reader = ChaChaReader::new(buf, KEY).unwrap();
reader.read_to_end().await.unwrap()
}
@@ -165,7 +169,7 @@ mod tests {
async fn roundtrip_incremental_writes() {
// Write one byte at a time
let data: Vec<u8> = (0u8..200).collect();
let mut writer = ChaChaWriterv1::new(Cursor::new(Vec::new()), KEY).unwrap();
let mut writer = ChaChaWriter::new(Cursor::new(Vec::new()), KEY).unwrap();
for byte in &data {
writer.write_all(&[*byte]).unwrap();
}
@@ -177,7 +181,7 @@ mod tests {
#[tokio::test]
async fn wrong_key_fails() {
let buf = encrypt(b"secret data");
let mut reader = ChaChaReaderv1::new(buf, [0u8; 32]).unwrap();
let mut reader = ChaChaReader::new(buf, [0u8; 32]).unwrap();
assert!(reader.read_to_end().await.is_err());
}
@@ -187,13 +191,13 @@ mod tests {
let mut buf = encrypt(b"data");
buf.get_mut()[0] = 0xFF;
buf.set_position(0);
assert!(ChaChaReaderv1::new(buf, KEY).is_err());
assert!(ChaChaReader::new(buf, KEY).is_err());
}
#[tokio::test]
async fn seek_from_start() {
let data: Vec<u8> = (0u8..100).collect();
let mut reader = ChaChaReaderv1::new(encrypt(&data), KEY).unwrap();
let mut reader = ChaChaReader::new(encrypt(&data), KEY).unwrap();
reader.seek(SeekFrom::Start(50)).await.unwrap();
let mut buf = [0u8; 10];
@@ -207,7 +211,7 @@ mod tests {
#[tokio::test]
async fn seek_from_end() {
let data: Vec<u8> = (0u8..100).collect();
let mut reader = ChaChaReaderv1::new(encrypt(&data), KEY).unwrap();
let mut reader = ChaChaReader::new(encrypt(&data), KEY).unwrap();
reader.seek(SeekFrom::End(-10)).await.unwrap();
assert_eq!(reader.read_to_end().await.unwrap(), &data[90..]);
@@ -217,7 +221,7 @@ mod tests {
async fn seek_across_chunk_boundary() {
// Seek to 6 bytes before the end of chunk 0, read 12 bytes spanning into chunk 1
let data: Vec<u8> = (0u8..=255).cycle().take(65536 + 500).collect();
let mut reader = ChaChaReaderv1::new(encrypt(&data), KEY).unwrap();
let mut reader = ChaChaReader::new(encrypt(&data), KEY).unwrap();
reader.seek(SeekFrom::Start(65530)).await.unwrap();
let mut buf = vec![0u8; 12];
@@ -231,7 +235,7 @@ mod tests {
#[tokio::test]
async fn seek_current() {
let data: Vec<u8> = (0u8..=255).cycle().take(200).collect();
let mut reader = ChaChaReaderv1::new(encrypt(&data), KEY).unwrap();
let mut reader = ChaChaReader::new(encrypt(&data), KEY).unwrap();
// Read 10, seek back 5, read 5 — should get bytes 5..10
let mut first = [0u8; 10];
@@ -251,7 +255,7 @@ mod tests {
#[tokio::test]
async fn seek_past_end_clamps() {
let data = b"hello";
let mut reader = ChaChaReaderv1::new(encrypt(data), KEY).unwrap();
let mut reader = ChaChaReader::new(encrypt(data), KEY).unwrap();
let pos = reader.seek(SeekFrom::Start(9999)).await.unwrap();
assert_eq!(pos, data.len() as u64);

View File

@@ -2,12 +2,11 @@ use std::io::SeekFrom;
use tokio::io::{AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt};
use crate::chacha::{ChaChaConfigv1, ChaChaHeaderv1};
use crate::{ChaChaHeader, ChaChaReaderConfig};
pub struct ChaChaWriterAsync<W: AsyncWrite + AsyncSeek + Unpin + Send> {
inner: W,
header: ChaChaHeaderv1,
config: ChaChaReaderConfig,
encryption_key: [u8; 32],
buffer: Vec<u8>,
plaintext_bytes_written: u64,
@@ -15,15 +14,18 @@ pub struct ChaChaWriterAsync<W: AsyncWrite + AsyncSeek + Unpin + Send> {
impl<W: AsyncWrite + AsyncSeek + Unpin + Send> ChaChaWriterAsync<W> {
pub async fn new(mut inner: W, encryption_key: [u8; 32]) -> Result<Self, std::io::Error> {
let header = ChaChaHeaderv1 {
config: ChaChaConfigv1::default(),
let config = ChaChaReaderConfig::default();
let header_bytes = serialize_header(ChaChaHeader {
chunk_size: config.chunk_size,
nonce_size: config.nonce_size,
tag_size: config.tag_size,
plaintext_size: 0,
};
inner.write_all(&serialize_header(header)?).await?;
})?;
inner.write_all(&header_bytes).await?;
Ok(Self {
inner,
header,
config,
encryption_key,
buffer: Vec::new(),
plaintext_bytes_written: 0,
@@ -34,7 +36,7 @@ impl<W: AsyncWrite + AsyncSeek + Unpin + Send> ChaChaWriterAsync<W> {
self.buffer.extend_from_slice(buf);
self.plaintext_bytes_written += buf.len() as u64;
let chunk_size = self.header.config.chunk_size as usize;
let chunk_size = self.config.chunk_size as usize;
while self.buffer.len() >= chunk_size {
let encrypted = encrypt_chunk(&self.encryption_key, &self.buffer[..chunk_size])?;
self.inner.write_all(&encrypted).await?;
@@ -53,8 +55,10 @@ impl<W: AsyncWrite + AsyncSeek + Unpin + Send> ChaChaWriterAsync<W> {
}
self.inner.seek(SeekFrom::Start(0)).await?;
let header_bytes = serialize_header(ChaChaHeaderv1 {
config: self.header.config,
let header_bytes = serialize_header(ChaChaHeader {
chunk_size: self.config.chunk_size,
nonce_size: self.config.nonce_size,
tag_size: self.config.tag_size,
plaintext_size: self.plaintext_bytes_written,
})?;
self.inner.write_all(&header_bytes).await?;
@@ -81,7 +85,7 @@ fn encrypt_chunk(key: &[u8; 32], plaintext: &[u8]) -> Result<Vec<u8>, std::io::E
Ok(output)
}
fn serialize_header(header: ChaChaHeaderv1) -> Result<Vec<u8>, std::io::Error> {
fn serialize_header(header: ChaChaHeader) -> Result<Vec<u8>, std::io::Error> {
use binrw::BinWriterExt;
use std::io::Cursor;

View File

@@ -4,4 +4,14 @@ pub use asyncreader::*;
mod s3reader;
pub use s3reader::*;
pub mod chacha;
mod chachareader;
pub use chachareader::*;
mod chachawriter;
pub use chachawriter::*;
mod chachareader_async;
pub use chachareader_async::*;
mod chachawriter_async;
pub use chachawriter_async::*;

View File

@@ -1,103 +1,10 @@
use aws_sdk_s3::config::{BehaviorVersion, Credentials, Region};
use smartstring::{LazyCompact, SmartString};
use std::{fmt::Debug, io::SeekFrom, sync::Arc};
use std::{io::SeekFrom, sync::Arc};
use crate::{AsyncReader, AsyncSeekReader};
//
// MARK: client
//
/// An interface to an S3 bucket.
///
/// This struct is [Send] + [Sync],
/// and should be shared between threads in an [Arc].
///
/// It provides intelligent caching that reduces the amount of data we need to download.
pub struct S3Client {
    // Underlying AWS SDK client used for all requests.
    pub client: aws_sdk_s3::Client,
    // Name of the bucket every request is issued against.
    bucket: SmartString<LazyCompact>,
    /// maximum number of bytes to use for cached data
    cache_limit_bytes: usize,
}

// Manual Debug that shows only the cheap-to-print fields and omits the
// inner SDK client.
impl Debug for S3Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("S3Client")
            .field("bucket", &self.bucket)
            .field("cache_limit_bytes", &self.cache_limit_bytes)
            .finish()
    }
}
impl S3Client {
    /// Build a client for `bucket`.
    ///
    /// `endpoint` overrides the AWS endpoint (enabling path-style
    /// addressing for S3-compatible services); `None` uses the default
    /// endpoint for `region`.
    pub async fn new(
        bucket: &str,
        endpoint: Option<&str>,
        region: &str,
        access_key_id: &str,
        secret_access_key: &str,
        cache_limit_bytes: usize,
    ) -> Arc<Self> {
        let client = {
            let mut s3_config = aws_sdk_s3::config::Builder::new()
                .behavior_version(BehaviorVersion::latest())
                .region(Region::new(region.to_owned()))
                .credentials_provider(Credentials::new(
                    access_key_id,
                    secret_access_key,
                    None,
                    None,
                    "pile",
                ));
            if let Some(ep) = endpoint {
                // Path-style addressing is required by most S3-compatible
                // services behind a custom endpoint.
                s3_config = s3_config.endpoint_url(ep).force_path_style(true);
            }
            aws_sdk_s3::Client::from_conf(s3_config.build())
        };
        return Arc::new(Self {
            bucket: bucket.into(),
            client,
            cache_limit_bytes,
        });
    }

    /// Name of the bucket this client talks to.
    pub fn bucket(&self) -> &str {
        &self.bucket
    }

    /// Open a reader for `key`.
    ///
    /// Issues a HEAD request up front to learn the object size; the reader
    /// then fetches data on demand.
    pub async fn get(self: &Arc<Self>, key: &str) -> Result<S3Reader, std::io::Error> {
        let head = self
            .client
            .head_object()
            .bucket(self.bucket.as_str())
            .key(key)
            .send()
            .await
            .map_err(std::io::Error::other)?;
        // NOTE(review): a missing content_length is treated as size 0 —
        // confirm this cannot silently truncate reads.
        let size = head.content_length().unwrap_or(0) as u64;
        Ok(S3Reader {
            client: self.clone(),
            bucket: self.bucket.clone(),
            key: key.into(),
            cursor: 0,
            size,
        })
    }
}
//
// MARK: reader
//
pub struct S3Reader {
pub client: Arc<S3Client>,
pub client: Arc<aws_sdk_s3::Client>,
pub bucket: SmartString<LazyCompact>,
pub key: SmartString<LazyCompact>,
pub cursor: u64,
@@ -116,7 +23,6 @@ impl AsyncReader for S3Reader {
let end_byte = start_byte + len_to_read - 1;
let resp = self
.client
.client
.get_object()
.bucket(self.bucket.as_str())

View File

@@ -31,6 +31,7 @@ image = { workspace = true, optional = true }
id3 = { workspace = true }
tokio = { workspace = true }
async-trait = { workspace = true }
aws-sdk-s3 = { workspace = true }
mime = { workspace = true }
mime_guess = { workspace = true }

View File

@@ -1,9 +1,9 @@
use aws_sdk_s3::config::{BehaviorVersion, Credentials, Region};
use chrono::{DateTime, Utc};
use pile_config::{
Label,
Label, S3Credentials,
pattern::{GroupPattern, GroupSegment},
};
use pile_io::S3Client;
use smartstring::{LazyCompact, SmartString};
use std::{
collections::{HashMap, HashSet},
@@ -19,9 +19,9 @@ use crate::{
#[derive(Debug)]
pub struct S3DataSource {
pub name: Label,
pub client: Arc<S3Client>,
pub bucket: SmartString<LazyCompact>,
pub prefix: Option<SmartString<LazyCompact>>,
pub client: Arc<aws_sdk_s3::Client>,
pub pattern: GroupPattern,
pub encryption_key: Option<[u8; 32]>,
pub index: OnceLock<HashMap<SmartString<LazyCompact>, Item>>,
@@ -30,30 +30,40 @@ pub struct S3DataSource {
impl S3DataSource {
pub async fn new(
name: &Label,
bucket: &str,
prefix: Option<&str>,
endpoint: Option<&str>,
region: &str,
access_key_id: &str,
secret_access_key: &str,
cache_limit_bytes: usize,
bucket: String,
prefix: Option<String>,
endpoint: Option<String>,
region: String,
credentials: &S3Credentials,
pattern: GroupPattern,
encryption_key: Option<[u8; 32]>,
) -> Result<Arc<Self>, std::io::Error> {
let client = S3Client::new(
bucket,
endpoint,
region,
access_key_id,
secret_access_key,
cache_limit_bytes,
)
.await;
let client = {
let creds = Credentials::new(
&credentials.access_key_id,
&credentials.secret_access_key,
None,
None,
"pile",
);
let mut s3_config = aws_sdk_s3::config::Builder::new()
.behavior_version(BehaviorVersion::latest())
.region(Region::new(region))
.credentials_provider(creds);
if let Some(ep) = endpoint {
s3_config = s3_config.endpoint_url(ep).force_path_style(true);
}
aws_sdk_s3::Client::from_conf(s3_config.build())
};
let source = Arc::new(Self {
name: name.clone(),
client,
bucket: bucket.into(),
prefix: prefix.map(|x| x.into()),
client: Arc::new(client),
pattern,
encryption_key,
index: OnceLock::new(),
@@ -68,10 +78,9 @@ impl S3DataSource {
loop {
let mut req = source
.client
.client
.list_objects_v2()
.bucket(source.client.bucket());
.bucket(source.bucket.as_str());
if let Some(prefix) = &source.prefix {
req = req.prefix(prefix.as_str());
@@ -182,11 +191,7 @@ impl DataSource for Arc<S3DataSource> {
let mut continuation_token: Option<String> = None;
loop {
let mut req = self
.client
.client
.list_objects_v2()
.bucket(self.client.bucket());
let mut req = self.client.list_objects_v2().bucket(self.bucket.as_str());
if let Some(prefix) = &self.prefix {
req = req.prefix(prefix.as_str());

View File

@@ -1,6 +1,6 @@
use mime::Mime;
use pile_config::Label;
use pile_io::{SyncReadBridge, chacha::ChaChaReaderv1Async};
use pile_io::{ChaChaReaderAsync, S3Reader, SyncReadBridge};
use smartstring::{LazyCompact, SmartString};
use std::{collections::HashMap, fs::File, path::PathBuf, sync::Arc};
@@ -59,13 +59,39 @@ impl Item {
}
};
let reader = source.client.get(&full_key).await?;
let head = source
.client
.head_object()
.bucket(source.bucket.as_str())
.key(full_key.as_str())
.send()
.await
.map_err(std::io::Error::other)?;
let size = head.content_length().unwrap_or(0) as u64;
match source.encryption_key {
None => ItemReader::S3(reader),
Some(enc_key) => {
ItemReader::EncryptedS3(ChaChaReaderv1Async::new(reader, enc_key).await?)
}
None => ItemReader::S3(S3Reader {
client: source.client.clone(),
bucket: source.bucket.clone(),
key: full_key,
cursor: 0,
size,
}),
Some(enc_key) => ItemReader::EncryptedS3(
ChaChaReaderAsync::new(
S3Reader {
client: source.client.clone(),
bucket: source.bucket.clone(),
key: full_key,
cursor: 0,
size,
},
enc_key,
)
.await?,
),
}
}
})

View File

@@ -1,4 +1,4 @@
use pile_io::{AsyncReader, AsyncSeekReader, S3Reader, chacha::ChaChaReaderv1Async};
use pile_io::{AsyncReader, AsyncSeekReader, ChaChaReaderAsync, S3Reader};
use std::{fs::File, io::Seek};
//
@@ -8,7 +8,7 @@ use std::{fs::File, io::Seek};
pub enum ItemReader {
File(File),
S3(S3Reader),
EncryptedS3(ChaChaReaderv1Async<S3Reader>),
EncryptedS3(ChaChaReaderAsync<S3Reader>),
}
impl AsyncReader for ItemReader {

View File

@@ -1,7 +1,6 @@
use anyhow::{Context, Result};
use clap::Args;
use pile_io::AsyncReader;
use pile_io::chacha::{ChaChaReaderv1, ChaChaWriterv1};
use pile_io::{AsyncReader, ChaChaReader, ChaChaWriter};
use pile_toolbox::cancelabletask::{CancelFlag, CancelableTaskError};
use pile_value::source::string_to_key;
use std::io::{Cursor, Write};
@@ -38,7 +37,7 @@ impl CliCmd for EncryptCommand {
.await
.with_context(|| format!("while reading '{}'", self.path.display()))?;
let mut writer = ChaChaWriterv1::new(Cursor::new(Vec::new()), key)
let mut writer = ChaChaWriter::new(Cursor::new(Vec::new()), key)
.context("while initializing encryptor")?;
writer.write_all(&plaintext).context("while encrypting")?;
let buf = writer.finish().context("while finalizing encryptor")?;
@@ -62,7 +61,7 @@ impl CliCmd for DecryptCommand {
.await
.with_context(|| format!("while reading '{}'", self.path.display()))?;
let mut reader = ChaChaReaderv1::new(Cursor::new(ciphertext), key)
let mut reader = ChaChaReader::new(Cursor::new(ciphertext), key)
.context("while initializing decryptor")?;
let plaintext = reader.read_to_end().await.context("while decrypting")?;

View File

@@ -4,7 +4,7 @@ use clap::Args;
use indicatif::ProgressBar;
use pile_config::Label;
use pile_dataset::{Dataset, Datasets};
use pile_io::chacha::ChaChaWriterv1;
use pile_io::ChaChaWriter;
use pile_toolbox::cancelabletask::{CancelFlag, CancelableTaskError};
use pile_value::source::{DataSource, DirDataSource, S3DataSource, encrypt_path};
use std::{
@@ -71,12 +71,12 @@ impl CliCmd for UploadCommand {
let bucket = self
.bucket
.as_deref()
.unwrap_or(s3_ds.client.bucket())
.unwrap_or(s3_ds.bucket.as_str())
.to_owned();
let full_prefix = self.prefix.trim_matches('/').to_owned();
// Check for existing objects at the target prefix
let existing_keys = list_prefix(&s3_ds.client.client, &bucket, &full_prefix)
let existing_keys = list_prefix(&s3_ds.client, &bucket, &full_prefix)
.await
.context("while checking for existing objects at target prefix")?;
@@ -89,7 +89,6 @@ impl CliCmd for UploadCommand {
);
for key in &existing_keys {
s3_ds
.client
.client
.delete_object()
.bucket(&bucket)
@@ -170,7 +169,7 @@ impl CliCmd for UploadCommand {
tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<u8>> {
let plaintext = std::fs::read(&path)
.with_context(|| format!("while opening '{}'", path.display()))?;
let mut writer = ChaChaWriterv1::new(Cursor::new(Vec::new()), enc_key)
let mut writer = ChaChaWriter::new(Cursor::new(Vec::new()), enc_key)
.context("while initializing encryptor")?;
writer.write_all(&plaintext).context("while encrypting")?;
Ok(writer.finish().context("while finalizing")?.into_inner())
@@ -185,7 +184,6 @@ impl CliCmd for UploadCommand {
};
client
.client
.put_object()
.bucket(&bucket)
.key(&key)
@@ -226,7 +224,7 @@ fn get_dir_source(
label: &Label,
name: &str,
) -> Result<Arc<DirDataSource>, anyhow::Error> {
match ds.sources.get(label).or(ds.disabled_sources.get(label)) {
match ds.sources.get(label) {
Some(Dataset::Dir(d)) => Ok(Arc::clone(d)),
Some(_) => Err(anyhow::anyhow!(
"source '{name}' is not a filesystem source"
@@ -242,7 +240,7 @@ fn get_s3_source(
label: &Label,
name: &str,
) -> Result<Arc<S3DataSource>, anyhow::Error> {
match ds.sources.get(label).or(ds.disabled_sources.get(label)) {
match ds.sources.get(label) {
Some(Dataset::S3(s)) => Ok(Arc::clone(s)),
Some(_) => Err(anyhow::anyhow!("source '{name}' is not an S3 source")),
None => Err(anyhow::anyhow!("s3 source '{name}' not found in config")),