336 lines
7.9 KiB
Rust
336 lines
7.9 KiB
Rust
use axum::{
|
|
Router, body::Body as AxumBody, extract::State, response::Response as AxumResponse,
|
|
routing::any,
|
|
};
|
|
use bytes::Bytes;
|
|
use futures_core::Stream;
|
|
use reqwest::{Client, StatusCode, header};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::pin::Pin;
|
|
use thiserror::Error;
|
|
|
|
//
|
|
// MARK: Error
|
|
//
|
|
|
|
#[derive(Debug, Error)]
|
|
pub enum ClientError {
|
|
#[error("invalid bearer token")]
|
|
InvalidToken,
|
|
|
|
#[error("HTTP {status}: {body}")]
|
|
Http { status: StatusCode, body: String },
|
|
|
|
#[error(transparent)]
|
|
Reqwest(#[from] reqwest::Error),
|
|
}
|
|
|
|
//
|
|
// MARK: Response types
|
|
//
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct DatasetInfo {
|
|
pub name: String,
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
pub struct LookupRequest {
|
|
pub query: String,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub limit: Option<usize>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct LookupResult {
|
|
pub score: f32,
|
|
pub source: String,
|
|
pub key: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct LookupResponse {
|
|
pub results: Vec<LookupResult>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct ItemRef {
|
|
pub source: String,
|
|
pub key: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct ItemsResponse {
|
|
pub items: Vec<ItemRef>,
|
|
pub total: usize,
|
|
pub offset: usize,
|
|
pub limit: usize,
|
|
}
|
|
|
|
/// Raw field response: the content-type and body bytes as returned by the server.
|
|
pub struct FieldResponse {
|
|
pub content_type: String,
|
|
pub data: Bytes,
|
|
}
|
|
|
|
//
|
|
// MARK: PileClient
|
|
//
|
|
|
|
/// A client for a pile server. Use [`PileClient::dataset`] to get a dataset-scoped client.
|
|
pub struct PileClient {
|
|
base_url: String,
|
|
client: Client,
|
|
token: Option<String>,
|
|
}
|
|
|
|
impl PileClient {
|
|
pub fn new(base_url: impl Into<String>, token: Option<&str>) -> Result<Self, ClientError> {
|
|
let mut headers = header::HeaderMap::new();
|
|
|
|
if let Some(token) = token {
|
|
let value = header::HeaderValue::from_str(&format!("Bearer {token}"))
|
|
.map_err(|_| ClientError::InvalidToken)?;
|
|
headers.insert(header::AUTHORIZATION, value);
|
|
}
|
|
|
|
let client = Client::builder()
|
|
.default_headers(headers)
|
|
.build()
|
|
.map_err(ClientError::Reqwest)?;
|
|
|
|
Ok(Self {
|
|
base_url: base_url.into(),
|
|
client,
|
|
token: token.map(str::to_owned),
|
|
})
|
|
}
|
|
|
|
/// Returns a client scoped to a specific dataset (i.e. `/{name}/...`).
|
|
pub fn dataset(&self, name: &str) -> DatasetClient {
|
|
DatasetClient {
|
|
base_url: format!("{}/{name}", self.base_url),
|
|
client: self.client.clone(),
|
|
token: self.token.clone(),
|
|
}
|
|
}
|
|
|
|
/// `GET /datasets` — list all datasets served by this server.
|
|
pub async fn list_datasets(&self) -> Result<Vec<DatasetInfo>, ClientError> {
|
|
let resp = self
|
|
.client
|
|
.get(format!("{}/datasets", self.base_url))
|
|
.send()
|
|
.await?;
|
|
|
|
check_status(resp).await?.json().await.map_err(Into::into)
|
|
}
|
|
}
|
|
|
|
//
|
|
// MARK: DatasetClient
|
|
//
|
|
|
|
/// A client scoped to a single dataset on the server.
|
|
pub struct DatasetClient {
|
|
base_url: String,
|
|
client: Client,
|
|
token: Option<String>,
|
|
}
|
|
|
|
impl DatasetClient {
|
|
/// `POST /lookup` — full-text search within this dataset.
|
|
pub async fn lookup(
|
|
&self,
|
|
query: impl Into<String>,
|
|
limit: Option<usize>,
|
|
) -> Result<LookupResponse, ClientError> {
|
|
let body = LookupRequest {
|
|
query: query.into(),
|
|
limit,
|
|
};
|
|
|
|
let resp = self
|
|
.client
|
|
.post(format!("{}/lookup", self.base_url))
|
|
.json(&body)
|
|
.send()
|
|
.await?;
|
|
|
|
check_status(resp).await?.json().await.map_err(Into::into)
|
|
}
|
|
|
|
/// `GET /item` — stream the raw bytes of an item.
|
|
///
|
|
/// The returned stream yields chunks as they arrive from the server.
|
|
pub async fn get_item(
|
|
&self,
|
|
source: &str,
|
|
key: &str,
|
|
) -> Result<Pin<Box<dyn Stream<Item = Result<Bytes, reqwest::Error>> + Send>>, ClientError> {
|
|
let resp = self
|
|
.client
|
|
.get(format!("{}/item", self.base_url))
|
|
.query(&[("source", source), ("key", key)])
|
|
.send()
|
|
.await?;
|
|
|
|
Ok(Box::pin(check_status(resp).await?.bytes_stream()))
|
|
}
|
|
|
|
/// `GET /field` — extract a field from an item by object path (e.g. `$.flac.title`).
|
|
pub async fn get_field(
|
|
&self,
|
|
source: &str,
|
|
key: &str,
|
|
path: &str,
|
|
) -> Result<FieldResponse, ClientError> {
|
|
let resp = self
|
|
.client
|
|
.get(format!("{}/field", self.base_url))
|
|
.query(&[("source", source), ("key", key), ("path", path)])
|
|
.send()
|
|
.await?;
|
|
|
|
let resp = check_status(resp).await?;
|
|
|
|
let content_type = resp
|
|
.headers()
|
|
.get(header::CONTENT_TYPE)
|
|
.and_then(|v| v.to_str().ok())
|
|
.unwrap_or("application/octet-stream")
|
|
.to_owned();
|
|
|
|
let data = resp.bytes().await?;
|
|
|
|
Ok(FieldResponse { content_type, data })
|
|
}
|
|
|
|
/// `GET /items` — paginate over all items in this dataset, ordered by (source, key).
|
|
pub async fn list_items(
|
|
&self,
|
|
offset: usize,
|
|
limit: usize,
|
|
) -> Result<ItemsResponse, ClientError> {
|
|
let resp = self
|
|
.client
|
|
.get(format!("{}/items", self.base_url))
|
|
.query(&[("offset", offset), ("limit", limit)])
|
|
.send()
|
|
.await?;
|
|
|
|
check_status(resp).await?.json().await.map_err(Into::into)
|
|
}
|
|
|
|
/// Returns an axum [`Router`] that proxies all requests to this dataset's
|
|
/// endpoints on the remote pile server, streaming responses without buffering.
|
|
/// All headers are forwarded; hop-by-hop headers are stripped.
|
|
pub fn proxy_router(&self) -> Router {
|
|
let state = ProxyState {
|
|
base_url: self.base_url.clone(),
|
|
client: self.client.clone(),
|
|
token: self.token.clone(),
|
|
};
|
|
Router::new()
|
|
.route("/", any(proxy_handler))
|
|
.route("/{*path}", any(proxy_handler))
|
|
.with_state(state)
|
|
}
|
|
}
|
|
|
|
//
|
|
// MARK: Proxy
|
|
//
|
|
|
|
#[derive(Clone)]
|
|
struct ProxyState {
|
|
base_url: String,
|
|
client: Client,
|
|
token: Option<String>,
|
|
}
|
|
|
|
async fn proxy_handler(
|
|
State(state): State<ProxyState>,
|
|
req: axum::extract::Request,
|
|
) -> AxumResponse {
|
|
let path = req.uri().path().to_owned();
|
|
let query_str = req
|
|
.uri()
|
|
.query()
|
|
.map(|q| format!("?{q}"))
|
|
.unwrap_or_default();
|
|
let method = req.method().clone();
|
|
|
|
let url = format!("{}{}{}", state.base_url, path, query_str);
|
|
let mut req_builder = state.client.request(method, &url);
|
|
|
|
// Forward all request headers except hop-by-hop and Host.
|
|
// Authorization is skipped so the client's default bearer token is used.
|
|
for (name, value) in req.headers() {
|
|
if !is_hop_by_hop(name) && name != header::HOST && name != header::AUTHORIZATION {
|
|
req_builder = req_builder.header(name, value);
|
|
}
|
|
}
|
|
|
|
// Attach bearer token if present (overrides client default for clarity).
|
|
if let Some(ref token) = state.token
|
|
&& let Ok(value) = header::HeaderValue::from_str(&format!("Bearer {token}"))
|
|
{
|
|
req_builder = req_builder.header(header::AUTHORIZATION, value);
|
|
}
|
|
|
|
// Stream the request body upstream.
|
|
let body_stream = req.into_body().into_data_stream();
|
|
req_builder = req_builder.body(reqwest::Body::wrap_stream(body_stream));
|
|
|
|
let upstream = match req_builder.send().await {
|
|
Ok(r) => r,
|
|
Err(e) => {
|
|
return AxumResponse::builder()
|
|
.status(StatusCode::BAD_GATEWAY.as_u16())
|
|
.body(AxumBody::from(e.to_string()))
|
|
.unwrap_or_else(|_| AxumResponse::new(AxumBody::empty()));
|
|
}
|
|
};
|
|
|
|
let status = upstream.status().as_u16();
|
|
let resp_headers = upstream.headers().clone();
|
|
|
|
let mut builder = AxumResponse::builder().status(status);
|
|
for (name, value) in &resp_headers {
|
|
if !is_hop_by_hop(name) {
|
|
builder = builder.header(name, value);
|
|
}
|
|
}
|
|
|
|
// Stream the response body without buffering.
|
|
builder
|
|
.body(AxumBody::from_stream(upstream.bytes_stream()))
|
|
.unwrap_or_else(|_| AxumResponse::new(AxumBody::empty()))
|
|
}
|
|
|
|
fn is_hop_by_hop(name: &header::HeaderName) -> bool {
|
|
name == header::CONNECTION
|
|
|| name == header::TRANSFER_ENCODING
|
|
|| name == header::TE
|
|
|| name == header::UPGRADE
|
|
|| name == header::PROXY_AUTHORIZATION
|
|
|| name == header::PROXY_AUTHENTICATE
|
|
|| name.as_str() == "keep-alive"
|
|
|| name.as_str() == "trailers"
|
|
}
|
|
|
|
//
|
|
// MARK: helpers
|
|
//
|
|
|
|
async fn check_status(resp: reqwest::Response) -> Result<reqwest::Response, ClientError> {
|
|
let status = resp.status();
|
|
if status.is_success() {
|
|
return Ok(resp);
|
|
}
|
|
|
|
let body = resp.text().await.unwrap_or_default();
|
|
Err(ClientError::Http { status, body })
|
|
}
|