fidl_next_protocol/endpoints/
client.rs1use core::future::Future;
8use core::mem::ManuallyDrop;
9use core::pin::Pin;
10use core::ptr;
11use core::task::{Context, Poll, ready};
12
13use fidl_constants::EPITAPH_ORDINAL;
14use fidl_next_codec::{
15 AsDecoder as _, DecoderExt as _, Encode, EncodeError, EncoderExt, Wire, wire,
16};
17use fuchsia_loom::sync::{Arc, Mutex};
18use pin_project::{pin_project, pinned_drop};
19
20use crate::endpoints::connection::{Connection, SendFutureState};
21use crate::endpoints::lockers::{LockerError, Lockers};
22use crate::wire::{Epitaph, MessageHeader};
23use crate::{
24 Flexibility, FrameworkError, Message, NonBlockingTransport, ProtocolError, SendFuture,
25 Transport,
26};
27
28const ORD_FRAMEWORK_ERR: u64 = 3;
29const ERR_UNKNOWN_METHOD: i32 = FrameworkError::UnknownMethod as i32;
30
/// State shared, through an `Arc`, between every `Client` handle and the
/// `ClientDispatcher` that drives the connection.
struct ClientInner<T: Transport> {
    /// The connection used both to send requests and to receive messages.
    connection: Connection<T>,
    /// Response lockers for in-flight two-way requests. Locker `i` receives the response
    /// whose transaction ID is `i + 1` (transaction ID 0 is used for one-way messages and
    /// events).
    responses: Mutex<Lockers<Message<T>>>,
}
35
36impl<T: Transport> ClientInner<T> {
37 fn new(shared: T::Shared) -> Self {
38 Self { connection: Connection::new(shared), responses: Mutex::new(Lockers::new()) }
39 }
40}
41
/// A client endpoint used to send one-way and two-way messages.
///
/// All clones share the same reference-counted `ClientInner` (connection plus response
/// lockers). In this file, clients are only created by `ClientDispatcher::client` and by
/// `clone`, so that state is also shared with the dispatcher.
pub struct Client<T: Transport> {
    inner: Arc<ClientInner<T>>,
}
46
47impl<T: Transport> Drop for Client<T> {
48 fn drop(&mut self) {
49 if Arc::strong_count(&self.inner) == 2 {
50 self.close();
53 }
54 }
55}
56
57impl<T: Transport> Client<T> {
58 pub fn close(&self) {
60 self.inner.connection.stop();
61 }
62
63 pub fn send_one_way<W>(
65 &self,
66 ordinal: u64,
67 flexibility: Flexibility,
68 request: impl Encode<W, T::SendBuffer>,
69 ) -> Result<SendFuture<'_, T>, EncodeError>
70 where
71 W: Wire<Constraint = ()>,
72 {
73 Ok(SendFuture::from_raw_parts(
74 &self.inner.connection,
75 self.send_message_raw(0, ordinal, flexibility, request)?,
76 ))
77 }
78
79 pub fn send_two_way<W>(
81 &self,
82 ordinal: u64,
83 flexibility: Flexibility,
84 request: impl Encode<W, T::SendBuffer>,
85 ) -> Result<TwoWayRequestFuture<'_, T>, EncodeError>
86 where
87 W: Wire<Constraint = ()>,
88 {
89 let index = self.inner.responses.lock().unwrap().alloc(ordinal);
90
91 match self.send_message_raw(index + 1, ordinal, flexibility, request) {
93 Ok(state) => Ok(TwoWayRequestFuture { inner: &self.inner, index: Some(index), state }),
94 Err(e) => {
95 self.inner.responses.lock().unwrap().free(index);
96 Err(e)
97 }
98 }
99 }
100
101 fn send_message_raw<W>(
102 &self,
103 txid: u32,
104 ordinal: u64,
105 flexibility: Flexibility,
106 message: impl Encode<W, T::SendBuffer>,
107 ) -> Result<SendFutureState<T>, EncodeError>
108 where
109 W: Wire<Constraint = ()>,
110 {
111 self.inner.connection.send_message_raw(|buffer| {
112 buffer.encode_next(MessageHeader::new(txid, ordinal, flexibility))?;
113 buffer.encode_next(message)
114 })
115 }
116}
117
118impl<T: Transport> Clone for Client<T> {
119 fn clone(&self) -> Self {
120 Self { inner: self.inner.clone() }
121 }
122}
123
/// A future that resolves to the response of a two-way request.
///
/// `index` is the response locker reserved for the request. It becomes `None` once the
/// future has returned `Poll::Ready` and the locker has been freed.
pub struct TwoWayResponseFuture<'a, T: Transport> {
    inner: &'a ClientInner<T>,
    index: Option<u32>,
}
129
130impl<T: Transport> Drop for TwoWayResponseFuture<'_, T> {
131 fn drop(&mut self) {
132 if let Some(index) = self.index {
134 let mut responses = self.inner.responses.lock().unwrap();
135 if responses.get(index).unwrap().cancel() {
136 responses.free(index);
137 }
138 }
139 }
140}
141
142fn handle_flexible_response<T: Transport>(
143 mut message: Message<T>,
144) -> Result<Message<T>, ProtocolError<T::Error>> {
145 if matches!(message.header().flexibility(), Flexibility::Flexible) {
146 let mut decoder = message.as_decoder();
147 let mut union = decoder
148 .take_slot::<wire::Union>()
149 .map_err(|_| ProtocolError::InvalidFlexibleResponse)?;
150 let ordinal = wire::Union::encoded_ordinal(union.as_mut());
151 if ordinal == ORD_FRAMEWORK_ERR {
152 wire::Union::decode_as::<_, wire::Int32>(union.as_mut(), &mut decoder, ())
153 .map_err(|_| ProtocolError::InvalidFlexibleResponse)?;
154 let union = unsafe { union.deref_unchecked() };
155 let error = unsafe { union.get().read_unchecked::<wire::Int32>() };
156 match error.0 {
157 ERR_UNKNOWN_METHOD => {
158 return Err(ProtocolError::FrameworkError(FrameworkError::UnknownMethod));
159 }
160 ordinal => return Err(ProtocolError::UnknownFrameworkError { ordinal }),
161 }
162 }
163 }
164
165 Ok(message)
166}
167
168impl<T: Transport> Future for TwoWayResponseFuture<'_, T> {
169 type Output = Result<Message<T>, ProtocolError<T::Error>>;
170
171 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
172 let this = Pin::into_inner(self);
173 let Some(index) = this.index else {
174 panic!("TwoWayResponseFuture polled after returning `Poll::Ready`");
175 };
176
177 let mut responses = this.inner.responses.lock().unwrap();
178 let message = if let Some(message) = responses.get(index).unwrap().read(cx.waker()) {
179 handle_flexible_response(message)
180 } else if let Some(termination_reason) = this.inner.connection.get_termination_reason() {
181 Err(termination_reason)
182 } else {
183 return Poll::Pending;
184 };
185
186 responses.free(index);
187 this.index = None;
188
189 Poll::Ready(message)
190 }
191}
192
/// A future that sends a two-way request and, once sent, yields a `TwoWayResponseFuture`.
///
/// `index` is the response locker reserved by `Client::send_two_way`. It becomes `None`
/// after `poll` completes. At that point the locker has either been freed (send error) or
/// handed to the returned response future.
#[pin_project(PinnedDrop)]
pub struct TwoWayRequestFuture<'a, T: Transport> {
    inner: &'a ClientInner<T>,
    index: Option<u32>,
    #[pin]
    state: SendFutureState<T>,
}
201
202#[pinned_drop]
203impl<T: Transport> PinnedDrop for TwoWayRequestFuture<'_, T> {
204 fn drop(self: Pin<&mut Self>) {
205 if let Some(index) = self.index {
206 let mut responses = self.inner.responses.lock().unwrap();
207
208 responses.free(index);
211 }
212 }
213}
214
215impl<'a, T: Transport> Future for TwoWayRequestFuture<'a, T> {
216 type Output = Result<TwoWayResponseFuture<'a, T>, ProtocolError<T::Error>>;
217
218 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
219 let this = self.project();
220
221 let Some(index) = *this.index else {
222 panic!("TwoWayRequestFuture polled after returning `Poll::Ready`");
223 };
224
225 let result = ready!(this.state.poll_send(cx, &this.inner.connection));
226 *this.index = None;
227 if let Err(error) = result {
228 this.inner.responses.lock().unwrap().free(index);
230 Poll::Ready(Err(error))
231 } else {
232 Poll::Ready(Ok(TwoWayResponseFuture { inner: this.inner, index: Some(index) }))
233 }
234 }
235}
236
237impl<'a, T: NonBlockingTransport> TwoWayRequestFuture<'a, T> {
238 pub fn send_immediately(self) -> Result<TwoWayResponseFuture<'a, T>, ProtocolError<T::Error>> {
249 let inner = self.inner;
250 let index = self.index;
251 let state = unsafe { ptr::read(&ManuallyDrop::new(self).state) };
252 if let Err(e) = state.send_immediately(&inner.connection) {
253 inner.responses.lock().unwrap().free(index.unwrap());
254 return Err(e);
255 }
256
257 Ok(TwoWayResponseFuture { inner, index })
258 }
259}
260
/// A handler for client events whose futures are not required to be `Send`.
///
/// Used by `ClientDispatcher::run_local`.
pub trait LocalClientHandler<T: Transport> {
    /// Handles an event, i.e. a received message whose transaction ID is 0.
    ///
    /// If the returned future yields an error, the dispatcher stops and terminates the
    /// connection with that error.
    fn on_event(
        &mut self,
        message: Message<T>,
    ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>>;
}
274
/// A `Send` handler for client events whose futures are also `Send`.
///
/// Used by `ClientDispatcher::run`, which adapts it to `LocalClientHandler`.
pub trait ClientHandler<T: Transport>: Send {
    /// Handles an event, i.e. a received message whose transaction ID is 0.
    ///
    /// If the returned future yields an error, the dispatcher stops and terminates the
    /// connection with that error.
    fn on_event(
        &mut self,
        message: Message<T>,
    ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>> + Send;
}
287
/// Wraps a `ClientHandler` so it can be used where a `LocalClientHandler` is expected
/// (as `ClientDispatcher::run` does before delegating to `run_local`).
#[repr(transparent)]
pub struct ClientHandlerToLocalAdapter<H>(H);
291
292impl<T, H> LocalClientHandler<T> for ClientHandlerToLocalAdapter<H>
293where
294 T: Transport,
295 H: ClientHandler<T>,
296{
297 #[inline]
298 fn on_event(
299 &mut self,
300 message: Message<T>,
301 ) -> impl Future<Output = Result<(), ProtocolError<T::Error>>> {
302 self.0.on_event(message)
303 }
304}
305
/// Receives messages on a client connection and routes them, either to pending response
/// lockers or, for transaction ID 0, to an event handler.
pub struct ClientDispatcher<T: Transport> {
    /// State shared with every `Client` handed out by `client()`.
    inner: Arc<ClientInner<T>>,
    /// The non-shared half returned by `Transport::split`, passed to `Connection::recv`.
    exclusive: T::Exclusive,
    /// Set by `run_local` after it has terminated the connection, so `Drop` does not do it
    /// again.
    is_terminated: bool,
}
319
320impl<T: Transport> Drop for ClientDispatcher<T> {
321 fn drop(&mut self) {
322 if !self.is_terminated {
323 unsafe {
325 self.terminate(ProtocolError::Stopped);
326 }
327 }
328 }
329}
330
impl<T: Transport> ClientDispatcher<T> {
    /// Creates a dispatcher for `transport`.
    ///
    /// The transport is split into a shared half, which goes into the `Arc`-ed `ClientInner`
    /// that clients also hold, and an exclusive half, which the dispatcher keeps for
    /// receiving.
    pub fn new(transport: T) -> Self {
        let (shared, exclusive) = transport.split();
        Self { inner: Arc::new(ClientInner::new(shared)), exclusive, is_terminated: false }
    }

    /// Terminates the connection with `error`, then wakes every response locker.
    ///
    /// # Safety
    ///
    /// NOTE(review): the caller's obligations are forwarded to `Connection::terminate`,
    /// whose requirements are not visible in this file. TODO: document them here.
    unsafe fn terminate(&mut self, error: ProtocolError<T::Error>) {
        // SAFETY: forwarded from this function's own contract.
        unsafe {
            self.inner.connection.terminate(error);
        }
        // Wake pending `TwoWayResponseFuture`s so they re-poll. Their `poll` checks
        // `get_termination_reason` when no response is present (presumably reporting the
        // error set just above — confirm).
        self.inner.responses.lock().unwrap().wake_all();
    }

    /// Returns a new `Client` sharing this dispatcher's connection and response lockers.
    pub fn client(&self) -> Client<T> {
        Client { inner: self.inner.clone() }
    }

    /// Runs the dispatcher with a `Send` handler.
    ///
    /// Wraps `handler` in `ClientHandlerToLocalAdapter`, delegates to `run_local`, and
    /// unwraps the handler again on success.
    pub async fn run<H>(self, handler: H) -> Result<H, ProtocolError<T::Error>>
    where
        H: ClientHandler<T>,
    {
        self.run_local(ClientHandlerToLocalAdapter(handler)).await.map(|adapter| adapter.0)
    }

    /// Receives and dispatches messages until `run_one` fails, then terminates the
    /// connection with that error.
    ///
    /// Returns `Ok(handler)` only when the error is `ProtocolError::Stopped`. Any other error
    /// is returned as `Err`.
    pub async fn run_local<H>(mut self, mut handler: H) -> Result<H, ProtocolError<T::Error>>
    where
        H: LocalClientHandler<T>,
    {
        let error = loop {
            // SAFETY: NOTE(review) — `run_one`'s contract is inherited from `Connection::recv`,
            // which is not visible in this file. TODO: state why it holds here.
            let result = unsafe { self.run_one(&mut handler).await };

            if let Err(error) = result {
                break error;
            }
        };

        // SAFETY: NOTE(review) — see `terminate`. TODO: confirm its contract is met here.
        unsafe {
            self.terminate(error.clone());
        }
        // Prevents `Drop` from terminating the connection a second time.
        self.is_terminated = true;

        match error {
            // Only a stop counts as a successful finish (presumably produced by
            // `Client::close` via `Connection::stop`).
            ProtocolError::Stopped => Ok(handler),

            _ => Err(error),
        }
    }

    /// Receives one message and routes it.
    ///
    /// - An epitaph ends the run with `ProtocolError::PeerClosedWithEpitaph`.
    /// - A message with transaction ID 0 is passed to `handler` as an event.
    /// - Any other transaction ID is written into response locker `txid - 1`.
    ///
    /// Unknown or non-writeable lockers, and ordinal mismatches, are returned as errors.
    ///
    /// # Safety
    ///
    /// NOTE(review): requirements are inherited from `Connection::recv`, which is not
    /// visible here. TODO: document them.
    async unsafe fn run_one<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
    where
        H: LocalClientHandler<T>,
    {
        // SAFETY: forwarded from this function's own contract.
        let buffer = unsafe { self.inner.connection.recv(&mut self.exclusive).await? };

        let mut message =
            Message::<T>::decode(buffer).map_err(ProtocolError::InvalidMessageHeader)?;
        let txid = *message.header().txid;
        let ordinal = *message.header().ordinal;

        // An epitaph carries the peer's closing status; surface it as the run's error.
        if ordinal == EPITAPH_ORDINAL {
            let mut decoder = message.as_decoder();
            let epitaph = decoder.decode::<Epitaph>().map_err(ProtocolError::InvalidEpitaphBody)?;
            return Err(ProtocolError::PeerClosedWithEpitaph(*epitaph.error));
        }

        if txid == 0 {
            handler.on_event(message).await?;
        } else {
            // `Client::send_two_way` uses `index + 1` as the txid, so the locker is `txid - 1`.
            let mut responses = self.inner.responses.lock().unwrap();
            let locker = responses
                .get(txid - 1)
                .ok_or_else(|| ProtocolError::UnrequestedResponse { txid })?;

            match locker.write(ordinal, message) {
                Ok(false) => (),
                // NOTE(review): `true` presumably means the waiting response future was
                // already dropped (cancelled), so nobody else will free the slot.
                Ok(true) => responses.free(txid - 1),
                Err(LockerError::NotWriteable) => {
                    return Err(ProtocolError::UnrequestedResponse { txid });
                }
                Err(LockerError::MismatchedOrdinal { expected, actual }) => {
                    return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
                }
            }
        }

        Ok(())
    }
}