1use core::future::Future;
8use core::pin::Pin;
9use core::sync::atomic::{AtomicBool, Ordering};
10use core::task::{Context, Poll};
11use std::sync::{Arc, Mutex};
12
13use fidl_next_codec::{Encode, EncodeError, EncoderExt};
14
15use crate::lockers::Lockers;
16use crate::{decode_header, encode_header, ProtocolError, SendFuture, Transport, TransportExt};
17
18use super::lockers::LockerError;
19
20struct Shared<T: Transport> {
21 is_closed: AtomicBool,
22 responses: Mutex<Lockers<T::RecvBuffer>>,
23}
24
25impl<T: Transport> Shared<T> {
26 fn new() -> Self {
27 Self { is_closed: AtomicBool::new(false), responses: Mutex::new(Lockers::new()) }
28 }
29}
30
31pub struct ClientSender<T: Transport> {
33 shared: Arc<Shared<T>>,
34 sender: T::Sender,
35}
36
37impl<T: Transport> ClientSender<T> {
38 pub fn close(&self) {
40 T::close(&self.sender);
41 }
42
43 pub fn send_one_way<M>(
45 &self,
46 ordinal: u64,
47 request: M,
48 ) -> Result<SendFuture<'_, T>, EncodeError>
49 where
50 M: Encode<T::SendBuffer>,
51 {
52 self.send_message(0, ordinal, request)
53 }
54
55 pub fn send_two_way<M>(
57 &self,
58 ordinal: u64,
59 request: M,
60 ) -> Result<ResponseFuture<'_, T>, EncodeError>
61 where
62 M: Encode<T::SendBuffer>,
63 {
64 let index = self.shared.responses.lock().unwrap().alloc(ordinal);
65
66 match self.send_message(index + 1, ordinal, request) {
68 Ok(future) => Ok(ResponseFuture {
69 shared: &self.shared,
70 index,
71 state: ResponseFutureState::Sending(future),
72 }),
73 Err(e) => {
74 self.shared.responses.lock().unwrap().free(index);
75 Err(e)
76 }
77 }
78 }
79
80 fn send_message<M>(
81 &self,
82 txid: u32,
83 ordinal: u64,
84 message: M,
85 ) -> Result<SendFuture<'_, T>, EncodeError>
86 where
87 M: Encode<T::SendBuffer>,
88 {
89 let mut buffer = T::acquire(&self.sender);
90 encode_header::<T>(&mut buffer, txid, ordinal)?;
91 buffer.encode_next(message)?;
92 Ok(T::send(&self.sender, buffer))
93 }
94}
95
96impl<T: Transport> Clone for ClientSender<T> {
97 fn clone(&self) -> Self {
98 Self { shared: self.shared.clone(), sender: self.sender.clone() }
99 }
100}
101
102enum ResponseFutureState<'a, T: Transport> {
103 Sending(SendFuture<'a, T>),
104 Receiving,
105 Completed,
108}
109
110pub struct ResponseFuture<'a, T: Transport> {
112 shared: &'a Shared<T>,
113 index: u32,
114 state: ResponseFutureState<'a, T>,
115}
116
117impl<T: Transport> Drop for ResponseFuture<'_, T> {
118 fn drop(&mut self) {
119 let mut responses = self.shared.responses.lock().unwrap();
120 match self.state {
121 ResponseFutureState::Sending(_) => responses.free(self.index),
124 ResponseFutureState::Receiving => {
125 if responses.get(self.index).unwrap().cancel() {
126 responses.free(self.index);
127 }
128 }
129 ResponseFutureState::Completed => (),
131 }
132 }
133}
134
135impl<T: Transport> ResponseFuture<'_, T> {
136 fn poll_receiving(&mut self, cx: &mut Context<'_>) -> Poll<<Self as Future>::Output> {
137 if self.shared.is_closed.load(Ordering::Relaxed) {
138 self.state = ResponseFutureState::Completed;
139 return Poll::Ready(Err(None));
140 }
141
142 let mut responses = self.shared.responses.lock().unwrap();
143 if let Some(ready) = responses.get(self.index).unwrap().read(cx.waker()) {
144 responses.free(self.index);
145 self.state = ResponseFutureState::Completed;
146 Poll::Ready(Ok(ready))
147 } else {
148 Poll::Pending
149 }
150 }
151}
152
153impl<T: Transport> Future for ResponseFuture<'_, T> {
154 type Output = Result<T::RecvBuffer, Option<T::Error>>;
155
156 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
157 let this = unsafe { Pin::into_inner_unchecked(self) };
159
160 match &mut this.state {
161 ResponseFutureState::Sending(future) => {
162 let pinned = unsafe { Pin::new_unchecked(future) };
164 match pinned.poll(cx) {
165 Poll::Pending => Poll::Pending,
167 Poll::Ready(Ok(())) => {
168 this.state = ResponseFutureState::Receiving;
171 this.poll_receiving(cx)
172 }
173 Poll::Ready(Err(e)) => {
174 this.shared.responses.lock().unwrap().free(this.index);
178 this.state = ResponseFutureState::Completed;
179 Poll::Ready(Err(Some(e)))
180 }
181 }
182 }
183 ResponseFutureState::Receiving => this.poll_receiving(cx),
184 ResponseFutureState::Completed => unreachable!(),
187 }
188 }
189}
190
191pub trait ClientHandler<T: Transport> {
193 fn on_event(
198 &mut self,
199 sender: &ClientSender<T>,
200 ordinal: u64,
201 buffer: T::RecvBuffer,
202 ) -> impl Future<Output = ()> + Send;
203}
204
205pub struct Client<T: Transport> {
209 sender: ClientSender<T>,
210 receiver: T::Receiver,
211}
212
213impl<T: Transport> Client<T> {
214 pub fn new(transport: T) -> Self {
216 let (sender, receiver) = transport.split();
217 let shared = Arc::new(Shared::new());
218 Self { sender: ClientSender { shared, sender }, receiver }
219 }
220
221 pub fn sender(&self) -> &ClientSender<T> {
223 &self.sender
224 }
225
226 pub async fn run<H>(&mut self, mut handler: H) -> Result<(), ProtocolError<T::Error>>
228 where
229 H: ClientHandler<T>,
230 {
231 let result = self.run_to_completion(&mut handler).await;
232 self.sender.shared.is_closed.store(true, Ordering::Relaxed);
233 self.sender.shared.responses.lock().unwrap().wake_all();
234
235 result
236 }
237
238 pub async fn run_sender(&mut self) -> Result<(), ProtocolError<T::Error>> {
240 self.run(IgnoreEvents).await
241 }
242
243 async fn run_to_completion<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
244 where
245 H: ClientHandler<T>,
246 {
247 while let Some(mut buffer) =
248 T::recv(&mut self.receiver).await.map_err(ProtocolError::TransportError)?
249 {
250 let (txid, ordinal) =
251 decode_header::<T>(&mut buffer).map_err(ProtocolError::InvalidMessageHeader)?;
252 if txid == 0 {
253 handler.on_event(&self.sender, ordinal, buffer).await;
254 } else {
255 let mut responses = self.sender.shared.responses.lock().unwrap();
256 let locker = responses
257 .get(txid - 1)
258 .ok_or_else(|| ProtocolError::UnrequestedResponse(txid))?;
259
260 match locker.write(ordinal, buffer) {
261 Ok(false) => (),
263 Ok(true) => responses.free(txid - 1),
265 Err(LockerError::NotWriteable) => {
266 return Err(ProtocolError::UnrequestedResponse(txid));
267 }
268 Err(LockerError::MismatchedOrdinal { expected, actual }) => {
269 return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
270 }
271 }
272 }
273 }
274
275 self.sender.close();
276
277 Ok(())
278 }
279}
280
/// A [`ClientHandler`] that discards every event it receives.
///
/// [`Client::run_sender`] uses it when the caller only needs to send requests and receive
/// responses.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct IgnoreEvents;
283
284impl<T: Transport> ClientHandler<T> for IgnoreEvents {
285 async fn on_event(&mut self, _: &ClientSender<T>, _: u64, _: T::RecvBuffer) {}
286}