use std::sync::{
    Arc,
    atomic::{AtomicBool, Ordering},
};
use thiserror::Error;
use tokio::{
    sync::Notify,
    task::{JoinError, JoinHandle},
};

/// A helper type that makes [`CancelableTask`]s easier to write.
///
/// A future that returns a `Result<T, CancelableTaskError<E>>`
/// can trivially be converted to a `CancelableTask<Result<T, E>>`.
///
/// Be careful propagating this with `?`: it isn't always safe to do so.
// NOTE(review): no variant carries an `#[error(...)]` attribute, so thiserror
// does not generate `Display` here; a manual `Display` impl is presumably
// provided elsewhere — confirm.
#[derive(Debug, Error)]
pub enum CancelableTaskError<E> {
    /// The operation failed with an error of type `E`.
    Error(#[from] E),

    /// The operation was cancelled.
    /// Converts into [`CancelableTaskResult::Cancelled`].
    Cancelled,
}

impl<E> CancelableTaskError<E> {
    /// Transforms the wrapped error with `f`.
    ///
    /// [`CancelableTaskError::Cancelled`] is passed through unchanged and `f`
    /// is not called in that case.
    pub fn map_err<F, O: FnOnce(E) -> F>(self, f: O) -> CancelableTaskError<F> {
        match self {
            Self::Cancelled => CancelableTaskError::Cancelled,
            Self::Error(e) => CancelableTaskError::Error(f(e)),
        }
    }
}

impl<T, E> From<Result<T, CancelableTaskError<E>>> for CancelableTaskResult<Result<T, E>> {
    /// Splits a cancelable result into "cancelled" vs. "finished".
    ///
    /// `Err(Cancelled)` becomes [`CancelableTaskResult::Cancelled`]; both
    /// `Ok(t)` and `Err(Error(e))` become [`CancelableTaskResult::Finished`]
    /// carrying a plain `Result<T, E>`.
    fn from(value: Result<T, CancelableTaskError<E>>) -> Self {
        match value {
            Err(CancelableTaskError::Cancelled) => CancelableTaskResult::Cancelled,
            Ok(t) => CancelableTaskResult::Finished(Ok(t)),
            Err(CancelableTaskError::Error(e)) => CancelableTaskResult::Finished(Err(e)),
        }
    }
}

/// The outcome of a cancelable operation.
pub enum CancelableTaskResult<T> {
    /// This operation was cancelled,
    /// and was able to exit cleanly.
    Cancelled,

    /// This operation finished.
    ///
    /// This variant is also returned if an error
    /// occurs after this task is cancelled.
    Finished(T),
}

impl<T> From<T> for CancelableTaskResult<T> {
    /// Wraps a plain value as [`CancelableTaskResult::Finished`].
    fn from(value: T) -> Self {
        Self::Finished(value)
    }
}

/// A cloneable cancellation signal.
///
/// Clones share the same state through `Arc`s: calling [`CancelFlag::cancel`]
/// on any clone is observed by all of them.
#[derive(Clone)]
pub struct CancelFlag {
    /// Set to `true` by [`CancelFlag::cancel`].
    flag: Arc<AtomicBool>,

    /// Wakes tasks waiting in [`CancelFlag::await_cancel`].
    notify: Arc<Notify>,
}

impl CancelFlag {
    /// Creates a new flag that has not been cancelled.
    #[inline]
    pub fn new() -> Self {
        Self {
            flag: Arc::new(AtomicBool::new(false)),
            notify: Arc::new(Notify::new()),
        }
    }

    /// Requests cancellation.
    ///
    /// Sets the flag and wakes every task currently waiting in
    /// [`CancelFlag::await_cancel`].
    #[inline]
    pub fn cancel(&self) {
        // Publish the flag before waking waiters, so anyone woken (or anyone
        // checking afterwards) observes `true`.
        self.flag.store(true, Ordering::Release);
        self.notify.notify_waiters();
    }

    /// Waits until this flag is cancelled.
    ///
    /// Returns immediately if [`CancelFlag::cancel`] has already been called.
    #[inline]
    pub async fn await_cancel(&self) {
        // Register for the notification *before* checking the flag.
        // `notify_waiters` stores no permit, so if we only awaited
        // `notified()`, a `cancel()` that already happened (or that races in
        // before registration) would be lost and this future would never
        // complete. A `Notified` receives `notify_waiters` wakeups from the
        // moment it is created, even before it is first polled, so this
        // ordering closes the race.
        let notified = self.notify.notified();
        if self.is_cancelled() {
            return;
        }
        notified.await;
        assert!(self.is_cancelled());
    }

    /// Returns `true` once [`CancelFlag::cancel`] has been called on this
    /// flag or any of its clones.
    #[inline]
    pub fn is_cancelled(&self) -> bool {
        self.flag.load(Ordering::Acquire)
    }
}

/// A handle to an asynchronous task that may be cancelled.
/// This is a wrapper around [`JoinHandle`].
/// /// Execution may only be interrupted at a well-defined breakpoint, /// as opposed to native future cancellation (which might stop at any /// `await` barrier.) /// /// It is safe to drop this struct because [tokio::JoinHandle] is cancel safe. /// If this is dropped, `task` will continue to run. pub struct CancelableTask { flag: CancelFlag, /// Always some unless joined. /// Used to detect dropped unjoined tasks. task: Option>>, } impl CancelableTask { /// Spawn a new cancellable task /// /// It is safe to drop this struct because [tokio::JoinHandle] is cancel safe. /// If this is dropped, `task` will continue to run. /// Nonetheless, avoid doing so. #[must_use] pub fn spawn(spawner: S) -> Self where S: FnOnce(CancelFlag) -> F, F: Future + Send + 'static, F::Output: Into>, { let flag = CancelFlag::new(); let task = spawner(flag.clone()); let task = tokio::task::spawn(async move { let res = task.await; let out: CancelableTaskResult = res.into(); out }); Self { flag, task: Some(task), } } #[inline] pub fn flag(&self) -> &CancelFlag { &self.flag } /// If true, this task is no longer running. /// May be cancelled or finished. #[inline] pub fn is_finished(&self) -> bool { #[expect(clippy::unwrap_used)] self.task.as_ref().unwrap().is_finished() } /// Wait for this task to finish #[inline] pub async fn join(mut self) -> Result, JoinError> { #[expect(clippy::unwrap_used)] self.task.take().unwrap().await } }