Add libservice

This commit is contained in:
2025-11-01 21:39:15 -07:00
parent 75820f97fc
commit 47f6296896
2 changed files with 384 additions and 0 deletions

View File

@@ -0,0 +1,369 @@
//! Abstractions for modular http API routes
use axum::{
Json, Router,
extract::{Request, State},
http::{StatusCode, header},
middleware::Next,
response::{IntoResponse, Response},
};
use std::ops::Deref;
use tower_http::trace::TraceLayer;
use tracing::info;
use utoipa::openapi::{
OpenApi, SecurityRequirement, Tag,
security::{Http, HttpAuthScheme, SecurityScheme},
};
use utoipa_swagger_ui::SwaggerUi;
/// A `Service` provides a set of api endpoints and docs.
/// This has no relation to [tower::Service].
pub trait ToService
where
Self: Sized + 'static,
{
/// Create a router for this service.
fn make_router(&self) -> Option<Router<()>>;
/// Create an openapi spec for this service
fn make_openapi(&self) -> OpenApi;
/// Get the service name for grouping endpoints
fn service_name(&self) -> Option<String> {
None
}
fn to_service(self) -> Service {
Service::new().merge(self)
}
}
//
// MARK: bearer
//
/// A router state that contains a bearer token
pub trait BearerAuthProvider
where
Self: Send + Sync + 'static + Clone,
{
/// Returns `true` if the provided token is valid
fn check_bearer_token(&self, token: &str) -> bool;
}
// For Arc<T>
impl<T, P> BearerAuthProvider for T
where
T: Deref<Target = P>,
T: Send + Sync + 'static + Clone,
P: BearerAuthProvider,
{
fn check_bearer_token(&self, token: &str) -> bool {
self.deref().check_bearer_token(token)
}
}
/// Wraps a service and enforces bearer auth for all requests.
/// Adds axum middleware & modifies utoipa docs.
pub struct BearerAuth<S, Provider>
where
S: ToService,
Provider: BearerAuthProvider,
Provider: Send + Sync + 'static + Clone,
{
service: S,
provider: Option<Provider>,
security_scheme_name: String,
}
impl<S, Provider> BearerAuth<S, Provider>
where
S: ToService,
Provider: BearerAuthProvider,
{
/// Middleware for bearer auth
///
/// Usage:
/// ```notrust
/// Router::new()
/// .nest(..)
/// .route_layer(middleware::from_fn_with_state(state.clone(), bearer_auth));
/// ```
async fn bearer_auth(State(state): State<Provider>, request: Request, next: Next) -> Response {
let auth_header = request
.headers()
.get(header::AUTHORIZATION)
.and_then(|header| header.to_str().ok())
.and_then(|value| value.strip_prefix("Bearer "));
if let Some(token) = auth_header
&& state.check_bearer_token(token)
{
return next.run(request).await;
}
info!(message = "Authentication failed");
return (StatusCode::UNAUTHORIZED, Json("invalid secret")).into_response();
}
}
impl<S, Provider> ToService for BearerAuth<S, Provider>
where
S: ToService,
Provider: BearerAuthProvider,
{
#[inline]
fn make_router(&self) -> Option<Router<()>> {
self.service
.make_router()
.map(|router| match &self.provider {
Some(provider) => router.route_layer(axum::middleware::from_fn_with_state(
provider.clone(),
Self::bearer_auth,
)),
None => router,
})
}
#[inline]
fn make_openapi(&self) -> utoipa::openapi::OpenApi {
let mut api = self.service.make_openapi();
if let Some(components) = api.components.as_mut() {
components.add_security_scheme(
self.security_scheme_name.clone(),
SecurityScheme::Http(Http::new(HttpAuthScheme::Bearer)),
)
}
// Attach `security((self.security_scheme_name = []))`
// to all routes in the api.
if self.provider.is_some() {
for (_path, path_item) in api.paths.paths.iter_mut() {
let req = SecurityRequirement::new(
self.security_scheme_name.clone(),
Vec::<String>::new(),
);
if let Some(operation) = &mut path_item.get {
match &mut operation.security {
Some(x) => x.push(req.clone()),
None => operation.security = Some(vec![req.clone()]),
}
}
if let Some(operation) = &mut path_item.post {
match &mut operation.security {
Some(x) => x.push(req.clone()),
None => operation.security = Some(vec![req.clone()]),
}
}
if let Some(operation) = &mut path_item.put {
match &mut operation.security {
Some(x) => x.push(req.clone()),
None => operation.security = Some(vec![req.clone()]),
}
}
if let Some(operation) = &mut path_item.delete {
match &mut operation.security {
Some(x) => x.push(req.clone()),
None => operation.security = Some(vec![req.clone()]),
}
}
if let Some(operation) = &mut path_item.options {
match &mut operation.security {
Some(x) => x.push(req.clone()),
None => operation.security = Some(vec![req.clone()]),
}
}
if let Some(operation) = &mut path_item.head {
match &mut operation.security {
Some(x) => x.push(req.clone()),
None => operation.security = Some(vec![req.clone()]),
}
}
if let Some(operation) = &mut path_item.patch {
match &mut operation.security {
Some(x) => x.push(req.clone()),
None => operation.security = Some(vec![req.clone()]),
}
}
if let Some(operation) = &mut path_item.trace {
match &mut operation.security {
Some(x) => x.push(req.clone()),
None => operation.security = Some(vec![req.clone()]),
}
}
}
}
return api;
}
}
pub trait AttachBearer<Provider>
where
Self: ToService,
Provider: BearerAuthProvider,
Provider: Send + Sync + 'static + Clone,
{
fn bearer(
self,
security_scheme_name: impl Into<String>,
provider: Option<Provider>,
) -> BearerAuth<Self, Provider> {
BearerAuth {
security_scheme_name: security_scheme_name.into(),
service: self,
provider,
}
}
}
impl<S, Provider> AttachBearer<Provider> for S
where
Self: ToService,
Provider: BearerAuthProvider,
Provider: Send + Sync + 'static + Clone,
{
}
//
// MARK: service
//
/// Wraps a service and enforces bearer auth for all requests.
/// Adds axum middleware & modifies utoipa docs.
pub struct Service {
router: Option<Router<()>>,
openapi: OpenApi,
trace: bool,
}
impl Service {
#[inline]
pub fn new() -> Self {
Self {
router: Some(Router::new()),
openapi: OpenApi::default(),
trace: false,
}
}
/// Tag all operations in the OpenAPI spec with the given service name
fn tag_routes(openapi: &mut OpenApi, service_name: &str) {
// Add service tag to tags list
let tag = Tag::new(service_name);
match &mut openapi.tags {
Some(tags) => tags.push(tag),
None => openapi.tags = Some(vec![tag]),
}
// Tag all operations with service name
for (_path, path_item) in openapi.paths.paths.iter_mut() {
let tag_name = service_name.to_owned();
if let Some(operation) = &mut path_item.get {
operation.tags = Some(vec![tag_name.clone()]);
}
if let Some(operation) = &mut path_item.post {
operation.tags = Some(vec![tag_name.clone()]);
}
if let Some(operation) = &mut path_item.put {
operation.tags = Some(vec![tag_name.clone()]);
}
if let Some(operation) = &mut path_item.delete {
operation.tags = Some(vec![tag_name.clone()]);
}
if let Some(operation) = &mut path_item.options {
operation.tags = Some(vec![tag_name.clone()]);
}
if let Some(operation) = &mut path_item.head {
operation.tags = Some(vec![tag_name.clone()]);
}
if let Some(operation) = &mut path_item.patch {
operation.tags = Some(vec![tag_name.clone()]);
}
if let Some(operation) = &mut path_item.trace {
operation.tags = Some(vec![tag_name.clone()]);
}
}
}
pub fn merge<S: ToService>(mut self, service: S) -> Self {
let mut service_openapi = service.make_openapi();
// Tag all paths with service name if provided
if let Some(service_name) = service.service_name() {
Self::tag_routes(&mut service_openapi, &service_name);
}
self.openapi.merge(service_openapi);
let this = self.router;
let other = service.make_router();
match (this, other) {
(_, None) | (None, _) => self.router = None,
(Some(this), Some(other)) => self.router = Some(this.merge(other)),
}
self
}
pub fn nest<S: ToService>(mut self, path: &str, service: S) -> Self {
let mut service_openapi = service.make_openapi();
// Tag all paths with service name if provided
if let Some(service_name) = service.service_name() {
Self::tag_routes(&mut service_openapi, &service_name);
}
self.openapi = self.openapi.nest(path, service_openapi);
let this = self.router;
let other = service.make_router();
match (this, other) {
(_, None) | (None, _) => self.router = None,
(Some(this), Some(other)) => self.router = Some(this.nest(path, other)),
}
self
}
pub fn trace(mut self) -> Self {
self.trace = true;
self
}
pub fn swagger(mut self, docs_path: &'static str) -> Self {
match self.router {
None => {}
Some(this) => {
let docs = SwaggerUi::new(docs_path)
.url(format!("{}/openapi.json", docs_path), self.openapi.clone());
self.router = Some(this.merge(docs))
}
}
return self;
}
}
impl ToService for Service {
#[inline]
fn make_router(&self) -> Option<Router<()>> {
match self.router.clone() {
None => return None,
Some(mut router) => {
if self.trace {
router = router.layer(TraceLayer::new_for_http())
}
return Some(router);
}
}
}
#[inline]
fn make_openapi(&self) -> utoipa::openapi::OpenApi {
self.openapi.clone()
}
}