From 0792b2f2c6146f8e5d5d5192a526415b059a835a Mon Sep 17 00:00:00 2001
From: rm-dr <96270320+rm-dr@users.noreply.github.com>
Date: Mon, 23 Mar 2026 22:15:00 -0700
Subject: [PATCH] Proxy router

---
 Cargo.lock                           |   2 +
 crates/pile-client/Cargo.toml        |   2 +
 crates/pile-client/src/lib.rs        | 109 +++++++++++++++++++++++++++
 crates/pile-dataset/src/dataset.rs   |   6 +-
 crates/pile-dataset/src/serve/mod.rs |  11 ++-
 crates/pile-value/src/source/mod.rs  |   6 +-
 6 files changed, 129 insertions(+), 7 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 2075f0d..40fd616 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2705,11 +2705,13 @@ dependencies = [
 name = "pile-client"
 version = "0.0.2"
 dependencies = [
+ "axum",
  "bytes",
  "futures-core",
  "reqwest",
  "serde",
  "thiserror",
+ "tracing",
 ]
 
 [[package]]
diff --git a/crates/pile-client/Cargo.toml b/crates/pile-client/Cargo.toml
index dae49d5..e51e893 100644
--- a/crates/pile-client/Cargo.toml
+++ b/crates/pile-client/Cargo.toml
@@ -13,3 +13,5 @@ futures-core = "0.3"
 serde = { workspace = true }
 thiserror = { workspace = true }
 bytes = { workspace = true }
+axum = { workspace = true }
+tracing = { workspace = true }
diff --git a/crates/pile-client/src/lib.rs b/crates/pile-client/src/lib.rs
index 4508d6b..35bdbae 100644
--- a/crates/pile-client/src/lib.rs
+++ b/crates/pile-client/src/lib.rs
@@ -1,9 +1,14 @@
+use axum::{
+    Router, body::Body as AxumBody, extract::State, response::Response as AxumResponse,
+    routing::any,
+};
 use bytes::Bytes;
 use futures_core::Stream;
 use reqwest::{Client, StatusCode, header};
 use serde::{Deserialize, Serialize};
 use std::pin::Pin;
 use thiserror::Error;
+use tracing::{trace, warn};
 
 //
 // MARK: Error
@@ -77,6 +82,7 @@ pub struct FieldResponse {
 pub struct PileClient {
     base_url: String,
     client: Client,
+    token: Option<String>,
 }
 
 impl PileClient {
@@ -97,6 +103,7 @@ impl PileClient {
         Ok(Self {
             base_url: base_url.into(),
             client,
+            token: token.map(str::to_owned),
         })
     }
 
@@ -105,6 +112,7 @@ impl PileClient {
         DatasetClient {
             base_url: format!("{}/{name}", self.base_url),
             client: self.client.clone(),
+            token: self.token.clone(),
         }
     }
 
@@ -128,6 +136,7 @@ impl PileClient {
 pub struct DatasetClient {
     base_url: String,
     client: Client,
+    token: Option<String>,
 }
 
 impl DatasetClient {
@@ -213,6 +222,106 @@ impl DatasetClient {
 
         check_status(resp).await?.json().await.map_err(Into::into)
     }
+
+    /// Returns an axum [`Router`] that proxies all requests to this dataset's
+    /// endpoints on the remote pile server, streaming bodies without buffering.
+    /// Request headers are forwarded except `Host`, `Authorization` and hop-by-hop ones.
+    pub fn proxy_router(&self) -> Router {
+        let state = ProxyState {
+            base_url: self.base_url.clone(),
+            client: self.client.clone(),
+            token: self.token.clone(),
+        };
+        Router::new()
+            .route("/", any(proxy_handler))
+            .route("/{*path}", any(proxy_handler))
+            .with_state(state)
+    }
+}
+
+//
+// MARK: Proxy
+//
+
+#[derive(Clone)]
+struct ProxyState {
+    base_url: String,
+    client: Client,
+    token: Option<String>,
+}
+
+async fn proxy_handler(
+    State(state): State<ProxyState>,
+    req: axum::extract::Request,
+) -> AxumResponse {
+    let path = req.uri().path().to_owned();
+    let query_str = req
+        .uri()
+        .query()
+        .map(|q| format!("?{q}"))
+        .unwrap_or_default();
+    let method = req.method().clone();
+
+    let url = format!("{}{}{}", state.base_url, path, query_str);
+    trace!(method = %method, url, "proxying request");
+    let mut req_builder = state.client.request(method, &url);
+
+    // Forward all request headers except hop-by-hop and Host.
+    // The caller's Authorization is dropped; upstream auth uses our own token below.
+    for (name, value) in req.headers() {
+        if !is_hop_by_hop(name) && name != header::HOST && name != header::AUTHORIZATION {
+            req_builder = req_builder.header(name, value);
+        }
+    }
+
+    // Attach our own bearer token, if one is configured and forms a valid header value.
+    if let Some(ref token) = state.token
+        && let Ok(value) = header::HeaderValue::from_str(&format!("Bearer {token}"))
+    {
+        req_builder = req_builder.header(header::AUTHORIZATION, value);
+    }
+
+    // Stream the request body upstream.
+    let body_stream = req.into_body().into_data_stream();
+    req_builder = req_builder.body(reqwest::Body::wrap_stream(body_stream));
+
+    let upstream = match req_builder.send().await {
+        Ok(r) => r,
+        Err(e) => {
+            warn!(error = %e, "upstream request failed");
+            return AxumResponse::builder()
+                .status(StatusCode::BAD_GATEWAY.as_u16())
+                .body(AxumBody::from(e.to_string()))
+                .unwrap_or_else(|_| AxumResponse::new(AxumBody::empty()));
+        }
+    };
+
+    let status = upstream.status().as_u16();
+    trace!(status, "upstream response");
+    let resp_headers = upstream.headers().clone();
+
+    let mut builder = AxumResponse::builder().status(status);
+    for (name, value) in &resp_headers {
+        if !is_hop_by_hop(name) {
+            builder = builder.header(name, value);
+        }
+    }
+
+    // Stream the response body without buffering.
+    builder
+        .body(AxumBody::from_stream(upstream.bytes_stream()))
+        .unwrap_or_else(|_| AxumResponse::new(AxumBody::empty()))
+}
+
+fn is_hop_by_hop(name: &header::HeaderName) -> bool {
+    name == header::CONNECTION
+        || name == header::TRANSFER_ENCODING
+        || name == header::TE
+        || name == header::UPGRADE
+        || name == header::PROXY_AUTHORIZATION
+        || name == header::PROXY_AUTHENTICATE
+        || name.as_str() == "keep-alive"
+        || name == header::TRAILER
 }
 
 //
diff --git a/crates/pile-dataset/src/dataset.rs b/crates/pile-dataset/src/dataset.rs
index 0a2be96..d39ca79 100644
--- a/crates/pile-dataset/src/dataset.rs
+++ b/crates/pile-dataset/src/dataset.rs
@@ -61,7 +61,11 @@ impl Dataset {
         }
     }
 
-    pub fn iter_page(&self, offset: usize, limit: usize) -> Box + Send + '_> {
+    pub fn iter_page(
+        &self,
+        offset: usize,
+        limit: usize,
+    ) -> Box + Send + '_> {
         match self {
             Self::Dir(ds) => Box::new(ds.iter_page(offset, limit)),
             Self::S3(ds) => Box::new(ds.iter_page(offset, limit)),
diff --git a/crates/pile-dataset/src/serve/mod.rs b/crates/pile-dataset/src/serve/mod.rs
index 41abb16..d8e6f96 100644
--- a/crates/pile-dataset/src/serve/mod.rs
+++ b/crates/pile-dataset/src/serve/mod.rs
@@ -24,7 +24,16 @@ pub use items::*;
 #[openapi(
     tags(),
     paths(lookup, item_get, get_field, items_list),
-    components(schemas(LookupRequest, LookupResponse, LookupResult, ItemQuery, FieldQuery, ItemsQuery, ItemsResponse, ItemRef))
+    components(schemas(
+        LookupRequest,
+        LookupResponse,
+        LookupResult,
+        ItemQuery,
+        FieldQuery,
+        ItemsQuery,
+        ItemsResponse,
+        ItemRef
+    ))
 )]
 pub(crate) struct Api;
 
diff --git a/crates/pile-value/src/source/mod.rs b/crates/pile-value/src/source/mod.rs
index 56de2e0..f231967 100644
--- a/crates/pile-value/src/source/mod.rs
+++ b/crates/pile-value/src/source/mod.rs
@@ -21,11 +21,7 @@ pub trait DataSource {
     fn iter(&self) -> impl Iterator;
 
     /// Iterate over a page of items, sorted by key
-    fn iter_page(
-        &self,
-        offset: usize,
-        limit: usize,
-    ) -> impl Iterator {
+    fn iter_page(&self, offset: usize, limit: usize) -> impl Iterator {
         self.iter().skip(offset).take(limit)
     }