1use core::future::Future;
6use core::marker::PhantomData;
7use core::pin::Pin;
8use core::task::{Context, Poll, ready};
9
10use fidl_next_codec::{
11 AsDecoder, AsDecoderExt as _, Constrained, Decode, Decoded, EncodeError, FromWire, IntoNatural,
12 Wire,
13};
14use fidl_next_protocol::{Message, NonBlockingTransport, Transport};
15use pin_project::pin_project;
16
17use crate::{Error, Response, TwoWayMethod};
18
/// The state machine shared by every two-way future type in this file.
///
/// A call normally moves through the variants in declaration order:
/// `SendRequest` -> `SendingRequest` -> `ReceiveResponse` -> `ReceivingResponse` ->
/// `DecodeBuffer`, and ends in `Finished` once a future takes its result out. `EncodeError` is
/// the starting state when the request could not be encoded.
///
/// Each in-flight future appears twice. The unpinned variants (`SendRequest`,
/// `ReceiveResponse`) let the future be moved out of the state through `project_replace` and
/// handed to another future type. The `#[pin]` variants (`SendingRequest`,
/// `ReceivingResponse`) hold the future while it is polled in place.
#[pin_project(project = TwoWayFutureStateProj, project_replace = TwoWayFutureStateOwn)]
enum TwoWayFutureState<'a, T: Transport> {
    // The request failed to encode; reported as `Error::Encode` when the state is advanced.
    EncodeError(EncodeError),
    // The request future has not been polled yet, so it may still be moved out.
    SendRequest(fidl_next_protocol::TwoWayRequestFuture<'a, T>),
    // The request future is being polled (pinned in place) to send the request.
    SendingRequest(#[pin] fidl_next_protocol::TwoWayRequestFuture<'a, T>),
    // The request was sent; the response future has not been polled yet and may still be moved.
    ReceiveResponse(fidl_next_protocol::TwoWayResponseFuture<'a, T>),
    // The response future is being polled (pinned in place) to receive the response.
    ReceivingResponse(#[pin] fidl_next_protocol::TwoWayResponseFuture<'a, T>),
    // The response message has arrived and is waiting to be taken out.
    DecodeBuffer(Message<T>),
    // The state was taken out or an error was returned; advancing again panics.
    Finished,
}
29
/// Generates accessors for the listed `TwoWayFutureState` variants.
///
/// For each `Variant(Type) => check unwrap` entry this emits:
///
/// - `check(&self) -> bool` on `TwoWayFutureState`, which reports whether the state is currently
///   that variant. The future types use it as the stop condition for `poll_until`.
/// - `unwrap(self) -> Type` on `TwoWayFutureStateOwn` (the `project_replace` output), which
///   moves the variant's payload out.
macro_rules! impl_two_way_future_state {
    ($(
        $variant:ident($ty:ty) => $check:ident $unwrap:ident
    ),* $(,)?) => {
        impl<T: Transport> TwoWayFutureState<'_, T> {
            $(
                // Some generated predicates are never used as a stop condition.
                #[allow(dead_code)]
                fn $check(&self) -> bool {
                    matches!(self, Self::$variant(_))
                }
            )*
        }

        impl<'a, T: Transport> TwoWayFutureStateOwn<'a, T> {
            $(
                // Moves the payload out of this variant. Callers only invoke it after matching
                // or checking for the same variant, so a mismatch is an internal bug. The panic
                // message names the expected variant to make such a bug easy to locate.
                // Some generated extractors are unused.
                #[allow(dead_code)]
                fn $unwrap(self) -> $ty {
                    let Self::$variant(value) = self else {
                        unreachable!(concat!(
                            "expected TwoWayFutureState::",
                            stringify!($variant)
                        ))
                    };
                    value
                }
            )*
        }
    };
}
56
57impl_two_way_future_state! {
58 EncodeError(EncodeError) => is_encode_error unwrap_encode_error,
59 SendRequest(fidl_next_protocol::TwoWayRequestFuture<'a, T>)
60 => is_send_request unwrap_send_request,
61 ReceiveResponse(fidl_next_protocol::TwoWayResponseFuture<'a, T>)
62 => is_receive_response unwrap_receive_response,
63 DecodeBuffer(Message<T>) => is_decode_buffer unwrap_decode_buffer,
64}
65
66impl<'a, T: Transport> TwoWayFutureState<'a, T> {
67 fn finish(self: Pin<&mut Self>) -> TwoWayFutureStateOwn<'a, T> {
68 self.project_replace(Self::Finished)
69 }
70
71 fn poll_advance(
72 mut self: Pin<&mut Self>,
73 cx: &mut Context<'_>,
74 ) -> Poll<Result<(), Error<T::Error>>> {
75 Poll::Ready(match self.as_mut().project() {
76 TwoWayFutureStateProj::EncodeError(_) => {
77 Err(Error::Encode(self.finish().unwrap_encode_error()))
78 }
79 TwoWayFutureStateProj::SendRequest(_) => {
80 let future = self.as_mut().finish().unwrap_send_request();
81 self.project_replace(Self::SendingRequest(future));
82 Ok(())
83 }
84 TwoWayFutureStateProj::SendingRequest(future) => match ready!(future.poll(cx)) {
85 Ok(future) => {
86 self.project_replace(Self::ReceiveResponse(future));
87 Ok(())
88 }
89 Err(error) => {
90 self.finish();
91 Err(Error::Protocol(error))
92 }
93 },
94 TwoWayFutureStateProj::ReceiveResponse(_) => {
95 let future = self.as_mut().finish().unwrap_receive_response();
96 self.project_replace(Self::ReceivingResponse(future));
97 Ok(())
98 }
99 TwoWayFutureStateProj::ReceivingResponse(future) => match ready!(future.poll(cx)) {
100 Ok(body) => {
101 self.project_replace(Self::DecodeBuffer(body));
102 Ok(())
103 }
104 Err(error) => {
105 self.finish();
106 Err(Error::Protocol(error))
107 }
108 },
109 TwoWayFutureStateProj::DecodeBuffer(_) | TwoWayFutureStateProj::Finished => {
110 panic!("TwoWayFutureState polled after completing");
111 }
112 })
113 }
114
115 fn poll_until(
116 mut self: Pin<&mut Self>,
117 cx: &mut Context<'_>,
118 is_done: impl Fn(&Self) -> bool,
119 ) -> Poll<Result<TwoWayFutureStateOwn<'a, T>, Error<T::Error>>> {
120 while !is_done(&self) {
121 if let Err(error) = ready!(self.as_mut().poll_advance(cx)) {
122 return Poll::Ready(Err(error));
123 }
124 }
125 Poll::Ready(Ok(self.finish()))
126 }
127}
128
/// Declares public future types that wrap a `TwoWayFutureState`.
///
/// Each entry generates a `#[pin_project]` struct `$future<'a, M, T>` and a `Future` impl for it.
/// The impl drives the shared state machine until `TwoWayFutureState::$check` holds, binds the
/// extracted `TwoWayFutureStateOwn` to `$state`, and resolves to `Ok($expr)`. Bounds listed
/// inside `where [...]` are added to that `Future` impl.
macro_rules! two_way_futures {
    ($(
        $(#[$metas:meta])* $future:ident -> $output:ty
        where [$($tt:tt)*]
        {
            $check:ident => |$state:ident| $expr:expr
        }
    ),* $(,)?) => {
        $(
            $(#[$metas])*
            #[must_use = "futures do nothing unless polled"]
            #[pin_project]
            pub struct $future<'a, M, T: Transport> {
                #[pin]
                state: TwoWayFutureState<'a, T>,
                _method: PhantomData<M>,
            }

            impl<'a, M, T> Future for $future<'a, M, T>
            where
                T: Transport,
                $($tt)*
            {
                type Output = Result<$output, Error<T::Error>>;

                fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                    let this = self.project();
                    match this.state.poll_until(cx, TwoWayFutureState::$check) {
                        Poll::Pending => Poll::Pending,
                        Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
                        Poll::Ready(Ok($state)) => Poll::Ready(Ok($expr)),
                    }
                }
            }
        )*
    };
}
169
170two_way_futures! {
171 TwoWayFuture -> <<M::Response as Response>::Payload as IntoNatural>::Natural
175 where [
176 M: TwoWayMethod,
177 M::Response: Wire<Constraint = ()>,
178 for<'de> <M::Response as Wire>::Narrowed<'de>: Decode<<T::RecvBuffer as AsDecoder<'de>>::Decoder, Constraint = <M::Response as Constrained>::Constraint>,
179 <M::Response as Response>::Payload: Wire + IntoNatural,
180 <<M::Response as Response>::Payload as IntoNatural>::Natural: for<'de> FromWire<<<M::Response as Response>::Payload as Wire>::Narrowed<'de>>,
181 ]
182 {
183 is_decode_buffer => |state| Response::into_payload(state.unwrap_decode_buffer().into_decoded::<M::Response>()?).take()
184 },
185
186 EncodedTwoWayFuture -> <<M::Response as Response>::Payload as IntoNatural>::Natural
193 where [
194 M: TwoWayMethod,
195 M::Response: Wire<Constraint = ()>,
196 for<'de> <M::Response as Wire>::Narrowed<'de>: Decode<<T::RecvBuffer as AsDecoder<'de>>::Decoder, Constraint = <M::Response as Constrained>::Constraint>,
197 <M::Response as Response>::Payload: Wire + IntoNatural,
198 <<M::Response as Response>::Payload as IntoNatural>::Natural: for<'de> FromWire<<<M::Response as Response>::Payload as Wire>::Narrowed<'de>>,
199 ]
200 {
201 is_decode_buffer => |state| Response::into_payload(state.unwrap_decode_buffer().into_decoded::<M::Response>()?).take()
202 },
203
204 SendTwoWayFuture -> SentTwoWayFuture<'a, M, T>
210 where []
211 {
212 is_receive_response => |state| SentTwoWayFuture {
213 state: TwoWayFutureState::ReceiveResponse(state.unwrap_receive_response()),
214 _method: PhantomData,
215 }
216 },
217
218 SentTwoWayFuture -> <<M::Response as Response>::Payload as IntoNatural>::Natural
225 where [
226 M: TwoWayMethod,
227 M::Response: Wire<Constraint = ()>,
228 for<'de> <M::Response as Wire>::Narrowed<'de>: Decode<<T::RecvBuffer as AsDecoder<'de>>::Decoder, Constraint = <M::Response as Constrained>::Constraint>,
229 <M::Response as Response>::Payload: Wire + IntoNatural,
230 <<M::Response as Response>::Payload as IntoNatural>::Natural: for<'de> FromWire<<<M::Response as Response>::Payload as Wire>::Narrowed<'de>>,
231 ]
232 {
233 is_decode_buffer => |state| Response::into_payload(state.unwrap_decode_buffer().into_decoded::<M::Response>()?).take()
234 },
235
236 RecvBufferTwoWayFuture -> Message<T>
242 where []
243 {
244 is_decode_buffer => |state| state.unwrap_decode_buffer()
245 },
246
247 WireTwoWayFuture -> Decoded<<M::Response as Response>::Payload, Message<T>>
253 where [
254 M: TwoWayMethod,
255 M::Response: Wire<Constraint = ()>,
256 for<'de> <M::Response as Wire>::Narrowed<'de>: Decode<<T::RecvBuffer as AsDecoder<'de>>::Decoder, Constraint = <M::Response as Constrained>::Constraint>,
257 <M::Response as Response>::Payload: Wire + IntoNatural,
258 <<M::Response as Response>::Payload as IntoNatural>::Natural: for<'de> FromWire<<<M::Response as Response>::Payload as Wire>::Narrowed<'de>>,
259 ]
260 {
261 is_decode_buffer => |state| Response::into_payload(state.unwrap_decode_buffer().into_decoded::<M::Response>()?)
262 }
263}
264
/// Adds the same inherent method to each listed future type.
///
/// `$method` is pasted into `impl<'a, M, T: Transport> $future<'a, M, T> { ... }` once per
/// future, so the method body may refer to `'a`, `M`, `T` and `Self`.
macro_rules! impl_for_futures {
    (
        $($future:ident)*,
        $method:item
    ) => {
        $(
            impl<'a, M, T: Transport> $future<'a, M, T> {
                $method
            }
        )*
    };
}
277
278impl_for_futures! {
283 TwoWayFuture,
284
285 pub fn encode(self) -> Result<EncodedTwoWayFuture<'a, M, T>, Error<T::Error>> {
289 Ok(EncodedTwoWayFuture {
290 state: match self.state {
291 TwoWayFutureState::EncodeError(error) => return Err(Error::Encode(error)),
292 state => state,
293 },
294 _method: PhantomData,
295 })
296 }
297}
298
299impl_for_futures! {
300 TwoWayFuture EncodedTwoWayFuture,
301
302 pub fn send(self) -> SendTwoWayFuture<'a, M, T> {
306 SendTwoWayFuture {
307 state: self.state,
308 _method: PhantomData,
309 }
310 }
311}
312
313impl_for_futures! {
314 TwoWayFuture EncodedTwoWayFuture SentTwoWayFuture,
315
316 pub fn recv_buffer(self) -> RecvBufferTwoWayFuture<'a, M, T> {
320 RecvBufferTwoWayFuture {
321 state: self.state,
322 _method: PhantomData,
323 }
324 }
325}
326
327impl_for_futures! {
328 TwoWayFuture EncodedTwoWayFuture SentTwoWayFuture,
329
330 pub fn wire(self) -> WireTwoWayFuture<'a, M, T> {
335 WireTwoWayFuture {
336 state: self.state,
337 _method: PhantomData,
338 }
339 }
340}
341
342impl<'a, M, T: Transport> TwoWayFuture<'a, M, T> {
343 pub fn from_untyped(
345 result: Result<fidl_next_protocol::TwoWayRequestFuture<'a, T>, EncodeError>,
346 ) -> Self {
347 Self {
348 state: match result {
349 Ok(future) => TwoWayFutureState::SendRequest(future),
350 Err(error) => TwoWayFutureState::EncodeError(error),
351 },
352 _method: PhantomData,
353 }
354 }
355}
356
357impl<'a, M, T: NonBlockingTransport> SendTwoWayFuture<'a, M, T> {
358 pub fn send_immediately(self) -> Result<SentTwoWayFuture<'a, M, T>, Error<T::Error>> {
369 match self.state {
370 TwoWayFutureState::EncodeError(e) => Err(Error::Encode(e)),
371 TwoWayFutureState::SendRequest(future) => Ok(SentTwoWayFuture {
372 state: TwoWayFutureState::ReceiveResponse(future.send_immediately()?),
373 _method: PhantomData,
374 }),
375 _ => unreachable!(),
376 }
377 }
378}