From 47f6296896188ffa6f7473653c97f7f205b2bba7 Mon Sep 17 00:00:00 2001 From: rm-dr <96270320+rm-dr@users.noreply.github.com> Date: Sat, 1 Nov 2025 21:39:15 -0700 Subject: [PATCH] Add `libservice` --- crates/libservice/Cargo.toml | 15 ++ crates/libservice/src/lib.rs | 369 +++++++++++++++++++++++++++++++++++ 2 files changed, 384 insertions(+) create mode 100644 crates/libservice/Cargo.toml create mode 100644 crates/libservice/src/lib.rs diff --git a/crates/libservice/Cargo.toml b/crates/libservice/Cargo.toml new file mode 100644 index 0000000..3f40782 --- /dev/null +++ b/crates/libservice/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "libservice" +version = { workspace = true } +rust-version = { workspace = true } +edition = { workspace = true } + +[lints] +workspace = true + +[dependencies] +axum = { workspace = true } +tracing = { workspace = true } +tower-http = { workspace = true } +utoipa = { workspace = true } +utoipa-swagger-ui = { workspace = true } diff --git a/crates/libservice/src/lib.rs b/crates/libservice/src/lib.rs new file mode 100644 index 0000000..45bef33 --- /dev/null +++ b/crates/libservice/src/lib.rs @@ -0,0 +1,369 @@ +//! Abstractions for modular http API routes + +use axum::{ + Json, Router, + extract::{Request, State}, + http::{StatusCode, header}, + middleware::Next, + response::{IntoResponse, Response}, +}; +use std::ops::Deref; +use tower_http::trace::TraceLayer; +use tracing::info; +use utoipa::openapi::{ + OpenApi, SecurityRequirement, Tag, + security::{Http, HttpAuthScheme, SecurityScheme}, +}; +use utoipa_swagger_ui::SwaggerUi; + +/// A `Service` provides a set of api endpoints and docs. +/// This has no relation to [tower::Service]. +pub trait ToService +where + Self: Sized + 'static, +{ + /// Create a router for this service. + fn make_router(&self) -> Option>; + + /// Create an openapi spec for this service + fn make_openapi(&self) -> OpenApi; + + /// Get the service name for grouping endpoints + fn service_name(&self) -> Option { + None + } + + fn to_service(self) -> Service { + Service::new().merge(self) + } +} + +// +// MARK: bearer +// + +/// A router state that contains a bearer token +pub trait BearerAuthProvider +where + Self: Send + Sync + 'static + Clone, +{ + /// Returns `true` if the provided token is valid + fn check_bearer_token(&self, token: &str) -> bool; +} + +// For Arc +impl BearerAuthProvider for T +where + T: Deref, + T: Send + Sync + 'static + Clone, + P: BearerAuthProvider, +{ + fn check_bearer_token(&self, token: &str) -> bool { + self.deref().check_bearer_token(token) + } +} + +/// Wraps a service and enforces bearer auth for all requests. +/// Adds axum middleware & modifies utoipa docs. +pub struct BearerAuth +where + S: ToService, + Provider: BearerAuthProvider, + Provider: Send + Sync + 'static + Clone, +{ + service: S, + provider: Option, + security_scheme_name: String, +} + +impl BearerAuth +where + S: ToService, + Provider: BearerAuthProvider, +{ + /// Middleware for bearer auth + /// + /// Usage: + /// ```notrust + /// Router::new() + /// .nest(..) + /// .route_layer(middleware::from_fn_with_state(state.clone(), bearer_auth)); + /// ``` + async fn bearer_auth(State(state): State, request: Request, next: Next) -> Response { + let auth_header = request + .headers() + .get(header::AUTHORIZATION) + .and_then(|header| header.to_str().ok()) + .and_then(|value| value.strip_prefix("Bearer ")); + + if let Some(token) = auth_header + && state.check_bearer_token(token) + { + return next.run(request).await; + } + + info!(message = "Authentication failed"); + return (StatusCode::UNAUTHORIZED, Json("invalid secret")).into_response(); + } +} + +impl ToService for BearerAuth +where + S: ToService, + Provider: BearerAuthProvider, +{ + #[inline] + fn make_router(&self) -> Option> { + self.service + .make_router() + .map(|router| match &self.provider { + Some(provider) => router.route_layer(axum::middleware::from_fn_with_state( + provider.clone(), + Self::bearer_auth, + )), + None => router, + }) + } + + #[inline] + fn make_openapi(&self) -> utoipa::openapi::OpenApi { + let mut api = self.service.make_openapi(); + + if let Some(components) = api.components.as_mut() { + components.add_security_scheme( + self.security_scheme_name.clone(), + SecurityScheme::Http(Http::new(HttpAuthScheme::Bearer)), + ) + } + + // Attach `security((self.security_scheme_name = []))` + // to all routes in the api. + if self.provider.is_some() { + for (_path, path_item) in api.paths.paths.iter_mut() { + let req = SecurityRequirement::new( + self.security_scheme_name.clone(), + Vec::::new(), + ); + + if let Some(operation) = &mut path_item.get { + match &mut operation.security { + Some(x) => x.push(req.clone()), + None => operation.security = Some(vec![req.clone()]), + } + } + if let Some(operation) = &mut path_item.post { + match &mut operation.security { + Some(x) => x.push(req.clone()), + None => operation.security = Some(vec![req.clone()]), + } + } + if let Some(operation) = &mut path_item.put { + match &mut operation.security { + Some(x) => x.push(req.clone()), + None => operation.security = Some(vec![req.clone()]), + } + } + if let Some(operation) = &mut path_item.delete { + match &mut operation.security { + Some(x) => x.push(req.clone()), + None => operation.security = Some(vec![req.clone()]), + } + } + if let Some(operation) = &mut path_item.options { + match &mut operation.security { + Some(x) => x.push(req.clone()), + None => operation.security = Some(vec![req.clone()]), + } + } + if let Some(operation) = &mut path_item.head { + match &mut operation.security { + Some(x) => x.push(req.clone()), + None => operation.security = Some(vec![req.clone()]), + } + } + if let Some(operation) = &mut path_item.patch { + match &mut operation.security { + Some(x) => x.push(req.clone()), + None => operation.security = Some(vec![req.clone()]), + } + } + if let Some(operation) = &mut path_item.trace { + match &mut operation.security { + Some(x) => x.push(req.clone()), + None => operation.security = Some(vec![req.clone()]), + } + } + } + } + + return api; + } +} + +pub trait AttachBearer +where + Self: ToService, + Provider: BearerAuthProvider, + Provider: Send + Sync + 'static + Clone, +{ + fn bearer( + self, + security_scheme_name: impl Into, + provider: Option, + ) -> BearerAuth { + BearerAuth { + security_scheme_name: security_scheme_name.into(), + service: self, + provider, + } + } +} + +impl AttachBearer for S +where + Self: ToService, + Provider: BearerAuthProvider, + Provider: Send + Sync + 'static + Clone, +{ +} + +// +// MARK: service +// + +/// Wraps a service and enforces bearer auth for all requests. +/// Adds axum middleware & modifies utoipa docs. +pub struct Service { + router: Option>, + openapi: OpenApi, + trace: bool, +} + +impl Service { + #[inline] + pub fn new() -> Self { + Self { + router: Some(Router::new()), + openapi: OpenApi::default(), + trace: false, + } + } + + /// Tag all operations in the OpenAPI spec with the given service name + fn tag_routes(openapi: &mut OpenApi, service_name: &str) { + // Add service tag to tags list + let tag = Tag::new(service_name); + match &mut openapi.tags { + Some(tags) => tags.push(tag), + None => openapi.tags = Some(vec![tag]), + } + + // Tag all operations with service name + for (_path, path_item) in openapi.paths.paths.iter_mut() { + let tag_name = service_name.to_owned(); + if let Some(operation) = &mut path_item.get { + operation.tags = Some(vec![tag_name.clone()]); + } + if let Some(operation) = &mut path_item.post { + operation.tags = Some(vec![tag_name.clone()]); + } + if let Some(operation) = &mut path_item.put { + operation.tags = Some(vec![tag_name.clone()]); + } + if let Some(operation) = &mut path_item.delete { + operation.tags = Some(vec![tag_name.clone()]); + } + if let Some(operation) = &mut path_item.options { + operation.tags = Some(vec![tag_name.clone()]); + } + if let Some(operation) = &mut path_item.head { + operation.tags = Some(vec![tag_name.clone()]); + } + if let Some(operation) = &mut path_item.patch { + operation.tags = Some(vec![tag_name.clone()]); + } + if let Some(operation) = &mut path_item.trace { + operation.tags = Some(vec![tag_name.clone()]); + } + } + } + + pub fn merge(mut self, service: S) -> Self { + let mut service_openapi = service.make_openapi(); + + // Tag all paths with service name if provided + if let Some(service_name) = service.service_name() { + Self::tag_routes(&mut service_openapi, &service_name); + } + + self.openapi.merge(service_openapi); + + let this = self.router; + let other = service.make_router(); + match (this, other) { + (_, None) | (None, _) => self.router = None, + (Some(this), Some(other)) => self.router = Some(this.merge(other)), + } + + self + } + + pub fn nest(mut self, path: &str, service: S) -> Self { + let mut service_openapi = service.make_openapi(); + + // Tag all paths with service name if provided + if let Some(service_name) = service.service_name() { + Self::tag_routes(&mut service_openapi, &service_name); + } + + self.openapi = self.openapi.nest(path, service_openapi); + + let this = self.router; + let other = service.make_router(); + match (this, other) { + (_, None) | (None, _) => self.router = None, + (Some(this), Some(other)) => self.router = Some(this.nest(path, other)), + } + + self + } + + pub fn trace(mut self) -> Self { + self.trace = true; + self + } + + pub fn swagger(mut self, docs_path: &'static str) -> Self { + match self.router { + None => {} + Some(this) => { + let docs = SwaggerUi::new(docs_path) + .url(format!("{}/openapi.json", docs_path), self.openapi.clone()); + self.router = Some(this.merge(docs)) + } + } + + return self; + } +} + +impl ToService for Service { + #[inline] + fn make_router(&self) -> Option> { + match self.router.clone() { + None => return None, + Some(mut router) => { + if self.trace { + router = router.layer(TraceLayer::new_for_http()) + } + + return Some(router); + } + } + } + + #[inline] + fn make_openapi(&self) -> utoipa::openapi::OpenApi { + self.openapi.clone() + } +}