async_task/
header.rs

1use core::cell::UnsafeCell;
2use core::fmt;
3use core::sync::atomic::{AtomicUsize, Ordering};
4use core::task::Waker;
5
6use crate::raw::TaskVTable;
7use crate::state::*;
8use crate::utils::abort_on_panic;
9
/// The header of a task.
///
/// This header is stored in memory at the beginning of the heap-allocated task.
pub(crate) struct Header {
    /// Current state of the task.
    ///
    /// Contains flags representing the current state and the reference count (the latter is
    /// read back as `state / REFERENCE`). The `NOTIFYING` and `REGISTERING` bits additionally
    /// act as the access guard for the `awaiter` field below.
    pub(crate) state: AtomicUsize,

    /// Waker of the task that is blocked on the `Task` handle.
    ///
    /// This waker needs to be woken up once the task completes or is closed.
    ///
    /// Kept in an `UnsafeCell` because it is mutated through `&self`. `Header::take` only
    /// touches it after moving `state` into `NOTIFYING` from a state with neither `NOTIFYING`
    /// nor `REGISTERING` set. `Header::register` only touches it while it holds the
    /// `REGISTERING` bit.
    pub(crate) awaiter: UnsafeCell<Option<Waker>>,

    /// The virtual table.
    ///
    /// In addition to the actual waker virtual table, it also contains pointers to several other
    /// methods necessary for bookkeeping the heap-allocated task.
    pub(crate) vtable: &'static TaskVTable,
}
30
impl Header {
    /// Notifies the awaiter blocked on this task.
    ///
    /// Takes the registered awaiter via `Header::take` and wakes it. If the awaiter is the same
    /// as the current waker (as judged by `Waker::will_wake`), it will not be notified.
    ///
    /// The `wake` call is wrapped in `abort_on_panic`. NOTE(review): presumably this makes a
    /// panicking waker abort rather than unwind through the task bookkeeping. Confirm in
    /// `crate::utils`.
    #[inline]
    pub(crate) fn notify(&self, current: Option<&Waker>) {
        if let Some(w) = self.take(current) {
            abort_on_panic(|| w.wake());
        }
    }

    /// Takes the awaiter blocked on this task.
    ///
    /// Returns `None` in any of these cases:
    /// - there is no awaiter;
    /// - the awaiter is the same as `current`, in which case it is dropped without being woken;
    /// - another thread was already notifying or registering, so this call does not touch the
    ///   awaiter field at all.
    ///
    /// Access protocol: this method always sets `NOTIFYING`, but it only reads the
    /// `UnsafeCell` when neither `NOTIFYING` nor `REGISTERING` was set beforehand. If
    /// `REGISTERING` was set, `NOTIFYING` is intentionally left set. `register` observes it,
    /// takes the waker it just stored, clears the bit, and wakes that waker itself.
    #[inline]
    pub(crate) fn take(&self, current: Option<&Waker>) -> Option<Waker> {
        // Set the bit indicating that the task is notifying its awaiter.
        let state = self.state.fetch_or(NOTIFYING, Ordering::AcqRel);

        // If the task was not notifying or registering an awaiter...
        if state & (NOTIFYING | REGISTERING) == 0 {
            // Take the waker out.
            // This thread is the one that moved `state` into NOTIFYING while REGISTERING was
            // clear, so no other thread is accessing `awaiter` at this point.
            let waker = unsafe { (*self.awaiter.get()).take() };

            // Unset the bit indicating that the task is notifying its awaiter.
            // AWAITER is cleared as well because the awaiter field is now empty. The Release
            // ordering publishes the `take()` above to the next thread that acquires `state`.
            self.state
                .fetch_and(!NOTIFYING & !AWAITER, Ordering::Release);

            // Finally, notify the waker if it's different from the current waker.
            if let Some(w) = waker {
                match current {
                    None => return Some(w),
                    Some(c) if !w.will_wake(c) => return Some(w),
                    // Same waker as the caller's: drop it without waking. Dropping runs the
                    // waker's drop code, hence the guard.
                    Some(_) => abort_on_panic(|| drop(w)),
                }
            }
        }

        None
    }

    /// Registers a new awaiter blocked on this task.
    ///
    /// This method is called when `Task` is polled and it has not yet completed.
    ///
    /// Protocol:
    /// 1. If `NOTIFYING` is set when we look, wake `waker` immediately and return without
    ///    registering anything.
    /// 2. Otherwise set `REGISTERING` with a CAS. This gives this thread exclusive access to
    ///    the `awaiter` field. Store a clone of `waker` there.
    /// 3. Clear `REGISTERING` with a CAS loop. If a concurrent `take` set `NOTIFYING` in the
    ///    meantime, take the stored waker back out, clear `NOTIFYING` and `AWAITER`, and wake
    ///    it after the loop. Otherwise set `AWAITER`.
    #[inline]
    pub(crate) fn register(&self, waker: &Waker) {
        // Load the state and synchronize with it.
        // `fetch_or(0, ..)` leaves the state unchanged. NOTE(review): presumably it is used
        // instead of `load(Ordering::Acquire)` because a read-modify-write always reads the
        // latest value in the modification order. Confirm intent.
        let mut state = self.state.fetch_or(0, Ordering::Acquire);

        loop {
            // There can't be two concurrent registrations because `Task` can only be polled
            // by a unique pinned reference.
            debug_assert!(state & REGISTERING == 0);

            // If we're in the notifying state at this moment, just wake and return without
            // registering.
            if state & NOTIFYING != 0 {
                abort_on_panic(|| waker.wake_by_ref());
                return;
            }

            // Mark the state to let other threads know we're registering a new awaiter.
            // `_weak` may fail spuriously; that is harmless because on failure the loop retries
            // with the freshly observed state (which may now have NOTIFYING set).
            match self.state.compare_exchange_weak(
                state,
                state | REGISTERING,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    state |= REGISTERING;
                    break;
                }
                Err(s) => state = s,
            }
        }

        // Put the waker into the awaiter field.
        // We hold REGISTERING, so `take` will not touch the field concurrently. Cloning the
        // waker and dropping whatever awaiter was stored before both run waker code, hence the
        // guard.
        unsafe {
            abort_on_panic(|| (*self.awaiter.get()) = Some(waker.clone()));
        }

        // This variable will contain the newly registered waker if a notification comes in before
        // we complete registration.
        let mut waker = None;

        loop {
            // If there was a notification, take the waker out of the awaiter field.
            // On the first iteration `state` is the value we wrote ourselves. A concurrent
            // NOTIFYING only becomes visible after a CAS failure refreshes `state` below.
            // The stored waker is taken at most once: after that the field is `None`, so
            // `waker` is still `None` whenever this assignment runs.
            if state & NOTIFYING != 0 {
                if let Some(w) = unsafe { (*self.awaiter.get()).take() } {
                    abort_on_panic(|| waker = Some(w));
                }
            }

            // The new state is not being notified nor registered, but there might or might not be
            // an awaiter depending on whether there was a concurrent notification.
            let new = if waker.is_none() {
                (state & !NOTIFYING & !REGISTERING) | AWAITER
            } else {
                state & !NOTIFYING & !REGISTERING & !AWAITER
            };

            // Retry on any concurrent change to `state` (including unrelated bits), recomputing
            // `new` from the observed value each time.
            match self
                .state
                .compare_exchange_weak(state, new, Ordering::AcqRel, Ordering::Acquire)
            {
                Ok(_) => break,
                Err(s) => state = s,
            }
        }

        // If there was a notification during registration, wake the awaiter now.
        // This happens only after the state no longer has NOTIFYING or REGISTERING set.
        if let Some(w) = waker {
            abort_on_panic(|| w.wake());
        }
    }
}
147
148impl fmt::Debug for Header {
149    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
150        let state = self.state.load(Ordering::SeqCst);
151
152        f.debug_struct("Header")
153            .field("scheduled", &(state & SCHEDULED != 0))
154            .field("running", &(state & RUNNING != 0))
155            .field("completed", &(state & COMPLETED != 0))
156            .field("closed", &(state & CLOSED != 0))
157            .field("awaiter", &(state & AWAITER != 0))
158            .field("task", &(state & TASK != 0))
159            .field("ref_count", &(state / REFERENCE))
160            .finish()
161    }
162}