Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'helix-event/src/cancel.rs')
-rw-r--r--helix-event/src/cancel.rs282
1 files changed, 275 insertions, 7 deletions
diff --git a/helix-event/src/cancel.rs b/helix-event/src/cancel.rs
index f027be80..2029c945 100644
--- a/helix-event/src/cancel.rs
+++ b/helix-event/src/cancel.rs
@@ -1,15 +1,18 @@
+use std::borrow::Borrow;
use std::future::Future;
+use std::sync::atomic::AtomicU64;
+use std::sync::atomic::Ordering::Relaxed;
+use std::sync::Arc;
-pub use oneshot::channel as cancelation;
-use tokio::sync::oneshot;
+use tokio::sync::Notify;
-pub type CancelTx = oneshot::Sender<()>;
-pub type CancelRx = oneshot::Receiver<()>;
-
-pub async fn cancelable_future<T>(future: impl Future<Output = T>, cancel: CancelRx) -> Option<T> {
+pub async fn cancelable_future<T>(
+ future: impl Future<Output = T>,
+ cancel: impl Borrow<TaskHandle>,
+) -> Option<T> {
tokio::select! {
biased;
- _ = cancel => {
+ _ = cancel.borrow().canceled() => {
None
}
res = future => {
@@ -17,3 +20,268 @@ pub async fn cancelable_future<T>(future: impl Future<Output = T>, cancel: Cance
}
}
}
+
+#[derive(Default, Debug)]
+struct Shared {
+ state: AtomicU64,
+ // `Notify` has some features that we don't really need here because it
+ // supports waking single tasks (`notify_one`) and does its own (more
+ // complicated) state tracking, we could reimplement the waiter linked list
+ // with modest effort and reduce memory consumption by one word/8 bytes and
+ // reduce code complexity/number of atomic operations.
+ //
+ // I don't think that's worth the complexity (unsafe code).
+ //
+ // if we only cared about async code then we could also only use a notify
+ // (without the generation count), this would be equivalent (or maybe more
+ // correct if we want to allow cloning the TX) but it would be extremly slow
+ // to frequently check for cancelation from sync code
+ notify: Notify,
+}
+
+impl Shared {
+ fn generation(&self) -> u32 {
+ self.state.load(Relaxed) as u32
+ }
+
+ fn num_running(&self) -> u32 {
+ (self.state.load(Relaxed) >> 32) as u32
+ }
+
+ /// Increments the generation count and sets `num_running`
+ /// to the provided value, this operation is not with
+ /// regard to the generation counter (doesn't use `fetch_add`)
+ /// so the calling code must ensure it cannot execute concurrently
+ /// to maintain correctness (but not safety)
+ fn inc_generation(&self, num_running: u32) -> (u32, u32) {
+ let state = self.state.load(Relaxed);
+ let generation = state as u32;
+ let prev_running = (state >> 32) as u32;
+ // no need to create a new generation if the refcount is zero (fastpath)
+ if prev_running == 0 && num_running == 0 {
+ return (generation, 0);
+ }
+ let new_generation = generation.saturating_add(1);
+ self.state.store(
+ new_generation as u64 | ((num_running as u64) << 32),
+ Relaxed,
+ );
+ self.notify.notify_waiters();
+ (new_generation, prev_running)
+ }
+
+ fn inc_running(&self, generation: u32) {
+ let mut state = self.state.load(Relaxed);
+ loop {
+ let current_generation = state as u32;
+ if current_generation != generation {
+ break;
+ }
+ let off = 1 << 32;
+ let res = self.state.compare_exchange_weak(
+ state,
+ state.saturating_add(off),
+ Relaxed,
+ Relaxed,
+ );
+ match res {
+ Ok(_) => break,
+ Err(new_state) => state = new_state,
+ }
+ }
+ }
+
+ fn dec_running(&self, generation: u32) {
+ let mut state = self.state.load(Relaxed);
+ loop {
+ let current_generation = state as u32;
+ if current_generation != generation {
+ break;
+ }
+ let num_running = (state >> 32) as u32;
+ // running can't be zero here, that would mean we miscounted somewhere
+ assert_ne!(num_running, 0);
+ let off = 1 << 32;
+ let res = self
+ .state
+ .compare_exchange_weak(state, state - off, Relaxed, Relaxed);
+ match res {
+ Ok(_) => break,
+ Err(new_state) => state = new_state,
+ }
+ }
+ }
+}
+
+// This intentionally doesn't implement `Clone` and requires a mutable reference
+// for cancelation to avoid races (in inc_generation).
+
+/// A task controller allows managing a single subtask enabling the controller
+/// to cancel the subtask and to check whether it is still running.
+///
+/// For efficiency reasons the controller can be reused/restarted,
+/// in that case the previous task is automatically canceled.
+///
+/// If the controller is dropped, the subtasks are automatically canceled.
+#[derive(Default, Debug)]
+pub struct TaskController {
+ shared: Arc<Shared>,
+}
+
+impl TaskController {
+ pub fn new() -> Self {
+ TaskController::default()
+ }
+ /// Cancels the active task (handle).
+ ///
+ /// Returns whether any tasks were still running before the cancelation.
+ pub fn cancel(&mut self) -> bool {
+ self.shared.inc_generation(0).1 != 0
+ }
+
+ /// Checks whether there are any task handles
+ /// that haven't been dropped (or canceled) yet.
+ pub fn is_running(&self) -> bool {
+ self.shared.num_running() != 0
+ }
+
+ /// Starts a new task and cancels the previous task (handles).
+ pub fn restart(&mut self) -> TaskHandle {
+ TaskHandle {
+ generation: self.shared.inc_generation(1).0,
+ shared: self.shared.clone(),
+ }
+ }
+}
+
+impl Drop for TaskController {
+ fn drop(&mut self) {
+ self.cancel();
+ }
+}
+
+/// A handle that is used to link a task with a task controller.
+///
+/// It can be used to cancel async futures very efficiently but can also be checked for
+/// cancelation very quickly (single atomic read) in blocking code.
+/// The handle can be cheaply cloned (reference counted).
+///
+/// The TaskController can check whether a task is "running" by inspecting the
+/// refcount of the (current) tasks handles. Therefore, if that information
+/// is important, ensure that the handle is not dropped until the task fully
+/// completes.
+pub struct TaskHandle {
+ shared: Arc<Shared>,
+ generation: u32,
+}
+
+impl Clone for TaskHandle {
+ fn clone(&self) -> Self {
+ self.shared.inc_running(self.generation);
+ TaskHandle {
+ shared: self.shared.clone(),
+ generation: self.generation,
+ }
+ }
+}
+
+impl Drop for TaskHandle {
+ fn drop(&mut self) {
+ self.shared.dec_running(self.generation);
+ }
+}
+
+impl TaskHandle {
+ /// Waits until [`TaskController::cancel`] is called for the corresponding
+ /// [`TaskController`]. Immediately returns if `cancel` was already called since
+ pub async fn canceled(&self) {
+ let notified = self.shared.notify.notified();
+ if !self.is_canceled() {
+ notified.await
+ }
+ }
+
+ pub fn is_canceled(&self) -> bool {
+ self.generation != self.shared.generation()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::future::poll_fn;
+
+ use futures_executor::block_on;
+ use tokio::task::yield_now;
+
+ use crate::{cancelable_future, TaskController};
+
+ #[test]
+ fn immediate_cancel() {
+ let mut controller = TaskController::new();
+ let handle = controller.restart();
+ controller.cancel();
+ assert!(handle.is_canceled());
+ controller.restart();
+ assert!(handle.is_canceled());
+
+ let res = block_on(cancelable_future(
+ poll_fn(|_cx| std::task::Poll::Ready(())),
+ handle,
+ ));
+ assert!(res.is_none());
+ }
+
+ #[test]
+ fn running_count() {
+ let mut controller = TaskController::new();
+ let handle = controller.restart();
+ assert!(controller.is_running());
+ assert!(!handle.is_canceled());
+ drop(handle);
+ assert!(!controller.is_running());
+ assert!(!controller.cancel());
+ let handle = controller.restart();
+ assert!(!handle.is_canceled());
+ assert!(controller.is_running());
+ let handle2 = handle.clone();
+ assert!(!handle.is_canceled());
+ assert!(controller.is_running());
+ drop(handle2);
+ assert!(!handle.is_canceled());
+ assert!(controller.is_running());
+ assert!(controller.cancel());
+ assert!(handle.is_canceled());
+ assert!(!controller.is_running());
+ }
+
+ #[test]
+ fn no_cancel() {
+ let mut controller = TaskController::new();
+ let handle = controller.restart();
+ assert!(!handle.is_canceled());
+
+ let res = block_on(cancelable_future(
+ poll_fn(|_cx| std::task::Poll::Ready(())),
+ handle,
+ ));
+ assert!(res.is_some());
+ }
+
+ #[test]
+ fn delayed_cancel() {
+ let mut controller = TaskController::new();
+ let handle = controller.restart();
+
+ let mut hit = false;
+ let res = block_on(cancelable_future(
+ async {
+ controller.cancel();
+ hit = true;
+ yield_now().await;
+ },
+ handle,
+ ));
+ assert!(res.is_none());
+ assert!(hit);
+ }
+}