diff --git a/crates/pile-toolbox/Cargo.toml b/crates/pile-toolbox/Cargo.toml new file mode 100644 index 0000000..e225fa0 --- /dev/null +++ b/crates/pile-toolbox/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "pile-toolbox" +version = { workspace = true } +rust-version = { workspace = true } +edition = { workspace = true } + +[lints] +workspace = true + +[dependencies] +tokio = { workspace = true } +thiserror = { workspace = true } diff --git a/crates/pile-toolbox/src/cancelabletask.rs b/crates/pile-toolbox/src/cancelabletask.rs new file mode 100644 index 0000000..64b9ec6 --- /dev/null +++ b/crates/pile-toolbox/src/cancelabletask.rs @@ -0,0 +1,156 @@ +use std::sync::{ + Arc, + atomic::{AtomicBool, Ordering}, +}; +use thiserror::Error; +use tokio::{ + sync::Notify, + task::{JoinError, JoinHandle}, +}; + +/// A helper type that makes CancelableTasks easier to write. +/// +/// A future that returns a `Result>` +/// can trivially be converted to a `CancelableTask>`. +/// +/// Be careful propagating this with ?, it isn't always safe to do so. +#[derive(Debug, Error)] +pub enum CancelableTaskError { + Error(#[from] E), + Cancelled, +} + +impl CancelableTaskError { + pub fn map_err F>(self, f: O) -> CancelableTaskError { + match self { + Self::Cancelled => CancelableTaskError::Cancelled, + Self::Error(e) => CancelableTaskError::Error(f(e)), + } + } +} + +impl From>> for CancelableTaskResult> { + fn from(value: Result>) -> Self { + match value { + Err(CancelableTaskError::Cancelled) => CancelableTaskResult::Cancelled, + Ok(t) => CancelableTaskResult::Finished(Ok(t)), + Err(CancelableTaskError::Error(e)) => CancelableTaskResult::Finished(Err(e)), + } + } +} + +pub enum CancelableTaskResult { + /// This operation was canceled, + /// and was able to exit cleanly. + Cancelled, + + /// This operation finished. + /// + /// This variant is returned if an error + /// occurs after this task is cancelled. + Finished(T), +} + +impl From for CancelableTaskResult { + fn from(value: T) -> Self { + Self::Finished(value) + } +} + +#[derive(Clone)] +pub struct CancelFlag { + flag: Arc, + notify: Arc, +} + +impl CancelFlag { + #[inline] + pub fn new() -> Self { + Self { + flag: Arc::new(AtomicBool::new(false)), + notify: Arc::new(Notify::new()), + } + } + + #[inline] + pub fn cancel(&self) { + self.flag.store(true, Ordering::Release); + self.notify.notify_waiters(); + } + + #[inline] + pub async fn await_cancel(&self) { + self.notify.notified().await; + assert!(self.is_cancelled()); + } + + #[inline] + pub fn is_cancelled(&self) -> bool { + return self.flag.load(Ordering::Acquire); + } +} + +/// A handle to an asynchronous task that may be cancelled. +/// This is a wrapper around [JoinHandle]. +/// +/// Execution may only be interrupted at a well-defined breakpoint, +/// as opposed to native future cancellation (which might stop at any +/// `await` barrier.) +/// +/// It is safe to drop this struct because [tokio::JoinHandle] is cancel safe. +/// If this is dropped, `task` will continue to run. +pub struct CancelableTask { + flag: CancelFlag, + + /// Always some unless joined. + /// Used to detect dropped unjoined tasks. + task: Option>>, +} + +impl CancelableTask { + /// Spawn a new cancellable task + /// + /// It is safe to drop this struct because [tokio::JoinHandle] is cancel safe. + /// If this is dropped, `task` will continue to run. + /// Nonetheless, avoid doing so. + #[must_use] + pub fn spawn(spawner: S) -> Self + where + S: FnOnce(CancelFlag) -> F, + F: Future + Send + 'static, + F::Output: Into>, + { + let flag = CancelFlag::new(); + let task = spawner(flag.clone()); + let task = tokio::task::spawn(async move { + let res = task.await; + let out: CancelableTaskResult = res.into(); + out + }); + + Self { + flag, + task: Some(task), + } + } + + #[inline] + pub fn flag(&self) -> &CancelFlag { + &self.flag + } + + /// If true, this task is no longer running. + /// May be cancelled or finished. + #[inline] + pub fn is_finished(&self) -> bool { + #[expect(clippy::unwrap_used)] + self.task.as_ref().unwrap().is_finished() + } + + /// Wait for this task to finish + #[inline] + pub async fn join(mut self) -> Result, JoinError> { + #[expect(clippy::unwrap_used)] + self.task.take().unwrap().await + } +} diff --git a/crates/pile-toolbox/src/lib.rs b/crates/pile-toolbox/src/lib.rs new file mode 100644 index 0000000..2c07e91 --- /dev/null +++ b/crates/pile-toolbox/src/lib.rs @@ -0,0 +1 @@ +pub mod cancelabletask;