usb_vsock/
connection.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use futures::channel::{mpsc, oneshot};
6use futures::lock::{Mutex, OwnedMutexGuard};
7use log::{debug, trace, warn};
8use std::collections::HashMap;
9use std::collections::hash_map::Entry;
10use std::future::Future;
11use std::io::{Error, ErrorKind};
12use std::ops::DerefMut;
13use std::pin::Pin;
14use std::sync::Arc;
15use std::task::{Context, Poll, Waker, ready};
16
17use fuchsia_async::Scope;
18use futures::io::{ReadHalf, WriteHalf};
19use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, FutureExt, SinkExt, StreamExt};
20
21use crate::connection::overflow_writer::OverflowHandleFut;
22use crate::{
23    Address, Header, Packet, PacketType, ProtocolVersion, UsbPacketBuilder, UsbPacketFiller,
24};
25
26mod overflow_writer;
27mod pause_state;
28
29use overflow_writer::OverflowWriter;
30use pause_state::PauseState;
31
/// Marker for buffer types a [`Connection`] can fill with outgoing packet data.
///
/// Any type that dereferences mutably to a byte slice and is `Send + Unpin + 'static`
/// implements this automatically through the blanket impl that follows.
pub trait PacketBuffer
where
    Self: DerefMut<Target = [u8]> + Send + Unpin + 'static,
{
}

impl<T: DerefMut<Target = [u8]> + Send + Unpin + 'static> PacketBuffer for T {}
35
/// Payload carried by a `Pause` packet: whether the peer should stop or resume sending data
/// for a connection.
#[derive(Copy, Clone, PartialEq, Eq)]
enum PausePacket {
    /// Stop sending; encoded as the single byte `1`.
    Pause,
    /// Resume sending; encoded as the single byte `0`.
    UnPause,
}

impl PausePacket {
    /// Returns the one-byte wire encoding of this value.
    fn bytes(&self) -> [u8; 1] {
        let is_pause = matches!(self, Self::Pause);
        [u8::from(is_pause)]
    }
}
50
51/// A connection that has been established with the other end and now just needs
52/// a socket to start transmitting.
53pub struct ReadyConnect<B, S> {
54    connections: Arc<std::sync::Mutex<HashMap<Address, VsockConnection<S>>>>,
55    packet_filler: Arc<UsbPacketFiller<B>>,
56    address: Address,
57    connection_state: ConnectionState,
58}
59
60impl<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static> ReadyConnect<B, S> {
61    /// Finish establishing the connection by providing a socket for data transfer.
62    pub async fn finish_connect(self, socket: S) -> ConnectionState {
63        let (read_socket, write_socket) = socket.split();
64        let writer = {
65            let conns = self.connections.lock().unwrap();
66            let Some(conn) = conns.get(&self.address) else {
67                warn!("Connection state was missing after connection success!");
68                return self.connection_state;
69            };
70            let VsockConnectionState::Connected { writer, reader_scope, pause_state, .. } =
71                &conn.state
72            else {
73                warn!("Connection state was invalid after connection success!");
74                return self.connection_state;
75            };
76            reader_scope.spawn(Connection::<B, S>::run_socket(
77                read_socket,
78                self.address,
79                self.packet_filler,
80                Arc::clone(pause_state),
81            ));
82            Arc::clone(writer)
83        };
84        let mut writer = writer.lock().await;
85        let ConnectionStateWriter::NotYetAvailable(wakers) = std::mem::replace(
86            &mut *writer,
87            ConnectionStateWriter::Available(OverflowWriter::new(write_socket)),
88        ) else {
89            unreachable!("Connection completed multiple times!")
90        };
91
92        wakers.into_iter().for_each(Waker::wake);
93        self.connection_state
94    }
95}
96
97/// Manages the state of a vsock-over-usb connection and the sockets over which data is being
98/// transmitted for them.
99///
100/// This implementation aims to be agnostic to both the underlying transport and the buffers used
101/// to read and write from it. The buffer type must conform to [`PacketBuffer`], which is essentially
102/// a type that holds a mutable slice of bytes and is [`Send`] and [`Unpin`]-able.
103///
104/// The client of this library will:
105/// - Use methods on this struct to initiate actions like connecting and accepting
106/// connections to the other end.
107/// - Provide buffers to be filled and sent to the other end with [`Connection::fill_usb_packet`].
108/// - Pump usb packets received into it using [`Connection::handle_vsock_packet`].
109pub struct Connection<B, S> {
110    control_socket_writer: Option<Mutex<WriteHalf<S>>>,
111    packet_filler: Arc<UsbPacketFiller<B>>,
112    protocol_version: ProtocolVersion,
113    connections: Arc<std::sync::Mutex<HashMap<Address, VsockConnection<S>>>>,
114    incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
115    task_scope: Scope,
116}
117
118impl<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static> Connection<B, S> {
119    /// Creates a new connection with:
120    /// - a `control_socket`, over which data addressed to and from cid 0, port 0 (a control channel
121    /// between host and device) can be read and written from. If this is `None`
122    /// we will discard control data.
123    /// - An `incoming_requests_tx` that is the sender half of a request queue for incoming
124    /// connection requests from the other side.
125    pub fn new(
126        protocol_version: ProtocolVersion,
127        control_socket: Option<S>,
128        incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
129    ) -> Self {
130        let packet_filler = Arc::new(UsbPacketFiller::default());
131        let connections = Default::default();
132        let task_scope = Scope::new_with_name("vsock_usb");
133        let control_socket_writer = control_socket.map(|control_socket| {
134            let (control_socket_reader, control_socket_writer) = control_socket.split();
135            task_scope.spawn(Self::run_socket(
136                control_socket_reader,
137                Address::default(),
138                packet_filler.clone(),
139                PauseState::new(),
140            ));
141            Mutex::new(control_socket_writer)
142        });
143        Self {
144            control_socket_writer,
145            packet_filler,
146            connections,
147            incoming_requests_tx,
148            protocol_version,
149            task_scope,
150        }
151    }
152
153    async fn send_close_packet(address: &Address, usb_packet_filler: &Arc<UsbPacketFiller<B>>) {
154        let header = &mut Header::new(PacketType::Finish);
155        header.set_address(address);
156        usb_packet_filler
157            .write_vsock_packet(&Packet { header, payload: &[] })
158            .await
159            .expect("Finish packet should never be too big");
160    }
161
162    async fn run_socket(
163        mut reader: ReadHalf<S>,
164        address: Address,
165        usb_packet_filler: Arc<UsbPacketFiller<B>>,
166        pause_state: Arc<PauseState>,
167    ) {
168        let mut buf = [0; 4096];
169        loop {
170            log::trace!("reading from control socket");
171            let read = match pause_state.while_unpaused(reader.read(&mut buf)).await {
172                Ok(0) => {
173                    if !address.is_zeros() {
174                        Self::send_close_packet(&address, &usb_packet_filler).await;
175                    }
176                    return;
177                }
178                Ok(read) => read,
179                Err(err) => {
180                    if address.is_zeros() {
181                        log::error!("Error reading usb socket: {err:?}");
182                    } else {
183                        Self::send_close_packet(&address, &usb_packet_filler).await;
184                    }
185                    return;
186                }
187            };
188            log::trace!("writing {read} bytes to vsock packet");
189            usb_packet_filler.write_vsock_data_all(&address, &buf[..read]).await;
190            log::trace!("wrote {read} bytes to vsock packet");
191        }
192    }
193
194    fn set_connection(
195        &self,
196        address: Address,
197        state: VsockConnectionState<S>,
198    ) -> Result<(), Error> {
199        let mut connections = self.connections.lock().unwrap();
200        if !connections.contains_key(&address) {
201            connections.insert(address.clone(), VsockConnection { _address: address, state });
202            Ok(())
203        } else {
204            Err(Error::other(format!("connection on address {address:?} already set")))
205        }
206    }
207
208    /// Sends an echo packet to the remote end that you don't care about the reply, so it doesn't
209    /// have a distinct target address or payload.
210    pub async fn send_empty_echo(&self) {
211        debug!("Sending empty echo packet");
212        let header = &mut Header::new(PacketType::Echo);
213        self.packet_filler
214            .write_vsock_packet(&Packet { header, payload: &[] })
215            .await
216            .expect("empty echo packet should never be too large to fit in a usb packet");
217    }
218
219    /// Starts a connection attempt to the other end of the USB connection, and provides a socket
220    /// to read and write from. The function will complete when the other end has accepted or
221    /// rejected the connection, and the returned [`ConnectionState`] handle can be used to wait
222    /// for the connection to be closed.
223    pub async fn connect(&self, addr: Address, socket: S) -> Result<ConnectionState, Error> {
224        let ready = self.connect_late(addr).await?;
225        Ok(ready.finish_connect(socket).await)
226    }
227
228    /// Same as [`connect`] but doesn't require the socket to be passed. Instead
229    /// we return a [`ReadyConnect`] which can be given the socket later. This
230    /// shouldn't be deferred very long but it is useful if the socket is
231    /// starting out speaking a different protocol and needs to execute a
232    /// protocol switch, but needs to know the connection status before doing
233    /// that switch.
234    pub async fn connect_late(&self, addr: Address) -> Result<ReadyConnect<B, S>, Error> {
235        let (connected_tx, connected_rx) = oneshot::channel();
236
237        self.set_connection(addr.clone(), VsockConnectionState::ConnectingOutgoing(connected_tx))?;
238
239        let header = &mut Header::new(PacketType::Connect);
240        header.set_address(&addr);
241        self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
242        let Ok(conn_state) = connected_rx.await else {
243            return Err(Error::other("Accept was never received for {addr:?}"));
244        };
245
246        Ok(ReadyConnect {
247            connections: Arc::clone(&self.connections),
248            packet_filler: Arc::clone(&self.packet_filler),
249            address: addr,
250            connection_state: conn_state,
251        })
252    }
253
254    /// Sends a request for the other end to close the connection.
255    pub async fn close(&self, address: &Address) {
256        Self::send_close_packet(address, &self.packet_filler).await
257    }
258
259    /// Resets the named connection without going through a close request.
260    pub async fn reset(&self, address: &Address) -> Result<(), Error> {
261        reset(address, &self.connections, &self.packet_filler).await
262    }
263
264    /// Accepts a connection for which an outstanding connection request has been made, and
265    /// provides a socket to read and write data packets to and from. The returned [`ConnectionState`]
266    /// can be used to wait for the connection to be closed.
267    pub async fn accept(
268        &self,
269        request: ConnectionRequest,
270        socket: S,
271    ) -> Result<ConnectionState, Error> {
272        let ready = self.accept_late(request).await?;
273        Ok(ready.finish_connect(socket).await)
274    }
275
276    /// Accepts a connection for which an outstanding connection request has been made, and
277    /// provides a socket to read and write data packets to and from. The returned [`ConnectionState`]
278    /// can be used to wait for the connection to be closed.
279    pub async fn accept_late(
280        &self,
281        request: ConnectionRequest,
282    ) -> Result<ReadyConnect<B, S>, Error> {
283        let address = request.address;
284        let notify_closed_rx;
285        if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
286            let VsockConnectionState::ConnectingIncoming = &conn.state else {
287                return Err(Error::other(format!(
288                    "Attempted to accept connection that was not waiting at {address:?}"
289                )));
290            };
291
292            let notify_closed = mpsc::channel(2);
293            notify_closed_rx = notify_closed.1;
294            let notify_closed = notify_closed.0;
295            let pause_state = PauseState::new();
296
297            let reader_scope = Scope::new_with_name("connection-reader");
298
299            conn.state = VsockConnectionState::Connected {
300                writer: Arc::new(Mutex::new(ConnectionStateWriter::NotYetAvailable(Vec::new()))),
301                reader_scope,
302                notify_closed,
303                pause_state,
304            };
305        } else {
306            return Err(Error::other(format!(
307                "Attempting to accept connection that did not exist at {address:?}"
308            )));
309        }
310        let header = &mut Header::new(PacketType::Accept);
311        header.set_address(&address);
312        self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
313        Ok(ReadyConnect {
314            connections: Arc::clone(&self.connections),
315            packet_filler: Arc::clone(&self.packet_filler),
316            address,
317            connection_state: ConnectionState(notify_closed_rx),
318        })
319    }
320
321    /// Rejects a pending connection request from the other side.
322    pub async fn reject(&self, request: ConnectionRequest) -> Result<(), Error> {
323        let address = request.address;
324        match self.connections.lock().unwrap().entry(address.clone()) {
325            Entry::Occupied(entry) => {
326                let VsockConnectionState::ConnectingIncoming = &entry.get().state else {
327                    return Err(Error::other(format!(
328                        "Attempted to reject connection that was not waiting at {address:?}"
329                    )));
330                };
331                entry.remove();
332            }
333            Entry::Vacant(_) => {
334                return Err(Error::other(format!(
335                    "Attempted to reject connection that was not waiting at {address:?}"
336                )));
337            }
338        }
339
340        let header = &mut Header::new(PacketType::Reset);
341        header.set_address(&address);
342        self.packet_filler
343            .write_vsock_packet(&Packet { header, payload: &[] })
344            .await
345            .expect("accept packet should never be too large for packet buffer");
346        Ok(())
347    }
348
349    async fn handle_data_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
350        // all zero data packets go to the control channel
351        if address.is_zeros() {
352            if let Some(writer) = self.control_socket_writer.as_ref() {
353                writer.lock().await.write_all(payload).await?;
354            } else {
355                trace!("Discarding {} bytes of data sent to control socket", payload.len());
356            }
357            Ok(())
358        } else {
359            let payload_socket;
360            if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
361                let VsockConnectionState::Connected { writer, .. } = &conn.state else {
362                    warn!(
363                        "Received data packet for connection in unexpected state for {address:?}"
364                    );
365                    return Ok(());
366                };
367                payload_socket = writer.clone();
368            } else {
369                warn!("Received data packet for connection that didn't exist at {address:?}");
370                return Ok(());
371            }
372            let mut socket_guard =
373                ConnectionStateWriter::wait_available(Arc::clone(&payload_socket)).await;
374            let ConnectionStateWriter::Available(socket) = &mut *socket_guard else {
375                unreachable!("wait_available didn't wait until socket was available!");
376            };
377            match socket.write_all(payload) {
378                Err(err) => {
379                    debug!(
380                        "Write to socket address {address:?} failed, \
381                         resetting connection immediately: {err:?}"
382                    );
383                    self.reset(&address)
384                        .await
385                        .inspect_err(|err| {
386                            warn!(
387                                "Attempt to reset connection to {address:?} \
388                                   failed after write error: {err:?}"
389                            );
390                        })
391                        .ok();
392                }
393                Ok(status) => {
394                    if status.overflowed() {
395                        if self.protocol_version.has_pause_packets() {
396                            let header = &mut Header::new(PacketType::Pause);
397                            let payload = &PausePacket::Pause.bytes();
398                            header.set_address(&address);
399                            header.payload_len.set(payload.len() as u32);
400                            self.packet_filler
401                                .write_vsock_packet(&Packet { header, payload })
402                                .await
403                                .expect(
404                                    "pause packet should never be too large to fit in a usb packet",
405                                );
406                        }
407
408                        let weak_payload_socket = Arc::downgrade(&payload_socket);
409                        let connections = Arc::clone(&self.connections);
410                        let has_pause_packets = self.protocol_version.has_pause_packets();
411                        let packet_filler = Arc::clone(&self.packet_filler);
412                        self.task_scope.spawn(async move {
413                            let res = OverflowHandleFut::new(weak_payload_socket).await;
414
415                            if let Err(err) = res {
416                                debug!(
417                                    "Write to socket address {address:?} failed while \
418                                     processing backlog, resetting connection at next poll: {err:?}"
419                                );
420                                if let Err(err) = reset(&address, &connections, &packet_filler).await {
421                                    debug!("Error sending reset frame after overflow write failed: {err:?}");
422                                }
423                            } else if has_pause_packets {
424                                let header = &mut Header::new(PacketType::Pause);
425                                let payload = &PausePacket::UnPause.bytes();
426                                header.set_address(&address);
427                                header.payload_len.set(payload.len() as u32);
428                                packet_filler
429                                    .write_vsock_packet(&Packet { header, payload })
430                                    .await
431                                    .expect("pause packet should never be too large to fit in a usb packet");
432                            }
433                        });
434                    }
435                }
436            }
437            Ok(())
438        }
439    }
440
441    async fn handle_echo_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
442        debug!("received echo for {address:?} with payload {payload:?}");
443        let header = &mut Header::new(PacketType::EchoReply);
444        header.payload_len.set(payload.len() as u32);
445        header.set_address(&address);
446        self.packet_filler
447            .write_vsock_packet(&Packet { header, payload })
448            .await
449            .map_err(|_| Error::other("Echo packet was too large to be sent back"))
450    }
451
452    async fn handle_echo_reply_packet(
453        &self,
454        address: Address,
455        payload: &[u8],
456    ) -> Result<(), Error> {
457        // ignore but log replies
458        debug!("received echo reply for {address:?} with payload {payload:?}");
459        Ok(())
460    }
461
462    async fn handle_accept_packet(&self, address: Address) -> Result<(), Error> {
463        if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
464            let state = std::mem::replace(&mut conn.state, VsockConnectionState::Invalid);
465            let VsockConnectionState::ConnectingOutgoing(connected_tx) = state else {
466                warn!("Received accept packet for connection in unexpected state for {address:?}");
467                return Ok(());
468            };
469            let (notify_closed, notify_closed_rx) = mpsc::channel(2);
470            if connected_tx.send(ConnectionState(notify_closed_rx)).is_err() {
471                warn!(
472                    "Accept packet received for {address:?} but connect caller stopped waiting for it"
473                );
474            }
475            let pause_state = PauseState::new();
476
477            let reader_scope = Scope::new_with_name("connection-reader");
478            conn.state = VsockConnectionState::Connected {
479                writer: Arc::new(Mutex::new(ConnectionStateWriter::NotYetAvailable(Vec::new()))),
480                reader_scope,
481                notify_closed,
482                pause_state,
483            };
484        } else {
485            warn!("Got accept packet for connection that was not being made at {address:?}");
486            return Ok(());
487        }
488        Ok(())
489    }
490
491    async fn handle_connect_packet(&self, address: Address) -> Result<(), Error> {
492        trace!("received connect packet for {address:?}");
493        match self.connections.lock().unwrap().entry(address.clone()) {
494            Entry::Vacant(entry) => {
495                debug!("valid connect request for {address:?}");
496                entry.insert(VsockConnection {
497                    _address: address,
498                    state: VsockConnectionState::ConnectingIncoming,
499                });
500            }
501            Entry::Occupied(_) => {
502                warn!(
503                    "Received connect packet for already existing \
504                     connection for address {address:?}. Ignoring"
505                );
506                return Ok(());
507            }
508        }
509
510        trace!("sending incoming connection request to client for {address:?}");
511        let connection_request = ConnectionRequest { address };
512        self.incoming_requests_tx
513            .clone()
514            .send(connection_request)
515            .await
516            .inspect(|_| trace!("sent incoming request for {address:?}"))
517            .map_err(|_| Error::other("Failed to send connection request"))
518    }
519
520    async fn handle_finish_packet(&self, address: Address) -> Result<(), Error> {
521        trace!("received finish packet for {address:?}");
522        let mut notify;
523        if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
524            let VsockConnectionState::Connected { notify_closed, .. } = conn.state else {
525                warn!(
526                    "Received finish (close) packet for {address:?} \
527                     which was not in a connected state. Ignoring and dropping connection state."
528                );
529                return Ok(());
530            };
531            notify = notify_closed;
532        } else {
533            warn!(
534                "Received finish (close) packet for connection that didn't exist \
535                 on address {address:?}. Ignoring"
536            );
537            return Ok(());
538        }
539
540        notify.send(Ok(())).await.ok();
541
542        let header = &mut Header::new(PacketType::Reset);
543        header.set_address(&address);
544        self.packet_filler
545            .write_vsock_packet(&Packet { header, payload: &[] })
546            .await
547            .expect("accept packet should never be too large for packet buffer");
548        Ok(())
549    }
550
551    async fn handle_reset_packet(&self, address: Address) -> Result<(), Error> {
552        trace!("received reset packet for {address:?}");
553        let mut notify = None;
554        if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
555            if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
556                notify = Some(notify_closed);
557            } else {
558                debug!(
559                    "Received reset packet for connection that wasn't in a connecting or \
560                    disconnected state on address {address:?}."
561                );
562            }
563        } else {
564            warn!(
565                "Received reset packet for connection that didn't \
566                exist on address {address:?}. Ignoring"
567            );
568        }
569
570        if let Some(mut notify) = notify {
571            notify.send(Ok(())).await.ok();
572        }
573        Ok(())
574    }
575
576    async fn handle_pause_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
577        if !self.protocol_version.has_pause_packets() {
578            warn!(
579                "Got a pause packet while using protocol \
580                 version {} which does not support them. Ignoring",
581                self.protocol_version
582            );
583            return Ok(());
584        }
585
586        let pause = match payload {
587            [1] => true,
588            [0] => false,
589            other => {
590                warn!("Ignoring unexpected pause packet payload {other:?}");
591                return Ok(());
592            }
593        };
594
595        if let Some(conn) = self.connections.lock().unwrap().get(&address) {
596            if let VsockConnectionState::Connected { pause_state, .. } = &conn.state {
597                pause_state.set_paused(pause);
598            } else {
599                warn!("Received pause packet for unestablished connection. Ignoring");
600            };
601        } else {
602            warn!(
603                "Received pause packet for connection that didn't exist on address {address:?}. Ignoring"
604            );
605        }
606
607        Ok(())
608    }
609
610    /// Dispatches the given vsock packet type and handles its effect on any outstanding connections
611    /// or the overall state of the connection.
612    pub async fn handle_vsock_packet(&self, packet: Packet<'_>) -> Result<(), Error> {
613        trace!("received vsock packet {header:?}", header = packet.header);
614        let payload_len = packet.header.payload_len.get() as usize;
615        let payload = &packet.payload[..payload_len];
616        let address = Address::from(packet.header);
617        match packet.header.packet_type {
618            PacketType::Sync => Err(Error::other("Received sync packet mid-stream")),
619            PacketType::Data => self.handle_data_packet(address, payload).await,
620            PacketType::Accept => self.handle_accept_packet(address).await,
621            PacketType::Connect => self.handle_connect_packet(address).await,
622            PacketType::Finish => self.handle_finish_packet(address).await,
623            PacketType::Reset => self.handle_reset_packet(address).await,
624            PacketType::Echo => self.handle_echo_packet(address, payload).await,
625            PacketType::EchoReply => self.handle_echo_reply_packet(address, payload).await,
626            PacketType::Pause => self.handle_pause_packet(address, payload).await,
627        }
628    }
629
630    /// Provides a packet builder for the state machine to write packets to. Returns a future that
631    /// will be fulfilled when there is data available to send on the packet.
632    ///
633    /// # Panics
634    ///
635    /// Panics if called while another [`Self::fill_usb_packet`] future is pending.
636    pub async fn fill_usb_packet(&self, builder: UsbPacketBuilder<B>) -> UsbPacketBuilder<B> {
637        self.packet_filler.fill_usb_packet(builder).await
638    }
639}
640
641async fn reset<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static>(
642    address: &Address,
643    connections: &std::sync::Mutex<HashMap<Address, VsockConnection<S>>>,
644    packet_filler: &UsbPacketFiller<B>,
645) -> Result<(), Error> {
646    let mut notify = None;
647    if let Some(conn) = connections.lock().unwrap().remove(&address) {
648        if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
649            notify = Some(notify_closed);
650        }
651    } else {
652        return Err(Error::other(
653            "Client asked to reset connection {address:?} that did not exist",
654        ));
655    }
656
657    if let Some(mut notify) = notify {
658        notify.send(Err(ErrorKind::ConnectionReset.into())).await.ok();
659    }
660
661    let header = &mut Header::new(PacketType::Reset);
662    header.set_address(address);
663    packet_filler
664        .write_vsock_packet(&Packet { header, payload: &[] })
665        .await
666        .expect("Reset packet should never be too big");
667    Ok(())
668}
669
670/// A writer inside of a [`ConnectionState`]. This is essentially an
671/// option-monad around an [`OverflowWriter`], but unlike
672/// [`std::option::Option`] the empty variant stores wakers that by convention
673/// will be woken when we replace it with the occupied variant.
674enum ConnectionStateWriter<S> {
675    NotYetAvailable(Vec<Waker>),
676    Available(OverflowWriter<S>),
677}
678
679impl<S> ConnectionStateWriter<S> {
680    /// Wait for the given `ConnectionStateWriter` to contain an actual writer.
681    fn wait_available(this: Arc<Mutex<ConnectionStateWriter<S>>>) -> ConnectionStateWriterFut<S> {
682        ConnectionStateWriterFut { writer: this, lock_fut: None }
683    }
684}
685
686/// Future returned by [`ConnectionStateWriter::wait_available`].
687struct ConnectionStateWriterFut<S> {
688    writer: Arc<Mutex<ConnectionStateWriter<S>>>,
689    lock_fut: Option<futures::lock::OwnedMutexLockFuture<ConnectionStateWriter<S>>>,
690}
691
692impl<S> Future for ConnectionStateWriterFut<S> {
693    type Output = OwnedMutexGuard<ConnectionStateWriter<S>>;
694
695    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
696        let writer = Arc::clone(&self.writer);
697        let lock_fut = self.lock_fut.get_or_insert_with(|| writer.lock_owned());
698        let mut lock = ready!(lock_fut.poll_unpin(cx));
699        self.lock_fut = None;
700        match &mut *lock {
701            ConnectionStateWriter::Available(_) => Poll::Ready(lock),
702            ConnectionStateWriter::NotYetAvailable(queue) => {
703                queue.push(cx.waker().clone());
704                Poll::Pending
705            }
706        }
707    }
708}
709
/// Lifecycle state of one entry in the connection table.
enum VsockConnectionState<S> {
    /// An outgoing connect is in progress. Holds the sender used to hand a
    /// [`ConnectionState`] back to the caller waiting on the connect.
    /// NOTE(review): presumably fulfilled when the peer's accept arrives;
    /// confirm against `Connection::connect`.
    ConnectingOutgoing(oneshot::Sender<ConnectionState>),
    /// An incoming connect request was received and is awaiting a local
    /// decision. NOTE(review): presumably resolved by `accept`/`reject`; confirm.
    ConnectingIncoming,
    /// The connection is established.
    Connected {
        /// Shared slot for the socket writer. It may still be
        /// [`ConnectionStateWriter::NotYetAvailable`].
        writer: Arc<Mutex<ConnectionStateWriter<S>>>,
        /// Sender for the close result. Its item type matches the receiver
        /// wrapped by [`ConnectionState`], so it is presumably that
        /// receiver's counterpart (verify).
        notify_closed: mpsc::Sender<Result<(), Error>>,
        /// Shared pause/unpause state for this connection.
        pause_state: Arc<PauseState>,
        /// Scope for this connection's reader-side tasks.
        /// NOTE(review): inferred from the name; confirm what is spawned here.
        reader_scope: Scope,
    },
    /// Placeholder state. NOTE(review): presumably used transiently while the
    /// real state is moved out (e.g. via `mem::replace`); confirm.
    Invalid,
}
721
/// An entry in the connection table (the `HashMap<Address, VsockConnection<S>>`
/// held by [`ReadyConnect`] and friends).
struct VsockConnection<S> {
    // Copy of this entry's key. The leading underscore marks it as not
    // otherwise read.
    _address: Address,
    // Current lifecycle state of the connection.
    state: VsockConnectionState<S>,
}
726
727/// A handle for the state of a connection established with either [`Connection::connect`] or
728/// [`Connection::accept`]. Use this to get notified when the connection has been closed without
729/// needing to hold on to the Socket end.
730#[derive(Debug)]
731pub struct ConnectionState(mpsc::Receiver<Result<(), Error>>);
732
733impl ConnectionState {
734    /// Wait for this connection to close. Returns Ok(()) if the connection was closed without error,
735    /// and an error if it closed because of an error.
736    pub async fn wait_for_close(mut self) -> Result<(), Error> {
737        self.0
738            .next()
739            .await
740            .ok_or_else(|| Error::other("Connection state's other end was dropped"))?
741    }
742}
743
744/// An outstanding connection request that needs to be either [`Connection::accept`]ed or
745/// [`Connection::reject`]ed.
746#[derive(Debug)]
747pub struct ConnectionRequest {
748    address: Address,
749}
750
751impl ConnectionRequest {
752    /// Creates a new connection request for the given address.
753    pub fn new(address: Address) -> Self {
754        Self { address }
755    }
756
757    /// The address this connection request is being made for.
758    pub fn address(&self) -> &Address {
759        &self.address
760    }
761}
762
// Unit tests that drive `Connection` over an in-process loopback instead of a
// real USB link.
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::VsockPacketIterator;

    use super::*;

    // Synchronous socket type: the emulated handle off-Fuchsia, a real zircon
    // socket on Fuchsia.
    #[cfg(not(target_os = "fuchsia"))]
    use fuchsia_async::emulated_handle::Socket as SyncSocket;
    use fuchsia_async::{Socket, Task};
    use futures::StreamExt;
    #[cfg(target_os = "fuchsia")]
    use zx::Socket as SyncSocket;

    /// Loops forever, draining outgoing USB packets from `echo_connection` and
    /// feeding every vsock packet they contain back into the same connection.
    ///
    /// `Connect` packets are answered with an `Accept` carrying the same header
    /// fields, and incoming `Accept` packets are dropped. This makes a single
    /// `Connection` behave as if it were talking to a peer that echoes data.
    async fn usb_echo_server(echo_connection: Arc<Connection<Vec<u8>, Socket>>) {
        let mut builder = UsbPacketBuilder::new(vec![0; 128]);
        loop {
            println!("waiting for usb packet");
            builder = echo_connection.fill_usb_packet(builder).await;
            let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
            println!("got usb packet, echoing it back to the other side");
            let mut packet_count = 0;
            for packet in packets {
                let packet = packet.unwrap();
                match packet.header.packet_type {
                    PacketType::Connect => {
                        // respond with an accept packet
                        let mut reply_header = packet.header.clone();
                        reply_header.packet_type = PacketType::Accept;
                        echo_connection
                            .handle_vsock_packet(Packet { header: &reply_header, payload: &[] })
                            .await
                            .unwrap();
                    }
                    PacketType::Accept => {
                        // just ignore it
                    }
                    _ => echo_connection.handle_vsock_packet(packet).await.unwrap(),
                }
                packet_count += 1;
            }
            println!("handled {packet_count} packets");
        }
    }

    /// Data written to the control socket is carried over USB, echoed back by
    /// `usb_echo_server`, and must read back unchanged for a range of sizes.
    #[fuchsia::test]
    async fn data_over_control_socket() {
        let (socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
        let mut socket = Socket::from_socket(socket);
        let connection = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket)),
            incoming_requests_tx,
        ));

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
            println!("round tripping packet of size {size}");
            socket.write_all(&vec![size; size as usize]).await.unwrap();
            let mut buf = vec![0u8; size as usize];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, vec![size; size as usize]);
        }
        echo_task.abort().await;
    }

    /// Same round trip as above, but over a socket attached with
    /// `Connection::connect`. The echo server answers the connect with an
    /// accept.
    #[fuchsia::test]
    async fn data_over_normal_outgoing_socket() {
        let (_control_socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
        let connection = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket)),
            incoming_requests_tx,
        ));

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        let (socket, other_socket) = SyncSocket::create_stream();
        let mut socket = Socket::from_socket(socket);
        connection
            .connect(
                Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 },
                Socket::from_socket(other_socket),
            )
            .await
            .unwrap();

        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
            println!("round tripping packet of size {size}");
            socket.write_all(&vec![size; size as usize]).await.unwrap();
            let mut buf = vec![0u8; size as usize];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, vec![size; size as usize]);
        }
        echo_task.abort().await;
    }

    /// Injects a `Connect` packet by hand, checks it surfaces as a
    /// `ConnectionRequest` for the same address, accepts it, and then round
    /// trips data over the accepted socket.
    #[fuchsia::test]
    async fn data_over_normal_incoming_socket() {
        let (_control_socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, mut incoming_requests) = mpsc::channel(5);
        let connection = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket)),
            incoming_requests_tx,
        ));

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        let header = &mut Header::new(PacketType::Connect);
        header.set_address(&Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 });
        connection.handle_vsock_packet(Packet { header, payload: &[] }).await.unwrap();

        let request = incoming_requests.next().await.unwrap();
        assert_eq!(
            request.address,
            Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
        );

        let (socket, other_socket) = SyncSocket::create_stream();
        let mut socket = Socket::from_socket(socket);
        connection.accept(request, Socket::from_socket(other_socket)).await.unwrap();

        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
            println!("round tripping packet of size {size}");
            socket.write_all(&vec![size; size as usize]).await.unwrap();
            let mut buf = vec![0u8; size as usize];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, vec![size; size as usize]);
        }
        echo_task.abort().await;
    }

    /// Loops forever, forwarding every vsock packet that `from` emits over
    /// USB into `to`, so that `to` sees `from`'s traffic as incoming.
    async fn copy_connection(from: &Connection<Vec<u8>, Socket>, to: &Connection<Vec<u8>, Socket>) {
        let mut builder = UsbPacketBuilder::new(vec![0; 1024]);
        loop {
            builder = from.fill_usb_packet(builder).await;
            let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
            for packet in packets {
                println!("forwarding vsock packet");
                to.handle_vsock_packet(packet.unwrap()).await.unwrap();
            }
        }
    }

    /// One side of an [`end_to_end_test`]. It receives its `Connection` and
    /// the receiver of incoming connection requests for that connection.
    pub(crate) trait EndToEndTestFn<R>:
        AsyncFnOnce(Arc<Connection<Vec<u8>, Socket>>, mpsc::Receiver<ConnectionRequest>) -> R
    {
    }
    impl<T, R> EndToEndTestFn<R> for T where
        T: AsyncFnOnce(Arc<Connection<Vec<u8>, Socket>>, mpsc::Receiver<ConnectionRequest>) -> R
    {
    }

    /// Builds two `Connection`s wired back-to-back with `copy_connection` in
    /// both directions. Runs `left_side` and `right_side` concurrently on them
    /// and returns both results once both complete. The forwarding task is
    /// aborted afterwards.
    pub(crate) async fn end_to_end_test<R1, R2>(
        left_side: impl EndToEndTestFn<R1>,
        right_side: impl EndToEndTestFn<R2>,
    ) -> (R1, R2) {
        type Connection = crate::Connection<Vec<u8>, Socket>;
        let (_control_socket1, other_socket1) = SyncSocket::create_stream();
        let (_control_socket2, other_socket2) = SyncSocket::create_stream();
        let (incoming_requests_tx1, incoming_requests1) = mpsc::channel(5);
        let (incoming_requests_tx2, incoming_requests2) = mpsc::channel(5);

        let connection1 = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket1)),
            incoming_requests_tx1,
        ));
        let connection2 = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket2)),
            incoming_requests_tx2,
        ));

        let conn1 = connection1.clone();
        let conn2 = connection2.clone();
        // Both copy loops run forever; this task only ends via abort below.
        let passthrough_task = Task::spawn(async move {
            futures::join!(copy_connection(&conn1, &conn2), copy_connection(&conn2, &conn1),);
            println!("passthrough task loop ended");
        });

        let res = futures::join!(
            left_side(connection1, incoming_requests1),
            right_side(connection2, incoming_requests2)
        );
        passthrough_task.abort().await;
        res
    }

    /// Side 1 connects and writes a sequence of payloads, then drops its socket.
    /// Side 2 accepts, reads every payload intact, observes EOF, and both sides
    /// see a clean close.
    #[fuchsia::test]
    async fn data_over_end_to_end() {
        end_to_end_test(
            async |conn, _incoming| {
                println!("sending request on connection 1");
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn
                    .connect(
                        Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 },
                        Socket::from_socket(other_socket),
                    )
                    .await
                    .unwrap();

                for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
                    println!("round tripping packet of size {size}");
                    socket.write_all(&vec![size; size as usize]).await.unwrap();
                }
                // Dropping our end is what initiates the close being tested.
                drop(socket);
                state.wait_for_close().await.unwrap();
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(
                    request.address,
                    Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
                );

                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();

                println!("accepted request on connection 2");
                for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
                    let mut buf = vec![0u8; size as usize];
                    socket.read_exact(&mut buf).await.unwrap();
                    assert_eq!(buf, vec![size; size as usize]);
                }
                // A zero-length read is EOF: the peer's close propagated.
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    /// `Connection::close` on one side yields EOF on both sockets and a clean
    /// (`Ok`) close result on both sides.
    #[fuchsia::test]
    async fn normal_close_end_to_end() {
        let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
        end_to_end_test(
            async |conn, _incoming| {
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state =
                    conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
                conn.close(&addr).await;
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, addr.clone(),);

                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    /// `Connection::reset` on one side yields EOF on both sockets. The
    /// resetting side's close result is an error, while the other side's
    /// close result is `Ok`.
    #[fuchsia::test]
    async fn reset_end_to_end() {
        let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
        end_to_end_test(
            async |conn, _incoming| {
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state =
                    conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
                conn.reset(&addr).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.expect_err("expected reset");
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, addr.clone(),);

                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }
}