1use core::mem::replace;
8use core::pin::Pin;
9use core::ptr::NonNull;
10use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
11use core::task::{Context, Poll};
12use std::sync::Arc;
13
14use fidl_next_codec::decoder::InternalHandleDecoder;
15use fidl_next_codec::encoder::InternalHandleEncoder;
16use fidl_next_codec::fuchsia::{HandleDecoder, HandleEncoder};
17use fidl_next_codec::{Chunk, DecodeError, Decoder, EncodeError, Encoder, CHUNK_SIZE};
18use fuchsia_async::{RWHandle, ReadableHandle as _};
19use futures::task::AtomicWaker;
20use zx::sys::{
21 zx_channel_read, zx_channel_write, zx_handle_t, ZX_ERR_BUFFER_TOO_SMALL, ZX_ERR_PEER_CLOSED,
22 ZX_ERR_SHOULD_WAIT, ZX_OK,
23};
24use zx::{AsHandleRef as _, Channel, Handle, HandleBased, Status};
25
26use crate::{NonBlockingTransport, Transport};
27
28struct Shared {
29 is_closed: AtomicBool,
30 sender_count: AtomicUsize,
31 closed_waker: AtomicWaker,
32 channel: RWHandle<Channel>,
33 }
35
36impl Shared {
37 fn new(channel: Channel) -> Self {
38 Self {
39 is_closed: AtomicBool::new(false),
40 sender_count: AtomicUsize::new(1),
41 closed_waker: AtomicWaker::new(),
42 channel: RWHandle::new(channel),
43 }
44 }
45
46 fn close(&self) {
47 self.is_closed.store(true, Ordering::Relaxed);
48 self.closed_waker.wake();
49 }
50}
51
52pub struct Sender {
54 shared: Arc<Shared>,
55}
56
57impl Drop for Sender {
58 fn drop(&mut self) {
59 let senders = self.shared.sender_count.fetch_sub(1, Ordering::Relaxed);
60 if senders == 1 {
61 self.shared.close();
62 }
63 }
64}
65
66impl Clone for Sender {
67 fn clone(&self) -> Self {
68 self.shared.sender_count.fetch_add(1, Ordering::Relaxed);
69 Self { shared: self.shared.clone() }
70 }
71}
72
73#[derive(Default)]
75pub struct Buffer {
76 handles: Vec<Handle>,
77 chunks: Vec<Chunk>,
78}
79
80impl Buffer {
81 pub fn new() -> Self {
83 Self::default()
84 }
85
86 pub fn handles(&self) -> &[Handle] {
88 &self.handles
89 }
90
91 pub fn bytes(&self) -> Vec<u8> {
93 self.chunks.iter().flat_map(|chunk| chunk.to_le_bytes()).collect()
94 }
95}
96
97impl InternalHandleEncoder for Buffer {
98 #[inline]
99 fn __internal_handle_count(&self) -> usize {
100 self.handles.len()
101 }
102}
103
104impl Encoder for Buffer {
105 #[inline]
106 fn bytes_written(&self) -> usize {
107 Encoder::bytes_written(&self.chunks)
108 }
109
110 #[inline]
111 fn write_zeroes(&mut self, len: usize) {
112 Encoder::write_zeroes(&mut self.chunks, len)
113 }
114
115 #[inline]
116 fn write(&mut self, bytes: &[u8]) {
117 Encoder::write(&mut self.chunks, bytes)
118 }
119
120 #[inline]
121 fn rewrite(&mut self, pos: usize, bytes: &[u8]) {
122 Encoder::rewrite(&mut self.chunks, pos, bytes)
123 }
124}
125
126impl HandleEncoder for Buffer {
127 fn push_handle(&mut self, handle: Handle) -> Result<(), EncodeError> {
128 self.handles.push(handle);
129 Ok(())
130 }
131
132 fn handles_pushed(&self) -> usize {
133 self.handles.len()
134 }
135}
136
137pub struct SendFutureState {
139 buffer: Buffer,
140}
141
142pub struct Receiver {
144 shared: Arc<Shared>,
145}
146
147pub struct RecvFutureState {
149 buffer: Option<Buffer>,
150}
151
152pub struct RecvBuffer {
154 buffer: Buffer,
155 chunks_taken: usize,
156 handles_taken: usize,
157}
158
159unsafe impl Decoder for RecvBuffer {
160 fn take_chunks_raw(&mut self, count: usize) -> Result<NonNull<Chunk>, DecodeError> {
161 if count > self.buffer.chunks.len() - self.chunks_taken {
162 return Err(DecodeError::InsufficientData);
163 }
164
165 let chunks = unsafe { self.buffer.chunks.as_mut_ptr().add(self.chunks_taken) };
166 self.chunks_taken += count;
167
168 unsafe { Ok(NonNull::new_unchecked(chunks)) }
169 }
170
171 fn commit(&mut self) {
172 for handle in &mut self.buffer.handles[0..self.handles_taken] {
173 let _ = replace(handle, Handle::invalid()).into_raw();
175 }
176 }
177
178 fn finish(&self) -> Result<(), DecodeError> {
179 if self.chunks_taken != self.buffer.chunks.len() {
180 return Err(DecodeError::ExtraBytes {
181 num_extra: (self.buffer.chunks.len() - self.chunks_taken) * CHUNK_SIZE,
182 });
183 }
184
185 if self.handles_taken != self.buffer.handles.len() {
186 return Err(DecodeError::ExtraHandles {
187 num_extra: self.buffer.handles.len() - self.handles_taken,
188 });
189 }
190
191 Ok(())
192 }
193}
194
195impl InternalHandleDecoder for RecvBuffer {
196 fn __internal_take_handles(&mut self, count: usize) -> Result<(), DecodeError> {
197 if count > self.buffer.handles.len() - self.handles_taken {
198 return Err(DecodeError::InsufficientHandles);
199 }
200
201 for i in self.handles_taken..self.handles_taken + count {
202 let handle = replace(&mut self.buffer.handles[i], Handle::invalid());
203 drop(handle);
204 }
205 self.handles_taken += count;
206
207 Ok(())
208 }
209
210 fn __internal_handles_remaining(&self) -> usize {
211 self.buffer.handles.len() - self.handles_taken
212 }
213}
214
215impl HandleDecoder for RecvBuffer {
216 fn take_raw_handle(&mut self) -> Result<zx_handle_t, DecodeError> {
217 if self.handles_taken >= self.buffer.handles.len() {
218 return Err(DecodeError::InsufficientHandles);
219 }
220
221 let handle = self.buffer.handles[self.handles_taken].raw_handle();
222 self.handles_taken += 1;
223
224 Ok(handle)
225 }
226
227 fn handles_remaining(&mut self) -> usize {
228 self.buffer.handles.len() - self.handles_taken
229 }
230}
231
232impl Transport for Channel {
233 type Error = Status;
234
235 fn split(self) -> (Self::Sender, Self::Receiver) {
236 let shared = Arc::new(Shared::new(self));
237 (Sender { shared: shared.clone() }, Receiver { shared })
238 }
239
240 type Sender = Sender;
241 type SendBuffer = Buffer;
242 type SendFutureState = SendFutureState;
243
244 fn acquire(_: &Self::Sender) -> Self::SendBuffer {
245 Buffer::new()
246 }
247
248 fn begin_send(_: &Self::Sender, buffer: Self::SendBuffer) -> Self::SendFutureState {
249 SendFutureState { buffer }
250 }
251
252 fn poll_send(
253 future_state: Pin<&mut Self::SendFutureState>,
254 _: &mut Context<'_>,
255 sender: &Self::Sender,
256 ) -> Poll<Result<(), Self::Error>> {
257 Poll::Ready(Self::send_immediately(future_state.get_mut(), sender))
258 }
259
260 fn close(sender: &Self::Sender) {
261 sender.shared.close();
262 }
263
264 type Receiver = Receiver;
265 type RecvFutureState = RecvFutureState;
266 type RecvBuffer = RecvBuffer;
267
268 fn begin_recv(_: &mut Self::Receiver) -> Self::RecvFutureState {
269 RecvFutureState { buffer: Some(Buffer::new()) }
270 }
271
272 fn poll_recv(
273 mut future_state: Pin<&mut Self::RecvFutureState>,
274 cx: &mut Context<'_>,
275 receiver: &mut Self::Receiver,
276 ) -> Poll<Result<Option<Self::RecvBuffer>, Self::Error>> {
277 let buffer = future_state.buffer.as_mut().unwrap();
278
279 let mut actual_bytes = 0;
280 let mut actual_handles = 0;
281
282 loop {
283 let result = unsafe {
284 zx_channel_read(
285 receiver.shared.channel.get_ref().raw_handle(),
286 0,
287 buffer.chunks.as_mut_ptr().cast(),
288 buffer.handles.as_mut_ptr().cast(),
289 (buffer.chunks.capacity() * CHUNK_SIZE) as u32,
290 buffer.handles.capacity() as u32,
291 &mut actual_bytes,
292 &mut actual_handles,
293 )
294 };
295
296 match result {
297 ZX_OK => {
298 unsafe {
299 buffer.chunks.set_len(actual_bytes as usize / CHUNK_SIZE);
300 buffer.handles.set_len(actual_handles as usize);
301 }
302 return Poll::Ready(Ok(Some(RecvBuffer {
303 buffer: future_state.buffer.take().unwrap(),
304 chunks_taken: 0,
305 handles_taken: 0,
306 })));
307 }
308 ZX_ERR_PEER_CLOSED => return Poll::Ready(Ok(None)),
309 ZX_ERR_BUFFER_TOO_SMALL => {
310 let min_chunks = (actual_bytes as usize).div_ceil(CHUNK_SIZE);
311 buffer.chunks.reserve(min_chunks - buffer.chunks.capacity());
312 buffer.handles.reserve(actual_handles as usize - buffer.handles.capacity());
313 }
314 ZX_ERR_SHOULD_WAIT => {
315 if matches!(receiver.shared.channel.need_readable(cx)?, Poll::Pending) {
316 receiver.shared.closed_waker.register(cx.waker());
317 if receiver.shared.is_closed.load(Ordering::Relaxed) {
318 return Poll::Ready(Ok(None));
319 }
320 return Poll::Pending;
321 }
322 }
323 raw => return Poll::Ready(Err(Status::from_raw(raw))),
324 }
325 }
326 }
327}
328
329impl NonBlockingTransport for Channel {
330 fn send_immediately(
331 future_state: &mut Self::SendFutureState,
332 sender: &Self::Sender,
333 ) -> Result<(), Self::Error> {
334 let result = unsafe {
335 zx_channel_write(
336 sender.shared.channel.get_ref().raw_handle(),
337 0,
338 future_state.buffer.chunks.as_ptr().cast::<u8>(),
339 (future_state.buffer.chunks.len() * CHUNK_SIZE) as u32,
340 future_state.buffer.handles.as_ptr().cast(),
341 future_state.buffer.handles.len() as u32,
342 )
343 };
344
345 if result == ZX_OK {
346 unsafe {
348 future_state.buffer.handles.set_len(0);
349 }
350 Ok(())
351 } else {
352 Err(Status::from_raw(result))
353 }
354 }
355}
356
#[cfg(test)]
mod tests {
    use core::mem::MaybeUninit;

    use fidl_next_codec::fuchsia::{HandleDecoder, HandleEncoder, WireHandle};
    use fidl_next_codec::{
        munge, Decode, DecodeError, DecoderExt as _, Encodable, Encode, EncodeError,
        EncoderExt as _, Slot, WireString, ZeroPadding,
    };
    use fuchsia_async as fasync;
    use zx::{AsHandleRef, Channel, Handle, HandleBased as _, Instant, Signals, WaitResult};

    use crate::fuchsia::channel::{Buffer, RecvBuffer};
    use crate::testing::{
        test_close_on_drop, test_event, test_multiple_two_way, test_one_way, test_two_way,
    };
    use crate::{Client, Responder, Server, ServerHandler, ServerSender, Transport};

    // The next five tests run the shared suites from `crate::testing` over a pair
    // of connected Zircon channels.

    #[fasync::run_singlethreaded(test)]
    async fn close_on_drop() {
        let (client_end, server_end) = Channel::create();
        test_close_on_drop(client_end, server_end).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn one_way() {
        let (client_end, server_end) = Channel::create();
        test_one_way(client_end, server_end).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn two_way() {
        let (client_end, server_end) = Channel::create();
        test_two_way(client_end, server_end).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn multiple_two_way() {
        let (client_end, server_end) = Channel::create();
        test_multiple_two_way(client_end, server_end).await;
    }

    #[fasync::run_singlethreaded(test)]
    async fn event() {
        let (client_end, server_end) = Channel::create();
        test_event(client_end, server_end).await;
    }

    // Natural (owned) form of a tiny test message carrying one handle and one bool.
    struct HandleAndBoolean {
        handle: Handle,
        boolean: bool,
    }

    // Wire form of `HandleAndBoolean`; `repr(C)` fixes the field order.
    #[derive(Debug)]
    #[repr(C)]
    struct WireHandleAndBoolean {
        handle: WireHandle,
        boolean: bool,
    }

    unsafe impl ZeroPadding for WireHandleAndBoolean {
        // Zeroes the whole value, which covers any padding bytes as well.
        fn zero_padding(out: &mut MaybeUninit<Self>) {
            unsafe {
                out.as_mut_ptr().write_bytes(0, 1);
            }
        }
    }

    impl Encodable for HandleAndBoolean {
        type Encoded = WireHandleAndBoolean;
    }

    unsafe impl<E: HandleEncoder + ?Sized> Encode<E> for HandleAndBoolean {
        // Encodes each field into its slot of the wire struct, handle first.
        fn encode(
            self,
            encoder: &mut E,
            out: &mut MaybeUninit<Self::Encoded>,
        ) -> Result<(), EncodeError> {
            munge!(let Self::Encoded { handle, boolean } = out);
            self.handle.encode(encoder, handle)?;
            self.boolean.encode(encoder, boolean)?;
            Ok(())
        }
    }

    unsafe impl<D: HandleDecoder + ?Sized> Decode<D> for WireHandleAndBoolean {
        // Decodes the handle first, then the boolean, so a bad boolean fails
        // after the handle has already been taken from the decoder.
        fn decode(slot: Slot<'_, Self>, decoder: &mut D) -> Result<(), DecodeError> {
            munge!(let Self { handle, boolean } = slot);
            Decode::decode(handle, decoder)?;
            Decode::decode(boolean, decoder)?;
            Ok(())
        }
    }

    // A decode that fails partway must leave the already-taken handle owned by
    // the `RecvBuffer`, so it is closed exactly when that buffer is dropped.
    #[test]
    fn partial_decode_drops_handles() {
        let (encode_end, check_end) = Channel::create();

        let mut buffer = Buffer::new();
        buffer
            .encode_next(HandleAndBoolean { handle: encode_end.into_handle(), boolean: false })
            .expect("encoding should succeed");
        // Set byte 4 of the first chunk to 2. Per the `expect_err` message below,
        // this is meant to turn the boolean into an invalid value (presumably the
        // boolean sits at offset 4 after a 4-byte `WireHandle` — TODO confirm).
        *buffer.chunks[0] |= 0x00000002_00000000;

        let mut recv_buffer = RecvBuffer { buffer, chunks_taken: 0, handles_taken: 0 };
        (&mut recv_buffer)
            .decode_prefix::<WireHandleAndBoolean>()
            .expect_err("decoding an invalid boolean should fail");

        // The handle must still be alive: a non-blocking wait for PEER_CLOSED
        // times out with only CHANNEL_WRITABLE asserted.
        assert_eq!(
            check_end.wait_handle(Signals::CHANNEL_PEER_CLOSED, Instant::INFINITE_PAST),
            WaitResult::TimedOut(Signals::CHANNEL_WRITABLE),
        );

        drop(recv_buffer);

        // Dropping the buffer closed the handle, so the peer now sees PEER_CLOSED.
        assert_eq!(
            check_end.wait_handle(Signals::CHANNEL_PEER_CLOSED, Instant::INFINITE_PAST),
            WaitResult::Ok(Signals::CHANNEL_PEER_CLOSED),
        );
    }

    // A successful decode must move the handle into the decoded value. The handle
    // then lives exactly as long as that value, independent of the receive buffer.
    #[test]
    fn complete_decode_moves_handles() {
        let (encode_end, check_end) = Channel::create();

        let mut buffer = Buffer::new();
        buffer
            .encode_next(HandleAndBoolean { handle: encode_end.into_handle(), boolean: false })
            .expect("encoding should succeed");

        let recv_buffer = RecvBuffer { buffer, chunks_taken: 0, handles_taken: 0 };
        let decoded =
            recv_buffer.decode::<WireHandleAndBoolean>().expect("decoding should succeed");

        // The decoded value still holds the handle, so the peer is not closed yet.
        assert_eq!(
            check_end.wait_handle(Signals::CHANNEL_PEER_CLOSED, Instant::INFINITE_PAST),
            WaitResult::TimedOut(Signals::CHANNEL_WRITABLE),
        );

        drop(decoded.handle.take());

        // Taking and dropping the handle out of the decoded value closed it.
        assert_eq!(
            check_end.wait_handle(Signals::CHANNEL_PEER_CLOSED, Instant::INFINITE_PAST),
            WaitResult::Ok(Signals::CHANNEL_PEER_CLOSED),
        );

        drop(decoded);
    }

    // Sends a one-way message through the non-blocking `send_immediately` path,
    // then closes the client so both tasks run to completion.
    #[fasync::run_singlethreaded(test)]
    async fn one_way_nonblocking() {
        let (client_end, server_end) = Channel::create();
        struct TestServer;

        impl<T: Transport> ServerHandler<T> for TestServer {
            fn on_one_way(&mut self, _: &ServerSender<T>, ordinal: u64, buffer: T::RecvBuffer) {
                assert_eq!(ordinal, 42);
                let message = buffer.decode::<WireString>().expect("failed to decode request");
                assert_eq!(&**message, "Hello world");
            }

            fn on_two_way(&mut self, _: &ServerSender<T>, _: u64, _: T::RecvBuffer, _: Responder) {
                panic!("unexpected two-way message");
            }
        }

        let mut client = Client::new(client_end);
        let client_sender = client.sender().clone();
        let client_task = fasync::Task::spawn(async move { client.run_sender().await });
        let mut server = Server::new(server_end);
        let server_task = fasync::Task::spawn(async move { server.run(TestServer).await });

        client_sender
            .send_one_way(42, "Hello world")
            .expect("client failed to encode request")
            .send_immediately()
            .expect("client failed to send request");
        client_sender.close();
        drop(client_sender);

        client_task.await.expect("client encountered an error");
        server_task.await.expect("server encountered an error");
    }
}