fidl_next_protocol/endpoints/connection.rs
1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use core::future::Future;
6use core::mem::{ManuallyDrop, MaybeUninit, replace, take};
7use core::pin::Pin;
8use core::task::{Context, Poll, Waker};
9use fidl_constants::EPITAPH_ORDINAL;
10
11use fidl_next_codec::{EncodeError, EncoderExt as _};
12use pin_project::pin_project;
13
14use crate::concurrency::cell::UnsafeCell;
15use crate::concurrency::future::AtomicWaker;
16use crate::concurrency::hint::unreachable_unchecked;
17use crate::concurrency::sync::Mutex;
18use crate::concurrency::sync::atomic::{AtomicUsize, Ordering};
19use crate::wire::{Epitaph, MessageHeader};
20use crate::{Flexibility, NonBlockingTransport, ProtocolError, Transport};
21
// Flag bit: the connection has been requested to stop. Connections are
// always stopped as they are terminated.
const STOPPING_BIT: usize = 1 << 0;
// Flag bit: the connection has been provided a termination reason.
const TERMINATED_BIT: usize = 1 << 1;
// Number of low bits reserved for flags; the refcount lives above them.
const BITS_COUNT: usize = 2;

// Each refcount represents a thread which is attempting to access the shared
// part of the transport.
const REFCOUNT: usize = 1 << BITS_COUNT;

/// A decoded snapshot of the connection's packed state word: flag bits in the
/// low `BITS_COUNT` bits, thread refcount in the remaining upper bits.
#[derive(Clone, Copy)]
struct State(usize);

impl State {
    /// Returns whether the given flag bit is set in this snapshot.
    fn has_flag(self, bit: usize) -> bool {
        self.0 & bit == bit
    }

    /// Returns whether the connection has been requested to stop.
    fn is_stopping(self) -> bool {
        self.has_flag(STOPPING_BIT)
    }

    /// Returns whether a termination reason has been recorded.
    fn is_terminated(self) -> bool {
        self.has_flag(TERMINATED_BIT)
    }

    /// Returns the number of threads currently accessing the shared part of
    /// the transport.
    fn refcount(self) -> usize {
        self.0 >> BITS_COUNT
    }
}
49
/// A wrapper around a transport which adds connectivity semantics.
///
/// The [`Transport`] trait only provides the bare minimum API surface required
/// to send and receive data. On top of that, FIDL requires that clients and
/// servers respect additional messaging semantics. Those semantics are provided
/// by [`Connection`]:
///
/// - `Transport`s are difficult to close because they may be accessed from
///   several threads simultaneously. `Connection`s provide a mechanism for
///   gracefully closing transports by causing all sends to pend until the
///   connection is terminated, and all receives to fail instead of pend.
/// - FIDL connections may send and receive an epitaph as the final message
///   before the underlying transport is closed. This epitaph should be provided
///   to all sends when they fail, which requires additional coordination.
pub struct Connection<T: Transport> {
    // The lowest `BITS_COUNT` of this field contain flags indicating the
    // current state of the transport. The remainder of the upper bits contain
    // the number of threads attempting to access the `shared` field.
    state: AtomicUsize,
    // A thread will drop `shared` if:
    //
    // - the connection is dropped before being terminated, or
    // - it set `TERMINATED_BIT` while the refcount was 0, or
    // - it decremented the refcount to 0 while `TERMINATED_BIT` was set.
    //
    // These cases are handled by `drop`, `terminate`, and `with_shared`
    // respectively.
    shared: UnsafeCell<ManuallyDrop<T::Shared>>,
    // Woken the first time `STOPPING_BIT` is set (see `stop`); pending
    // receives register with it so they can observe stopping.
    stop_waker: AtomicWaker,
    // Wakers for send futures waiting on a termination reason; all are woken
    // and drained by `terminate`.
    // TODO: switch this to intrusive linked list in send futures
    termination_wakers: Mutex<Vec<Waker>>,
    // Initialized if `TERMINATED_BIT` is set.
    termination_reason: UnsafeCell<MaybeUninit<ProtocolError<T::Error>>>,
}
84
// SAFETY: Access to the `UnsafeCell` fields is coordinated through the atomic
// `state` word and the `termination_wakers` mutex (see `with_shared`,
// `terminate`, and `Drop`), so the interior mutability here does not by itself
// make the type thread-unsafe. NOTE(review): these impls are unconditional —
// confirm that `Transport`'s bounds guarantee `T::Shared` and `T::Error` are
// themselves safe to share across threads.
unsafe impl<T: Transport> Send for Connection<T> {}
unsafe impl<T: Transport> Sync for Connection<T> {}
87
impl<T: Transport> Drop for Connection<T> {
    // On drop, exactly one of `shared` or `termination_reason` is still live
    // and must be manually dropped, depending on whether the connection was
    // terminated (see the invariants documented on the `shared` field).
    fn drop(&mut self) {
        // `&mut self` guarantees exclusive access, so `state` can be read
        // through `with_mut` without atomic synchronization.
        self.state.with_mut(|state| {
            let state = State(*state);

            if !state.is_terminated() {
                self.shared.with_mut(|shared| {
                    // SAFETY: The connection was not terminated before being
                    // dropped, so `shared` has not yet been dropped.
                    unsafe {
                        ManuallyDrop::drop(&mut *shared);
                    }
                });
            } else {
                self.termination_reason.with_mut(|termination_reason| {
                    // SAFETY: The connection was terminated before being
                    // dropped, so `termination_reason` is initialized.
                    unsafe {
                        MaybeUninit::assume_init_drop(&mut *termination_reason);
                    }
                });
            }
        });
    }
}
113
impl<T: Transport> Connection<T> {
    /// Creates a new connection from the shared part of a transport.
    ///
    /// The connection starts in the running state: not stopping, not
    /// terminated, with a refcount of zero.
    pub fn new(shared: T::Shared) -> Self {
        Self {
            state: AtomicUsize::new(0),
            shared: UnsafeCell::new(ManuallyDrop::new(shared)),
            stop_waker: AtomicWaker::new(),
            termination_wakers: Mutex::new(Vec::new()),
            termination_reason: UnsafeCell::new(MaybeUninit::uninit()),
        }
    }

    /// Returns a clone of the termination reason without checking that one
    /// has been recorded.
    ///
    /// # Safety
    ///
    /// This thread must have loaded `state` with at least `Ordering::Acquire`
    /// and observed that `TERMINATED_BIT` was set.
    unsafe fn get_termination_reason_unchecked(&self) -> ProtocolError<T::Error> {
        self.termination_reason.with(|termination_reason| {
            // SAFETY: The caller guaranteed that `state` was loaded with at
            // least `Ordering::Acquire` ordering and observed that
            // `TERMINATED_BIT` was set.
            unsafe { MaybeUninit::assume_init_ref(&*termination_reason).clone() }
        })
    }

    /// Returns the termination reason for the connection, if any.
    pub fn get_termination_reason(&self) -> Option<ProtocolError<T::Error>> {
        if State(self.state.load(Ordering::Acquire)).is_terminated() {
            // SAFETY: We loaded the state with `Ordering::Acquire` and observed
            // that `TERMINATED_BIT` was set.
            unsafe { Some(self.get_termination_reason_unchecked()) }
        } else {
            None
        }
    }

    /// Returns a reference to the shared part of the transport without
    /// checking that it is still alive.
    ///
    /// # Safety
    ///
    /// `shared` must not have been dropped. See the documentation on `shared`
    /// for acceptable criteria.
    unsafe fn get_shared_unchecked(&self) -> &T::Shared {
        self.shared.with(|shared| {
            // SAFETY: The caller guaranteed that `shared` has not been dropped.
            unsafe { &*shared }
        })
    }

    /// Runs `success` against the shared part of the transport, returning its
    /// result only if the connection was still running both before and after
    /// the call. Otherwise returns the result of `failure`, passing
    /// `Some(reason)` if the connection has terminated, or `None` if it has
    /// only been stopped so far.
    ///
    /// The refcount in `state` is held for the duration of `success` so that
    /// `shared` cannot be dropped while it is in use.
    fn with_shared<U>(
        &self,
        success: impl FnOnce(&T::Shared) -> U,
        failure: impl FnOnce(Option<ProtocolError<T::Error>>) -> U,
    ) -> U {
        let pre_increment = State(self.state.fetch_add(REFCOUNT, Ordering::Acquire));

        // After the refcount drops to zero (and `shared` is dropped), threads
        // may still increment and decrement the refcount to attempt to read it.
        // To avoid dropping `shared` more than once, we prevent the refcount
        // from being decremented to 0 more than once after `TERMINATED_BIT` is
        // set.
        //
        // We do this by having each thread check whether its increment changed
        // the refcount from 0 to 1 while `TERMINATED_BIT` was set. If it did,
        // the thread will not decrement that refcount, leaving it "dangling"
        // instead. This ensures that the refcount never falls below 1 again.
        if pre_increment.is_terminated() && pre_increment.refcount() == 0 {
            // SAFETY: We loaded `state` with `Ordering::Acquire` and observed
            // that `TERMINATED_BIT` was set.
            let termination_reason = unsafe { self.get_termination_reason_unchecked() };
            return failure(Some(termination_reason));
        }

        let mut success_result = None;
        if !pre_increment.is_stopping() {
            // SAFETY: Termination always sets `STOPPING_BIT`. We incremented
            // the refcount while `STOPPING_BIT` was not set, so `shared` won't
            // be dropped until we decrement our refcount.
            let shared = unsafe { self.get_shared_unchecked() };
            success_result = Some(success(shared));
        }

        let pre_decrement = State(self.state.fetch_sub(REFCOUNT, Ordering::AcqRel));

        if !pre_decrement.is_stopping() {
            // `STOPPING_BIT` was never set while we held the refcount, so
            // `success` definitely ran; its result is present.
            success_result.unwrap()
        } else if !pre_decrement.is_terminated() {
            // Stopping but not yet terminated: no reason is available yet.
            failure(None)
        } else {
            // The connection is terminated. If we decremented the refcount to
            // 0, then we need to drop `shared`.
            if pre_decrement.refcount() == 1 {
                self.shared.with_mut(|shared| {
                    // SAFETY: We decremented the refcount to 0 while
                    // `TERMINATED_BIT` was set.
                    unsafe {
                        ManuallyDrop::drop(&mut *shared);
                    }
                });
            }

            // SAFETY: We loaded `state` with `Ordering::Acquire` and observed
            // that `TERMINATED_BIT` was set.
            let termination_reason = unsafe { self.get_termination_reason_unchecked() };
            failure(Some(termination_reason))
        }
    }

    /// Sends a message to the underlying transport.
    ///
    /// `f` encodes the message into the send buffer acquired from the
    /// transport. Returns a `SendFutureState` which can be polled to
    /// completion; if the connection is already stopping or terminated, the
    /// returned state resolves to the appropriate error when polled.
    pub fn send_message_raw(
        &self,
        f: impl FnOnce(&mut T::SendBuffer) -> Result<(), EncodeError>,
    ) -> Result<SendFutureState<T>, EncodeError> {
        self.with_shared(
            |shared| {
                let mut buffer = T::acquire(shared);
                f(&mut buffer)?;
                Ok(SendFutureState::Running { future_state: T::begin_send(shared, buffer) })
            },
            |error| {
                Ok(error
                    // Some(Error) => Terminated
                    .map(|error| SendFutureState::Terminated { error })
                    // None => Stopping
                    .unwrap_or(SendFutureState::Stopping))
            },
        )
    }

    /// Sends an epitaph to the underlying transport.
    ///
    /// This send ignores the current state of the connection, and does not
    /// report back any errors encountered while sending.
    ///
    /// # Safety
    ///
    /// The connection must not be terminated, and the returned future must be
    /// completed or canceled before the connection is terminated.
    pub unsafe fn send_epitaph(&self, error: i32) -> SendEpitaphFuture<'_, T> {
        // SAFETY: The caller has guaranteed that the connection is not
        // terminated, and will not be terminated until the returned future is
        // completed or canceled. As long as the connection is not terminated,
        // `shared` will not be dropped.
        let shared = unsafe { self.get_shared_unchecked() };

        let mut buffer = T::acquire(shared);
        // Encoding a header and epitaph into a freshly-acquired buffer is
        // infallible for these types, hence the unwraps.
        buffer.encode_next(MessageHeader::new(0, EPITAPH_ORDINAL, Flexibility::Strict)).unwrap();
        buffer.encode_next(Epitaph::new(error)).unwrap();
        let future_state = T::begin_send(shared, buffer);

        SendEpitaphFuture { shared, future_state }
    }

    /// Returns a new [`RecvFuture`] which receives the next message.
    ///
    /// # Safety
    ///
    /// The connection must not be terminated, and the returned future must be
    /// completed or canceled before the connection is terminated.
    pub unsafe fn recv<'a>(&'a self, exclusive: &'a mut T::Exclusive) -> RecvFuture<'a, T> {
        // SAFETY: The caller has guaranteed that the connection is not
        // terminated, and will not be terminated until the returned future is
        // completed or canceled. As long as the connection is not terminated,
        // `shared` will not be dropped.
        let shared = unsafe { self.get_shared_unchecked() };
        let future_state = T::begin_recv(shared, exclusive);

        RecvFuture { connection: self, exclusive, future_state }
    }

    /// Stops the connection to wait for termination.
    ///
    /// This modifies the behavior of this connection's futures:
    ///
    /// - Polled [`SendFutureState`]s will return `Poll::Pending` without
    ///   calling [`poll_send`].
    /// - Polled [`RecvFuture`]s will call [`poll_recv`], but will return
    ///   `Poll::Ready` with an error when they would normally return
    ///   `Poll::Pending`.
    ///
    /// [`poll_send`]: Transport::poll_send
    /// [`poll_recv`]: Transport::poll_recv
    pub fn stop(&self) {
        let prev_state = State(self.state.fetch_or(STOPPING_BIT, Ordering::Relaxed));
        // Only the first caller to set `STOPPING_BIT` wakes the stop waker,
        // so pending receives are woken exactly once per stop.
        if !prev_state.is_stopping() {
            self.stop_waker.wake();
        }
    }

    /// Terminates the connection.
    ///
    /// This causes this connection's futures to return `Poll::Ready` with an
    /// error of the given termination reason.
    ///
    /// # Safety
    ///
    /// `terminate` may only be called once per connection.
    pub unsafe fn terminate(&self, reason: ProtocolError<T::Error>) {
        // Record the reason first; setting `TERMINATED_BIT` below with
        // `AcqRel` publishes this write to threads which observe the bit.
        self.termination_reason.with_mut(|termination_reason| {
            // SAFETY: The caller guaranteed that this is the only time
            // `terminate` is called on this connection.
            unsafe {
                termination_reason.write(MaybeUninit::new(reason));
            }
        });
        let pre_terminate =
            State(self.state.fetch_or(STOPPING_BIT | TERMINATED_BIT, Ordering::AcqRel));

        // If we set `TERMINATED_BIT` and the refcount was 0, then we need to
        // drop `shared`.
        if !pre_terminate.is_terminated() && pre_terminate.refcount() == 0 {
            self.shared.with_mut(|shared| {
                // SAFETY: We set `TERMINATED_BIT` while the refcount was 0.
                unsafe {
                    ManuallyDrop::drop(&mut *shared);
                }
            });
        }

        // Wake all of the futures waiting for a termination reason. The list
        // is drained while the lock is held, but the wakers are invoked after
        // the lock is released.
        let wakers = take(&mut *self.termination_wakers.lock().unwrap());
        for waker in wakers {
            waker.wake();
        }
    }
}
340
/// The output of a send operation: `Ok(())` on success, or the
/// [`ProtocolError`] which stopped or terminated the connection.
pub type SendFutureOutput<T> = Result<(), ProtocolError<<T as Transport>::Error>>;
342
/// The state machine for an in-progress send operation over a [`Connection`].
#[pin_project(project = SendFutureStateProj, project_replace = SendFutureStateProjOwn)]
pub enum SendFutureState<T: Transport> {
    /// The connection is running and the transport-level send is in flight.
    Running {
        #[pin]
        future_state: T::SendFutureState,
    },
    /// The connection is stopping; the send pends until a termination reason
    /// becomes available.
    Stopping,
    /// The connection terminated before the send was created; polling yields
    /// the contained error.
    Terminated {
        error: ProtocolError<T::Error>,
    },
    /// A waker has been registered at `waker_index` in the connection's
    /// `termination_wakers`; the send is waiting to be woken by `terminate`.
    Waiting {
        waker_index: usize,
    },
    /// The send has resolved; polling again panics.
    Finished,
}
358
impl<T: Transport> SendFutureState<T> {
    /// Registers the context's waker to be woken on termination, moving `self`
    /// into the `Waiting` state if it is not already there.
    ///
    /// `waker_index` is the slot previously claimed in `termination_wakers`
    /// by an earlier `Waiting` poll, if any. If the connection has already
    /// terminated, returns `Poll::Ready` with the termination reason instead
    /// of registering.
    fn register_termination_waker(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        connection: &Connection<T>,
        waker_index: Option<usize>,
    ) -> Poll<SendFutureOutput<T>> {
        let mut wakers = connection.termination_wakers.lock().unwrap();

        // Re-check the state now that we're holding the lock again. This
        // prevents us from adding wakers after termination (which would "leak"
        // them).
        if let Some(termination_reason) = connection.get_termination_reason() {
            Poll::Ready(Err(termination_reason))
        } else {
            let waker = cx.waker().clone();
            if let Some(waker_index) = waker_index {
                // Overwrite an existing waker
                let old_waker = replace(&mut wakers[waker_index], waker);

                // Drop the old waker outside of the mutex lock
                drop(wakers);
                drop(old_waker);
            } else {
                // Insert a new waker
                let waker_index = wakers.len();
                wakers.push(waker);

                // Update the state outside of the mutex lock. If we were
                // running then a `T::SendFutureState` may be dropped.
                drop(wakers);
                self.set(SendFutureState::Waiting { waker_index });
            }
            Poll::Pending
        }
    }

    /// Polls the send operation against `connection`.
    ///
    /// While the connection is running, this forwards to
    /// [`Transport::poll_send`]. Once the connection stops, the future pends
    /// until a termination reason is available and then resolves to that
    /// error.
    ///
    /// # Panics
    ///
    /// Panics if polled again after returning `Poll::Ready`.
    pub fn poll_send(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        connection: &Connection<T>,
    ) -> Poll<SendFutureOutput<T>> {
        match self.as_mut().project() {
            SendFutureStateProj::Running { future_state } => {
                let result = connection.with_shared(
                    |shared| {
                        T::poll_send(future_state, cx, shared)
                            // `Err(Some(error))` =>
                            // `Err(Some(TransportError(error)))`
                            .map_err(|error| error.map(ProtocolError::TransportError))
                    },
                    |error| Poll::Ready(Err(error)),
                );

                let result = match result {
                    Poll::Pending => Poll::Pending,
                    Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
                    // `Err(None)`: the connection stopped without a reason
                    // yet; wait for termination instead of resolving.
                    Poll::Ready(Err(None)) => {
                        self.as_mut().register_termination_waker(cx, connection, None)
                    }
                    Poll::Ready(Err(Some(error))) => Poll::Ready(Err(error)),
                };

                if result.is_ready() {
                    self.set(Self::Finished);
                }

                result
            }
            SendFutureStateProj::Stopping => self.register_termination_waker(cx, connection, None),
            SendFutureStateProj::Terminated { .. } => {
                let state = self.project_replace(Self::Finished);
                let SendFutureStateProjOwn::Terminated { error } = state else {
                    // SAFETY: We just checked that our state is Terminated.
                    unsafe { unreachable_unchecked() }
                };
                Poll::Ready(Err(error))
            }
            SendFutureStateProj::Waiting { waker_index } => {
                let waker_index = *waker_index;
                self.register_termination_waker(cx, connection, Some(waker_index))
            }
            SendFutureStateProj::Finished => {
                panic!("SendFuture polled after returning `Poll::Ready`")
            }
        }
    }

    /// Completes the send operation synchronously and without blocking.
    ///
    /// If the connection is not running, this resolves immediately: with the
    /// termination reason if one is available, or `ProtocolError::Stopped`
    /// otherwise.
    ///
    /// # Panics
    ///
    /// Panics if the state has already resolved (`Finished`).
    pub fn send_immediately(self, connection: &Connection<T>) -> SendFutureOutput<T>
    where
        T: NonBlockingTransport,
    {
        match self {
            SendFutureState::Running { mut future_state } => {
                connection.with_shared(
                    |shared| {
                        // Connection is running, try to send immediately.
                        T::send_immediately(&mut future_state, shared).map_err(|e| {
                            // Immediate send failed:
                            // - `None` => `PeerClosed`
                            // - `Some(T::Error)` => `TransportError(T::Error)`
                            e.map_or(ProtocolError::PeerClosed, ProtocolError::TransportError)
                        })
                    },
                    // Getting shared failed, but we may have a termination
                    // reason. If we don't have one, return `Stopped`.
                    |error| Err(error.unwrap_or(ProtocolError::Stopped)),
                )
            }
            SendFutureState::Stopping | SendFutureState::Waiting { waker_index: _ } => {
                // Try to get the termination reason. If we don't have one yet,
                // return `Stopped`.
                Err(connection.get_termination_reason().unwrap_or(ProtocolError::Stopped))
            }
            SendFutureState::Terminated { error } => Err(error),
            SendFutureState::Finished => panic!("SendFuture polled after returning `Poll::Ready`"),
        }
    }
}
478
/// A future which sends an encoded message to a connection.
#[must_use = "futures do nothing unless polled"]
#[pin_project]
pub struct SendFuture<'a, T: Transport> {
    // The connection the message is being sent over.
    connection: &'a Connection<T>,
    // The state machine driving the send; see `SendFutureState::poll_send`.
    #[pin]
    state: SendFutureState<T>,
}
487
impl<'a, T: Transport> SendFuture<'a, T> {
    /// Creates a `SendFuture` from its raw connection reference and state.
    ///
    /// The `state` is typically the result of
    /// [`Connection::send_message_raw`] on the same connection.
    pub fn from_raw_parts(connection: &'a Connection<T>, state: SendFutureState<T>) -> Self {
        Self { connection, state }
    }
}
494
impl<T: NonBlockingTransport> SendFuture<'_, T> {
    /// Completes the send operation synchronously and without blocking.
    ///
    /// Using this method prevents transports from applying backpressure. Prefer
    /// awaiting when possible to allow for backpressure.
    ///
    /// Because failed sends return immediately, `send_immediately` may observe
    /// transport closure prematurely. This can manifest as this method
    /// returning `Err(PeerClosed)` or `Err(Stopped)` when it should have
    /// returned `Err(PeerClosedWithEpitaph)`. Prefer awaiting when possible for
    /// correctness.
    pub fn send_immediately(self) -> SendFutureOutput<T> {
        // Consumes the future; the state machine resolves in a single step.
        self.state.send_immediately(self.connection)
    }
}
510
impl<T: Transport> Future for SendFuture<'_, T> {
    type Output = SendFutureOutput<T>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        // Delegate to the state machine, which handles the running, stopping,
        // and terminated connection states.
        this.state.poll_send(cx, this.connection)
    }
}
519
/// A future which sends an encoded epitaph directly over the shared part of a
/// transport, bypassing the connection's stop/terminate state machine.
///
/// Created by [`Connection::send_epitaph`]; per that method's safety contract,
/// the connection must not be terminated while this future is pending.
#[pin_project]
pub struct SendEpitaphFuture<'a, T: Transport> {
    shared: &'a T::Shared,
    #[pin]
    future_state: T::SendFutureState,
}
526
impl<T: Transport> Future for SendEpitaphFuture<'_, T> {
    // Raw transport result, not a `ProtocolError`, because this send bypasses
    // the connection state machine. NOTE(review): elsewhere in this file
    // `Err(None)` is treated as peer closure and `Err(Some(e))` as a
    // transport error — confirm against `Transport::poll_send`'s contract.
    type Output = Result<(), Option<T::Error>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        T::poll_send(this.future_state, cx, this.shared)
    }
}
535
/// A future which receives an encoded message over the transport.
#[must_use = "futures do nothing unless polled"]
#[pin_project]
pub struct RecvFuture<'a, T: Transport> {
    // The connection being received from; also consulted to observe stopping.
    connection: &'a Connection<T>,
    // The exclusive part of the transport, required by `Transport::poll_recv`.
    exclusive: &'a mut T::Exclusive,
    #[pin]
    future_state: T::RecvFutureState,
}
545
impl<T: Transport> Future for RecvFuture<'_, T> {
    type Output = Result<T::RecvBuffer, ProtocolError<T::Error>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        // SAFETY: This future is created by `Connection::recv`. The connection
        // will not be terminated until this is completed or canceled, and so
        // `shared` will not be dropped.
        let shared = unsafe { this.connection.get_shared_unchecked() };

        let termination_reason = match T::poll_recv(this.future_state, cx, shared, this.exclusive) {
            Poll::Pending => {
                // Receive didn't complete, register waker before
                // re-checking state. Registering first ensures a concurrent
                // `stop` either sets the bit before our load below or wakes
                // the registered waker.
                this.connection.stop_waker.register_by_ref(cx.waker());
                let state = State(this.connection.state.load(Ordering::Relaxed));
                if state.is_stopping() {
                    // The connection is stopping. Return an error that the
                    // connection has been stopped.
                    ProtocolError::Stopped
                } else {
                    // Still running, we'll get polled again later.
                    return Poll::Pending;
                }
            }

            // Receive succeeded.
            Poll::Ready(Ok(buffer)) => return Poll::Ready(Ok(buffer)),

            // Normal failure: return peer closed error.
            Poll::Ready(Err(None)) => ProtocolError::PeerClosed,

            // Abnormal failure: return transport error.
            Poll::Ready(Err(Some(error))) => ProtocolError::TransportError(error),
        };

        Poll::Ready(Err(termination_reason))
    }
}