usb_vsock/
connection.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use futures::channel::{mpsc, oneshot};
6use futures::lock::{Mutex, OwnedMutexGuard};
7use log::{debug, trace, warn};
8use std::collections::HashMap;
9use std::collections::hash_map::Entry;
10use std::future::Future;
11use std::io::{Error, ErrorKind};
12use std::ops::DerefMut;
13use std::pin::Pin;
14use std::sync::Arc;
15use std::task::{Context, Poll, Waker, ready};
16
17use fuchsia_async::Scope;
18use futures::io::{ReadHalf, WriteHalf};
19use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, FutureExt, SinkExt, StreamExt};
20
21use crate::connection::overflow_writer::OverflowHandleFut;
22use crate::{
23    Address, Header, Packet, PacketType, ProtocolVersion, UsbPacketBuilder, UsbPacketFiller,
24};
25
26mod overflow_writer;
27mod pause_state;
28
29use overflow_writer::OverflowWriter;
30use pause_state::PauseState;
31
/// Marker for byte buffers that a [`Connection`] can use as packet buffers.
///
/// Any owner of a mutable byte slice that is `Send`, `Unpin` and `'static` qualifies. The blanket
/// impl below means no type ever has to opt in explicitly.
pub trait PacketBuffer: DerefMut<Target = [u8]> + Send + Unpin + 'static {}

impl<T> PacketBuffer for T
where
    T: DerefMut<Target = [u8]> + Send + Unpin + 'static,
{
}
35
/// Payload of a `PacketType::Pause` packet.
///
/// It is a single byte: `1` asks the receiver to stop reading from the connection's socket, and
/// `0` lets it resume.
#[derive(Copy, Clone, PartialEq, Eq)]
enum PausePacket {
    /// Encoded as the single byte `1`.
    Pause,
    /// Encoded as the single byte `0`.
    UnPause,
}

impl PausePacket {
    /// Returns the one-byte wire encoding of this pause request.
    fn bytes(&self) -> [u8; 1] {
        let encoded: u8 = match *self {
            Self::Pause => 1,
            Self::UnPause => 0,
        };
        [encoded]
    }
}
50
51/// A connection that has been established with the other end and now just needs
52/// a socket to start transmitting.
53pub struct ReadyConnect<B, S> {
54    connections: Arc<fuchsia_sync::Mutex<HashMap<Address, VsockConnection<S>>>>,
55    packet_filler: Arc<UsbPacketFiller<B>>,
56    address: Address,
57}
58
59impl<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static> ReadyConnect<B, S> {
60    /// Finish establishing the connection by providing a socket for data transfer.
61    pub async fn finish_connect(self, socket: S) {
62        let (read_socket, write_socket) = socket.split();
63        let writer = {
64            let conns = self.connections.lock();
65            let Some(conn) = conns.get(&self.address) else {
66                warn!("Connection state was missing after connection success!");
67                return;
68            };
69            let VsockConnectionState::Connected { writer, reader_scope, pause_state, .. } =
70                &conn.state
71            else {
72                warn!("Connection state was invalid after connection success!");
73                return;
74            };
75            reader_scope.spawn(Connection::<B, S>::run_socket(
76                read_socket,
77                self.address,
78                self.packet_filler,
79                Arc::clone(pause_state),
80            ));
81            Arc::clone(writer)
82        };
83        let mut writer = writer.lock().await;
84        let ConnectionStateWriter::NotYetAvailable(wakers) = std::mem::replace(
85            &mut *writer,
86            ConnectionStateWriter::Available(OverflowWriter::new(write_socket)),
87        ) else {
88            unreachable!("Connection completed multiple times!")
89        };
90
91        wakers.into_iter().for_each(Waker::wake);
92    }
93}
94
95/// Manages the state of a vsock-over-usb connection and the sockets over which data is being
96/// transmitted for them.
97///
98/// This implementation aims to be agnostic to both the underlying transport and the buffers used
99/// to read and write from it. The buffer type must conform to [`PacketBuffer`], which is essentially
100/// a type that holds a mutable slice of bytes and is [`Send`] and [`Unpin`]-able.
101///
102/// The client of this library will:
103/// - Use methods on this struct to initiate actions like connecting and accepting
104/// connections to the other end.
105/// - Provide buffers to be filled and sent to the other end with [`Connection::fill_usb_packet`].
106/// - Pump usb packets received into it using [`Connection::handle_vsock_packet`].
107pub struct Connection<B, S> {
108    control_socket_writer: Option<Mutex<WriteHalf<S>>>,
109    packet_filler: Arc<UsbPacketFiller<B>>,
110    protocol_version: ProtocolVersion,
111    connections: Arc<fuchsia_sync::Mutex<HashMap<Address, VsockConnection<S>>>>,
112    incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
113    task_scope: Scope,
114}
115
116impl<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static> Connection<B, S> {
117    /// Creates a new connection with:
118    /// - a `control_socket`, over which data addressed to and from cid 0, port 0 (a control channel
119    /// between host and device) can be read and written from. If this is `None`
120    /// we will discard control data.
121    /// - An `incoming_requests_tx` that is the sender half of a request queue for incoming
122    /// connection requests from the other side.
123    pub fn new(
124        protocol_version: ProtocolVersion,
125        control_socket: Option<S>,
126        incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
127    ) -> Self {
128        let packet_filler = Arc::new(UsbPacketFiller::default());
129        let connections = Default::default();
130        let task_scope = Scope::new_with_name("vsock_usb");
131        let control_socket_writer = control_socket.map(|control_socket| {
132            let (control_socket_reader, control_socket_writer) = control_socket.split();
133            task_scope.spawn(Self::run_socket(
134                control_socket_reader,
135                Address::default(),
136                packet_filler.clone(),
137                PauseState::new(),
138            ));
139            Mutex::new(control_socket_writer)
140        });
141        Self {
142            control_socket_writer,
143            packet_filler,
144            connections,
145            incoming_requests_tx,
146            protocol_version,
147            task_scope,
148        }
149    }
150
151    async fn send_close_packet(address: &Address, usb_packet_filler: &Arc<UsbPacketFiller<B>>) {
152        let header = &mut Header::new(PacketType::Finish);
153        header.set_address(address);
154        usb_packet_filler
155            .write_vsock_packet(&Packet { header, payload: &[] })
156            .await
157            .expect("Finish packet should never be too big");
158    }
159
160    async fn run_socket(
161        mut reader: ReadHalf<S>,
162        address: Address,
163        usb_packet_filler: Arc<UsbPacketFiller<B>>,
164        pause_state: Arc<PauseState>,
165    ) {
166        let mut buf = [0; 4096];
167        loop {
168            log::trace!("reading from control socket");
169            let read = match pause_state.while_unpaused(reader.read(&mut buf)).await {
170                Ok(0) => {
171                    if !address.is_zeros() {
172                        Self::send_close_packet(&address, &usb_packet_filler).await;
173                    }
174                    return;
175                }
176                Ok(read) => read,
177                Err(err) => {
178                    if address.is_zeros() {
179                        log::error!("Error reading usb socket: {err:?}");
180                    } else {
181                        Self::send_close_packet(&address, &usb_packet_filler).await;
182                    }
183                    return;
184                }
185            };
186            log::trace!("writing {read} bytes to vsock packet");
187            usb_packet_filler.write_vsock_data_all(&address, &buf[..read]).await;
188            log::trace!("wrote {read} bytes to vsock packet");
189        }
190    }
191
192    fn set_connection(
193        &self,
194        address: Address,
195        state: VsockConnectionState<S>,
196    ) -> Result<(), Error> {
197        let mut connections = self.connections.lock();
198        if !connections.contains_key(&address) {
199            connections.insert(address.clone(), VsockConnection { _address: address, state });
200            Ok(())
201        } else {
202            Err(Error::other(format!("connection on address {address:?} already set")))
203        }
204    }
205
206    /// Sends an echo packet to the remote end that you don't care about the reply, so it doesn't
207    /// have a distinct target address or payload.
208    pub async fn send_empty_echo(&self) {
209        debug!("Sending empty echo packet");
210        let header = &mut Header::new(PacketType::Echo);
211        self.packet_filler
212            .write_vsock_packet(&Packet { header, payload: &[] })
213            .await
214            .expect("empty echo packet should never be too large to fit in a usb packet");
215    }
216
217    /// Starts a connection attempt to the other end of the USB connection, and provides a socket
218    /// to read and write from. The function will complete when the other end has accepted or
219    /// rejected the connection, and the returned [`ConnectionState`] handle can be used to wait
220    /// for the connection to be closed.
221    pub async fn connect(&self, addr: Address, socket: S) -> Result<ConnectionState, Error> {
222        let (ready, state) = self.connect_late(addr).await?;
223        ready.finish_connect(socket).await;
224        Ok(state)
225    }
226
227    /// Same as [`connect`] but doesn't require the socket to be passed. Instead
228    /// we return a [`ReadyConnect`] which can be given the socket later. This
229    /// shouldn't be deferred very long but it is useful if the socket is
230    /// starting out speaking a different protocol and needs to execute a
231    /// protocol switch, but needs to know the connection status before doing
232    /// that switch.
233    pub async fn connect_late(
234        &self,
235        addr: Address,
236    ) -> Result<(ReadyConnect<B, S>, ConnectionState), Error> {
237        let (connected_tx, connected_rx) = oneshot::channel();
238
239        self.set_connection(addr.clone(), VsockConnectionState::ConnectingOutgoing(connected_tx))?;
240
241        let header = &mut Header::new(PacketType::Connect);
242        header.set_address(&addr);
243        self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
244        let Ok(conn_state) = connected_rx.await else {
245            return Err(Error::other("Accept was never received for {addr:?}"));
246        };
247
248        Ok((
249            ReadyConnect {
250                connections: Arc::clone(&self.connections),
251                packet_filler: Arc::clone(&self.packet_filler),
252                address: addr,
253            },
254            conn_state,
255        ))
256    }
257
258    /// Sends a request for the other end to close the connection.
259    pub async fn close(&self, address: &Address) {
260        Self::send_close_packet(address, &self.packet_filler).await
261    }
262
263    /// Resets the named connection without going through a close request.
264    pub async fn reset(&self, address: &Address) -> Result<(), Error> {
265        reset(address, &self.connections, &self.packet_filler).await
266    }
267
268    /// Accepts a connection for which an outstanding connection request has been made, and
269    /// provides a socket to read and write data packets to and from. The returned [`ConnectionState`]
270    /// can be used to wait for the connection to be closed.
271    pub async fn accept(
272        &self,
273        request: ConnectionRequest,
274        socket: S,
275    ) -> Result<ConnectionState, Error> {
276        let (ready, state) = self.accept_late(request).await?;
277        ready.finish_connect(socket).await;
278        Ok(state)
279    }
280
281    /// Accepts a connection for which an outstanding connection request has been made, and
282    /// provides a socket to read and write data packets to and from. The returned [`ConnectionState`]
283    /// can be used to wait for the connection to be closed.
284    pub async fn accept_late(
285        &self,
286        request: ConnectionRequest,
287    ) -> Result<(ReadyConnect<B, S>, ConnectionState), Error> {
288        let address = request.address;
289        let notify_closed_rx;
290        if let Some(conn) = self.connections.lock().get_mut(&address) {
291            let VsockConnectionState::ConnectingIncoming = &conn.state else {
292                return Err(Error::other(format!(
293                    "Attempted to accept connection that was not waiting at {address:?}"
294                )));
295            };
296
297            let notify_closed = mpsc::channel(2);
298            notify_closed_rx = notify_closed.1;
299            let notify_closed = notify_closed.0;
300            let pause_state = PauseState::new();
301
302            let reader_scope = Scope::new_with_name("connection-reader");
303
304            conn.state = VsockConnectionState::Connected {
305                writer: Arc::new(Mutex::new(ConnectionStateWriter::NotYetAvailable(Vec::new()))),
306                reader_scope,
307                notify_closed,
308                pause_state,
309            };
310        } else {
311            return Err(Error::other(format!(
312                "Attempting to accept connection that did not exist at {address:?}"
313            )));
314        }
315        let header = &mut Header::new(PacketType::Accept);
316        header.set_address(&address);
317        self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
318        Ok((
319            ReadyConnect {
320                connections: Arc::clone(&self.connections),
321                packet_filler: Arc::clone(&self.packet_filler),
322                address,
323            },
324            ConnectionState(notify_closed_rx),
325        ))
326    }
327
328    /// Rejects a pending connection request from the other side.
329    pub async fn reject(&self, request: ConnectionRequest) -> Result<(), Error> {
330        let address = request.address;
331        match self.connections.lock().entry(address.clone()) {
332            Entry::Occupied(entry) => {
333                let VsockConnectionState::ConnectingIncoming = &entry.get().state else {
334                    return Err(Error::other(format!(
335                        "Attempted to reject connection that was not waiting at {address:?}"
336                    )));
337                };
338                entry.remove();
339            }
340            Entry::Vacant(_) => {
341                return Err(Error::other(format!(
342                    "Attempted to reject connection that was not waiting at {address:?}"
343                )));
344            }
345        }
346
347        let header = &mut Header::new(PacketType::Reset);
348        header.set_address(&address);
349        self.packet_filler
350            .write_vsock_packet(&Packet { header, payload: &[] })
351            .await
352            .expect("accept packet should never be too large for packet buffer");
353        Ok(())
354    }
355
356    async fn handle_data_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
357        // all zero data packets go to the control channel
358        if address.is_zeros() {
359            if let Some(writer) = self.control_socket_writer.as_ref() {
360                writer.lock().await.write_all(payload).await?;
361            } else {
362                trace!("Discarding {} bytes of data sent to control socket", payload.len());
363            }
364            Ok(())
365        } else {
366            let payload_socket;
367            if let Some(conn) = self.connections.lock().get_mut(&address) {
368                let VsockConnectionState::Connected { writer, .. } = &conn.state else {
369                    warn!(
370                        "Received data packet for connection in unexpected state for {address:?}"
371                    );
372                    return Ok(());
373                };
374                payload_socket = writer.clone();
375            } else {
376                warn!("Received data packet for connection that didn't exist at {address:?}");
377                return Ok(());
378            }
379            let mut socket_guard =
380                ConnectionStateWriter::wait_available(Arc::clone(&payload_socket)).await;
381            let ConnectionStateWriter::Available(socket) = &mut *socket_guard else {
382                unreachable!("wait_available didn't wait until socket was available!");
383            };
384            match socket.write_all(payload) {
385                Err(err) => {
386                    debug!(
387                        "Write to socket address {address:?} failed, \
388                         resetting connection immediately: {err:?}"
389                    );
390                    self.reset(&address)
391                        .await
392                        .inspect_err(|err| {
393                            warn!(
394                                "Attempt to reset connection to {address:?} \
395                                   failed after write error: {err:?}"
396                            );
397                        })
398                        .ok();
399                }
400                Ok(status) => {
401                    if status.overflowed() {
402                        if self.protocol_version.has_pause_packets() {
403                            let header = &mut Header::new(PacketType::Pause);
404                            let payload = &PausePacket::Pause.bytes();
405                            header.set_address(&address);
406                            header.payload_len.set(payload.len() as u32);
407                            self.packet_filler
408                                .write_vsock_packet(&Packet { header, payload })
409                                .await
410                                .expect(
411                                    "pause packet should never be too large to fit in a usb packet",
412                                );
413                        }
414
415                        let weak_payload_socket = Arc::downgrade(&payload_socket);
416                        let connections = Arc::clone(&self.connections);
417                        let has_pause_packets = self.protocol_version.has_pause_packets();
418                        let packet_filler = Arc::clone(&self.packet_filler);
419                        self.task_scope.spawn(async move {
420                            let res = OverflowHandleFut::new(weak_payload_socket).await;
421
422                            if let Err(err) = res {
423                                debug!(
424                                    "Write to socket address {address:?} failed while \
425                                     processing backlog, resetting connection at next poll: {err:?}"
426                                );
427                                if let Err(err) = reset(&address, &connections, &packet_filler).await {
428                                    debug!("Error sending reset frame after overflow write failed: {err:?}");
429                                }
430                            } else if has_pause_packets {
431                                let header = &mut Header::new(PacketType::Pause);
432                                let payload = &PausePacket::UnPause.bytes();
433                                header.set_address(&address);
434                                header.payload_len.set(payload.len() as u32);
435                                packet_filler
436                                    .write_vsock_packet(&Packet { header, payload })
437                                    .await
438                                    .expect("pause packet should never be too large to fit in a usb packet");
439                            }
440                        });
441                    }
442                }
443            }
444            Ok(())
445        }
446    }
447
448    async fn handle_echo_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
449        debug!("received echo for {address:?} with payload {payload:?}");
450        let header = &mut Header::new(PacketType::EchoReply);
451        header.payload_len.set(payload.len() as u32);
452        header.set_address(&address);
453        self.packet_filler
454            .write_vsock_packet(&Packet { header, payload })
455            .await
456            .map_err(|_| Error::other("Echo packet was too large to be sent back"))
457    }
458
459    async fn handle_echo_reply_packet(
460        &self,
461        address: Address,
462        payload: &[u8],
463    ) -> Result<(), Error> {
464        // ignore but log replies
465        debug!("received echo reply for {address:?} with payload {payload:?}");
466        Ok(())
467    }
468
469    async fn handle_accept_packet(&self, address: Address) -> Result<(), Error> {
470        if let Some(conn) = self.connections.lock().get_mut(&address) {
471            let state = std::mem::replace(&mut conn.state, VsockConnectionState::Invalid);
472            let VsockConnectionState::ConnectingOutgoing(connected_tx) = state else {
473                warn!("Received accept packet for connection in unexpected state for {address:?}");
474                return Ok(());
475            };
476            let (notify_closed, notify_closed_rx) = mpsc::channel(2);
477            if connected_tx.send(ConnectionState(notify_closed_rx)).is_err() {
478                warn!(
479                    "Accept packet received for {address:?} but connect caller stopped waiting for it"
480                );
481            }
482            let pause_state = PauseState::new();
483
484            let reader_scope = Scope::new_with_name("connection-reader");
485            conn.state = VsockConnectionState::Connected {
486                writer: Arc::new(Mutex::new(ConnectionStateWriter::NotYetAvailable(Vec::new()))),
487                reader_scope,
488                notify_closed,
489                pause_state,
490            };
491        } else {
492            warn!("Got accept packet for connection that was not being made at {address:?}");
493            return Ok(());
494        }
495        Ok(())
496    }
497
498    async fn handle_connect_packet(&self, address: Address) -> Result<(), Error> {
499        trace!("received connect packet for {address:?}");
500        match self.connections.lock().entry(address.clone()) {
501            Entry::Vacant(entry) => {
502                debug!("valid connect request for {address:?}");
503                entry.insert(VsockConnection {
504                    _address: address,
505                    state: VsockConnectionState::ConnectingIncoming,
506                });
507            }
508            Entry::Occupied(_) => {
509                warn!(
510                    "Received connect packet for already existing \
511                     connection for address {address:?}. Ignoring"
512                );
513                return Ok(());
514            }
515        }
516
517        trace!("sending incoming connection request to client for {address:?}");
518        let connection_request = ConnectionRequest { address };
519        self.incoming_requests_tx
520            .clone()
521            .send(connection_request)
522            .await
523            .inspect(|_| trace!("sent incoming request for {address:?}"))
524            .map_err(|_| Error::other("Failed to send connection request"))
525    }
526
527    async fn handle_finish_packet(&self, address: Address) -> Result<(), Error> {
528        trace!("received finish packet for {address:?}");
529        let mut notify;
530        if let Some(conn) = self.connections.lock().remove(&address) {
531            let VsockConnectionState::Connected { notify_closed, .. } = conn.state else {
532                warn!(
533                    "Received finish (close) packet for {address:?} \
534                     which was not in a connected state. Ignoring and dropping connection state."
535                );
536                return Ok(());
537            };
538            notify = notify_closed;
539        } else {
540            warn!(
541                "Received finish (close) packet for connection that didn't exist \
542                 on address {address:?}. Ignoring"
543            );
544            return Ok(());
545        }
546
547        notify.send(Ok(())).await.ok();
548
549        let header = &mut Header::new(PacketType::Reset);
550        header.set_address(&address);
551        self.packet_filler
552            .write_vsock_packet(&Packet { header, payload: &[] })
553            .await
554            .expect("accept packet should never be too large for packet buffer");
555        Ok(())
556    }
557
558    async fn handle_reset_packet(&self, address: Address) -> Result<(), Error> {
559        trace!("received reset packet for {address:?}");
560        let mut notify = None;
561        if let Some(conn) = self.connections.lock().remove(&address) {
562            if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
563                notify = Some(notify_closed);
564            } else {
565                debug!(
566                    "Received reset packet for connection that wasn't in a connecting or \
567                    disconnected state on address {address:?}."
568                );
569            }
570        } else {
571            warn!(
572                "Received reset packet for connection that didn't \
573                exist on address {address:?}. Ignoring"
574            );
575        }
576
577        if let Some(mut notify) = notify {
578            notify.send(Ok(())).await.ok();
579        }
580        Ok(())
581    }
582
583    async fn handle_pause_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
584        if !self.protocol_version.has_pause_packets() {
585            warn!(
586                "Got a pause packet while using protocol \
587                 version {} which does not support them. Ignoring",
588                self.protocol_version
589            );
590            return Ok(());
591        }
592
593        let pause = match payload {
594            [1] => true,
595            [0] => false,
596            other => {
597                warn!("Ignoring unexpected pause packet payload {other:?}");
598                return Ok(());
599            }
600        };
601
602        if let Some(conn) = self.connections.lock().get(&address) {
603            if let VsockConnectionState::Connected { pause_state, .. } = &conn.state {
604                pause_state.set_paused(pause);
605            } else {
606                warn!("Received pause packet for unestablished connection. Ignoring");
607            };
608        } else {
609            warn!(
610                "Received pause packet for connection that didn't exist on address {address:?}. Ignoring"
611            );
612        }
613
614        Ok(())
615    }
616
617    /// Dispatches the given vsock packet type and handles its effect on any outstanding connections
618    /// or the overall state of the connection.
619    pub async fn handle_vsock_packet(&self, packet: Packet<'_>) -> Result<(), Error> {
620        trace!("received vsock packet {header:?}", header = packet.header);
621        let payload_len = packet.header.payload_len.get() as usize;
622        let payload = &packet.payload[..payload_len];
623        let address = Address::from(packet.header);
624        match packet.header.packet_type {
625            PacketType::Sync => Err(Error::other("Received sync packet mid-stream")),
626            PacketType::Data => self.handle_data_packet(address, payload).await,
627            PacketType::Accept => self.handle_accept_packet(address).await,
628            PacketType::Connect => self.handle_connect_packet(address).await,
629            PacketType::Finish => self.handle_finish_packet(address).await,
630            PacketType::Reset => self.handle_reset_packet(address).await,
631            PacketType::Echo => self.handle_echo_packet(address, payload).await,
632            PacketType::EchoReply => self.handle_echo_reply_packet(address, payload).await,
633            PacketType::Pause => self.handle_pause_packet(address, payload).await,
634        }
635    }
636
637    /// Provides a packet builder for the state machine to write packets to. Returns a future that
638    /// will be fulfilled when there is data available to send on the packet.
639    ///
640    /// # Panics
641    ///
642    /// Panics if called while another [`Self::fill_usb_packet`] future is pending.
643    pub async fn fill_usb_packet(&self, builder: UsbPacketBuilder<B>) -> UsbPacketBuilder<B> {
644        self.packet_filler.fill_usb_packet(builder).await
645    }
646}
647
648async fn reset<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static>(
649    address: &Address,
650    connections: &fuchsia_sync::Mutex<HashMap<Address, VsockConnection<S>>>,
651    packet_filler: &UsbPacketFiller<B>,
652) -> Result<(), Error> {
653    let mut notify = None;
654    if let Some(conn) = connections.lock().remove(&address) {
655        if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
656            notify = Some(notify_closed);
657        }
658    } else {
659        return Err(Error::other(
660            "Client asked to reset connection {address:?} that did not exist",
661        ));
662    }
663
664    if let Some(mut notify) = notify {
665        notify.send(Err(ErrorKind::ConnectionReset.into())).await.ok();
666    }
667
668    let header = &mut Header::new(PacketType::Reset);
669    header.set_address(address);
670    packet_filler
671        .write_vsock_packet(&Packet { header, payload: &[] })
672        .await
673        .expect("Reset packet should never be too big");
674    Ok(())
675}
676
677/// A writer inside of a [`ConnectionState`]. This is essentially an
678/// option-monad around an [`OverflowWriter`], but unlike
679/// [`std::option::Option`] the empty variant stores wakers that by convention
680/// will be woken when we replace it with the occupied variant.
681enum ConnectionStateWriter<S> {
682    NotYetAvailable(Vec<Waker>),
683    Available(OverflowWriter<S>),
684}
685
686impl<S> ConnectionStateWriter<S> {
687    /// Wait for the given `ConnectionStateWriter` to contain an actual writer.
688    fn wait_available(this: Arc<Mutex<ConnectionStateWriter<S>>>) -> ConnectionStateWriterFut<S> {
689        ConnectionStateWriterFut { writer: this, lock_fut: None }
690    }
691}
692
693/// Future returned by [`ConnectionStateWriter::wait_available`].
694struct ConnectionStateWriterFut<S> {
695    writer: Arc<Mutex<ConnectionStateWriter<S>>>,
696    lock_fut: Option<futures::lock::OwnedMutexLockFuture<ConnectionStateWriter<S>>>,
697}
698
699impl<S> Future for ConnectionStateWriterFut<S> {
700    type Output = OwnedMutexGuard<ConnectionStateWriter<S>>;
701
702    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
703        let writer = Arc::clone(&self.writer);
704        let lock_fut = self.lock_fut.get_or_insert_with(|| writer.lock_owned());
705        let mut lock = ready!(lock_fut.poll_unpin(cx));
706        self.lock_fut = None;
707        match &mut *lock {
708            ConnectionStateWriter::Available(_) => Poll::Ready(lock),
709            ConnectionStateWriter::NotYetAvailable(queue) => {
710                queue.push(cx.waker().clone());
711                Poll::Pending
712            }
713        }
714    }
715}
716
/// The state of a single vsock connection, as stored in a [`VsockConnection`].
enum VsockConnectionState<S> {
    /// An outgoing connection has been requested and is not yet established.
    /// NOTE(review): the sender presumably hands a [`ConnectionState`] back to the caller of
    /// [`Connection::connect`] once the peer accepts — confirm.
    ConnectingOutgoing(oneshot::Sender<ConnectionState>),
    /// An incoming connection is pending.
    /// NOTE(review): presumably waiting for [`Connection::accept`] or [`Connection::reject`]
    /// on the matching [`ConnectionRequest`] — confirm.
    ConnectingIncoming,
    /// The connection is established.
    Connected {
        /// Shared writer slot; the writer itself may not be installed yet
        /// (see [`ConnectionStateWriter`]).
        writer: Arc<Mutex<ConnectionStateWriter<S>>>,
        /// Channel used to report how the connection closed.
        /// NOTE(review): presumably paired with the receiver inside a [`ConnectionState`] — confirm.
        notify_closed: mpsc::Sender<Result<(), Error>>,
        /// Shared pause/unpause state for this connection.
        pause_state: Arc<PauseState>,
        /// NOTE(review): by its name, the scope owning this connection's reader task(s) — confirm.
        reader_scope: Scope,
    },
    /// NOTE(review): presumably a placeholder left behind while moving out of another state —
    /// confirm against the code that assigns it.
    Invalid,
}
728
/// Bookkeeping for one vsock connection, keyed by [`Address`] in the connection map.
struct VsockConnection<S> {
    /// The connection's address. The leading underscore marks it as not currently read.
    _address: Address,
    /// Where this connection is in its lifecycle.
    state: VsockConnectionState<S>,
}
733
/// A handle for the state of a connection established with either [`Connection::connect`] or
/// [`Connection::accept`]. Use this to get notified when the connection has been closed without
/// needing to hold on to the Socket end.
///
/// Wraps the receiving end of a channel that carries the connection's close result; consume it
/// with [`ConnectionState::wait_for_close`].
#[derive(Debug)]
pub struct ConnectionState(mpsc::Receiver<Result<(), Error>>);
739
740impl ConnectionState {
741    /// Wait for this connection to close. Returns Ok(()) if the connection was closed without error,
742    /// and an error if it closed because of an error.
743    pub async fn wait_for_close(mut self) -> Result<(), Error> {
744        self.0
745            .next()
746            .await
747            .ok_or_else(|| Error::other("Connection state's other end was dropped"))?
748    }
749}
750
/// An outstanding connection request that needs to be either [`Connection::accept`]ed or
/// [`Connection::reject`]ed.
#[derive(Debug)]
pub struct ConnectionRequest {
    /// The vsock address the peer is asking to connect on.
    address: Address,
}
757
758impl ConnectionRequest {
759    /// Creates a new connection request for the given address.
760    pub fn new(address: Address) -> Self {
761        Self { address }
762    }
763
764    /// The address this connection request is being made for.
765    pub fn address(&self) -> &Address {
766        &self.address
767    }
768}
769
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::VsockPacketIterator;

    use super::*;

    #[cfg(not(target_os = "fuchsia"))]
    use fuchsia_async::emulated_handle::Socket as SyncSocket;
    use fuchsia_async::{Socket, Task};
    use futures::StreamExt;
    #[cfg(target_os = "fuchsia")]
    use zx::Socket as SyncSocket;

    /// Payload sizes exercised by the data-transfer tests. The payload for size `n` is `n`
    /// bytes all equal to `n`, so truncated or mixed-up payloads are easy to spot.
    const TEST_SIZES: [u8; 8] = [1, 2, 8, 16, 32, 64, 128, 255];

    /// The vsock address used by every test connection.
    fn test_address() -> Address {
        Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
    }

    /// Builds the payload sent for one entry of [`TEST_SIZES`].
    fn test_payload(size: u8) -> Vec<u8> {
        vec![size; size as usize]
    }

    /// Writes each [`TEST_SIZES`] payload to `socket` and asserts that the same bytes are read
    /// back. Requires something on the far side of the connection to echo the data.
    async fn assert_echo_round_trips(socket: &mut Socket) {
        for size in TEST_SIZES {
            println!("round tripping packet of size {size}");
            let payload = test_payload(size);
            socket.write_all(&payload).await.unwrap();
            let mut buf = vec![0u8; payload.len()];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, payload);
        }
    }

    /// Plays the USB peer of `echo_connection`. It drains outgoing USB packets and feeds every
    /// vsock packet back into the same connection. `Connect` requests are answered with an
    /// `Accept`, and the connection's own `Accept` packets are dropped.
    async fn usb_echo_server(echo_connection: Arc<Connection<Vec<u8>, Socket>>) {
        let mut builder = UsbPacketBuilder::new(vec![0; 128]);
        loop {
            println!("waiting for usb packet");
            builder = echo_connection.fill_usb_packet(builder).await;
            let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
            println!("got usb packet, echoing it back to the other side");
            let mut packet_count = 0;
            for packet in packets {
                let packet = packet.unwrap();
                match packet.header.packet_type {
                    PacketType::Connect => {
                        // respond with an accept packet
                        let mut reply_header = packet.header.clone();
                        reply_header.packet_type = PacketType::Accept;
                        echo_connection
                            .handle_vsock_packet(Packet { header: &reply_header, payload: &[] })
                            .await
                            .unwrap();
                    }
                    PacketType::Accept => {
                        // just ignore it
                    }
                    _ => echo_connection.handle_vsock_packet(packet).await.unwrap(),
                }
                packet_count += 1;
            }
            println!("handled {packet_count} packets");
        }
    }

    #[fuchsia::test]
    async fn data_over_control_socket() {
        let (socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
        let mut socket = Socket::from_socket(socket);
        let connection = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket)),
            incoming_requests_tx,
        ));

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        assert_echo_round_trips(&mut socket).await;
        echo_task.abort().await;
    }

    #[fuchsia::test]
    async fn data_over_normal_outgoing_socket() {
        let (_control_socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
        let connection = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket)),
            incoming_requests_tx,
        ));

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        let (socket, other_socket) = SyncSocket::create_stream();
        let mut socket = Socket::from_socket(socket);
        connection.connect(test_address(), Socket::from_socket(other_socket)).await.unwrap();

        assert_echo_round_trips(&mut socket).await;
        echo_task.abort().await;
    }

    #[fuchsia::test]
    async fn data_over_normal_incoming_socket() {
        let (_control_socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, mut incoming_requests) = mpsc::channel(5);
        let connection = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket)),
            incoming_requests_tx,
        ));

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        let header = &mut Header::new(PacketType::Connect);
        header.set_address(&test_address());
        connection.handle_vsock_packet(Packet { header, payload: &[] }).await.unwrap();

        let request = incoming_requests.next().await.unwrap();
        assert_eq!(request.address, test_address());

        let (socket, other_socket) = SyncSocket::create_stream();
        let mut socket = Socket::from_socket(socket);
        connection.accept(request, Socket::from_socket(other_socket)).await.unwrap();

        assert_echo_round_trips(&mut socket).await;
        echo_task.abort().await;
    }

    /// Forwards every vsock packet produced by `from` into `to`, forever.
    async fn copy_connection(from: &Connection<Vec<u8>, Socket>, to: &Connection<Vec<u8>, Socket>) {
        let mut builder = UsbPacketBuilder::new(vec![0; 1024]);
        loop {
            builder = from.fill_usb_packet(builder).await;
            let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
            for packet in packets {
                println!("forwarding vsock packet");
                to.handle_vsock_packet(packet.unwrap()).await.unwrap();
            }
        }
    }

    pub(crate) trait EndToEndTestFn<R>:
        AsyncFnOnce(Arc<Connection<Vec<u8>, Socket>>, mpsc::Receiver<ConnectionRequest>) -> R
    {
    }
    impl<T, R> EndToEndTestFn<R> for T where
        T: AsyncFnOnce(Arc<Connection<Vec<u8>, Socket>>, mpsc::Receiver<ConnectionRequest>) -> R
    {
    }

    /// Wires two [`Connection`]s back to back. Each one's outgoing USB packets are delivered
    /// to the other. Then `left_side` and `right_side` run concurrently against them, and both
    /// results are returned.
    pub(crate) async fn end_to_end_test<R1, R2>(
        left_side: impl EndToEndTestFn<R1>,
        right_side: impl EndToEndTestFn<R2>,
    ) -> (R1, R2) {
        type Connection = crate::Connection<Vec<u8>, Socket>;
        let (_control_socket1, other_socket1) = SyncSocket::create_stream();
        let (_control_socket2, other_socket2) = SyncSocket::create_stream();
        let (incoming_requests_tx1, incoming_requests1) = mpsc::channel(5);
        let (incoming_requests_tx2, incoming_requests2) = mpsc::channel(5);

        let connection1 = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket1)),
            incoming_requests_tx1,
        ));
        let connection2 = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket2)),
            incoming_requests_tx2,
        ));

        let conn1 = connection1.clone();
        let conn2 = connection2.clone();
        let passthrough_task = Task::spawn(async move {
            futures::join!(copy_connection(&conn1, &conn2), copy_connection(&conn2, &conn1));
            println!("passthrough task loop ended");
        });

        let res = futures::join!(
            left_side(connection1, incoming_requests1),
            right_side(connection2, incoming_requests2)
        );
        passthrough_task.abort().await;
        res
    }

    #[fuchsia::test]
    async fn data_over_end_to_end() {
        end_to_end_test(
            async |conn, _incoming| {
                println!("sending request on connection 1");
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state =
                    conn.connect(test_address(), Socket::from_socket(other_socket)).await.unwrap();

                // This side only writes; the other side reads and verifies the data.
                for size in TEST_SIZES {
                    println!("sending packet of size {size}");
                    socket.write_all(&test_payload(size)).await.unwrap();
                }
                drop(socket);
                state.wait_for_close().await.unwrap();
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, test_address());

                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();

                println!("accepted request on connection 2");
                for size in TEST_SIZES {
                    let mut buf = vec![0u8; size as usize];
                    socket.read_exact(&mut buf).await.unwrap();
                    assert_eq!(buf, test_payload(size));
                }
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    #[fuchsia::test]
    async fn normal_close_end_to_end() {
        let addr = test_address();
        end_to_end_test(
            async |conn, _incoming| {
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state =
                    conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
                conn.close(&addr).await;
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, addr);

                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    #[fuchsia::test]
    async fn reset_end_to_end() {
        let addr = test_address();
        end_to_end_test(
            async |conn, _incoming| {
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state =
                    conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
                conn.reset(&addr).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.expect_err("expected reset");
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, addr);

                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }
}