run_test_suite_lib/
stream_util.rs1use futures::stream::Stream;
6use futures::task::{Context, Poll};
7use std::pin::Pin;
8
9pub(crate) trait StreamUtil<S: Stream<Item = T> + Unpin, T> {
11 fn take_until_stop_after<F: Fn(&T) -> bool + Unpin>(
18 self,
19 stop_after: F,
20 ) -> TakeUntilStopAfterStream<F, S, T>;
21}
22
23impl<S, T> StreamUtil<S, T> for S
24where
25 S: Stream<Item = T> + Unpin,
26{
27 fn take_until_stop_after<F: Fn(&T) -> bool + Unpin>(
28 self,
29 stop_after_fn: F,
30 ) -> TakeUntilStopAfterStream<F, S, T> {
31 TakeUntilStopAfterStream { stop_after_fn, inner: self, stopped: false }
32 }
33}
34
35pub struct TakeUntilStopAfterStream<F, S, T>
37where
38 F: Fn(&T) -> bool + Unpin,
39 S: Stream<Item = T> + Unpin,
40{
41 stop_after_fn: F,
42 inner: S,
43 stopped: bool,
44}
45
46impl<F, S, T> Stream for TakeUntilStopAfterStream<F, S, T>
47where
48 F: Fn(&T) -> bool + Unpin,
49 S: Stream<Item = T> + Unpin,
50{
51 type Item = <S as Stream>::Item;
52
53 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
54 let self_mut = self.get_mut();
55 if self_mut.stopped {
56 return Poll::Ready(None);
57 }
58
59 let inner_poll = Pin::new(&mut self_mut.inner).poll_next(cx);
60 self_mut.stopped = match &inner_poll {
61 Poll::Ready(None) => true,
62 Poll::Ready(Some(item)) => (self_mut.stop_after_fn)(item),
63 Poll::Pending => false,
64 };
65 inner_poll
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream::StreamExt;

    #[fuchsia_async::run_singlethreaded(test)]
    async fn stops_after_test_fn_returns_true() {
        let stream = futures::stream::iter(0..u32::MAX);
        let results: Vec<_> = stream.take_until_stop_after(|num| *num == 5).collect().await;
        assert_eq!(vec![0, 1, 2, 3, 4, 5], results);
    }

    #[fuchsia_async::run_singlethreaded(test)]
    async fn does_not_poll_after_test_fn_returns_true() {
        // The inner stream never resolves after yielding 5, so `collect` only completes if the
        // adapter stops polling it once the test fn returns true.
        let stream = futures::stream::iter(0..6).chain(futures::stream::pending());
        let results: Vec<_> = stream.take_until_stop_after(|num| *num == 5).collect().await;
        assert_eq!(vec![0, 1, 2, 3, 4, 5], results);
    }

    #[fuchsia_async::run_singlethreaded(test)]
    async fn yields_everything_when_test_fn_never_returns_true() {
        let stream = futures::stream::iter(0..4u32);
        let results: Vec<_> = stream.take_until_stop_after(|_| false).collect().await;
        assert_eq!(vec![0, 1, 2, 3], results);
    }

    #[fuchsia_async::run_singlethreaded(test)]
    async fn empty_stream_yields_nothing() {
        let stream = futures::stream::iter(Vec::<u32>::new());
        let results: Vec<u32> = stream.take_until_stop_after(|_| true).collect().await;
        assert!(results.is_empty());
    }

    #[fuchsia_async::run_singlethreaded(test)]
    async fn keeps_returning_none_after_stopping() {
        let mut stream = futures::stream::iter(0..u32::MAX).take_until_stop_after(|num| *num == 0);
        assert_eq!(Some(0), stream.next().await);
        assert_eq!(None, stream.next().await);
        assert_eq!(None, stream.next().await);
    }
}