Add server client
Some checks failed
CI / Typos (push) Successful in 24s
CI / Clippy (push) Successful in 1m16s
CI / Build and test (all features) (push) Failing after 5m5s
CI / Build and test (push) Failing after 6m55s

This commit is contained in:
2026-03-23 21:53:39 -07:00
parent dfcb4b0a24
commit e83c522e78
11 changed files with 673 additions and 11 deletions

255
Cargo.lock generated
View File

@@ -928,6 +928,16 @@ version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d52eff69cd5e647efe296129160853a42795992097e8af39800e1060caeea9b"
[[package]]
name = "core-foundation"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "core-foundation"
version = "0.10.1"
@@ -1393,6 +1403,21 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb"
[[package]]
name = "foreign-types"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
dependencies = [
"foreign-types-shared",
]
[[package]]
name = "foreign-types-shared"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]]
name = "form_urlencoded"
version = "1.2.2"
@@ -1433,6 +1458,12 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d"
[[package]]
name = "futures-io"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
[[package]]
name = "futures-macro"
version = "0.3.32"
@@ -1463,8 +1494,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6"
dependencies = [
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-project-lite",
"slab",
]
@@ -1780,6 +1814,22 @@ dependencies = [
"tower-service",
]
[[package]]
name = "hyper-tls"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
dependencies = [
"bytes",
"http-body-util",
"hyper 1.8.1",
"hyper-util",
"native-tls",
"tokio",
"tokio-native-tls",
"tower-service",
]
[[package]]
name = "hyper-util"
version = "0.1.20"
@@ -1798,9 +1848,11 @@ dependencies = [
"percent-encoding",
"pin-project-lite",
"socket2 0.6.2",
"system-configuration",
"tokio",
"tower-service",
"tracing",
"windows-registry",
]
[[package]]
@@ -2019,6 +2071,16 @@ version = "2.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2"
[[package]]
name = "iri-string"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d8e7418f59cc01c88316161279a7f665217ae316b388e58a0d10e29f54f1e5eb"
dependencies = [
"memchr",
"serde",
]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.2"
@@ -2351,6 +2413,23 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13d2233c9842d08cfe13f9eac96e207ca6a2ea10b80259ebe8ad0268be27d2af"
[[package]]
name = "native-tls"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "465500e14ea162429d264d44189adc38b199b62b1c21eea9f69e4b73cb03bbf2"
dependencies = [
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework",
"security-framework-sys",
"tempfile",
]
[[package]]
name = "nom"
version = "7.1.3"
@@ -2419,12 +2498,50 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
[[package]]
name = "openssl"
version = "0.10.76"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "951c002c75e16ea2c65b8c7e4d3d51d5530d8dfa7d060b4776828c88cfb18ecf"
dependencies = [
"bitflags",
"cfg-if",
"foreign-types",
"libc",
"once_cell",
"openssl-macros",
"openssl-sys",
]
[[package]]
name = "openssl-macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.117",
]
[[package]]
name = "openssl-probe"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe"
[[package]]
name = "openssl-sys"
version = "0.9.112"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57d55af3b3e226502be1526dfdba67ab0e9c96fc293004e79576b2b9edb0dbdb"
dependencies = [
"cc",
"libc",
"pkg-config",
"vcpkg",
]
[[package]]
name = "outref"
version = "0.5.2"
@@ -2584,6 +2701,17 @@ dependencies = [
"utoipa-swagger-ui",
]
[[package]]
name = "pile-client"
version = "0.0.2"
dependencies = [
"bytes",
"futures-core",
"reqwest",
"serde",
"thiserror",
]
[[package]]
name = "pile-config"
version = "0.0.2"
@@ -2911,6 +3039,49 @@ version = "0.8.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58"
[[package]]
name = "reqwest"
version = "0.12.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147"
dependencies = [
"base64",
"bytes",
"encoding_rs",
"futures-core",
"futures-util",
"h2 0.4.13",
"http 1.4.0",
"http-body 1.0.1",
"http-body-util",
"hyper 1.8.1",
"hyper-rustls 0.27.7",
"hyper-tls",
"hyper-util",
"js-sys",
"log",
"mime",
"native-tls",
"percent-encoding",
"pin-project-lite",
"rustls-pki-types",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tokio-native-tls",
"tokio-util",
"tower",
"tower-http",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
]
[[package]]
name = "rfc6979"
version = "0.3.1"
@@ -3150,7 +3321,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d"
dependencies = [
"bitflags",
"core-foundation",
"core-foundation 0.10.1",
"core-foundation-sys",
"libc",
"security-framework-sys",
@@ -3499,6 +3670,9 @@ name = "sync_wrapper"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
dependencies = [
"futures-core",
]
[[package]]
name = "synstructure"
@@ -3511,6 +3685,27 @@ dependencies = [
"syn 2.0.117",
]
[[package]]
name = "system-configuration"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b"
dependencies = [
"bitflags",
"core-foundation 0.9.4",
"system-configuration-sys",
]
[[package]]
name = "system-configuration-sys"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "tantivy"
version = "0.25.0"
@@ -3783,6 +3978,16 @@ dependencies = [
"syn 2.0.117",
]
[[package]]
name = "tokio-native-tls"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
dependencies = [
"native-tls",
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.24.1"
@@ -3882,6 +4087,24 @@ dependencies = [
"tracing",
]
[[package]]
name = "tower-http"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8"
dependencies = [
"bitflags",
"bytes",
"futures-util",
"http 1.4.0",
"http-body 1.0.1",
"iri-string",
"pin-project-lite",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "tower-layer"
version = "0.3.3"
@@ -4184,6 +4407,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
[[package]]
name = "vcpkg"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]]
name = "vecmath"
version = "1.0.0"
@@ -4349,6 +4578,19 @@ dependencies = [
"wasmparser",
]
[[package]]
name = "wasm-streams"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
dependencies = [
"futures-util",
"js-sys",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "wasmparser"
version = "0.244.0"
@@ -4459,6 +4701,17 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
[[package]]
name = "windows-registry"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720"
dependencies = [
"windows-link",
"windows-result",
"windows-strings",
]
[[package]]
name = "windows-result"
version = "0.4.1"

View File

@@ -70,6 +70,7 @@ pile-flac = { path = "crates/pile-flac" }
pile-dataset = { path = "crates/pile-dataset" }
pile-value = { path = "crates/pile-value" }
pile-io = { path = "crates/pile-io" }
pile-client = { path = "crates/pile-client" }
# Clients & servers
tantivy = "0.25.0"

View File

@@ -0,0 +1,15 @@
[package]
name = "pile-client"
version = { workspace = true }
rust-version = { workspace = true }
edition = { workspace = true }
[lints]
workspace = true
[dependencies]
reqwest = { version = "0.12", features = ["json", "stream"] }
futures-core = "0.3"
serde = { workspace = true }
thiserror = { workspace = true }
bytes = { workspace = true }

View File

@@ -0,0 +1,230 @@
use bytes::Bytes;
use futures_core::Stream;
use reqwest::{Client, StatusCode, header};
use serde::{Deserialize, Serialize};
use std::pin::Pin;
use thiserror::Error;
//
// MARK: Error
//
#[derive(Debug, Error)]
pub enum ClientError {
#[error("invalid bearer token")]
InvalidToken,
#[error("HTTP {status}: {body}")]
Http { status: StatusCode, body: String },
#[error(transparent)]
Reqwest(#[from] reqwest::Error),
}
//
// MARK: Response types
//
#[derive(Debug, Deserialize)]
pub struct DatasetInfo {
pub name: String,
}
#[derive(Debug, Serialize)]
pub struct LookupRequest {
pub query: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub limit: Option<usize>,
}
#[derive(Debug, Deserialize)]
pub struct LookupResult {
pub score: f32,
pub source: String,
pub key: String,
}
#[derive(Debug, Deserialize)]
pub struct LookupResponse {
pub results: Vec<LookupResult>,
}
#[derive(Debug, Deserialize)]
pub struct ItemRef {
pub source: String,
pub key: String,
}
#[derive(Debug, Deserialize)]
pub struct ItemsResponse {
pub items: Vec<ItemRef>,
pub total: usize,
pub offset: usize,
pub limit: usize,
}
/// Raw field response: the content-type and body bytes as returned by the server.
pub struct FieldResponse {
pub content_type: String,
pub data: Bytes,
}
//
// MARK: PileClient
//
/// A client for a pile server. Use [`PileClient::dataset`] to get a dataset-scoped client.
pub struct PileClient {
base_url: String,
client: Client,
}
impl PileClient {
pub fn new(base_url: impl Into<String>, token: Option<&str>) -> Result<Self, ClientError> {
let mut headers = header::HeaderMap::new();
if let Some(token) = token {
let value = header::HeaderValue::from_str(&format!("Bearer {token}"))
.map_err(|_| ClientError::InvalidToken)?;
headers.insert(header::AUTHORIZATION, value);
}
let client = Client::builder()
.default_headers(headers)
.build()
.map_err(ClientError::Reqwest)?;
Ok(Self {
base_url: base_url.into(),
client,
})
}
/// Returns a client scoped to a specific dataset (i.e. `/{name}/...`).
pub fn dataset(&self, name: &str) -> DatasetClient {
DatasetClient {
base_url: format!("{}/{name}", self.base_url),
client: self.client.clone(),
}
}
/// `GET /datasets` — list all datasets served by this server.
pub async fn list_datasets(&self) -> Result<Vec<DatasetInfo>, ClientError> {
let resp = self
.client
.get(format!("{}/datasets", self.base_url))
.send()
.await?;
check_status(resp).await?.json().await.map_err(Into::into)
}
}
//
// MARK: DatasetClient
//
/// A client scoped to a single dataset on the server.
pub struct DatasetClient {
base_url: String,
client: Client,
}
impl DatasetClient {
/// `POST /lookup` — full-text search within this dataset.
pub async fn lookup(
&self,
query: impl Into<String>,
limit: Option<usize>,
) -> Result<LookupResponse, ClientError> {
let body = LookupRequest {
query: query.into(),
limit,
};
let resp = self
.client
.post(format!("{}/lookup", self.base_url))
.json(&body)
.send()
.await?;
check_status(resp).await?.json().await.map_err(Into::into)
}
/// `GET /item` — stream the raw bytes of an item.
///
/// The returned stream yields chunks as they arrive from the server.
pub async fn get_item(
&self,
source: &str,
key: &str,
) -> Result<Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>, ClientError> {
let resp = self
.client
.get(format!("{}/item", self.base_url))
.query(&[("source", source), ("key", key)])
.send()
.await?;
Ok(Box::pin(check_status(resp).await?.bytes_stream()))
}
/// `GET /field` — extract a field from an item by object path (e.g. `$.flac.title`).
pub async fn get_field(
&self,
source: &str,
key: &str,
path: &str,
) -> Result<FieldResponse, ClientError> {
let resp = self
.client
.get(format!("{}/field", self.base_url))
.query(&[("source", source), ("key", key), ("path", path)])
.send()
.await?;
let resp = check_status(resp).await?;
let content_type = resp
.headers()
.get(header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("application/octet-stream")
.to_owned();
let data = resp.bytes().await?;
Ok(FieldResponse { content_type, data })
}
/// `GET /items` — paginate over all items in this dataset, ordered by (source, key).
pub async fn list_items(
&self,
offset: usize,
limit: usize,
) -> Result<ItemsResponse, ClientError> {
let resp = self
.client
.get(format!("{}/items", self.base_url))
.query(&[("offset", offset), ("limit", limit)])
.send()
.await?;
check_status(resp).await?.json().await.map_err(Into::into)
}
}
//
// MARK: helpers
//
async fn check_status(resp: reqwest::Response) -> Result<reqwest::Response, ClientError> {
let status = resp.status();
if status.is_success() {
return Ok(resp);
}
let body = resp.text().await.unwrap_or_default();
Err(ClientError::Http { status, body })
}

View File

@@ -61,6 +61,13 @@ impl Dataset {
}
}
pub fn iter_page(&self, offset: usize, limit: usize) -> Box<dyn Iterator<Item = &Item> + Send + '_> {
match self {
Self::Dir(ds) => Box::new(ds.iter_page(offset, limit)),
Self::S3(ds) => Box::new(ds.iter_page(offset, limit)),
}
}
pub async fn latest_change(&self) -> Result<Option<DateTime<Utc>>, std::io::Error> {
match self {
Self::Dir(ds) => ds.latest_change().await,

View File

@@ -0,0 +1,104 @@
use axum::{
Json,
extract::{Query, State},
http::StatusCode,
response::{IntoResponse, Response},
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tracing::debug;
use utoipa::ToSchema;
use crate::Datasets;
#[derive(Deserialize, ToSchema)]
pub struct ItemsQuery {
#[serde(default)]
offset: usize,
#[serde(default = "default_limit")]
limit: usize,
}
fn default_limit() -> usize {
100
}
#[derive(Serialize, ToSchema)]
pub struct ItemsResponse {
pub items: Vec<ItemRef>,
pub total: usize,
pub offset: usize,
pub limit: usize,
}
#[derive(Serialize, ToSchema)]
pub struct ItemRef {
pub source: String,
pub key: String,
}
/// List all items across all sources with consistent ordering, paginated by offset and limit
#[utoipa::path(
get,
path = "/items",
params(
("offset" = usize, Query, description = "Number of items to skip"),
("limit" = usize, Query, description = "Maximum number of items to return (max 1000)"),
),
responses(
(status = 200, description = "Paginated list of items", body = ItemsResponse),
)
)]
pub async fn items_list(
State(state): State<Arc<Datasets>>,
Query(params): Query<ItemsQuery>,
) -> Response {
let limit = params.limit.min(1000);
let offset = params.offset;
debug!(message = "Serving /items", offset, limit);
// Sort sources by label for a consistent global order: (source, key)
let mut source_labels: Vec<_> = state.sources.keys().collect();
source_labels.sort();
let mut items: Vec<ItemRef> = Vec::with_capacity(limit);
let mut total = 0usize;
let mut remaining_offset = offset;
for label in source_labels {
let dataset = &state.sources[label];
let source_len = dataset.len();
if remaining_offset >= source_len {
// This entire source is before our window; skip it efficiently
remaining_offset -= source_len;
total += source_len;
continue;
}
let want = (limit - items.len()).min(source_len - remaining_offset);
let source_str = label.as_str().to_owned();
for item in dataset.iter_page(remaining_offset, want) {
items.push(ItemRef {
source: source_str.clone(),
key: item.key().to_string(),
});
}
remaining_offset = 0;
total += source_len;
}
debug!(message = "Served /items", offset, limit, total);
(
StatusCode::OK,
Json(ItemsResponse {
items,
total,
offset,
limit,
}),
)
.into_response()
}

View File

@@ -17,11 +17,14 @@ pub use item::*;
mod field;
pub use field::*;
mod items;
pub use items::*;
#[derive(OpenApi)]
#[openapi(
tags(),
paths(lookup, item_get, get_field),
components(schemas(LookupRequest, LookupResponse, LookupResult, ItemQuery, FieldQuery))
paths(lookup, item_get, get_field, items_list),
components(schemas(LookupRequest, LookupResponse, LookupResult, ItemQuery, FieldQuery, ItemsQuery, ItemsResponse, ItemRef))
)]
pub(crate) struct Api;
@@ -37,6 +40,7 @@ impl Datasets {
.route("/lookup", post(lookup))
.route("/item", get(item_get))
.route("/field", get(get_field))
.route("/items", get(items_list))
.with_state(self.clone());
if let Some(prefix) = prefix {

View File

@@ -5,7 +5,7 @@ use pile_config::{
};
use smartstring::{LazyCompact, SmartString};
use std::{
collections::{HashMap, HashSet},
collections::{BTreeMap, HashMap, HashSet},
path::PathBuf,
sync::{Arc, OnceLock},
};
@@ -22,7 +22,7 @@ pub struct DirDataSource {
pub name: Label,
pub dir: PathBuf,
pub pattern: GroupPattern,
pub index: OnceLock<HashMap<SmartString<LazyCompact>, Item>>,
pub index: OnceLock<BTreeMap<SmartString<LazyCompact>, Item>>,
}
impl DirDataSource {
@@ -73,7 +73,7 @@ impl DirDataSource {
// MARK: resolve groups
//
let mut index = HashMap::new();
let mut index = BTreeMap::new();
'entry: for path in paths_items.difference(&paths_grouped_items) {
let path_str = match path.to_str() {
Some(x) => x,

View File

@@ -17,9 +17,18 @@ pub trait DataSource {
key: &str,
) -> impl Future<Output = Result<Option<crate::value::Item>, std::io::Error>> + Send;
/// Iterate over all items in this source in an arbitrary order
/// Iterate over all items in this source in sorted key order
fn iter(&self) -> impl Iterator<Item = &crate::value::Item>;
/// Iterate over a page of items, sorted by key
fn iter_page(
&self,
offset: usize,
limit: usize,
) -> impl Iterator<Item = &crate::value::Item> {
self.iter().skip(offset).take(limit)
}
/// Return the time of the latest change to the data in this source
fn latest_change(
&self,

View File

@@ -6,7 +6,7 @@ use pile_config::{
use pile_io::S3Client;
use smartstring::{LazyCompact, SmartString};
use std::{
collections::{HashMap, HashSet},
collections::{BTreeMap, HashMap, HashSet},
sync::{Arc, OnceLock},
};
@@ -24,7 +24,7 @@ pub struct S3DataSource {
pub prefix: Option<SmartString<LazyCompact>>,
pub pattern: GroupPattern,
pub encryption_key: Option<[u8; 32]>,
pub index: OnceLock<HashMap<SmartString<LazyCompact>, Item>>,
pub index: OnceLock<BTreeMap<SmartString<LazyCompact>, Item>>,
}
impl S3DataSource {
@@ -119,7 +119,7 @@ impl S3DataSource {
}
}
let mut index = HashMap::new();
let mut index = BTreeMap::new();
for key in all_keys.difference(&keys_grouped) {
let groups = resolve_groups(&source.pattern, key).await;
let group = groups

View File

@@ -1,8 +1,9 @@
use anyhow::{Context, Result};
use axum::{
Json, Router,
extract::State,
extract::{Request, State},
http::StatusCode,
middleware::{Next, from_fn_with_state},
response::{IntoResponse, Response},
routing::get,
};
@@ -30,6 +31,10 @@ pub struct ServerCommand {
/// If provided, do not serve docs
#[arg(long)]
no_docs: bool,
/// If provided, require this bearer token for all requests
#[arg(long)]
token: Option<String>,
}
impl CliCmd for ServerCommand {
@@ -50,6 +55,8 @@ impl CliCmd for ServerCommand {
Arc::new(datasets)
};
let bearer = BearerToken(self.token.map(Arc::new));
let mut router = Router::new();
for d in datasets.iter() {
let prefix = format!("/{}", d.config.dataset.name);
@@ -70,6 +77,8 @@ impl CliCmd for ServerCommand {
router = router.merge(docs);
}
router = router.layer(from_fn_with_state(bearer, bearer_auth_middleware));
let app = router.into_make_service_with_connect_info::<std::net::SocketAddr>();
let listener = match tokio::net::TcpListener::bind(self.addr.clone()).await {
@@ -114,6 +123,36 @@ impl CliCmd for ServerCommand {
}
}
//
// MARK: bearer auth middleware
//
#[derive(Clone)]
struct BearerToken(Option<Arc<String>>);
async fn bearer_auth_middleware(
State(BearerToken(expected)): State<BearerToken>,
request: Request,
next: Next,
) -> Response {
let Some(expected) = expected else {
return next.run(request).await;
};
let authorized = request
.headers()
.get(axum::http::header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "))
.is_some_and(|token| token == expected.as_str());
if authorized {
next.run(request).await
} else {
StatusCode::UNAUTHORIZED.into_response()
}
}
//
// MARK: routes
//