1use futures::task::AtomicWaker;
6use std::num::NonZero;
7use std::ptr::NonNull;
8use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::task::Poll;
11
12use fidl_next::Chunk;
13use zx::Status;
14
15use fdf_channel::arena::{Arena, ArenaBox};
16use fdf_channel::channel::Channel;
17use fdf_channel::futures::ReadMessageState;
18use fdf_channel::message::Message;
19use fdf_core::dispatcher::{CurrentDispatcher, OnDispatcher};
20use fdf_core::handle::{MixedHandle, MixedHandleType};
21
22pub struct DriverChannel<D = CurrentDispatcher> {
25 dispatcher: D,
26 channel: Channel<[Chunk]>,
27}
28
29impl<D> DriverChannel<D> {
30 pub fn new_with_dispatcher(dispatcher: D, channel: Channel<[Chunk]>) -> Self {
33 Self { dispatcher, channel }
34 }
35}
36
37impl DriverChannel<CurrentDispatcher> {
38 pub fn new(channel: Channel<[Chunk]>) -> Self {
41 Self::new_with_dispatcher(CurrentDispatcher, channel)
42 }
43}
44
45pub struct SendBuffer {
47 handles: Vec<Option<MixedHandle>>,
48 data: Vec<Chunk>,
49}
50
51impl SendBuffer {
52 fn new() -> Self {
53 Self { handles: Vec::new(), data: Vec::new() }
54 }
55}
56
57impl fidl_next::Encoder for SendBuffer {
58 #[inline]
59 fn bytes_written(&self) -> usize {
60 fidl_next::Encoder::bytes_written(&self.data)
61 }
62
63 #[inline]
64 fn write(&mut self, bytes: &[u8]) {
65 fidl_next::Encoder::write(&mut self.data, bytes)
66 }
67
68 #[inline]
69 fn rewrite(&mut self, pos: usize, bytes: &[u8]) {
70 fidl_next::Encoder::rewrite(&mut self.data, pos, bytes)
71 }
72
73 fn write_zeroes(&mut self, len: usize) {
74 fidl_next::Encoder::write_zeroes(&mut self.data, len);
75 }
76}
77
78impl fidl_next::encoder::InternalHandleEncoder for SendBuffer {
79 #[inline]
80 fn __internal_handle_count(&self) -> usize {
81 self.handles.len()
82 }
83}
84
85impl fidl_next::fuchsia::HandleEncoder for SendBuffer {
86 fn push_handle(&mut self, handle: zx::Handle) -> Result<(), fidl_next::EncodeError> {
87 if let Some(handle) = MixedHandle::from_zircon_handle(handle) {
88 if handle.is_driver() {
89 return Err(fidl_next::EncodeError::ExpectedZirconHandle);
90 }
91 self.handles.push(Some(handle));
92 } else {
93 self.handles.push(None);
94 }
95 Ok(())
96 }
97
98 fn push_raw_driver_handle(&mut self, handle: u32) -> Result<(), fidl_next::EncodeError> {
99 if let Some(handle) = NonZero::new(handle) {
100 let handle = unsafe { MixedHandle::from_raw(handle) };
103 if !handle.is_driver() {
104 return Err(fidl_next::EncodeError::ExpectedDriverHandle);
105 }
106 self.handles.push(Some(handle));
107 } else {
108 self.handles.push(None);
109 }
110 Ok(())
111 }
112
113 fn handles_pushed(&self) -> usize {
114 self.handles.len()
115 }
116}
117
118pub struct RecvBuffer {
119 buffer: Message<[Chunk]>,
120 data_offset: usize,
121 handle_offset: usize,
122}
123
124impl RecvBuffer {
125 fn take_next_handle(&mut self) -> Result<MixedHandle, fidl_next::DecodeError> {
126 let Some(handles) = self.buffer.handles_mut() else {
127 return Err(fidl_next::DecodeError::InsufficientHandles);
128 };
129 if handles.len() < self.handle_offset + 1 {
130 return Err(fidl_next::DecodeError::InsufficientHandles);
131 }
132 handles[self.handle_offset].take().ok_or(fidl_next::DecodeError::RequiredHandleAbsent)
133 }
134}
135
136unsafe impl fidl_next::Decoder for RecvBuffer {
141 fn take_chunks_raw(&mut self, count: usize) -> Result<NonNull<Chunk>, fidl_next::DecodeError> {
144 let Some(data) = self.buffer.data_mut() else {
145 return Err(fidl_next::DecodeError::InsufficientData);
146 };
147 if data.len() < self.data_offset + count {
148 return Err(fidl_next::DecodeError::InsufficientData);
149 }
150 let pos = self.data_offset;
151 self.data_offset += count;
152 Ok(unsafe { NonNull::new_unchecked((&mut data[pos..(pos + count)]).as_mut_ptr()) })
153 }
154
155 fn finish(&mut self) -> Result<(), fidl_next::DecodeError> {
156 let data_len = self.buffer.data().unwrap_or(&[]).len();
157 if self.data_offset != data_len {
158 return Err(fidl_next::DecodeError::ExtraBytes {
159 num_extra: data_len - self.data_offset,
160 });
161 }
162 let handle_len = self.buffer.handles().unwrap_or(&[]).len();
163 if self.handle_offset != handle_len {
164 return Err(fidl_next::DecodeError::ExtraHandles {
165 num_extra: handle_len - self.handle_offset,
166 });
167 }
168 Ok(())
169 }
170}
171
172impl fidl_next::decoder::InternalHandleDecoder for RecvBuffer {
173 fn __internal_take_handles(&mut self, count: usize) -> Result<(), fidl_next::DecodeError> {
174 let Some(handles) = self.buffer.handles_mut() else {
175 return Err(fidl_next::DecodeError::InsufficientHandles);
176 };
177 if handles.len() < self.handle_offset + count {
178 return Err(fidl_next::DecodeError::InsufficientHandles);
179 }
180 let pos = self.handle_offset;
181 self.handle_offset = pos + count;
182 Ok(())
183 }
184
185 fn __internal_handles_remaining(&self) -> usize {
186 self.buffer.handles().unwrap_or(&[]).len() - self.handle_offset
187 }
188}
189
190impl fidl_next::fuchsia::HandleDecoder for RecvBuffer {
191 fn take_handle(&mut self) -> Result<zx::Handle, fidl_next::DecodeError> {
192 let handle = self.take_next_handle()?.resolve();
193 let MixedHandleType::Zircon(handle) = handle else {
194 return Err(fidl_next::DecodeError::ExpectedZirconHandle);
195 };
196 let pos = self.handle_offset;
197 self.handle_offset = pos + 1;
198 Ok(handle)
199 }
200
201 fn take_raw_driver_handle(&mut self) -> Result<u32, fidl_next::DecodeError> {
202 let handle = self.take_next_handle()?.resolve();
203 let MixedHandleType::Driver(handle) = handle else {
204 return Err(fidl_next::DecodeError::ExpectedDriverHandle);
205 };
206 let pos = self.handle_offset;
207 self.handle_offset = pos + 1;
208 Ok(handle.into_raw().get())
209 }
210
211 fn handles_remaining(&mut self) -> usize {
212 self.buffer.handles().unwrap_or(&[]).len() - self.handle_offset
213 }
214}
215
216pub struct DriverRecvState(ReadMessageState);
218
219struct Shared<D> {
220 is_closed: AtomicBool,
221 sender_count: AtomicUsize,
222 closed_waker: AtomicWaker,
223 channel: DriverChannel<D>,
224}
225
226impl<D> Shared<D> {
227 fn new(channel: DriverChannel<D>) -> Self {
228 Self {
229 is_closed: AtomicBool::new(false),
230 sender_count: AtomicUsize::new(1),
231 closed_waker: AtomicWaker::new(),
232 channel,
233 }
234 }
235
236 fn close(&self) {
237 self.is_closed.store(true, Ordering::Relaxed);
238 self.closed_waker.wake();
239 }
240}
241pub struct DriverSender<D> {
243 shared: Arc<Shared<D>>,
244}
245
246impl<D> Drop for DriverSender<D> {
247 fn drop(&mut self) {
248 let senders = self.shared.sender_count.fetch_sub(1, Ordering::Relaxed);
249 if senders == 1 {
250 self.shared.close();
251 }
252 }
253}
254
255impl<D> Clone for DriverSender<D> {
256 fn clone(&self) -> Self {
257 self.shared.sender_count.fetch_add(1, Ordering::Relaxed);
258 Self { shared: self.shared.clone() }
259 }
260}
261
262pub struct DriverReceiver<D> {
264 shared: Arc<Shared<D>>,
265}
266
267impl<D: OnDispatcher> fidl_next::protocol::Transport for DriverChannel<D> {
268 type Error = Status;
269
270 fn split(self) -> (Self::Sender, Self::Receiver) {
271 let shared = Arc::new(Shared::new(self));
272 let sender = DriverSender { shared: shared.clone() };
273 let receiver = DriverReceiver { shared };
274 (sender, receiver)
275 }
276
277 type Sender = DriverSender<D>;
278
279 type SendBuffer = SendBuffer;
280
281 type SendFutureState = SendBuffer;
282
283 fn acquire(_sender: &Self::Sender) -> Self::SendBuffer {
284 SendBuffer::new()
285 }
286
287 fn close(sender: &Self::Sender) {
288 sender.shared.close();
289 }
290
291 type Receiver = DriverReceiver<D>;
292
293 type RecvFutureState = DriverRecvState;
294
295 type RecvBuffer = RecvBuffer;
296
297 fn begin_send(_sender: &Self::Sender, buffer: Self::SendBuffer) -> Self::SendFutureState {
298 buffer
299 }
300
301 fn poll_send(
302 mut buffer: std::pin::Pin<&mut Self::SendFutureState>,
303 _cx: &mut std::task::Context<'_>,
304 sender: &Self::Sender,
305 ) -> std::task::Poll<Result<(), Self::Error>> {
306 let arena = Arena::new();
307 let message = Message::new_with(arena, |arena| {
308 let data = arena.insert_slice(&buffer.data);
309 let handles = buffer.handles.split_off(0);
310 let handles = arena.insert_from_iter(handles.into_iter());
311 (Some(data), Some(handles))
312 });
313 Poll::Ready(sender.shared.channel.channel.write(message))
314 }
315
316 fn begin_recv(receiver: &mut Self::Receiver) -> Self::RecvFutureState {
317 let state =
320 unsafe { ReadMessageState::new(receiver.shared.channel.channel.driver_handle()) };
321 DriverRecvState(state)
322 }
323
324 fn poll_recv(
325 mut future: std::pin::Pin<&mut Self::RecvFutureState>,
326 cx: &mut std::task::Context<'_>,
327 receiver: &mut Self::Receiver,
328 ) -> std::task::Poll<Result<Option<Self::RecvBuffer>, Self::Error>> {
329 use std::task::Poll::*;
330 match future.as_mut().0.poll_with_dispatcher(cx, receiver.shared.channel.dispatcher.clone())
331 {
332 Ready(Ok(Some(buffer))) => {
333 let buffer = buffer.map_data(|_, data| {
334 let bytes = data.len();
335 assert_eq!(
336 0,
337 bytes % size_of::<Chunk>(),
338 "Received driver channel buffer was not a multiple of {} bytes",
339 size_of::<Chunk>()
340 );
341 let new_box = unsafe {
345 let ptr = ArenaBox::into_ptr(data).cast();
346 ArenaBox::new(NonNull::slice_from_raw_parts(
347 ptr,
348 bytes / size_of::<Chunk>(),
349 ))
350 };
351 new_box
352 });
353
354 Ready(Ok(Some(RecvBuffer { buffer, data_offset: 0, handle_offset: 0 })))
355 }
356 Ready(Ok(None)) => Ready(Ok(None)),
357 Ready(Err(err)) => Ready(Err(err)),
358 Pending => {
359 receiver.shared.closed_waker.register(cx.waker());
360 if receiver.shared.is_closed.load(Ordering::Relaxed) {
361 return Poll::Ready(Ok(None));
362 }
363 Pending
364 }
365 }
366 }
367}
368
#[cfg(test)]
mod test {
    use fidl_next::{Client, ClientEnd, Responder, Server, ServerEnd, ServerSender};
    use fidl_next_fuchsia_examples_gizmo::device::{GetEvent, GetHardwareId};
    use fidl_next_fuchsia_examples_gizmo::{
        Device, DeviceClientHandler, DeviceClientSender, DeviceGetEventResponse,
        DeviceGetHardwareIdResponse, DeviceServerHandler,
    };
    use fuchsia_async::OnSignals;
    use zx::{AsHandleRef, Event, HandleBased, Signals};

    use super::*;
    use fdf_core::dispatcher::{CurrentDispatcher, OnDispatcher};
    use fdf_env::test::spawn_in_driver;

    /// `Device` server used by the test. Each reply is sent from a task spawned on the
    /// current dispatcher.
    struct DeviceServer;
    impl DeviceServerHandler<DriverChannel> for DeviceServer {
        /// Replies with a hardware ID of 4004.
        fn get_hardware_id(
            &mut self,
            sender: &ServerSender<DriverChannel, Device>,
            responder: Responder<GetHardwareId>,
        ) {
            let sender = sender.clone();
            CurrentDispatcher
                .spawn_task(async move {
                    responder
                        .respond(
                            &sender,
                            &mut Result::<_, i32>::Ok(DeviceGetHardwareIdResponse {
                                response: 4004,
                            }),
                        )
                        .unwrap()
                        .await
                        .unwrap();
                })
                .unwrap();
        }

        /// Creates an event, signals `USER_0` on it, and returns it in the reply. This lets the
        /// client check that a Zircon handle survives the driver-channel round trip.
        fn get_event(
            &mut self,
            sender: &ServerSender<DriverChannel, Device>,
            responder: Responder<GetEvent>,
        ) {
            let sender = sender.clone();
            let event = Event::create();
            event.signal_handle(Signals::empty(), Signals::USER_0).unwrap();
            let mut response = DeviceGetEventResponse { event: event.into_handle() };
            CurrentDispatcher
                .spawn_task(async move {
                    responder.respond(&sender, &mut response).unwrap().await.unwrap();
                })
                .unwrap();
        }
    }

    /// Client handler with an empty impl, so it relies entirely on the trait's default methods.
    struct DeviceClient;
    impl DeviceClientHandler<DriverChannel> for DeviceClient {}

    /// End-to-end test over a driver channel pair.
    ///
    /// It runs a `Device` server and client on the current dispatcher, then issues a plain-data
    /// call (`get_hardware_id`) and a handle-carrying call (`get_event`).
    #[test]
    fn driver_fidl_server() {
        spawn_in_driver("driver fidl server", async {
            let (server_chan, client_chan) = Channel::<[Chunk]>::create();
            let client_end = ClientEnd::from_untyped(DriverChannel::new(client_chan));
            let server_end: ServerEnd<_, Device> =
                ServerEnd::from_untyped(DriverChannel::new(server_chan));
            let mut client = Client::new(client_end);
            let mut server = Server::new(server_end);
            // Keep a sender so calls can be issued after `client` moves into its run task.
            let client_sender = client.sender().clone();

            CurrentDispatcher
                .spawn_task(async move {
                    server.run(DeviceServer).await.unwrap();
                    println!("server task finished");
                })
                .unwrap();
            CurrentDispatcher
                .spawn_task(async move {
                    client.run(DeviceClient).await.unwrap();
                    println!("client task finished");
                })
                .unwrap();

            // Plain-data round trip.
            {
                let mut res = client_sender.get_hardware_id().unwrap().await.unwrap();
                let res = res.decode().unwrap();
                let hardware_id = res.unwrap();
                assert_eq!(hardware_id.response, 4004);
            }

            // Handle-carrying round trip: the received event should still carry the `USER_0`
            // signal the server set before sending it.
            {
                let mut res = client_sender.get_event().unwrap().await.unwrap();
                let res = res.decode().unwrap();
                let event = Event::from_handle(res.event.take());

                // Waits for the signal on a separate local executor.
                // NOTE(review): the reason for nesting an executor here rather than awaiting
                // directly is not visible in this file — confirm.
                let mut executor = fuchsia_async::LocalExecutor::new();
                let signalled =
                    executor.run_singlethreaded(OnSignals::new(event, Signals::USER_0)).unwrap();
                assert_eq!(Signals::USER_0, signalled);
            }
        });
    }
}
472}