157 lines
3.7 KiB
Rust
157 lines
3.7 KiB
Rust
use std::sync::{
|
|
Arc,
|
|
atomic::{AtomicBool, Ordering},
|
|
};
|
|
use thiserror::Error;
|
|
use tokio::{
|
|
sync::Notify,
|
|
task::{JoinError, JoinHandle},
|
|
};
|
|
|
|
/// A helper type that makes CancelableTasks easier to write.
|
|
///
|
|
/// A future that returns a `Result<T, CancelableTaskError<E>>`
|
|
/// can trivially be converted to a `CancelableTask<Result<T, E>>`.
|
|
///
|
|
/// Be careful propagating this with ?, it isn't always safe to do so.
|
|
#[derive(Debug, Error)]
|
|
pub enum CancelableTaskError<E> {
|
|
Error(#[from] E),
|
|
Cancelled,
|
|
}
|
|
|
|
impl<E> CancelableTaskError<E> {
|
|
pub fn map_err<F, O: FnOnce(E) -> F>(self, f: O) -> CancelableTaskError<F> {
|
|
match self {
|
|
Self::Cancelled => CancelableTaskError::Cancelled,
|
|
Self::Error(e) => CancelableTaskError::Error(f(e)),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<T, E> From<Result<T, CancelableTaskError<E>>> for CancelableTaskResult<Result<T, E>> {
|
|
fn from(value: Result<T, CancelableTaskError<E>>) -> Self {
|
|
match value {
|
|
Err(CancelableTaskError::Cancelled) => CancelableTaskResult::Cancelled,
|
|
Ok(t) => CancelableTaskResult::Finished(Ok(t)),
|
|
Err(CancelableTaskError::Error(e)) => CancelableTaskResult::Finished(Err(e)),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// The outcome of a [`CancelableTask`].
///
/// Derives `Debug` and comparison traits so results can be logged,
/// inspected and asserted on (each derive requires the matching bound on `T`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CancelableTaskResult<T> {
    /// This operation was cancelled and was able to exit cleanly.
    Cancelled,

    /// This operation finished and produced a value.
    ///
    /// This variant is also returned if an error
    /// occurs after this task is cancelled.
    Finished(T),
}

impl<T> From<T> for CancelableTaskResult<T> {
    /// Wraps a plain value as a finished result.
    fn from(value: T) -> Self {
        Self::Finished(value)
    }
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct CancelFlag {
|
|
flag: Arc<AtomicBool>,
|
|
notify: Arc<Notify>,
|
|
}
|
|
|
|
impl CancelFlag {
|
|
#[inline]
|
|
pub fn new() -> Self {
|
|
Self {
|
|
flag: Arc::new(AtomicBool::new(false)),
|
|
notify: Arc::new(Notify::new()),
|
|
}
|
|
}
|
|
|
|
#[inline]
|
|
pub fn cancel(&self) {
|
|
self.flag.store(true, Ordering::Release);
|
|
self.notify.notify_waiters();
|
|
}
|
|
|
|
#[inline]
|
|
pub async fn await_cancel(&self) {
|
|
self.notify.notified().await;
|
|
assert!(self.is_cancelled());
|
|
}
|
|
|
|
#[inline]
|
|
pub fn is_cancelled(&self) -> bool {
|
|
return self.flag.load(Ordering::Acquire);
|
|
}
|
|
}
|
|
|
|
/// A handle to an asynchronous task that may be cancelled.
|
|
/// This is a wrapper around [JoinHandle].
|
|
///
|
|
/// Execution may only be interrupted at a well-defined breakpoint,
|
|
/// as opposed to native future cancellation (which might stop at any
|
|
/// `await` barrier.)
|
|
///
|
|
/// It is safe to drop this struct because [tokio::JoinHandle] is cancel safe.
|
|
/// If this is dropped, `task` will continue to run.
|
|
pub struct CancelableTask<T: Send + 'static> {
|
|
flag: CancelFlag,
|
|
|
|
/// Always some unless joined.
|
|
/// Used to detect dropped unjoined tasks.
|
|
task: Option<JoinHandle<CancelableTaskResult<T>>>,
|
|
}
|
|
|
|
impl<T: Send + 'static> CancelableTask<T> {
|
|
/// Spawn a new cancellable task
|
|
///
|
|
/// It is safe to drop this struct because [tokio::JoinHandle] is cancel safe.
|
|
/// If this is dropped, `task` will continue to run.
|
|
/// Nonetheless, avoid doing so.
|
|
#[must_use]
|
|
pub fn spawn<S, F>(spawner: S) -> Self
|
|
where
|
|
S: FnOnce(CancelFlag) -> F,
|
|
F: Future + Send + 'static,
|
|
F::Output: Into<CancelableTaskResult<T>>,
|
|
{
|
|
let flag = CancelFlag::new();
|
|
let task = spawner(flag.clone());
|
|
let task = tokio::task::spawn(async move {
|
|
let res = task.await;
|
|
let out: CancelableTaskResult<T> = res.into();
|
|
out
|
|
});
|
|
|
|
Self {
|
|
flag,
|
|
task: Some(task),
|
|
}
|
|
}
|
|
|
|
#[inline]
|
|
pub fn flag(&self) -> &CancelFlag {
|
|
&self.flag
|
|
}
|
|
|
|
/// If true, this task is no longer running.
|
|
/// May be cancelled or finished.
|
|
#[inline]
|
|
pub fn is_finished(&self) -> bool {
|
|
#[expect(clippy::unwrap_used)]
|
|
self.task.as_ref().unwrap().is_finished()
|
|
}
|
|
|
|
/// Wait for this task to finish
|
|
#[inline]
|
|
pub async fn join(mut self) -> Result<CancelableTaskResult<T>, JoinError> {
|
|
#[expect(clippy::unwrap_used)]
|
|
self.task.take().unwrap().await
|
|
}
|
|
}
|