Add pile serve
This commit is contained in:
@@ -17,7 +17,7 @@ use crate::{Item, PileValue, extract::MetaExtractor};
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FtsLookupResult {
|
||||
pub score: f32,
|
||||
pub source_name: Label,
|
||||
pub source: Label,
|
||||
pub key: String,
|
||||
}
|
||||
|
||||
@@ -270,7 +270,7 @@ impl DbFtsIndex {
|
||||
|
||||
out.push(FtsLookupResult {
|
||||
score,
|
||||
source_name,
|
||||
source: source_name,
|
||||
key,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -26,3 +26,7 @@ signal-hook = { workspace = true }
|
||||
anstyle = { workspace = true }
|
||||
toml = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
||||
axum = { workspace = true }
|
||||
utoipa = { workspace = true }
|
||||
utoipa-swagger-ui = { workspace = true }
|
||||
|
||||
@@ -10,6 +10,7 @@ mod index;
|
||||
mod init;
|
||||
mod lookup;
|
||||
mod probe;
|
||||
mod serve;
|
||||
|
||||
use crate::{Cli, GlobalContext};
|
||||
|
||||
@@ -54,6 +55,12 @@ pub enum SubCommand {
|
||||
#[command(flatten)]
|
||||
cmd: probe::ProbeCommand,
|
||||
},
|
||||
|
||||
/// Expose a dataset via an http api
|
||||
Serve {
|
||||
#[command(flatten)]
|
||||
cmd: serve::cli::ServeCommand,
|
||||
},
|
||||
}
|
||||
|
||||
impl CliCmdDispatch for SubCommand {
|
||||
@@ -65,6 +72,7 @@ impl CliCmdDispatch for SubCommand {
|
||||
Self::Index { cmd } => cmd.start(ctx),
|
||||
Self::Lookup { cmd } => cmd.start(ctx),
|
||||
Self::Probe { cmd } => cmd.start(ctx),
|
||||
Self::Serve { cmd } => cmd.start(ctx),
|
||||
|
||||
Self::Docs {} => {
|
||||
print_help_recursively(&mut Cli::command(), None);
|
||||
|
||||
89
crates/pile/src/command/serve/api.rs
Normal file
89
crates/pile/src/command/serve/api.rs
Normal file
@@ -0,0 +1,89 @@
|
||||
use axum::{
|
||||
Json, Router,
|
||||
extract::{DefaultBodyLimit, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
routing::post,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use utoipa::{OpenApi, ToSchema};
|
||||
use utoipa_swagger_ui::SwaggerUi;
|
||||
|
||||
use crate::command::serve::cli::ServeState;
|
||||
|
||||
#[derive(OpenApi)]
|
||||
#[openapi(
|
||||
tags(),
|
||||
paths(lookup),
|
||||
components(schemas(LookupRequest, LookupResponse, LookupResult))
|
||||
)]
|
||||
pub(super) struct Api;
|
||||
|
||||
#[inline]
|
||||
pub(super) fn router(state: Arc<ServeState>) -> Router<()> {
|
||||
let docs_path = "/docs";
|
||||
let docs = SwaggerUi::new(docs_path).url(format!("{}/openapi.json", docs_path), Api::openapi());
|
||||
|
||||
Router::new()
|
||||
.route("/lookup", post(lookup))
|
||||
.merge(docs)
|
||||
.with_state(state)
|
||||
.layer(DefaultBodyLimit::max(32 * 1024 * 1024))
|
||||
}
|
||||
|
||||
//
|
||||
// MARK: lookup
|
||||
//
|
||||
|
||||
#[derive(Serialize, Deserialize, ToSchema, Debug)]
|
||||
pub struct LookupRequest {
|
||||
pub query: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub limit: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
||||
struct LookupResponse {
|
||||
pub results: Vec<LookupResult>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
||||
pub struct LookupResult {
|
||||
pub score: f32,
|
||||
pub source: String,
|
||||
pub key: String,
|
||||
}
|
||||
|
||||
/// Search a user's captures
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/lookup",
|
||||
responses(
|
||||
(status = 200, description = "Search results", body = Vec<LookupResponse>),
|
||||
(status = 400, description = "Invalid request"),
|
||||
(status = 401, description = "Unauthorized"),
|
||||
(status = 404, description = "URL not found"),
|
||||
(status = 500, description = "Internal server error"),
|
||||
)
|
||||
)]
|
||||
async fn lookup(State(state): State<Arc<ServeState>>, Json(body): Json<LookupRequest>) -> Response {
|
||||
let results: Vec<LookupResult> =
|
||||
match state.ds.fts_lookup(&body.query, body.limit.unwrap_or(10)) {
|
||||
Ok(x) => x
|
||||
.into_iter()
|
||||
.map(|x| LookupResult {
|
||||
key: x.key,
|
||||
score: x.score,
|
||||
source: x.source.into(),
|
||||
})
|
||||
.collect(),
|
||||
|
||||
Err(error) => {
|
||||
return (StatusCode::INTERNAL_SERVER_ERROR, format!("{error:?}")).into_response();
|
||||
}
|
||||
};
|
||||
|
||||
return (StatusCode::OK, Json(LookupResponse { results })).into_response();
|
||||
}
|
||||
99
crates/pile/src/command/serve/cli.rs
Normal file
99
crates/pile/src/command/serve/cli.rs
Normal file
@@ -0,0 +1,99 @@
|
||||
use anyhow::{Context, Result};
|
||||
use clap::Args;
|
||||
use pile_dataset::Datasets;
|
||||
use pile_toolbox::cancelabletask::{CancelFlag, CancelableTaskError};
|
||||
use std::{fmt::Debug, path::PathBuf, sync::Arc};
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::{CliCmd, GlobalContext, command::serve::api};
|
||||
|
||||
pub(super) struct ServeState {
|
||||
pub ds: Datasets,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
pub struct ServeCommand {
|
||||
/// Address to bind to
|
||||
#[arg(default_value = "0.0.0.0:9000")]
|
||||
addr: String,
|
||||
|
||||
/// Path to dataset config
|
||||
#[arg(long, short = 'c', default_value = "./pile.toml")]
|
||||
config: PathBuf,
|
||||
|
||||
/// If provided, refresh fts if it is out-of-date
|
||||
#[arg(long)]
|
||||
refresh: bool,
|
||||
|
||||
/// Number of threads to use for indexing
|
||||
#[arg(long, short = 'j', default_value = "3")]
|
||||
jobs: usize,
|
||||
}
|
||||
|
||||
impl CliCmd for ServeCommand {
|
||||
async fn run(
|
||||
self,
|
||||
_ctx: GlobalContext,
|
||||
flag: CancelFlag,
|
||||
) -> Result<i32, CancelableTaskError<anyhow::Error>> {
|
||||
let ds = Datasets::open(&self.config)
|
||||
.with_context(|| format!("while opening dataset for {}", self.config.display()))?;
|
||||
|
||||
if self.refresh && ds.needs_fts().await.context("while checking dataset fts")? {
|
||||
info!("FTS index is missing or out-of-date, regenerating");
|
||||
ds.fts_refresh(self.jobs, Some(flag.clone()))
|
||||
.await
|
||||
.map_err(|x| {
|
||||
x.map_err(|x| {
|
||||
anyhow::Error::from(x).context(format!(
|
||||
"while refreshing fts for {}",
|
||||
self.config.display()
|
||||
))
|
||||
})
|
||||
})?;
|
||||
}
|
||||
|
||||
let app = api::router(Arc::new(ServeState { ds }))
|
||||
.into_make_service_with_connect_info::<std::net::SocketAddr>();
|
||||
|
||||
let listener = match tokio::net::TcpListener::bind(self.addr.clone()).await {
|
||||
Ok(x) => x,
|
||||
Err(error) => {
|
||||
match error.kind() {
|
||||
std::io::ErrorKind::AddrInUse => {
|
||||
error!(
|
||||
message = "Cannot bind to address, already in use",
|
||||
addr = self.addr
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
error!(message = "Error while starting server", ?error);
|
||||
}
|
||||
}
|
||||
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
|
||||
match listener.local_addr() {
|
||||
Ok(x) => info!("listening on http://{x}/docs"),
|
||||
Err(error) => {
|
||||
error!(message = "Could not determine local address", ?error);
|
||||
return Err(anyhow::Error::from(error).into());
|
||||
}
|
||||
}
|
||||
|
||||
match axum::serve(listener, app)
|
||||
.with_graceful_shutdown(async move { flag.await_cancel().await })
|
||||
.await
|
||||
{
|
||||
Ok(_) => {}
|
||||
Err(error) => {
|
||||
error!(message = "Error while serving api", ?error);
|
||||
return Err(anyhow::Error::from(error).into());
|
||||
}
|
||||
}
|
||||
|
||||
return Err(CancelableTaskError::Cancelled);
|
||||
}
|
||||
}
|
||||
2
crates/pile/src/command/serve/mod.rs
Normal file
2
crates/pile/src/command/serve/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod api;
|
||||
pub mod cli;
|
||||
Reference in New Issue
Block a user