fidl_next_protocol/endpoints/
client.rs1use core::future::Future;
8use core::mem::ManuallyDrop;
9use core::pin::Pin;
10use core::ptr;
11use core::task::{Context, Poll, ready};
12
13use fidl_constants::EPITAPH_ORDINAL;
14use fidl_next_codec::{AsDecoder as _, DecoderExt as _, Encode, EncodeError, EncoderExt, Wire};
15use pin_project::{pin_project, pinned_drop};
16
17use crate::concurrency::sync::{Arc, Mutex};
18use crate::endpoints::connection::{Connection, SendFutureState};
19use crate::endpoints::lockers::{LockerError, Lockers};
20use crate::wire::{Epitaph, MessageHeader};
21use crate::{Body, Flexibility, NonBlockingTransport, ProtocolError, SendFuture, Transport};
22
23struct ClientInner<T: Transport> {
24 connection: Connection<T>,
25 responses: Mutex<Lockers<Body<T>>>,
26}
27
28impl<T: Transport> ClientInner<T> {
29 fn new(shared: T::Shared) -> Self {
30 Self { connection: Connection::new(shared), responses: Mutex::new(Lockers::new()) }
31 }
32}
33
34pub struct Client<T: Transport> {
36 inner: Arc<ClientInner<T>>,
37}
38
39impl<T: Transport> Drop for Client<T> {
40 fn drop(&mut self) {
41 if Arc::strong_count(&self.inner) == 2 {
42 self.close();
45 }
46 }
47}
48
49impl<T: Transport> Client<T> {
50 pub fn close(&self) {
52 self.inner.connection.stop();
53 }
54
55 pub fn send_one_way<W>(
57 &self,
58 ordinal: u64,
59 flexibility: Flexibility,
60 request: impl Encode<W, T::SendBuffer>,
61 ) -> Result<SendFuture<'_, T>, EncodeError>
62 where
63 W: Wire<Constraint = ()>,
64 {
65 Ok(SendFuture::from_raw_parts(
66 &self.inner.connection,
67 self.send_message_raw(0, ordinal, flexibility, request)?,
68 ))
69 }
70
71 pub fn send_two_way<W>(
73 &self,
74 ordinal: u64,
75 flexibility: Flexibility,
76 request: impl Encode<W, T::SendBuffer>,
77 ) -> Result<TwoWayRequestFuture<'_, T>, EncodeError>
78 where
79 W: Wire<Constraint = ()>,
80 {
81 let index = self.inner.responses.lock().unwrap().alloc(ordinal);
82
83 match self.send_message_raw(index + 1, ordinal, flexibility, request) {
85 Ok(state) => Ok(TwoWayRequestFuture { inner: &self.inner, index: Some(index), state }),
86 Err(e) => {
87 self.inner.responses.lock().unwrap().free(index);
88 Err(e)
89 }
90 }
91 }
92
93 fn send_message_raw<W>(
94 &self,
95 txid: u32,
96 ordinal: u64,
97 flexibility: Flexibility,
98 message: impl Encode<W, T::SendBuffer>,
99 ) -> Result<SendFutureState<T>, EncodeError>
100 where
101 W: Wire<Constraint = ()>,
102 {
103 self.inner.connection.send_message_raw(|buffer| {
104 buffer.encode_next(MessageHeader::new(txid, ordinal, flexibility))?;
105 buffer.encode_next(message)
106 })
107 }
108}
109
110impl<T: Transport> Clone for Client<T> {
111 fn clone(&self) -> Self {
112 Self { inner: self.inner.clone() }
113 }
114}
115
116pub struct TwoWayResponseFuture<'a, T: Transport> {
118 inner: &'a ClientInner<T>,
119 index: Option<u32>,
120}
121
122impl<T: Transport> Drop for TwoWayResponseFuture<'_, T> {
123 fn drop(&mut self) {
124 if let Some(index) = self.index {
126 let mut responses = self.inner.responses.lock().unwrap();
127 if responses.get(index).unwrap().cancel() {
128 responses.free(index);
129 }
130 }
131 }
132}
133
134impl<T: Transport> Future for TwoWayResponseFuture<'_, T> {
135 type Output = Result<Body<T>, ProtocolError<T::Error>>;
136
137 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
138 let this = Pin::into_inner(self);
139 let Some(index) = this.index else {
140 panic!("TwoWayResponseFuture polled after returning `Poll::Ready`");
141 };
142
143 let mut responses = this.inner.responses.lock().unwrap();
144 let ready = if let Some(ready) = responses.get(index).unwrap().read(cx.waker()) {
145 Ok(ready)
146 } else if let Some(termination_reason) = this.inner.connection.get_termination_reason() {
147 Err(termination_reason)
148 } else {
149 return Poll::Pending;
150 };
151
152 responses.free(index);
153 this.index = None;
154 Poll::Ready(ready)
155 }
156}
157
158#[pin_project(PinnedDrop)]
160pub struct TwoWayRequestFuture<'a, T: Transport> {
161 inner: &'a ClientInner<T>,
162 index: Option<u32>,
163 #[pin]
164 state: SendFutureState<T>,
165}
166
167#[pinned_drop]
168impl<T: Transport> PinnedDrop for TwoWayRequestFuture<'_, T> {
169 fn drop(self: Pin<&mut Self>) {
170 if let Some(index) = self.index {
171 let mut responses = self.inner.responses.lock().unwrap();
172
173 responses.free(index);
176 }
177 }
178}
179
180impl<'a, T: Transport> Future for TwoWayRequestFuture<'a, T> {
181 type Output = Result<TwoWayResponseFuture<'a, T>, ProtocolError<T::Error>>;
182
183 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
184 let this = self.project();
185
186 let Some(index) = *this.index else {
187 panic!("TwoWayRequestFuture polled after returning `Poll::Ready`");
188 };
189
190 let result = ready!(this.state.poll_send(cx, &this.inner.connection));
191 *this.index = None;
192 if let Err(error) = result {
193 this.inner.responses.lock().unwrap().free(index);
195 Poll::Ready(Err(error))
196 } else {
197 Poll::Ready(Ok(TwoWayResponseFuture { inner: this.inner, index: Some(index) }))
198 }
199 }
200}
201
202impl<'a, T: NonBlockingTransport> TwoWayRequestFuture<'a, T> {
203 pub fn send_immediately(self) -> Result<TwoWayResponseFuture<'a, T>, ProtocolError<T::Error>> {
214 let inner = self.inner;
215 let index = self.index;
216 let state = unsafe { ptr::read(&ManuallyDrop::new(self).state) };
217 if let Err(e) = state.send_immediately(&inner.connection) {
218 inner.responses.lock().unwrap().free(index.unwrap());
219 return Err(e);
220 }
221
222 Ok(TwoWayResponseFuture { inner, index })
223 }
224}
225
226pub trait LocalClientHandler<T: Transport> {
231 fn on_event(
235 &mut self,
236 ordinal: u64,
237 flexibility: Flexibility,
238 body: Body<T>,
239 ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>>;
240}
241
242pub trait ClientHandler<T: Transport>: Send {
244 fn on_event(
250 &mut self,
251 ordinal: u64,
252 flexibility: Flexibility,
253 body: Body<T>,
254 ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>> + Send;
255}
256
257#[repr(transparent)]
259pub struct ClientHandlerToLocalAdapter<H>(H);
260
261impl<T, H> LocalClientHandler<T> for ClientHandlerToLocalAdapter<H>
262where
263 T: Transport,
264 H: ClientHandler<T>,
265{
266 #[inline]
267 fn on_event(
268 &mut self,
269 ordinal: u64,
270 flexibility: Flexibility,
271 body: Body<T>,
272 ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>> {
273 self.0.on_event(ordinal, flexibility, body)
274 }
275}
276
277pub struct ClientDispatcher<T: Transport> {
286 inner: Arc<ClientInner<T>>,
287 exclusive: T::Exclusive,
288 is_terminated: bool,
289}
290
291impl<T: Transport> Drop for ClientDispatcher<T> {
292 fn drop(&mut self) {
293 if !self.is_terminated {
294 unsafe {
296 self.terminate(ProtocolError::Stopped);
297 }
298 }
299 }
300}
301
302impl<T: Transport> ClientDispatcher<T> {
303 pub fn new(transport: T) -> Self {
305 let (shared, exclusive) = transport.split();
306 Self { inner: Arc::new(ClientInner::new(shared)), exclusive, is_terminated: false }
307 }
308
309 unsafe fn terminate(&mut self, error: ProtocolError<T::Error>) {
313 unsafe {
315 self.inner.connection.terminate(error);
316 }
317 self.inner.responses.lock().unwrap().wake_all();
318 }
319
320 pub fn client(&self) -> Client<T> {
324 Client { inner: self.inner.clone() }
325 }
326
327 pub async fn run<H>(self, handler: H) -> Result<H, ProtocolError<T::Error>>
329 where
330 H: ClientHandler<T>,
331 {
332 self.run_local(ClientHandlerToLocalAdapter(handler)).await.map(|adapter| adapter.0)
335 }
336
337 pub async fn run_local<H>(mut self, mut handler: H) -> Result<H, ProtocolError<T::Error>>
339 where
340 H: LocalClientHandler<T>,
341 {
342 let error = loop {
348 let result = unsafe { self.run_one(&mut handler).await };
350 if let Err(error) = result {
351 break error;
352 }
353 };
354
355 unsafe {
357 self.terminate(error.clone());
358 }
359 self.is_terminated = true;
360
361 match error {
362 ProtocolError::Stopped => Ok(handler),
365
366 _ => Err(error),
368 }
369 }
370
371 async unsafe fn run_one<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
375 where
376 H: LocalClientHandler<T>,
377 {
378 let mut buffer = unsafe { self.inner.connection.recv(&mut self.exclusive).await? };
380
381 let header = {
393 let mut decoder = buffer.as_decoder();
394
395 let header = decoder
396 .decode_prefix::<MessageHeader>()
397 .map_err(ProtocolError::InvalidMessageHeader)?;
398
399 if header.ordinal == EPITAPH_ORDINAL {
403 let epitaph =
404 decoder.decode::<Epitaph>().map_err(ProtocolError::InvalidEpitaphBody)?;
405 return Err(ProtocolError::PeerClosedWithEpitaph(*epitaph.error));
406 }
407
408 header
409 };
410
411 if header.txid == 0 {
412 handler.on_event(*header.ordinal, header.flexibility(), Body::new(buffer)).await?;
413 } else {
414 let mut responses = self.inner.responses.lock().unwrap();
415 let locker = responses
416 .get(*header.txid - 1)
417 .ok_or_else(|| ProtocolError::UnrequestedResponse { txid: *header.txid })?;
418
419 match locker.write(*header.ordinal, Body::new(buffer)) {
420 Ok(false) => (),
422 Ok(true) => responses.free(*header.txid - 1),
424 Err(LockerError::NotWriteable) => {
425 return Err(ProtocolError::UnrequestedResponse { txid: *header.txid });
426 }
427 Err(LockerError::MismatchedOrdinal { expected, actual }) => {
428 return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
429 }
430 }
431 }
432
433 Ok(())
434 }
435}