fidl_next_protocol/endpoints/connection.rs

// Copyright 2025 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use core::future::Future;
use core::mem::{ManuallyDrop, MaybeUninit, replace, take};
use core::pin::Pin;
use core::task::{Context, Poll, Waker};

use fidl_next_codec::EncodeError;
use pin_project::pin_project;

use crate::concurrency::cell::UnsafeCell;
use crate::concurrency::future::AtomicWaker;
use crate::concurrency::hint::unreachable_unchecked;
use crate::concurrency::sync::Mutex;
use crate::concurrency::sync::atomic::{AtomicUsize, Ordering};
use crate::{
    Flexibility, NonBlockingTransport, ProtocolError, Transport, encode_epitaph, encode_header,
};

pub const ORDINAL_EPITAPH: u64 = 0xffff_ffff_ffff_ffff;

// Indicates that the connection has been requested to stop. Connections are
// always stopped as they are terminated.
const STOPPING_BIT: usize = 1 << 0;
// Indicates that the connection has been provided a termination reason.
const TERMINATED_BIT: usize = 1 << 1;
const BITS_COUNT: usize = 2;

// Each refcount represents a thread which is attempting to access the shared
// part of the transport.
const REFCOUNT: usize = 1 << BITS_COUNT;

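// For illustration only (not used by the code): a state value of
// `2 * REFCOUNT | STOPPING_BIT` decodes as two threads currently accessing
// `shared` on a connection that has been asked to stop but has not yet been
// given a termination reason.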
#[derive(Clone, Copy)]
struct State(usize);

impl State {
    fn is_stopping(self) -> bool {
        self.0 & STOPPING_BIT != 0
    }

    fn is_terminated(self) -> bool {
        self.0 & TERMINATED_BIT != 0
    }

    fn refcount(self) -> usize {
        self.0 >> BITS_COUNT
    }
}

/// A wrapper around a transport which provides connectivity semantics.
///
/// The [`Transport`] trait only provides the bare minimum API surface required
/// to send and receive data. On top of that, FIDL requires that clients and
/// servers respect additional messaging semantics. Those semantics are provided
/// by [`Connection`]:
///
/// - `Transport`s are difficult to close because they may be accessed from
///   several threads simultaneously. `Connection`s provide a mechanism for
///   gracefully closing transports by causing all sends to pend until the
///   connection is terminated, and all receives to fail instead of pending.
/// - FIDL connections may send and receive an epitaph as the final message
///   before the underlying transport is closed. This epitaph should be provided
///   to all sends when they fail, which requires additional coordination.
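///
/// # Example
///
/// A minimal sketch (not compiled here), assuming some transport type
/// `MyTransport: Transport` and an illustrative `encode_request` helper:
///
/// ```ignore
/// let connection = Connection::<MyTransport>::new(shared);
///
/// // Encode and send a message; awaiting allows the transport to apply
/// // backpressure.
/// connection.send_message(|buffer| encode_request(buffer))?.await?;
///
/// // Request a graceful stop: in-flight sends pend until the connection is
/// // terminated, and receives fail instead of pending.
/// connection.stop();
/// ```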
pub struct Connection<T: Transport> {
    // The lowest `BITS_COUNT` bits of this field contain flags indicating the
    // current state of the transport. The remaining upper bits contain the
    // number of threads attempting to access the `shared` field.
    state: AtomicUsize,
    // A thread will drop `shared` if:
    //
    // - the connection is dropped before being terminated, or
    // - it set `TERMINATED_BIT` while the refcount was 0, or
    // - it decremented the refcount to 0 while `TERMINATED_BIT` was set.
    //
    // These cases are handled by `drop`, `terminate`, and `with_shared`
    // respectively.
    shared: UnsafeCell<ManuallyDrop<T::Shared>>,
    stop_waker: AtomicWaker,
    // TODO: switch this to an intrusive linked list in the send futures
    termination_wakers: Mutex<Vec<Waker>>,
    // Initialized if `TERMINATED_BIT` is set.
    termination_reason: UnsafeCell<MaybeUninit<ProtocolError<T::Error>>>,
}

unsafe impl<T: Transport> Send for Connection<T> {}
unsafe impl<T: Transport> Sync for Connection<T> {}

impl<T: Transport> Drop for Connection<T> {
    fn drop(&mut self) {
        self.state.with_mut(|state| {
            let state = State(*state);

            if !state.is_terminated() {
                self.shared.with_mut(|shared| {
                    // SAFETY: The connection was not terminated before being
                    // dropped, so `shared` has not yet been dropped.
                    unsafe {
                        ManuallyDrop::drop(&mut *shared);
                    }
                });
            } else {
                self.termination_reason.with_mut(|termination_reason| {
                    // SAFETY: The connection was terminated before being
                    // dropped, so `termination_reason` is initialized.
                    unsafe {
                        MaybeUninit::assume_init_drop(&mut *termination_reason);
                    }
                });
            }
        });
    }
}

impl<T: Transport> Connection<T> {
    /// Creates a new connection from the shared part of a transport.
    pub fn new(shared: T::Shared) -> Self {
        Self {
            state: AtomicUsize::new(0),
            shared: UnsafeCell::new(ManuallyDrop::new(shared)),
            stop_waker: AtomicWaker::new(),
            termination_wakers: Mutex::new(Vec::new()),
            termination_reason: UnsafeCell::new(MaybeUninit::uninit()),
        }
    }

    /// # Safety
    ///
    /// This thread must have loaded `state` with at least `Ordering::Acquire`
    /// and observed that `TERMINATED_BIT` was set.
    unsafe fn get_termination_reason_unchecked(&self) -> ProtocolError<T::Error> {
        self.termination_reason.with(|termination_reason| {
            // SAFETY: The caller guaranteed that `state` was loaded with at
            // least `Ordering::Acquire` ordering and observed that
            // `TERMINATED_BIT` was set.
            unsafe { MaybeUninit::assume_init_ref(&*termination_reason).clone() }
        })
    }

    /// Returns the termination reason for the connection, if any.
    pub fn get_termination_reason(&self) -> Option<ProtocolError<T::Error>> {
        if State(self.state.load(Ordering::Acquire)).is_terminated() {
            // SAFETY: We loaded the state with `Ordering::Acquire` and observed
            // that `TERMINATED_BIT` was set.
            unsafe { Some(self.get_termination_reason_unchecked()) }
        } else {
            None
        }
    }

    /// # Safety
    ///
    /// `shared` must not have been dropped. See the documentation on `shared`
    /// for acceptable criteria.
    unsafe fn get_shared_unchecked(&self) -> &T::Shared {
        self.shared.with(|shared| {
            // SAFETY: The caller guaranteed that `shared` has not been dropped.
            unsafe { &*shared }
        })
    }

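    // Runs `success` against the shared part of the transport while the
    // connection is running. If the connection is stopping or terminated,
    // `failure` is run instead with the termination reason, if one has been
    // set (`None` means the connection has only been stopped so far). The
    // refcount in `state` keeps `shared` alive for the duration of `success`.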
    fn with_shared<U>(
        &self,
        success: impl FnOnce(&T::Shared) -> U,
        failure: impl FnOnce(Option<ProtocolError<T::Error>>) -> U,
    ) -> U {
        let pre_increment = State(self.state.fetch_add(REFCOUNT, Ordering::Acquire));

        // After the refcount drops to zero (and `shared` is dropped), threads
        // may still increment and decrement the refcount to attempt to read it.
        // To avoid dropping `shared` more than once, we prevent the refcount
        // from being decremented to 0 more than once after `TERMINATED_BIT` is
        // set.
        //
        // We do this by having each thread check whether its increment changed
        // the refcount from 0 to 1 while `TERMINATED_BIT` was set. If it did,
        // the thread will not decrement that refcount, leaving it "dangling"
        // instead. This ensures that the refcount never falls below 1 again.
        if pre_increment.is_terminated() && pre_increment.refcount() == 0 {
            // SAFETY: We loaded `state` with `Ordering::Acquire` and observed
            // that `TERMINATED_BIT` was set.
            let termination_reason = unsafe { self.get_termination_reason_unchecked() };
            return failure(Some(termination_reason));
        }

        let mut success_result = None;
        if !pre_increment.is_stopping() {
            // SAFETY: Termination always sets `STOPPING_BIT`. We incremented
            // the refcount while `STOPPING_BIT` was not set, so `shared` won't
            // be dropped until we decrement our refcount.
            let shared = unsafe { self.get_shared_unchecked() };
            success_result = Some(success(shared));
        }

        let pre_decrement = State(self.state.fetch_sub(REFCOUNT, Ordering::AcqRel));

        if !pre_decrement.is_stopping() {
            success_result.unwrap()
        } else if !pre_decrement.is_terminated() {
            failure(None)
        } else {
            // The connection is terminated. If we decremented the refcount to
            // 0, then we need to drop `shared`.
            if pre_decrement.refcount() == 1 {
                self.shared.with_mut(|shared| {
                    // SAFETY: We decremented the refcount to 0 while
                    // `TERMINATED_BIT` was set.
                    unsafe {
                        ManuallyDrop::drop(&mut *shared);
                    }
                });
            }

            // SAFETY: We loaded `state` with `Ordering::Acquire` and observed
            // that `TERMINATED_BIT` was set.
            let termination_reason = unsafe { self.get_termination_reason_unchecked() };
            failure(Some(termination_reason))
        }
    }

    /// Sends a message to the underlying transport.
    ///
    /// Returns a `SendFutureState` which can be polled to completion.
    pub fn send_message_raw(
        &self,
        f: impl FnOnce(&mut T::SendBuffer) -> Result<(), EncodeError>,
    ) -> Result<SendFutureState<T>, EncodeError> {
        self.with_shared(
            |shared| {
                let mut buffer = T::acquire(shared);
                f(&mut buffer)?;
                Ok(SendFutureState::Running { future_state: T::begin_send(shared, buffer) })
            },
            |error| {
                Ok(error
                    // Some(Error) => Terminated
                    .map(|error| SendFutureState::Terminated { error })
                    // None => Stopping
                    .unwrap_or(SendFutureState::Stopping))
            },
        )
    }

    /// Sends a message to the underlying transport.
    pub fn send_message(
        &self,
        f: impl FnOnce(&mut T::SendBuffer) -> Result<(), EncodeError>,
    ) -> Result<SendFuture<'_, T>, EncodeError> {
        Ok(SendFuture { connection: self, state: self.send_message_raw(f)? })
    }

    /// Sends an epitaph to the underlying transport.
    ///
    /// This send ignores the current state of the connection, and does not
    /// report back any errors encountered while sending.
    ///
    /// # Safety
    ///
    /// The connection must not be terminated, and the returned future must be
    /// completed or canceled before the connection is terminated.
    pub unsafe fn send_epitaph(&self, error: i32) -> SendEpitaphFuture<'_, T> {
        // SAFETY: The caller has guaranteed that the connection is not
        // terminated, and will not be terminated until the returned future is
        // completed or canceled. As long as the connection is not terminated,
        // `shared` will not be dropped.
        let shared = unsafe { self.get_shared_unchecked() };

        let mut buffer = T::acquire(shared);
        encode_header::<T>(&mut buffer, 0, ORDINAL_EPITAPH, Flexibility::Strict).unwrap();
        encode_epitaph::<T>(&mut buffer, error).unwrap();
        let future_state = T::begin_send(shared, buffer);

        SendEpitaphFuture { shared, future_state }
    }

    /// Returns a new [`RecvFuture`] which receives the next message.
    ///
    /// # Safety
    ///
    /// The connection must not be terminated, and the returned future must be
    /// completed or canceled before the connection is terminated.
    pub unsafe fn recv<'a>(&'a self, exclusive: &'a mut T::Exclusive) -> RecvFuture<'a, T> {
        // SAFETY: The caller has guaranteed that the connection is not
        // terminated, and will not be terminated until the returned future is
        // completed or canceled. As long as the connection is not terminated,
        // `shared` will not be dropped.
        let shared = unsafe { self.get_shared_unchecked() };
        let future_state = T::begin_recv(shared, exclusive);

        RecvFuture { connection: self, exclusive, future_state }
    }

    /// Stops the connection to wait for termination.
    ///
    /// This modifies the behavior of this connection's futures:
    ///
    /// - Polled [`SendFutureState`]s will return `Poll::Pending` without
    ///   calling [`poll_send`].
    /// - Polled [`RecvFuture`]s will call [`poll_recv`], but will return
    ///   `Poll::Ready` with an error when they would normally return
    ///   `Poll::Pending`.
    ///
    /// [`poll_send`]: Transport::poll_send
    /// [`poll_recv`]: Transport::poll_recv
    pub fn stop(&self) {
        let prev_state = State(self.state.fetch_or(STOPPING_BIT, Ordering::Relaxed));
        if !prev_state.is_stopping() {
            self.stop_waker.wake();
        }
    }

    /// Terminates the connection.
    ///
    /// This causes this connection's futures to return `Poll::Ready` with the
    /// given termination reason as an error.
    ///
    /// # Safety
    ///
    /// `terminate` may only be called once per connection.
    pub unsafe fn terminate(&self, reason: ProtocolError<T::Error>) {
        self.termination_reason.with_mut(|termination_reason| {
            // SAFETY: The caller guaranteed that this is the only time
            // `terminate` is called on this connection.
            unsafe {
                termination_reason.write(MaybeUninit::new(reason));
            }
        });
        let pre_terminate =
            State(self.state.fetch_or(STOPPING_BIT | TERMINATED_BIT, Ordering::AcqRel));

        // If we set `TERMINATED_BIT` and the refcount was 0, then we need to
        // drop `shared`.
        if !pre_terminate.is_terminated() && pre_terminate.refcount() == 0 {
            self.shared.with_mut(|shared| {
                // SAFETY: We set `TERMINATED_BIT` while the refcount was 0.
                unsafe {
                    ManuallyDrop::drop(&mut *shared);
                }
            });
        }

        // Wake all of the futures waiting for a termination reason.
        let wakers = take(&mut *self.termination_wakers.lock().unwrap());
        for waker in wakers {
            waker.wake();
        }
    }
}

pub type SendFutureOutput<T> = Result<(), ProtocolError<<T as Transport>::Error>>;

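// The state machine for a pending send, roughly:
//
// - `Running`: the send is still being polled against the transport.
// - `Stopping` / `Waiting`: the send pends until a termination reason becomes
//   available; `Waiting` records the index of the waker registered in the
//   connection's `termination_wakers`.
// - `Terminated`: the send fails with the termination reason when polled.
// - `Finished`: the send has returned `Poll::Ready` and must not be polled
//   again.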
#[pin_project(project = SendFutureStateProj, project_replace = SendFutureStateProjOwn)]
pub enum SendFutureState<T: Transport> {
    Running {
        #[pin]
        future_state: T::SendFutureState,
    },
    Stopping,
    Terminated {
        error: ProtocolError<T::Error>,
    },
    Waiting {
        waker_index: usize,
    },
    Finished,
}

impl<T: Transport> SendFutureState<T> {
    fn register_termination_waker(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        connection: &Connection<T>,
        waker_index: Option<usize>,
    ) -> Poll<SendFutureOutput<T>> {
        let mut wakers = connection.termination_wakers.lock().unwrap();

        // Re-check the state now that we're holding the lock. This prevents us
        // from adding wakers after termination (which would "leak" them).
        if let Some(termination_reason) = connection.get_termination_reason() {
            Poll::Ready(Err(termination_reason))
        } else {
            let waker = cx.waker().clone();
            if let Some(waker_index) = waker_index {
                // Overwrite an existing waker
                let old_waker = replace(&mut wakers[waker_index], waker);

                // Drop the old waker outside of the mutex lock
                drop(wakers);
                drop(old_waker);
            } else {
                // Insert a new waker
                let waker_index = wakers.len();
                wakers.push(waker);

                // Update the state outside of the mutex lock, since setting the
                // state may drop a `T::SendFutureState` if we were `Running`.
                drop(wakers);
                self.set(SendFutureState::Waiting { waker_index });
            }
            Poll::Pending
        }
    }

    pub fn poll_send(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        connection: &Connection<T>,
    ) -> Poll<SendFutureOutput<T>> {
        match self.as_mut().project() {
            SendFutureStateProj::Running { future_state } => {
                let result = connection.with_shared(
                    |shared| {
                        T::poll_send(future_state, cx, shared)
                            // `Err(Some(error))` =>
                            //   `Err(Some(TransportError(error)))`
                            .map_err(|error| error.map(ProtocolError::TransportError))
                    },
                    |error| Poll::Ready(Err(error)),
                );

                let result = match result {
                    Poll::Pending => Poll::Pending,
                    Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
                    Poll::Ready(Err(None)) => {
                        self.as_mut().register_termination_waker(cx, connection, None)
                    }
                    Poll::Ready(Err(Some(error))) => Poll::Ready(Err(error)),
                };

                if result.is_ready() {
                    self.set(Self::Finished);
                }

                result
            }
            SendFutureStateProj::Stopping => self.register_termination_waker(cx, connection, None),
            SendFutureStateProj::Terminated { .. } => {
                let state = self.project_replace(Self::Finished);
                let SendFutureStateProjOwn::Terminated { error } = state else {
                    // SAFETY: We just checked that our state is Terminated.
                    unsafe { unreachable_unchecked() }
                };
                Poll::Ready(Err(error))
            }
            SendFutureStateProj::Waiting { waker_index } => {
                let waker_index = *waker_index;
                self.register_termination_waker(cx, connection, Some(waker_index))
            }
            SendFutureStateProj::Finished => {
                panic!("SendFuture polled after returning `Poll::Ready`")
            }
        }
    }

    pub fn send_immediately(self, connection: &Connection<T>) -> SendFutureOutput<T>
    where
        T: NonBlockingTransport,
    {
        match self {
            SendFutureState::Running { mut future_state } => {
                connection.with_shared(
                    |shared| {
                        // Connection is running, try to send immediately.
                        T::send_immediately(&mut future_state, shared).map_err(|e| {
                            // Immediate send failed:
                            // - `None` => `PeerClosed`
                            // - `Some(T::Error)` => `TransportError(T::Error)`
                            e.map_or(ProtocolError::PeerClosed, ProtocolError::TransportError)
                        })
                    },
                    // Getting shared failed, but we may have a termination
                    // reason. If we don't have one, return `Stopped`.
                    |error| Err(error.unwrap_or(ProtocolError::Stopped)),
                )
            }
            SendFutureState::Stopping | SendFutureState::Waiting { waker_index: _ } => {
                // Try to get the termination reason. If we don't have one yet,
                // return `Stopped`.
                Err(connection.get_termination_reason().unwrap_or(ProtocolError::Stopped))
            }
            SendFutureState::Terminated { error } => Err(error),
            SendFutureState::Finished => panic!("SendFuture polled after returning `Poll::Ready`"),
        }
    }
}

/// A future which sends an encoded message to a connection.
#[must_use = "futures do nothing unless polled"]
#[pin_project]
pub struct SendFuture<'a, T: Transport> {
    connection: &'a Connection<T>,
    #[pin]
    state: SendFutureState<T>,
}

impl<T: NonBlockingTransport> SendFuture<'_, T> {
    /// Completes the send operation synchronously and without blocking.
    ///
    /// Using this method prevents transports from applying backpressure. Prefer
    /// awaiting when possible to allow for backpressure.
    ///
    /// Because failed sends return immediately, `send_immediately` may observe
    /// transport closure prematurely. This can manifest as this method
    /// returning `Err(PeerClosed)` or `Err(Stopped)` when it should have
    /// returned `Err(PeerClosedWithEpitaph)`. Prefer awaiting when possible for
    /// correctness.
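    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled here), assuming a non-blocking transport
    /// and an illustrative `encode_request` helper:
    ///
    /// ```ignore
    /// // Encode the message, then complete the send without awaiting.
    /// connection.send_message(|buffer| encode_request(buffer))?.send_immediately()?;
    /// ```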
    pub fn send_immediately(self) -> SendFutureOutput<T> {
        self.state.send_immediately(self.connection)
    }
}

impl<T: Transport> Future for SendFuture<'_, T> {
    type Output = SendFutureOutput<T>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        this.state.poll_send(cx, this.connection)
    }
}

#[pin_project]
pub struct SendEpitaphFuture<'a, T: Transport> {
    shared: &'a T::Shared,
    #[pin]
    future_state: T::SendFutureState,
}

impl<T: Transport> Future for SendEpitaphFuture<'_, T> {
    type Output = Result<(), Option<T::Error>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        T::poll_send(this.future_state, cx, this.shared)
    }
}

/// A future which receives an encoded message over the transport.
#[must_use = "futures do nothing unless polled"]
#[pin_project]
pub struct RecvFuture<'a, T: Transport> {
    connection: &'a Connection<T>,
    exclusive: &'a mut T::Exclusive,
    #[pin]
    future_state: T::RecvFutureState,
}

impl<T: Transport> Future for RecvFuture<'_, T> {
    type Output = Result<T::RecvBuffer, ProtocolError<T::Error>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        // SAFETY: This future is created by `Connection::recv`. The connection
        // will not be terminated until this is completed or canceled, and so
        // `shared` will not be dropped.
        let shared = unsafe { this.connection.get_shared_unchecked() };

        let termination_reason = match T::poll_recv(this.future_state, cx, shared, this.exclusive) {
            Poll::Pending => {
                // The receive didn't complete; register the waker before
                // re-checking the state.
                this.connection.stop_waker.register_by_ref(cx.waker());
                let state = State(this.connection.state.load(Ordering::Relaxed));
                if state.is_stopping() {
                    // The connection is stopping. Return an error indicating
                    // that the connection has been stopped.
                    ProtocolError::Stopped
                } else {
                    // Still running; we'll get polled again later.
                    return Poll::Pending;
                }
            }

            // Receive succeeded.
            Poll::Ready(Ok(buffer)) => return Poll::Ready(Ok(buffer)),

            // Normal failure: return peer closed error.
            Poll::Ready(Err(None)) => ProtocolError::PeerClosed,

            // Abnormal failure: return transport error.
            Poll::Ready(Err(Some(error))) => ProtocolError::TransportError(error),
        };

        Poll::Ready(Err(termination_reason))
    }
}