1#![cfg_attr(not(feature = "full"), allow(dead_code))]
2
3use crate::runtime::context;
33
/// Per-context operation budget.
///
/// `Some(n)` allows `n` more successful decrements before further attempts
/// fail; `None` is the unconstrained budget that never runs out.
#[derive(Clone, Copy, Debug)]
pub(crate) struct Budget(Option<u8>);
38
/// Outcome of attempting to consume one unit of a [`Budget`].
pub(crate) struct BudgetDecrement {
    /// `true` when this decrement consumed the last remaining unit of a
    /// constrained budget (i.e. the counter is now zero).
    hit_zero: bool,
    /// `true` when the operation is allowed to proceed.
    success: bool,
}
43
44impl Budget {
45    const fn initial() -> Budget {
56        Budget(Some(128))
57    }
58
59    pub(super) const fn unconstrained() -> Budget {
61        Budget(None)
62    }
63
64    fn has_remaining(self) -> bool {
65        self.0.map_or(true, |budget| budget > 0)
66    }
67}
68
69#[inline(always)]
72pub(crate) fn budget<R>(f: impl FnOnce() -> R) -> R {
73    with_budget(Budget::initial(), f)
74}
75
76#[inline(always)]
79pub(crate) fn with_unconstrained<R>(f: impl FnOnce() -> R) -> R {
80    with_budget(Budget::unconstrained(), f)
81}
82
83#[inline(always)]
84fn with_budget<R>(budget: Budget, f: impl FnOnce() -> R) -> R {
85    struct ResetGuard {
86        prev: Budget,
87    }
88
89    impl Drop for ResetGuard {
90        fn drop(&mut self) {
91            let _ = context::budget(|cell| {
92                cell.set(self.prev);
93            });
94        }
95    }
96
97    #[allow(unused_variables)]
98    let maybe_guard = context::budget(|cell| {
99        let prev = cell.get();
100        cell.set(budget);
101
102        ResetGuard { prev }
103    });
104
105    f()
108}
109
110#[inline(always)]
111pub(crate) fn has_budget_remaining() -> bool {
112    context::budget(|cell| cell.get().has_remaining()).unwrap_or(true)
115}
116
117cfg_rt_multi_thread! {
118    pub(crate) fn set(budget: Budget) {
120        let _ = context::budget(|cell| cell.set(budget));
121    }
122}
123
124cfg_rt! {
125    pub(crate) fn stop() -> Budget {
129        context::budget(|cell| {
130            let prev = cell.get();
131            cell.set(Budget::unconstrained());
132            prev
133        }).unwrap_or(Budget::unconstrained())
134    }
135}
136
137cfg_coop! {
138    use std::cell::Cell;
139    use std::task::{Context, Poll};
140
141    #[must_use]
142    pub(crate) struct RestoreOnPending(Cell<Budget>);
143
144    impl RestoreOnPending {
145        pub(crate) fn made_progress(&self) {
146            self.0.set(Budget::unconstrained());
147        }
148    }
149
150    impl Drop for RestoreOnPending {
151        fn drop(&mut self) {
152            let budget = self.0.get();
155            if !budget.is_unconstrained() {
156                let _ = context::budget(|cell| {
157                    cell.set(budget);
158                });
159            }
160        }
161    }
162
163    #[inline]
176    pub(crate) fn poll_proceed(cx: &mut Context<'_>) -> Poll<RestoreOnPending> {
177        context::budget(|cell| {
178            let mut budget = cell.get();
179
180            let decrement = budget.decrement();
181
182            if decrement.success {
183                let restore = RestoreOnPending(Cell::new(cell.get()));
184                cell.set(budget);
185
186                if decrement.hit_zero {
188                    inc_budget_forced_yield_count();
189                }
190
191                Poll::Ready(restore)
192            } else {
193                cx.waker().wake_by_ref();
194                Poll::Pending
195            }
196        }).unwrap_or(Poll::Ready(RestoreOnPending(Cell::new(Budget::unconstrained()))))
197    }
198
199    cfg_rt! {
200        cfg_unstable_metrics! {
201            #[inline(always)]
202            fn inc_budget_forced_yield_count() {
203                let _ = context::with_current(|handle| {
204                    handle.scheduler_metrics().inc_budget_forced_yield_count();
205                });
206            }
207        }
208
209        cfg_not_unstable_metrics! {
210            #[inline(always)]
211            fn inc_budget_forced_yield_count() {}
212        }
213    }
214
215    cfg_not_rt! {
216        #[inline(always)]
217        fn inc_budget_forced_yield_count() {}
218    }
219
220    impl Budget {
221        fn decrement(&mut self) -> BudgetDecrement {
224            if let Some(num) = &mut self.0 {
225                if *num > 0 {
226                    *num -= 1;
227
228                    let hit_zero = *num == 0;
229
230                    BudgetDecrement { success: true, hit_zero }
231                } else {
232                    BudgetDecrement { success: false, hit_zero: false }
233                }
234            } else {
235                BudgetDecrement { success: true, hit_zero: false }
236            }
237        }
238
239        fn is_unconstrained(self) -> bool {
240            self.0.is_none()
241        }
242    }
243}
244
#[cfg(all(test, not(loom)))]
mod test {
    use super::*;

    // On wasm targets other than WASI, `#[test]` is routed to wasm-bindgen's
    // test attribute instead of the built-in one.
    #[cfg(all(target_family = "wasm", not(target_os = "wasi")))]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    /// Reads the budget currently stored in the context, or
    /// `Budget::unconstrained()` if the budget cell cannot be accessed.
    fn get() -> Budget {
        context::budget(|cell| cell.get()).unwrap_or(Budget::unconstrained())
    }

    /// Walks through the budget lifecycle: unconstrained polling, refund on
    /// drop, `made_progress`, nested `budget` calls, and exhaustion.
    #[test]
    fn budgeting() {
        use futures::future::poll_fn;
        use tokio_test::*;

        // Outside any `budget(..)` call the budget is unconstrained.
        assert!(get().0.is_none());

        // Polling with an unconstrained budget succeeds without changing it.
        let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx)));

        assert!(get().0.is_none());
        drop(coop);
        assert!(get().0.is_none());

        budget(|| {
            // A fresh budget starts at `Budget::initial()`.
            assert_eq!(get().0, Budget::initial().0);

            // Dropping the guard without `made_progress` refunds the unit.
            let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx)));
            assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1);
            drop(coop);
            assert_eq!(get().0, Budget::initial().0);

            // `made_progress` keeps the unit consumed after the guard is dropped.
            let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx)));
            assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1);
            coop.made_progress();
            drop(coop);
            assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1);

            let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx)));
            assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 2);
            coop.made_progress();
            drop(coop);
            assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 2);

            // A nested `budget(..)` installs a fresh budget...
            budget(|| {
                assert_eq!(get().0, Budget::initial().0);

                let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx)));
                assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1);
                coop.made_progress();
                drop(coop);
                assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 1);
            });

            // ...and the outer budget is restored when it returns.
            assert_eq!(get().0.unwrap(), Budget::initial().0.unwrap() - 2);
        });

        // Leaving the outermost `budget(..)` restores the unconstrained budget.
        assert!(get().0.is_none());

        budget(|| {
            let n = get().0.unwrap();

            // Consume every unit of the fresh budget.
            for _ in 0..n {
                let coop = assert_ready!(task::spawn(()).enter(|cx, _| poll_proceed(cx)));
                coop.made_progress();
            }

            // With the budget at zero, `poll_proceed` must report Pending.
            let mut task = task::spawn(poll_fn(|cx| {
                let coop = ready!(poll_proceed(cx));
                coop.made_progress();
                Poll::Ready(())
            }));

            assert_pending!(task.poll());
        });
    }
}