use axum::{ Router, body::Body as AxumBody, extract::State, response::Response as AxumResponse, routing::any, }; use bytes::Bytes; use futures_core::Stream; use reqwest::{Client, StatusCode, header}; use serde::{Deserialize, Serialize}; use std::pin::Pin; use thiserror::Error; // // MARK: Error // #[derive(Debug, Error)] pub enum ClientError { #[error("invalid bearer token")] InvalidToken, #[error("HTTP {status}: {body}")] Http { status: StatusCode, body: String }, #[error(transparent)] Reqwest(#[from] reqwest::Error), } // // MARK: Response types // #[derive(Debug, Deserialize)] pub struct DatasetInfo { pub name: String, } #[derive(Debug, Serialize)] pub struct LookupRequest { pub query: String, #[serde(skip_serializing_if = "Option::is_none")] pub limit: Option, } #[derive(Debug, Deserialize)] pub struct LookupResult { pub score: f32, pub source: String, pub key: String, } #[derive(Debug, Deserialize)] pub struct LookupResponse { pub results: Vec, } #[derive(Debug, Deserialize)] pub struct ItemRef { pub source: String, pub key: String, } #[derive(Debug, Deserialize)] pub struct ItemsResponse { pub items: Vec, pub total: usize, pub offset: usize, pub limit: usize, } /// Raw field response: the content-type and body bytes as returned by the server. pub struct FieldResponse { pub content_type: String, pub data: Bytes, } // // MARK: PileClient // /// A client for a pile server. Use [`PileClient::dataset`] to get a dataset-scoped client. pub struct PileClient { base_url: String, client: Client, token: Option, } impl PileClient { pub fn new(base_url: impl Into, token: Option<&str>) -> Result { let mut headers = header::HeaderMap::new(); if let Some(token) = token { let value = header::HeaderValue::from_str(&format!("Bearer {token}")) .map_err(|_| ClientError::InvalidToken)?; headers.insert(header::AUTHORIZATION, value); } let client = Client::builder() .default_headers(headers) .build() .map_err(ClientError::Reqwest)?; Ok(Self { base_url: base_url.into(), client, token: token.map(str::to_owned), }) } /// Returns a client scoped to a specific dataset (i.e. `/{name}/...`). pub fn dataset(&self, name: &str) -> DatasetClient { DatasetClient { base_url: format!("{}/{name}", self.base_url), client: self.client.clone(), token: self.token.clone(), } } /// `GET /datasets` — list all datasets served by this server. pub async fn list_datasets(&self) -> Result, ClientError> { let resp = self .client .get(format!("{}/datasets", self.base_url)) .send() .await?; check_status(resp).await?.json().await.map_err(Into::into) } } // // MARK: DatasetClient // /// A client scoped to a single dataset on the server. pub struct DatasetClient { base_url: String, client: Client, token: Option, } impl DatasetClient { /// `POST /lookup` — full-text search within this dataset. pub async fn lookup( &self, query: impl Into, limit: Option, ) -> Result { let body = LookupRequest { query: query.into(), limit, }; let resp = self .client .post(format!("{}/lookup", self.base_url)) .json(&body) .send() .await?; check_status(resp).await?.json().await.map_err(Into::into) } /// `GET /item` — stream the raw bytes of an item. /// /// The returned stream yields chunks as they arrive from the server. pub async fn get_item( &self, source: &str, key: &str, ) -> Result> + Send>>, ClientError> { let resp = self .client .get(format!("{}/item", self.base_url)) .query(&[("source", source), ("key", key)]) .send() .await?; Ok(Box::pin(check_status(resp).await?.bytes_stream())) } /// `GET /field` — extract a field from an item by object path (e.g. `$.flac.title`). pub async fn get_field( &self, source: &str, key: &str, path: &str, ) -> Result { let resp = self .client .get(format!("{}/field", self.base_url)) .query(&[("source", source), ("key", key), ("path", path)]) .send() .await?; let resp = check_status(resp).await?; let content_type = resp .headers() .get(header::CONTENT_TYPE) .and_then(|v| v.to_str().ok()) .unwrap_or("application/octet-stream") .to_owned(); let data = resp.bytes().await?; Ok(FieldResponse { content_type, data }) } /// `GET /items` — paginate over all items in this dataset, ordered by (source, key). pub async fn list_items( &self, offset: usize, limit: usize, ) -> Result { let resp = self .client .get(format!("{}/items", self.base_url)) .query(&[("offset", offset), ("limit", limit)]) .send() .await?; check_status(resp).await?.json().await.map_err(Into::into) } /// Returns an axum [`Router`] that proxies all requests to this dataset's /// endpoints on the remote pile server, streaming responses without buffering. /// All headers are forwarded; hop-by-hop headers are stripped. pub fn proxy_router(&self) -> Router { let state = ProxyState { base_url: self.base_url.clone(), client: self.client.clone(), token: self.token.clone(), }; Router::new() .route("/", any(proxy_handler)) .route("/{*path}", any(proxy_handler)) .with_state(state) } } // // MARK: Proxy // #[derive(Clone)] struct ProxyState { base_url: String, client: Client, token: Option, } async fn proxy_handler( State(state): State, req: axum::extract::Request, ) -> AxumResponse { let path = req.uri().path().to_owned(); let query_str = req .uri() .query() .map(|q| format!("?{q}")) .unwrap_or_default(); let method = req.method().clone(); let url = format!("{}{}{}", state.base_url, path, query_str); let mut req_builder = state.client.request(method, &url); // Forward all request headers except hop-by-hop and Host. // Authorization is skipped so the client's default bearer token is used. for (name, value) in req.headers() { if !is_hop_by_hop(name) && name != header::HOST && name != header::AUTHORIZATION { req_builder = req_builder.header(name, value); } } // Attach bearer token if present (overrides client default for clarity). if let Some(ref token) = state.token && let Ok(value) = header::HeaderValue::from_str(&format!("Bearer {token}")) { req_builder = req_builder.header(header::AUTHORIZATION, value); } // Stream the request body upstream. let body_stream = req.into_body().into_data_stream(); req_builder = req_builder.body(reqwest::Body::wrap_stream(body_stream)); let upstream = match req_builder.send().await { Ok(r) => r, Err(e) => { return AxumResponse::builder() .status(StatusCode::BAD_GATEWAY.as_u16()) .body(AxumBody::from(e.to_string())) .unwrap_or_else(|_| AxumResponse::new(AxumBody::empty())); } }; let status = upstream.status().as_u16(); let resp_headers = upstream.headers().clone(); let mut builder = AxumResponse::builder().status(status); for (name, value) in &resp_headers { if !is_hop_by_hop(name) { builder = builder.header(name, value); } } // Stream the response body without buffering. builder .body(AxumBody::from_stream(upstream.bytes_stream())) .unwrap_or_else(|_| AxumResponse::new(AxumBody::empty())) } fn is_hop_by_hop(name: &header::HeaderName) -> bool { name == header::CONNECTION || name == header::TRANSFER_ENCODING || name == header::TE || name == header::UPGRADE || name == header::PROXY_AUTHORIZATION || name == header::PROXY_AUTHENTICATE || name.as_str() == "keep-alive" || name.as_str() == "trailers" } // // MARK: helpers // async fn check_status(resp: reqwest::Response) -> Result { let status = resp.status(); if status.is_success() { return Ok(resp); } let body = resp.text().await.unwrap_or_default(); Err(ClientError::Http { status, body }) }