Some checks failed
CI / Check typos (push) Failing after 9s
CI / Check links (push) Failing after 14s
CI / Clippy (push) Successful in 53s
CI / Build and test (push) Successful in 1m19s
CI / Build container (push) Has been skipped
CI / Deploy on waypoint (push) Has been skipped
276 lines
7.1 KiB
Rust
276 lines
7.1 KiB
Rust
use axum::{
|
|
Router,
|
|
body::Body,
|
|
http::{HeaderMap, HeaderValue, Method, Request, StatusCode, header},
|
|
response::{IntoResponse, Response},
|
|
};
|
|
use chrono::TimeDelta;
|
|
use std::{
|
|
collections::{BTreeMap, HashMap},
|
|
convert::Infallible,
|
|
net::SocketAddr,
|
|
pin::Pin,
|
|
sync::Arc,
|
|
task::{Context, Poll},
|
|
time::Instant,
|
|
};
|
|
use toolbox::mime::MimeType;
|
|
use tower::Service;
|
|
use tracing::trace;
|
|
|
|
use crate::{ClientInfo, RenderContext, Rendered, RenderedBody, servable::Servable};
|
|
|
|
struct Default404 {}
|
|
impl Servable for Default404 {
|
|
fn head<'a>(
|
|
&'a self,
|
|
_ctx: &'a RenderContext,
|
|
) -> Pin<Box<dyn Future<Output = Rendered<()>> + 'a + Send + Sync>> {
|
|
Box::pin(async {
|
|
return Rendered {
|
|
code: StatusCode::NOT_FOUND,
|
|
body: (),
|
|
ttl: Some(TimeDelta::days(1)),
|
|
immutable: true,
|
|
headers: HeaderMap::new(),
|
|
mime: Some(MimeType::Html),
|
|
};
|
|
})
|
|
}
|
|
|
|
fn render<'a>(
|
|
&'a self,
|
|
ctx: &'a RenderContext,
|
|
) -> Pin<Box<dyn Future<Output = Rendered<RenderedBody>> + 'a + Send + Sync>> {
|
|
Box::pin(async { self.head(ctx).await.with_body(RenderedBody::Empty) })
|
|
}
|
|
}
|
|
|
|
/// A set of related [Servable]s under one route.
|
|
///
|
|
/// Use as follows:
|
|
/// ```ignore
|
|
///
|
|
/// // Add compression, for example.
|
|
/// // Also consider CORS and timeout.
|
|
/// let compression: CompressionLayer = CompressionLayer::new()
|
|
/// .br(true)
|
|
/// .deflate(true)
|
|
/// .gzip(true)
|
|
/// .zstd(true)
|
|
/// .compress_when(DefaultPredicate::new());
|
|
///
|
|
/// let route = ServableRoute::new()
|
|
/// .add_page(
|
|
/// "/page",
|
|
/// StaticAsset {
|
|
/// bytes: "I am a page".as_bytes(),
|
|
/// mime: MimeType::Text,
|
|
/// },
|
|
/// );
|
|
///
|
|
/// Router::new()
|
|
/// .nest_service("/", route)
|
|
/// .layer(compression.clone());
|
|
/// ```
|
|
#[derive(Clone)]
|
|
pub struct ServableRoute {
|
|
pages: Arc<HashMap<String, Arc<dyn Servable>>>,
|
|
notfound: Arc<dyn Servable>,
|
|
}
|
|
|
|
impl ServableRoute {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
pages: Arc::new(HashMap::new()),
|
|
notfound: Arc::new(Default404 {}),
|
|
}
|
|
}
|
|
|
|
/// Set this server's "not found" page
|
|
pub fn with_404<S: Servable + 'static>(mut self, page: S) -> Self {
|
|
self.notfound = Arc::new(page);
|
|
self
|
|
}
|
|
|
|
/// Add a page to this server at the given route.
|
|
/// - panics if route does not start with a `/`, ends with a `/`, or contains `//`.
|
|
/// - urls are normalized, routes that violate this condition will never be served.
|
|
/// - `/` is an exception, it is valid.
|
|
/// - panics if called after this service is started
|
|
/// - overwrites existing pages
|
|
pub fn add_page<S: Servable + 'static>(mut self, route: impl Into<String>, page: S) -> Self {
|
|
let route = route.into();
|
|
|
|
if !route.starts_with("/") {
|
|
panic!("route must start with /")
|
|
};
|
|
|
|
if route.ends_with("/") && route != "/" {
|
|
panic!("route must not end with /")
|
|
};
|
|
|
|
if route.contains("//") {
|
|
panic!("route must not contain //")
|
|
};
|
|
|
|
#[expect(clippy::expect_used)]
|
|
Arc::get_mut(&mut self.pages)
|
|
.expect("add_pages called after service was started")
|
|
.insert(route, Arc::new(page));
|
|
|
|
self
|
|
}
|
|
|
|
/// Convenience method.
|
|
/// Turns this service into a router.
|
|
///
|
|
/// Equivalent to:
|
|
/// ```ignore
|
|
/// Router::new().fallback_service(self)
|
|
/// ```
|
|
pub fn into_router<T: Clone + Send + Sync + 'static>(self) -> Router<T> {
|
|
Router::new().fallback_service(self)
|
|
}
|
|
}
|
|
|
|
//
|
|
// MARK: impl Service
|
|
//
|
|
|
|
impl Service<Request<Body>> for ServableRoute {
|
|
type Response = Response;
|
|
type Error = Infallible;
|
|
type Future =
|
|
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
|
|
|
|
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
Poll::Ready(Ok(()))
|
|
}
|
|
|
|
fn call(&mut self, req: Request<Body>) -> Self::Future {
|
|
if req.method() != Method::GET && req.method() != Method::HEAD {
|
|
let mut headers = HeaderMap::with_capacity(1);
|
|
headers.insert(header::ACCEPT, HeaderValue::from_static("GET,HEAD"));
|
|
return Box::pin(async {
|
|
Ok((StatusCode::METHOD_NOT_ALLOWED, headers).into_response())
|
|
});
|
|
}
|
|
|
|
let pages = self.pages.clone();
|
|
let notfound = self.notfound.clone();
|
|
Box::pin(async move {
|
|
let addr = req.extensions().get::<SocketAddr>().copied();
|
|
let route = req.uri().path().to_owned();
|
|
let headers = req.headers().clone();
|
|
let query: BTreeMap<String, String> =
|
|
serde_urlencoded::from_str(req.uri().query().unwrap_or("")).unwrap_or_default();
|
|
|
|
let start = Instant::now();
|
|
let client_info = ClientInfo::from_headers(&headers);
|
|
let ua = headers
|
|
.get("user-agent")
|
|
.and_then(|x| x.to_str().ok())
|
|
.unwrap_or("");
|
|
|
|
trace!(
|
|
message = "Serving route",
|
|
route,
|
|
addr = ?addr,
|
|
user_agent = ua,
|
|
device_type = ?client_info.device_type
|
|
);
|
|
|
|
// Normalize url with redirect
|
|
if (route.ends_with('/') && route != "/") || route.contains("//") {
|
|
let mut new_route = route.clone();
|
|
while new_route.contains("//") {
|
|
new_route = new_route.replace("//", "/");
|
|
}
|
|
let new_route = new_route.trim_matches('/');
|
|
|
|
trace!(
|
|
message = "Redirecting",
|
|
route,
|
|
new_route,
|
|
addr = ?addr,
|
|
user_agent = ua,
|
|
device_type = ?client_info.device_type
|
|
);
|
|
|
|
let mut headers = HeaderMap::with_capacity(1);
|
|
match HeaderValue::from_str(&format!("/{new_route}")) {
|
|
Ok(x) => headers.append(header::LOCATION, x),
|
|
Err(_) => return Ok(StatusCode::BAD_REQUEST.into_response()),
|
|
};
|
|
return Ok((StatusCode::PERMANENT_REDIRECT, headers).into_response());
|
|
}
|
|
|
|
let ctx = RenderContext {
|
|
client_info,
|
|
route,
|
|
query,
|
|
};
|
|
|
|
let page = pages.get(&ctx.route).unwrap_or(¬found);
|
|
let mut rend = match req.method() == Method::HEAD {
|
|
true => page.head(&ctx).await.with_body(RenderedBody::Empty),
|
|
false => page.render(&ctx).await,
|
|
};
|
|
|
|
// Tweak headers
|
|
{
|
|
if !rend.headers.contains_key(header::CACHE_CONTROL) {
|
|
let max_age = rend.ttl.map(|x| x.num_seconds()).unwrap_or(1).max(1);
|
|
|
|
let mut value = String::new();
|
|
if rend.immutable {
|
|
value.push_str("immutable, ");
|
|
}
|
|
|
|
value.push_str("public, ");
|
|
value.push_str(&format!("max-age={}, ", max_age));
|
|
|
|
#[expect(clippy::unwrap_used)]
|
|
rend.headers.insert(
|
|
header::CACHE_CONTROL,
|
|
HeaderValue::from_str(value.trim().trim_end_matches(',')).unwrap(),
|
|
);
|
|
}
|
|
|
|
if !rend.headers.contains_key("Accept-CH") {
|
|
rend.headers
|
|
.insert("Accept-CH", HeaderValue::from_static("Sec-CH-UA-Mobile"));
|
|
}
|
|
|
|
if !rend.headers.contains_key(header::CONTENT_TYPE)
|
|
&& let Some(mime) = &rend.mime
|
|
{
|
|
#[expect(clippy::unwrap_used)]
|
|
rend.headers.insert(
|
|
header::CONTENT_TYPE,
|
|
HeaderValue::from_str(&mime.to_string()).unwrap(),
|
|
);
|
|
}
|
|
}
|
|
|
|
trace!(
|
|
message = "Served route",
|
|
route = ctx.route,
|
|
addr = ?addr,
|
|
user_agent = ua,
|
|
device_type = ?client_info.device_type,
|
|
time_ns = start.elapsed().as_nanos()
|
|
);
|
|
|
|
Ok(match rend.body {
|
|
RenderedBody::Markup(m) => (rend.code, rend.headers, m.0).into_response(),
|
|
RenderedBody::Static(d) => (rend.code, rend.headers, d).into_response(),
|
|
RenderedBody::Bytes(d) => (rend.code, rend.headers, d).into_response(),
|
|
RenderedBody::String(s) => (rend.code, rend.headers, s).into_response(),
|
|
RenderedBody::Empty => (rend.code, rend.headers).into_response(),
|
|
})
|
|
})
|
|
}
|
|
}
|