//! Abstractions for modular http API routes use axum::{ Json, Router, extract::{Request, State}, http::{StatusCode, header}, middleware::Next, response::{IntoResponse, Response}, }; use std::ops::Deref; use tower_http::trace::TraceLayer; use tracing::info; use utoipa::openapi::{ OpenApi, SecurityRequirement, Tag, security::{Http, HttpAuthScheme, SecurityScheme}, }; use utoipa_swagger_ui::SwaggerUi; /// A `Service` provides a set of api endpoints and docs. /// This has no relation to [tower::Service]. pub trait ToService where Self: Sized + 'static, { /// Create a router for this service. fn make_router(&self) -> Option>; /// Create an openapi spec for this service fn make_openapi(&self) -> OpenApi; /// Get the service name for grouping endpoints fn service_name(&self) -> Option { None } fn to_service(self) -> Service { Service::new().merge(self) } } // // MARK: bearer // /// A router state that contains a bearer token pub trait BearerAuthProvider where Self: Send + Sync + 'static + Clone, { /// Returns `true` if the provided token is valid fn check_bearer_token(&self, token: &str) -> bool; } // For Arc impl BearerAuthProvider for T where T: Deref, T: Send + Sync + 'static + Clone, P: BearerAuthProvider, { fn check_bearer_token(&self, token: &str) -> bool { self.deref().check_bearer_token(token) } } /// Wraps a service and enforces bearer auth for all requests. /// Adds axum middleware & modifies utoipa docs. pub struct BearerAuth where S: ToService, Provider: BearerAuthProvider, Provider: Send + Sync + 'static + Clone, { service: S, provider: Option, security_scheme_name: String, } impl BearerAuth where S: ToService, Provider: BearerAuthProvider, { /// Middleware for bearer auth /// /// Usage: /// ```notrust /// Router::new() /// .nest(..) /// .route_layer(middleware::from_fn_with_state(state.clone(), bearer_auth)); /// ``` async fn bearer_auth(State(state): State, request: Request, next: Next) -> Response { let auth_header = request .headers() .get(header::AUTHORIZATION) .and_then(|header| header.to_str().ok()) .and_then(|value| value.strip_prefix("Bearer ")); if let Some(token) = auth_header && state.check_bearer_token(token) { return next.run(request).await; } info!(message = "Authentication failed"); return (StatusCode::UNAUTHORIZED, Json("invalid secret")).into_response(); } } impl ToService for BearerAuth where S: ToService, Provider: BearerAuthProvider, { #[inline] fn make_router(&self) -> Option> { self.service .make_router() .map(|router| match &self.provider { Some(provider) => router.route_layer(axum::middleware::from_fn_with_state( provider.clone(), Self::bearer_auth, )), None => router, }) } #[inline] fn make_openapi(&self) -> utoipa::openapi::OpenApi { let mut api = self.service.make_openapi(); if let Some(components) = api.components.as_mut() { components.add_security_scheme( self.security_scheme_name.clone(), SecurityScheme::Http(Http::new(HttpAuthScheme::Bearer)), ) } // Attach `security((self.security_scheme_name = []))` // to all routes in the api. if self.provider.is_some() { for (_path, path_item) in api.paths.paths.iter_mut() { let req = SecurityRequirement::new( self.security_scheme_name.clone(), Vec::::new(), ); if let Some(operation) = &mut path_item.get { match &mut operation.security { Some(x) => x.push(req.clone()), None => operation.security = Some(vec![req.clone()]), } } if let Some(operation) = &mut path_item.post { match &mut operation.security { Some(x) => x.push(req.clone()), None => operation.security = Some(vec![req.clone()]), } } if let Some(operation) = &mut path_item.put { match &mut operation.security { Some(x) => x.push(req.clone()), None => operation.security = Some(vec![req.clone()]), } } if let Some(operation) = &mut path_item.delete { match &mut operation.security { Some(x) => x.push(req.clone()), None => operation.security = Some(vec![req.clone()]), } } if let Some(operation) = &mut path_item.options { match &mut operation.security { Some(x) => x.push(req.clone()), None => operation.security = Some(vec![req.clone()]), } } if let Some(operation) = &mut path_item.head { match &mut operation.security { Some(x) => x.push(req.clone()), None => operation.security = Some(vec![req.clone()]), } } if let Some(operation) = &mut path_item.patch { match &mut operation.security { Some(x) => x.push(req.clone()), None => operation.security = Some(vec![req.clone()]), } } if let Some(operation) = &mut path_item.trace { match &mut operation.security { Some(x) => x.push(req.clone()), None => operation.security = Some(vec![req.clone()]), } } } } return api; } } pub trait AttachBearer where Self: ToService, Provider: BearerAuthProvider, Provider: Send + Sync + 'static + Clone, { fn bearer( self, security_scheme_name: impl Into, provider: Option, ) -> BearerAuth { BearerAuth { security_scheme_name: security_scheme_name.into(), service: self, provider, } } } impl AttachBearer for S where Self: ToService, Provider: BearerAuthProvider, Provider: Send + Sync + 'static + Clone, { } // // MARK: service // /// Wraps a service and enforces bearer auth for all requests. /// Adds axum middleware & modifies utoipa docs. pub struct Service { router: Option>, openapi: OpenApi, trace: bool, } impl Service { #[inline] pub fn new() -> Self { Self { router: Some(Router::new()), openapi: OpenApi::default(), trace: false, } } /// Tag all operations in the OpenAPI spec with the given service name fn tag_routes(openapi: &mut OpenApi, service_name: &str) { // Add service tag to tags list let tag = Tag::new(service_name); match &mut openapi.tags { Some(tags) => tags.push(tag), None => openapi.tags = Some(vec![tag]), } // Tag all operations with service name for (_path, path_item) in openapi.paths.paths.iter_mut() { let tag_name = service_name.to_owned(); if let Some(operation) = &mut path_item.get { operation.tags = Some(vec![tag_name.clone()]); } if let Some(operation) = &mut path_item.post { operation.tags = Some(vec![tag_name.clone()]); } if let Some(operation) = &mut path_item.put { operation.tags = Some(vec![tag_name.clone()]); } if let Some(operation) = &mut path_item.delete { operation.tags = Some(vec![tag_name.clone()]); } if let Some(operation) = &mut path_item.options { operation.tags = Some(vec![tag_name.clone()]); } if let Some(operation) = &mut path_item.head { operation.tags = Some(vec![tag_name.clone()]); } if let Some(operation) = &mut path_item.patch { operation.tags = Some(vec![tag_name.clone()]); } if let Some(operation) = &mut path_item.trace { operation.tags = Some(vec![tag_name.clone()]); } } } pub fn merge(mut self, service: S) -> Self { let mut service_openapi = service.make_openapi(); // Tag all paths with service name if provided if let Some(service_name) = service.service_name() { Self::tag_routes(&mut service_openapi, &service_name); } self.openapi.merge(service_openapi); let this = self.router; let other = service.make_router(); match (this, other) { (_, None) | (None, _) => self.router = None, (Some(this), Some(other)) => self.router = Some(this.merge(other)), } self } pub fn nest(mut self, path: &str, service: S) -> Self { let mut service_openapi = service.make_openapi(); // Tag all paths with service name if provided if let Some(service_name) = service.service_name() { Self::tag_routes(&mut service_openapi, &service_name); } self.openapi = self.openapi.nest(path, service_openapi); let this = self.router; let other = service.make_router(); match (this, other) { (_, None) | (None, _) => self.router = None, (Some(this), Some(other)) => self.router = Some(this.nest(path, other)), } self } pub fn trace(mut self) -> Self { self.trace = true; self } pub fn swagger(mut self, docs_path: &'static str) -> Self { match self.router { None => {} Some(this) => { let docs = SwaggerUi::new(docs_path) .url(format!("{}/openapi.json", docs_path), self.openapi.clone()); self.router = Some(this.merge(docs)) } } return self; } } impl ToService for Service { #[inline] fn make_router(&self) -> Option> { match self.router.clone() { None => return None, Some(mut router) => { if self.trace { router = router.layer(TraceLayer::new_for_http()) } return Some(router); } } } #[inline] fn make_openapi(&self) -> utoipa::openapi::OpenApi { self.openapi.clone() } }