1use core::future::Future;
8use core::pin::Pin;
9use core::sync::atomic::{AtomicBool, Ordering};
10use core::task::{Context, Poll};
11use std::sync::{Arc, Mutex};
12
13use fidl_next_codec::{Encode, EncodeError, EncoderExt};
14
15use crate::lockers::Lockers;
16use crate::{decode_header, encode_header, ProtocolError, SendFuture, Transport, TransportExt};
17
18use super::lockers::LockerError;
19
20struct Shared<T: Transport> {
21 is_closed: AtomicBool,
22 responses: Mutex<Lockers<T::RecvBuffer>>,
23}
24
25impl<T: Transport> Shared<T> {
26 fn new() -> Self {
27 Self { is_closed: AtomicBool::new(false), responses: Mutex::new(Lockers::new()) }
28 }
29}
30
31pub struct ClientSender<T: Transport> {
33 shared: Arc<Shared<T>>,
34 sender: T::Sender,
35}
36
37impl<T: Transport> ClientSender<T> {
38 pub fn close(&self) {
40 T::close(&self.sender);
41 }
42
43 pub fn send_one_way<M>(
45 &self,
46 ordinal: u64,
47 request: M,
48 ) -> Result<SendFuture<'_, T>, EncodeError>
49 where
50 M: Encode<T::SendBuffer>,
51 {
52 self.send_message(0, ordinal, request)
53 }
54
55 pub fn send_two_way<M>(
57 &self,
58 ordinal: u64,
59 request: M,
60 ) -> Result<ResponseFuture<'_, T>, EncodeError>
61 where
62 M: Encode<T::SendBuffer>,
63 {
64 let index = self.shared.responses.lock().unwrap().alloc(ordinal);
65
66 match self.send_message(index + 1, ordinal, request) {
68 Ok(future) => Ok(ResponseFuture {
69 shared: &self.shared,
70 index,
71 state: ResponseFutureState::Sending(future),
72 }),
73 Err(e) => {
74 self.shared.responses.lock().unwrap().free(index);
75 Err(e)
76 }
77 }
78 }
79
80 fn send_message<M>(
81 &self,
82 txid: u32,
83 ordinal: u64,
84 message: M,
85 ) -> Result<SendFuture<'_, T>, EncodeError>
86 where
87 M: Encode<T::SendBuffer>,
88 {
89 let mut buffer = T::acquire(&self.sender);
90 encode_header::<T>(&mut buffer, txid, ordinal)?;
91 buffer.encode_next(message)?;
92 Ok(T::send(&self.sender, buffer))
93 }
94}
95
96impl<T: Transport> Clone for ClientSender<T> {
97 fn clone(&self) -> Self {
98 Self { shared: self.shared.clone(), sender: self.sender.clone() }
99 }
100}
101
102enum ResponseFutureState<'a, T: Transport> {
103 Sending(SendFuture<'a, T>),
104 Receiving,
105 Completed,
108}
109
110pub struct ResponseFuture<'a, T: Transport> {
112 shared: &'a Shared<T>,
113 index: u32,
114 state: ResponseFutureState<'a, T>,
115}
116
117impl<T: Transport> Drop for ResponseFuture<'_, T> {
118 fn drop(&mut self) {
119 let mut responses = self.shared.responses.lock().unwrap();
120 match self.state {
121 ResponseFutureState::Sending(_) => responses.free(self.index),
124 ResponseFutureState::Receiving => {
125 if responses.get(self.index).unwrap().cancel() {
126 responses.free(self.index);
127 }
128 }
129 ResponseFutureState::Completed => (),
131 }
132 }
133}
134
135impl<T: Transport> ResponseFuture<'_, T> {
136 fn poll_receiving(&mut self, cx: &mut Context<'_>) -> Poll<<Self as Future>::Output> {
137 if self.shared.is_closed.load(Ordering::Relaxed) {
138 self.state = ResponseFutureState::Completed;
139 return Poll::Ready(Err(None));
140 }
141
142 let mut responses = self.shared.responses.lock().unwrap();
143 if let Some(ready) = responses.get(self.index).unwrap().read(cx.waker()) {
144 responses.free(self.index);
145 self.state = ResponseFutureState::Completed;
146 Poll::Ready(Ok(ready))
147 } else {
148 Poll::Pending
149 }
150 }
151}
152
153impl<T: Transport> Future for ResponseFuture<'_, T> {
154 type Output = Result<T::RecvBuffer, Option<T::Error>>;
155
156 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
157 let this = unsafe { Pin::into_inner_unchecked(self) };
159
160 match &mut this.state {
161 ResponseFutureState::Sending(future) => {
162 let pinned = unsafe { Pin::new_unchecked(future) };
164 match pinned.poll(cx) {
165 Poll::Pending => Poll::Pending,
167 Poll::Ready(Ok(())) => {
168 this.state = ResponseFutureState::Receiving;
171 this.poll_receiving(cx)
172 }
173 Poll::Ready(Err(e)) => {
174 this.shared.responses.lock().unwrap().free(this.index);
178 this.state = ResponseFutureState::Completed;
179 Poll::Ready(Err(Some(e)))
180 }
181 }
182 }
183 ResponseFutureState::Receiving => this.poll_receiving(cx),
184 ResponseFutureState::Completed => unreachable!(),
187 }
188 }
189}
190
191pub trait ClientHandler<T: Transport> {
193 fn on_event(&mut self, sender: &ClientSender<T>, ordinal: u64, buffer: T::RecvBuffer);
199}
200
201pub struct Client<T: Transport> {
205 sender: ClientSender<T>,
206 receiver: T::Receiver,
207}
208
209impl<T: Transport> Client<T> {
210 pub fn new(transport: T) -> Self {
212 let (sender, receiver) = transport.split();
213 let shared = Arc::new(Shared::new());
214 Self { sender: ClientSender { shared, sender }, receiver }
215 }
216
217 pub fn sender(&self) -> &ClientSender<T> {
219 &self.sender
220 }
221
222 pub async fn run<H>(&mut self, mut handler: H) -> Result<(), ProtocolError<T::Error>>
224 where
225 H: ClientHandler<T>,
226 {
227 let result = self.run_to_completion(&mut handler).await;
228 self.sender.shared.is_closed.store(true, Ordering::Relaxed);
229 self.sender.shared.responses.lock().unwrap().wake_all();
230
231 result
232 }
233
234 pub async fn run_sender(&mut self) -> Result<(), ProtocolError<T::Error>> {
236 self.run(IgnoreEvents).await
237 }
238
239 async fn run_to_completion<H>(&mut self, handler: &mut H) -> Result<(), ProtocolError<T::Error>>
240 where
241 H: ClientHandler<T>,
242 {
243 while let Some(mut buffer) =
244 T::recv(&mut self.receiver).await.map_err(ProtocolError::TransportError)?
245 {
246 let (txid, ordinal) =
247 decode_header::<T>(&mut buffer).map_err(ProtocolError::InvalidMessageHeader)?;
248 if txid == 0 {
249 handler.on_event(&self.sender, ordinal, buffer);
250 } else {
251 let mut responses = self.sender.shared.responses.lock().unwrap();
252 let locker = responses
253 .get(txid - 1)
254 .ok_or_else(|| ProtocolError::UnrequestedResponse(txid))?;
255
256 match locker.write(ordinal, buffer) {
257 Ok(false) => (),
259 Ok(true) => responses.free(txid - 1),
261 Err(LockerError::NotWriteable) => {
262 return Err(ProtocolError::UnrequestedResponse(txid));
263 }
264 Err(LockerError::MismatchedOrdinal { expected, actual }) => {
265 return Err(ProtocolError::InvalidResponseOrdinal { expected, actual });
266 }
267 }
268 }
269 }
270
271 self.sender.close();
272
273 Ok(())
274 }
275}
276
/// A [`ClientHandler`] that discards every event it is given.
pub struct IgnoreEvents;
279
280impl<T: Transport> ClientHandler<T> for IgnoreEvents {
281 fn on_event(&mut self, _: &ClientSender<T>, _: u64, _: T::RecvBuffer) {}
282}