fidl_next_protocol/endpoints/
client.rs1use core::future::Future;
8use core::pin::Pin;
9use core::task::{Context, Poll, ready};
10
11use fidl_next_codec::{Constrained, Encode, EncodeError, EncoderExt};
12use pin_project::{pin_project, pinned_drop};
13
14use crate::concurrency::sync::{Arc, Mutex};
15use crate::endpoints::connection::{Connection, ORDINAL_EPITAPH};
16use crate::endpoints::lockers::{LockerError, Lockers};
17use crate::{ProtocolError, SendFuture, Transport, decode_epitaph, decode_header, encode_header};
18
19struct ClientInner<T: Transport> {
20 connection: Connection<T>,
21 responses: Mutex<Lockers<T::RecvBuffer>>,
22}
23
24impl<T: Transport> ClientInner<T> {
25 fn new(shared: T::Shared) -> Self {
26 Self { connection: Connection::new(shared), responses: Mutex::new(Lockers::new()) }
27 }
28}
29
30pub struct Client<T: Transport> {
32 inner: Arc<ClientInner<T>>,
33}
34
35impl<T: Transport> Drop for Client<T> {
36 fn drop(&mut self) {
37 if Arc::strong_count(&self.inner) == 2 {
38 self.close();
41 }
42 }
43}
44
45impl<T: Transport> Client<T> {
46 pub fn close(&self) {
48 self.inner.connection.stop();
49 }
50
51 pub fn send_one_way<M>(
53 &self,
54 ordinal: u64,
55 request: M,
56 ) -> Result<SendFuture<'_, T>, EncodeError>
57 where
58 M: Encode<T::SendBuffer>,
59 M::Encoded: Constrained<Constraint = ()>,
60 {
61 self.send_message(0, ordinal, request)
62 }
63
64 pub fn send_two_way<M>(
66 &self,
67 ordinal: u64,
68 request: M,
69 ) -> Result<TwoWayRequestFuture<'_, T>, EncodeError>
70 where
71 M: Encode<T::SendBuffer>,
72 M::Encoded: Constrained<Constraint = ()>,
73 {
74 let index = self.inner.responses.lock().unwrap().alloc(ordinal);
75
76 match self.send_message(index + 1, ordinal, request) {
78 Ok(send_future) => {
79 Ok(TwoWayRequestFuture { inner: &self.inner, index: Some(index), send_future })
80 }
81 Err(e) => {
82 self.inner.responses.lock().unwrap().free(index);
83 Err(e)
84 }
85 }
86 }
87
88 fn send_message<M>(
89 &self,
90 txid: u32,
91 ordinal: u64,
92 message: M,
93 ) -> Result<SendFuture<'_, T>, EncodeError>
94 where
95 M: Encode<T::SendBuffer>,
96 M::Encoded: Constrained<Constraint = ()>,
97 {
98 self.inner.connection.send_message(|buffer| {
99 encode_header::<T>(buffer, txid, ordinal)?;
100 buffer.encode_next(message, ())
101 })
102 }
103}
104
105impl<T: Transport> Clone for Client<T> {
106 fn clone(&self) -> Self {
107 Self { inner: self.inner.clone() }
108 }
109}
110
111pub struct TwoWayResponseFuture<'a, T: Transport> {
113 inner: &'a ClientInner<T>,
114 index: Option<u32>,
115}
116
117impl<T: Transport> Drop for TwoWayResponseFuture<'_, T> {
118 fn drop(&mut self) {
119 if let Some(index) = self.index {
121 let mut responses = self.inner.responses.lock().unwrap();
122 if responses.get(index).unwrap().cancel() {
123 responses.free(index);
124 }
125 }
126 }
127}
128
129impl<T: Transport> Future for TwoWayResponseFuture<'_, T> {
130 type Output = Result<T::RecvBuffer, ProtocolError<T::Error>>;
131
132 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
133 let this = Pin::into_inner(self);
134 let Some(index) = this.index else {
135 panic!("TwoWayResponseFuture polled after returning `Poll::Ready`");
136 };
137
138 let mut responses = this.inner.responses.lock().unwrap();
139 let ready = if let Some(ready) = responses.get(index).unwrap().read(cx.waker()) {
140 Ok(ready)
141 } else if let Some(termination_reason) = this.inner.connection.get_termination_reason() {
142 Err(termination_reason)
143 } else {
144 return Poll::Pending;
145 };
146
147 responses.free(index);
148 this.index = None;
149 Poll::Ready(ready)
150 }
151}
152
153#[pin_project(PinnedDrop)]
155pub struct TwoWayRequestFuture<'a, T: Transport> {
156 inner: &'a ClientInner<T>,
157 index: Option<u32>,
158 #[pin]
159 send_future: SendFuture<'a, T>,
160}
161
162#[pinned_drop]
163impl<T: Transport> PinnedDrop for TwoWayRequestFuture<'_, T> {
164 fn drop(self: Pin<&mut Self>) {
165 if let Some(index) = self.index {
166 let mut responses = self.inner.responses.lock().unwrap();
167
168 responses.free(index);
171 }
172 }
173}
174
175impl<'a, T: Transport> Future for TwoWayRequestFuture<'a, T> {
176 type Output = Result<TwoWayResponseFuture<'a, T>, ProtocolError<T::Error>>;
177
178 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
179 let this = self.project();
180
181 let Some(index) = *this.index else {
182 panic!("TwoWayRequestFuture polled after returning `Poll::Ready`");
183 };
184
185 let result = ready!(this.send_future.poll(cx));
186 *this.index = None;
187 if let Err(error) = result {
188 this.inner.responses.lock().unwrap().free(index);
190 Poll::Ready(Err(error))
191 } else {
192 Poll::Ready(Ok(TwoWayResponseFuture { inner: this.inner, index: Some(index) }))
193 }
194 }
195}
196
197pub trait ClientHandler<T: Transport> {
199 fn on_event(
206 &mut self,
207 ordinal: u64,
208 buffer: T::RecvBuffer,
209 ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>> + Send;
210}
211
212pub struct ClientDispatcher<T: Transport> {
221 inner: Arc<ClientInner<T>>,
222 exclusive: T::Exclusive,
223 is_terminated: bool,
224}
225
226impl<T: Transport> Drop for ClientDispatcher<T> {
227 fn drop(&mut self) {
228 if !self.is_terminated {
229 unsafe {
231 self.terminate(ProtocolError::Stopped);
232 }
233 }
234 }
235}
236
impl<T: Transport> ClientDispatcher<T> {
    /// Creates a new client dispatcher from a transport.
    ///
    /// The transport is split into its shared half, which backs the connection
    /// that every [`Client`] handle uses, and its exclusive half, which only the
    /// dispatcher uses to receive messages.
    pub fn new(transport: T) -> Self {
        let (shared, exclusive) = transport.split();
        Self { inner: Arc::new(ClientInner::new(shared)), exclusive, is_terminated: false }
    }

    /// Terminates the connection with `error`, then calls `wake_all` on the
    /// response lockers so that pending [`TwoWayResponseFuture`]s are polled again
    /// and can observe the termination reason.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is inherited from `Connection::terminate`, which
    /// is not visible here — confirm it before adding callers. Existing callers
    /// invoke this at most once per dispatcher: `run` sets `is_terminated`
    /// immediately afterwards, and `Drop` only calls it while `is_terminated` is
    /// still `false`.
    unsafe fn terminate(&mut self, error: ProtocolError<T::Error>) {
        // SAFETY: the caller upholds this function's contract, which is forwarded
        // unchanged to `Connection::terminate`.
        unsafe {
            self.inner.connection.terminate(error);
        }
        self.inner.responses.lock().unwrap().wake_all();
    }

    /// Returns a new [`Client`] that shares this dispatcher's connection.
    ///
    /// Two-way responses for the client are delivered by [`run`](Self::run).
    pub fn client(&self) -> Client<T> {
        Client { inner: self.inner.clone() }
    }

    /// Runs the dispatcher, receiving and dispatching one message at a time and
    /// forwarding events to `handler`.
    ///
    /// When a message fails to dispatch, the loop ends. The connection is then
    /// terminated with that error and pending two-way response futures are woken.
    ///
    /// # Errors
    ///
    /// Returns `Ok(handler)` only if the loop ended with `ProtocolError::Stopped`.
    /// Every other terminating error is returned as `Err`.
    pub async fn run<H>(mut self, mut handler: H) -> Result<H, ProtocolError<T::Error>>
    where
        H: ClientHandler<T>,
    {
        let error = loop {
            // SAFETY: NOTE(review) — `run_one` forwards to the unsafe
            // `Connection::recv`, whose contract is not visible here. `self` is
            // borrowed mutably for each call, so no other receive can run on this
            // dispatcher at the same time — confirm that is the required invariant.
            let result = unsafe { self.run_one(&mut handler).await };

            if let Err(error) = result {
                break error;
            }
        };

        // SAFETY: this is the only `terminate` call on this path, and `is_terminated`
        // is set right after so `Drop` will not terminate a second time.
        unsafe {
            self.terminate(error.clone());
        }
        self.is_terminated = true;

        match error {
            // Only a stop counts as a successful finish.
            ProtocolError::Stopped => Ok(handler),

            // Any other error means the client finished abnormally.
            _ => Err(error),
        }
    }

    /// Receives one message and dispatches it:
    ///
    /// - an epitaph ends dispatching with `ProtocolError::PeerClosedWithEpitaph`;
    /// - a message with transaction ID `0` is an event and is passed to `handler`;
    /// - any other message is a response and is written into slot `txid - 1`.
    ///
    /// # Safety
    ///
    /// NOTE(review): inherits the contract of `Connection::recv`, which is not
    /// visible here — confirm before adding callers.
    async unsafe fn run_one<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
    where
        H: ClientHandler<T>,
    {
        // SAFETY: forwarded to the caller; see `# Safety` above.
        let mut buffer = unsafe { self.inner.connection.recv(&mut self.exclusive).await? };

        let (txid, ordinal) =
            decode_header::<T>(&mut buffer).map_err(ProtocolError::InvalidMessageHeader)?;

        if ordinal == ORDINAL_EPITAPH {
            let epitaph =
                decode_epitaph::<T>(&mut buffer).map_err(ProtocolError::InvalidEpitaphBody)?;
            return Err(ProtocolError::PeerClosedWithEpitaph(epitaph));
        } else if txid == 0 {
            handler.on_event(ordinal, buffer).await?;
        } else {
            // This branch has no `.await`, so the lock is never held across a suspension.
            let mut responses = self.inner.responses.lock().unwrap();
            // Two-way requests are sent with `index + 1` as their txid, so the slot
            // index is `txid - 1`. It cannot underflow because `txid != 0` here.
            let locker = responses
                .get(txid - 1)
                .ok_or_else(|| ProtocolError::UnrequestedResponse { txid })?;

            match locker.write(ordinal, buffer) {
                // NOTE(review): presumably the waiting response future will pick up the
                // response and free the slot itself.
                Ok(false) => (),
                // NOTE(review): presumably the response future was already dropped
                // (cancelled), so the dispatcher frees the slot here.
                Ok(true) => responses.free(txid - 1),
                Err(LockerError::NotWriteable) => {
                    return Err(ProtocolError::UnrequestedResponse { txid });
                }
                Err(LockerError::MismatchedOrdinal { expected, actual }) => {
                    return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
                }
            }
        }

        Ok(())
    }

    /// Runs the dispatcher with a handler that ignores every event.
    ///
    /// # Errors
    ///
    /// Same as [`run`](Self::run): any terminating error other than
    /// `ProtocolError::Stopped` is returned.
    pub async fn run_client(self) -> Result<(), ProtocolError<T::Error>> {
        self.run(IgnoreEvents).await.map(|_| ())
    }
}
343
/// A [`ClientHandler`] which discards every event it receives.
///
/// Used by [`ClientDispatcher::run_client`].
#[derive(Debug, Clone, Copy, Default)]
pub struct IgnoreEvents;
346
347impl<T: Transport> ClientHandler<T> for IgnoreEvents {
348 async fn on_event(&mut self, _: u64, _: T::RecvBuffer) -> Result<(), ProtocolError<T::Error>> {
349 Ok(())
350 }
351}