1use futures::task::AtomicWaker;
6use std::num::NonZero;
7use std::ptr::NonNull;
8use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::task::Poll;
11
12use fidl_next::Chunk;
13use zx::Status;
14
15use fdf_channel::arena::{Arena, ArenaBox};
16use fdf_channel::channel::Channel;
17use fdf_channel::futures::ReadMessageState;
18use fdf_channel::message::Message;
19use fdf_core::dispatcher::{CurrentDispatcher, OnDispatcher};
20use fdf_core::handle::{MixedHandle, MixedHandleType};
21
/// A driver-runtime channel that carries FIDL messages as `Chunk`s, paired with the
/// dispatcher used when polling for incoming messages.
///
/// Defaults to [`CurrentDispatcher`] when no dispatcher type is given.
pub struct DriverChannel<D = CurrentDispatcher> {
    /// Dispatcher handed (by clone) to the read future in `poll_recv`.
    dispatcher: D,
    /// The underlying driver channel that messages are written to and read from.
    channel: Channel<[Chunk]>,
}
28
29impl<D> DriverChannel<D> {
30 pub fn new_with_dispatcher(dispatcher: D, channel: Channel<[Chunk]>) -> Self {
33 Self { dispatcher, channel }
34 }
35}
36
37impl DriverChannel<CurrentDispatcher> {
38 pub fn new(channel: Channel<[Chunk]>) -> Self {
41 Self::new_with_dispatcher(CurrentDispatcher, channel)
42 }
43}
44
/// Encode buffer for an outgoing message: the encoded bytes plus the handles that travel
/// with them.
pub struct SendBuffer {
    /// Handle slots in the order the encoder pushed them. `None` is recorded for a zero raw
    /// driver handle, or for a Zircon handle that `MixedHandle::from_zircon_handle` does not
    /// convert (presumably an invalid handle — confirm).
    handles: Vec<Option<MixedHandle>>,
    /// Encoded message bytes, stored as whole `Chunk`s.
    data: Vec<Chunk>,
}
50
51impl SendBuffer {
52 fn new() -> Self {
53 Self { handles: Vec::new(), data: Vec::new() }
54 }
55}
56
57impl fidl_next::Encoder for SendBuffer {
58 #[inline]
59 fn bytes_written(&self) -> usize {
60 fidl_next::Encoder::bytes_written(&self.data)
61 }
62
63 #[inline]
64 fn write(&mut self, bytes: &[u8]) {
65 fidl_next::Encoder::write(&mut self.data, bytes)
66 }
67
68 #[inline]
69 fn rewrite(&mut self, pos: usize, bytes: &[u8]) {
70 fidl_next::Encoder::rewrite(&mut self.data, pos, bytes)
71 }
72
73 fn write_zeroes(&mut self, len: usize) {
74 fidl_next::Encoder::write_zeroes(&mut self.data, len);
75 }
76}
77
78impl fidl_next::encoder::InternalHandleEncoder for SendBuffer {
79 #[inline]
80 fn __internal_handle_count(&self) -> usize {
81 self.handles.len()
82 }
83}
84
85impl fidl_next::fuchsia::HandleEncoder for SendBuffer {
86 fn push_handle(&mut self, handle: zx::Handle) -> Result<(), fidl_next::EncodeError> {
87 if let Some(handle) = MixedHandle::from_zircon_handle(handle) {
88 if handle.is_driver() {
89 return Err(fidl_next::EncodeError::ExpectedZirconHandle);
90 }
91 self.handles.push(Some(handle));
92 } else {
93 self.handles.push(None);
94 }
95 Ok(())
96 }
97
98 fn push_raw_driver_handle(&mut self, handle: u32) -> Result<(), fidl_next::EncodeError> {
99 if let Some(handle) = NonZero::new(handle) {
100 let handle = unsafe { MixedHandle::from_raw(handle) };
103 if !handle.is_driver() {
104 return Err(fidl_next::EncodeError::ExpectedDriverHandle);
105 }
106 self.handles.push(Some(handle));
107 } else {
108 self.handles.push(None);
109 }
110 Ok(())
111 }
112
113 fn handles_pushed(&self) -> usize {
114 self.handles.len()
115 }
116}
117
/// A received message being decoded, together with read cursors into its data and handles.
pub struct RecvBuffer {
    /// The received message, with its payload reinterpreted as `Chunk`s (see `poll_recv`).
    buffer: Message<[Chunk]>,
    /// Number of chunks already handed out by `take_chunks_raw`.
    data_offset: usize,
    /// Number of handle slots already consumed by the decoder.
    handle_offset: usize,
}
123
124impl RecvBuffer {
125 fn next_handle(&self) -> Result<&MixedHandle, fidl_next::DecodeError> {
126 let Some(handles) = self.buffer.handles() else {
127 return Err(fidl_next::DecodeError::InsufficientHandles);
128 };
129 if handles.len() < self.handle_offset + 1 {
130 return Err(fidl_next::DecodeError::InsufficientHandles);
131 }
132 handles[self.handle_offset].as_ref().ok_or(fidl_next::DecodeError::RequiredHandleAbsent)
133 }
134}
135
136unsafe impl fidl_next::Decoder for RecvBuffer {
141 fn take_chunks_raw(&mut self, count: usize) -> Result<NonNull<Chunk>, fidl_next::DecodeError> {
144 let Some(data) = self.buffer.data_mut() else {
145 return Err(fidl_next::DecodeError::InsufficientData);
146 };
147 if data.len() < self.data_offset + count {
148 return Err(fidl_next::DecodeError::InsufficientData);
149 }
150 let pos = self.data_offset;
151 self.data_offset += count;
152 Ok(unsafe { NonNull::new_unchecked((&mut data[pos..(pos + count)]).as_mut_ptr()) })
153 }
154
155 fn commit(&mut self) {
156 if let Some(handles) = self.buffer.handles_mut() {
157 for i in 0..self.handle_offset {
158 core::mem::forget(handles[i].take());
159 }
160 }
161 }
162
163 fn finish(&self) -> Result<(), fidl_next::DecodeError> {
164 let data_len = self.buffer.data().unwrap_or(&[]).len();
165 if self.data_offset != data_len {
166 return Err(fidl_next::DecodeError::ExtraBytes {
167 num_extra: data_len - self.data_offset,
168 });
169 }
170 let handle_len = self.buffer.handles().unwrap_or(&[]).len();
171 if self.handle_offset != handle_len {
172 return Err(fidl_next::DecodeError::ExtraHandles {
173 num_extra: handle_len - self.handle_offset,
174 });
175 }
176 Ok(())
177 }
178}
179
180impl fidl_next::decoder::InternalHandleDecoder for RecvBuffer {
181 fn __internal_take_handles(&mut self, count: usize) -> Result<(), fidl_next::DecodeError> {
182 let Some(handles) = self.buffer.handles_mut() else {
183 return Err(fidl_next::DecodeError::InsufficientHandles);
184 };
185 if handles.len() < self.handle_offset + count {
186 return Err(fidl_next::DecodeError::InsufficientHandles);
187 }
188 let pos = self.handle_offset;
189 self.handle_offset = pos + count;
190 Ok(())
191 }
192
193 fn __internal_handles_remaining(&self) -> usize {
194 self.buffer.handles().unwrap_or(&[]).len() - self.handle_offset
195 }
196}
197
198impl fidl_next::fuchsia::HandleDecoder for RecvBuffer {
199 fn take_raw_handle(&mut self) -> Result<zx::sys::zx_handle_t, fidl_next::DecodeError> {
200 let result = {
201 let handle = self.next_handle()?.resolve_ref();
202 let MixedHandleType::Zircon(handle) = handle else {
203 return Err(fidl_next::DecodeError::ExpectedZirconHandle);
204 };
205 handle.raw_handle()
206 };
207 let pos = self.handle_offset;
208 self.handle_offset = pos + 1;
209 Ok(result)
210 }
211
212 fn take_raw_driver_handle(&mut self) -> Result<u32, fidl_next::DecodeError> {
213 let result = {
214 let handle = self.next_handle()?.resolve_ref();
215 let MixedHandleType::Driver(handle) = handle else {
216 return Err(fidl_next::DecodeError::ExpectedDriverHandle);
217 };
218 unsafe { handle.get_raw().get() }
219 };
220 let pos = self.handle_offset;
221 self.handle_offset = pos + 1;
222 Ok(result)
223 }
224
225 fn handles_remaining(&mut self) -> usize {
226 self.buffer.handles().unwrap_or(&[]).len() - self.handle_offset
227 }
228}
229
/// Receive-future state for the driver transport: the driver runtime's `ReadMessageState`
/// for the channel, created in `begin_recv` and polled in `poll_recv`.
pub struct DriverRecvState(ReadMessageState);
232
/// State shared, via `Arc`, between the `DriverSender`s and the `DriverReceiver` produced
/// by `split`.
struct Shared<D> {
    /// Set by `close`. `poll_recv` checks it and ends receiving with `Ok(None)`.
    is_closed: AtomicBool,
    /// Number of live `DriverSender`s; dropping the last one closes the transport.
    sender_count: AtomicUsize,
    /// Waker of a receiver parked in `poll_recv`; woken by `close`.
    closed_waker: AtomicWaker,
    /// The channel (and its dispatcher) that both halves operate on.
    channel: DriverChannel<D>,
}
239
240impl<D> Shared<D> {
241 fn new(channel: DriverChannel<D>) -> Self {
242 Self {
243 is_closed: AtomicBool::new(false),
244 sender_count: AtomicUsize::new(1),
245 closed_waker: AtomicWaker::new(),
246 channel,
247 }
248 }
249
250 fn close(&self) {
251 self.is_closed.store(true, Ordering::Relaxed);
252 self.closed_waker.wake();
253 }
254}
/// Sending half of a split `DriverChannel`. Cloneable; the transport is closed when the last
/// clone is dropped.
pub struct DriverSender<D> {
    /// State shared with the other senders and the receiver.
    shared: Arc<Shared<D>>,
}
259
260impl<D> Drop for DriverSender<D> {
261 fn drop(&mut self) {
262 let senders = self.shared.sender_count.fetch_sub(1, Ordering::Relaxed);
263 if senders == 1 {
264 self.shared.close();
265 }
266 }
267}
268
269impl<D> Clone for DriverSender<D> {
270 fn clone(&self) -> Self {
271 self.shared.sender_count.fetch_add(1, Ordering::Relaxed);
272 Self { shared: self.shared.clone() }
273 }
274}
275
/// Receiving half of a split `DriverChannel`.
pub struct DriverReceiver<D> {
    /// State shared with the senders; carries the channel and the closed flag/waker.
    shared: Arc<Shared<D>>,
}
280
281impl<D: OnDispatcher> fidl_next::protocol::Transport for DriverChannel<D> {
282 type Error = Status;
283
284 fn split(self) -> (Self::Sender, Self::Receiver) {
285 let shared = Arc::new(Shared::new(self));
286 let sender = DriverSender { shared: shared.clone() };
287 let receiver = DriverReceiver { shared };
288 (sender, receiver)
289 }
290
291 type Sender = DriverSender<D>;
292
293 type SendBuffer = SendBuffer;
294
295 type SendFutureState = SendBuffer;
296
297 fn acquire(_sender: &Self::Sender) -> Self::SendBuffer {
298 SendBuffer::new()
299 }
300
301 fn close(sender: &Self::Sender) {
302 sender.shared.close();
303 }
304
305 type Receiver = DriverReceiver<D>;
306
307 type RecvFutureState = DriverRecvState;
308
309 type RecvBuffer = RecvBuffer;
310
311 fn begin_send(_sender: &Self::Sender, buffer: Self::SendBuffer) -> Self::SendFutureState {
312 buffer
313 }
314
315 fn poll_send(
316 mut buffer: std::pin::Pin<&mut Self::SendFutureState>,
317 _cx: &mut std::task::Context<'_>,
318 sender: &Self::Sender,
319 ) -> std::task::Poll<Result<(), Self::Error>> {
320 let arena = Arena::new();
321 let message = Message::new_with(arena, |arena| {
322 let data = arena.insert_slice(&buffer.data);
323 let handles = buffer.handles.split_off(0);
324 let handles = arena.insert_from_iter(handles.into_iter());
325 (Some(data), Some(handles))
326 });
327 Poll::Ready(sender.shared.channel.channel.write(message))
328 }
329
330 fn begin_recv(receiver: &mut Self::Receiver) -> Self::RecvFutureState {
331 let state =
334 unsafe { ReadMessageState::new(receiver.shared.channel.channel.driver_handle()) };
335 DriverRecvState(state)
336 }
337
338 fn poll_recv(
339 mut future: std::pin::Pin<&mut Self::RecvFutureState>,
340 cx: &mut std::task::Context<'_>,
341 receiver: &mut Self::Receiver,
342 ) -> std::task::Poll<Result<Option<Self::RecvBuffer>, Self::Error>> {
343 use std::task::Poll::*;
344 match future.as_mut().0.poll_with_dispatcher(cx, receiver.shared.channel.dispatcher.clone())
345 {
346 Ready(Ok(Some(buffer))) => {
347 let buffer = buffer.map_data(|_, data| {
348 let bytes = data.len();
349 assert_eq!(
350 0,
351 bytes % size_of::<Chunk>(),
352 "Received driver channel buffer was not a multiple of {} bytes",
353 size_of::<Chunk>()
354 );
355 let new_box = unsafe {
359 let ptr = ArenaBox::into_ptr(data).cast();
360 ArenaBox::new(NonNull::slice_from_raw_parts(
361 ptr,
362 bytes / size_of::<Chunk>(),
363 ))
364 };
365 new_box
366 });
367
368 Ready(Ok(Some(RecvBuffer { buffer, data_offset: 0, handle_offset: 0 })))
369 }
370 Ready(Ok(None)) => Ready(Ok(None)),
371 Ready(Err(err)) => Ready(Err(err)),
372 Pending => {
373 receiver.shared.closed_waker.register(cx.waker());
374 if receiver.shared.is_closed.load(Ordering::Relaxed) {
375 return Poll::Ready(Ok(None));
376 }
377 Pending
378 }
379 }
380 }
381}
382
#[cfg(test)]
mod test {
    use fidl_next::{Client, ClientEnd, Responder, Server, ServerEnd, ServerSender};
    use fidl_next_fuchsia_examples_gizmo::device::{GetEvent, GetHardwareId};
    use fidl_next_fuchsia_examples_gizmo::{
        Device, DeviceClientHandler, DeviceClientSender, DeviceGetEventResponse,
        DeviceGetHardwareIdResponse, DeviceServerHandler,
    };
    use fuchsia_async::OnSignals;
    use zx::{AsHandleRef, Event, HandleBased, Signals};

    use super::*;
    use fdf_core::dispatcher::{CurrentDispatcher, OnDispatcher};
    use fdf_env::test::spawn_in_driver;

    /// Test server for the gizmo `Device` protocol over a `DriverChannel`.
    ///
    /// Each handler clones the sender and sends its reply from a task spawned on the current
    /// dispatcher, rather than replying inline.
    struct DeviceServer;
    impl DeviceServerHandler<DriverChannel> for DeviceServer {
        /// Replies with a successful response whose `response` field is 4004.
        fn get_hardware_id(
            &mut self,
            sender: &ServerSender<DriverChannel, Device>,
            responder: Responder<GetHardwareId>,
        ) {
            let sender = sender.clone();
            CurrentDispatcher
                .spawn_task(async move {
                    responder
                        .respond(
                            &sender,
                            Result::<_, i32>::Ok(DeviceGetHardwareIdResponse { response: 4004 }),
                        )
                        .unwrap()
                        .await
                        .unwrap();
                })
                .unwrap();
        }

        /// Replies with a new Zircon event that already has `USER_0` asserted. This exercises
        /// Zircon-handle transfer over the driver channel.
        fn get_event(
            &mut self,
            sender: &ServerSender<DriverChannel, Device>,
            responder: Responder<GetEvent>,
        ) {
            let sender = sender.clone();
            let event = Event::create();
            event.signal_handle(Signals::empty(), Signals::USER_0).unwrap();
            let response = DeviceGetEventResponse { event: event.into_handle() };
            CurrentDispatcher
                .spawn_task(async move {
                    responder.respond(&sender, response).unwrap().await.unwrap();
                })
                .unwrap();
        }
    }

    /// No-op client handler; the test only issues calls through the client's sender.
    struct DeviceClient;
    impl DeviceClientHandler<DriverChannel> for DeviceClient {}

    /// End-to-end test: a client and a server connected by a driver channel pair, each run in
    /// its own task on the current dispatcher, exchanging both gizmo `Device` methods.
    #[test]
    fn driver_fidl_server() {
        spawn_in_driver("driver fidl server", async {
            let (server_chan, client_chan) = Channel::<[Chunk]>::create();
            let client_end = ClientEnd::from_untyped(DriverChannel::new(client_chan));
            let server_end: ServerEnd<_, Device> =
                ServerEnd::from_untyped(DriverChannel::new(server_chan));
            let mut client = Client::new(client_end);
            let mut server = Server::new(server_end);
            // Keep a sender so calls can be issued after `client` moves into its run task.
            let client_sender = client.sender().clone();

            CurrentDispatcher
                .spawn_task(async move {
                    server.run(DeviceServer).await.unwrap();
                    println!("server task finished");
                })
                .unwrap();
            CurrentDispatcher
                .spawn_task(async move {
                    client.run(DeviceClient).await.unwrap();
                    println!("client task finished");
                })
                .unwrap();

            // Plain-data round trip.
            {
                let res = client_sender.get_hardware_id().unwrap().await.unwrap();
                let hardware_id = res.unwrap();
                assert_eq!(hardware_id.response, 4004);
            }

            // Handle round trip: the received event must still carry the server's USER_0.
            {
                let res = client_sender.get_event().unwrap().await.unwrap();
                let event = Event::from_handle(res.event.take());

                // Waits with a separate fuchsia_async executor. The server asserted USER_0
                // before replying, so this wait is expected to resolve right away.
                let mut executor = fuchsia_async::LocalExecutor::new();
                let signalled =
                    executor.run_singlethreaded(OnSignals::new(event, Signals::USER_0)).unwrap();
                assert_eq!(Signals::USER_0, signalled);
            }
        });
    }
}