usb_vsock/
connection.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use futures::channel::{mpsc, oneshot};
6use futures::lock::Mutex;
7use log::{debug, trace, warn};
8use std::collections::hash_map::Entry;
9use std::collections::HashMap;
10use std::io::{Error, ErrorKind};
11use std::ops::DerefMut;
12use std::sync::Arc;
13
14use fuchsia_async::Scope;
15use futures::io::{ReadHalf, WriteHalf};
16use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, SinkExt, StreamExt};
17
18use crate::connection::overflow_writer::OverflowHandleFut;
19use crate::{
20    Address, Header, Packet, PacketType, ProtocolVersion, UsbPacketBuilder, UsbPacketFiller,
21};
22
23mod overflow_writer;
24mod pause_state;
25
26use overflow_writer::OverflowWriter;
27use pause_state::PauseState;
28
/// Marker trait for the byte buffers a [`Connection`] fills with outgoing USB packets.
///
/// Any mutably-dereferenceable byte slice that is [`Send`], [`Unpin`] and `'static` qualifies
/// automatically through the blanket implementation that follows.
pub trait PacketBuffer: DerefMut<Target = [u8]> + Send + Unpin + 'static {}

impl<T: DerefMut<Target = [u8]> + Send + Unpin + 'static> PacketBuffer for T {}
32
/// Payload of a `PacketType::Pause` packet, asking the other end to pause or resume sending
/// data for a connection.
#[derive(Copy, Clone, PartialEq, Eq)]
enum PausePacket {
    Pause,
    UnPause,
}

impl PausePacket {
    /// Encodes this request as its single-byte wire payload: `1` to pause, `0` to unpause.
    fn bytes(&self) -> [u8; 1] {
        [u8::from(*self == PausePacket::Pause)]
    }
}
47
/// Manages the state of a vsock-over-usb connection and the sockets over which data is being
/// transmitted for them.
///
/// This implementation aims to be agnostic to both the underlying transport and the buffers used
/// to read and write from it. The buffer type must conform to [`PacketBuffer`], which is essentially
/// a type that holds a mutable slice of bytes and is [`Send`] and [`Unpin`]-able.
///
/// The client of this library will:
/// - Use methods on this struct to initiate actions like connecting and accepting
/// connections to the other end.
/// - Provide buffers to be filled and sent to the other end with [`Connection::fill_usb_packet`].
/// - Pump usb packets received into it using [`Connection::handle_vsock_packet`].
pub struct Connection<B, S> {
    // Write half of the optional control socket (cid 0, port 0). `None` means control data
    // received from the other end is discarded.
    control_socket_writer: Option<Mutex<WriteHalf<S>>>,
    // Accumulates outgoing vsock packets until `fill_usb_packet` copies them into a USB buffer.
    packet_filler: Arc<UsbPacketFiller<B>>,
    // Protocol version in use; gates optional features such as pause packets.
    protocol_version: ProtocolVersion,
    // Per-address state of every non-control connection. Shared with background overflow tasks,
    // hence the `Arc`.
    connections: Arc<std::sync::Mutex<HashMap<Address, VsockConnection<S>>>>,
    // Incoming connection requests from the other end are delivered here.
    incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
    // Owns background tasks: the control socket reader and overflow-drain tasks.
    task_scope: Scope,
}
68
69impl<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static> Connection<B, S> {
70    /// Creates a new connection with:
71    /// - a `control_socket`, over which data addressed to and from cid 0, port 0 (a control channel
72    /// between host and device) can be read and written from. If this is `None`
73    /// we will discard control data.
74    /// - An `incoming_requests_tx` that is the sender half of a request queue for incoming
75    /// connection requests from the other side.
76    pub fn new(
77        protocol_version: ProtocolVersion,
78        control_socket: Option<S>,
79        incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
80    ) -> Self {
81        let packet_filler = Arc::new(UsbPacketFiller::default());
82        let connections = Default::default();
83        let task_scope = Scope::new_with_name("vsock_usb");
84        let control_socket_writer = control_socket.map(|control_socket| {
85            let (control_socket_reader, control_socket_writer) = control_socket.split();
86            task_scope.spawn(Self::run_socket(
87                control_socket_reader,
88                Address::default(),
89                packet_filler.clone(),
90                PauseState::new(),
91            ));
92            Mutex::new(control_socket_writer)
93        });
94        Self {
95            control_socket_writer,
96            packet_filler,
97            connections,
98            incoming_requests_tx,
99            protocol_version,
100            task_scope,
101        }
102    }
103
104    async fn send_close_packet(address: &Address, usb_packet_filler: &Arc<UsbPacketFiller<B>>) {
105        let header = &mut Header::new(PacketType::Finish);
106        header.set_address(address);
107        usb_packet_filler
108            .write_vsock_packet(&Packet { header, payload: &[] })
109            .await
110            .expect("Finish packet should never be too big");
111    }
112
113    async fn run_socket(
114        mut reader: ReadHalf<S>,
115        address: Address,
116        usb_packet_filler: Arc<UsbPacketFiller<B>>,
117        pause_state: Arc<PauseState>,
118    ) {
119        let mut buf = [0; 4096];
120        loop {
121            log::trace!("reading from control socket");
122            let read = match pause_state.while_unpaused(reader.read(&mut buf)).await {
123                Ok(0) => {
124                    if !address.is_zeros() {
125                        Self::send_close_packet(&address, &usb_packet_filler).await;
126                    }
127                    return;
128                }
129                Ok(read) => read,
130                Err(err) => {
131                    if address.is_zeros() {
132                        log::error!("Error reading usb socket: {err:?}");
133                    } else {
134                        Self::send_close_packet(&address, &usb_packet_filler).await;
135                    }
136                    return;
137                }
138            };
139            log::trace!("writing {read} bytes to vsock packet");
140            usb_packet_filler.write_vsock_data_all(&address, &buf[..read]).await;
141            log::trace!("wrote {read} bytes to vsock packet");
142        }
143    }
144
145    fn set_connection(
146        &self,
147        address: Address,
148        state: VsockConnectionState<S>,
149    ) -> Result<(), Error> {
150        let mut connections = self.connections.lock().unwrap();
151        if !connections.contains_key(&address) {
152            connections.insert(address.clone(), VsockConnection { _address: address, state });
153            Ok(())
154        } else {
155            Err(Error::other(format!("connection on address {address:?} already set")))
156        }
157    }
158
159    /// Sends an echo packet to the remote end that you don't care about the reply, so it doesn't
160    /// have a distinct target address or payload.
161    pub async fn send_empty_echo(&self) {
162        debug!("Sending empty echo packet");
163        let header = &mut Header::new(PacketType::Echo);
164        self.packet_filler
165            .write_vsock_packet(&Packet { header, payload: &[] })
166            .await
167            .expect("empty echo packet should never be too large to fit in a usb packet");
168    }
169
170    /// Starts a connection attempt to the other end of the USB connection, and provides a socket
171    /// to read and write from. The function will complete when the other end has accepted or
172    /// rejected the connection, and the returned [`ConnectionState`] handle can be used to wait
173    /// for the connection to be closed.
174    pub async fn connect(&self, addr: Address, socket: S) -> Result<ConnectionState, Error> {
175        let (read_socket, write_socket) = socket.split();
176        let write_socket = Arc::new(Mutex::new(OverflowWriter::new(write_socket)));
177        let (connected_tx, connected_rx) = oneshot::channel();
178
179        self.set_connection(
180            addr.clone(),
181            VsockConnectionState::ConnectingOutgoing(write_socket, read_socket, connected_tx),
182        )?;
183
184        let header = &mut Header::new(PacketType::Connect);
185        header.set_address(&addr);
186        self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
187        connected_rx.await.map_err(|_| Error::other("Accept was never received for {addr:?}"))?
188    }
189
190    /// Sends a request for the other end to close the connection.
191    pub async fn close(&self, address: &Address) {
192        Self::send_close_packet(address, &self.packet_filler).await
193    }
194
195    /// Resets the named connection without going through a close request.
196    pub async fn reset(&self, address: &Address) -> Result<(), Error> {
197        reset(address, &self.connections, &self.packet_filler).await
198    }
199
200    /// Accepts a connection for which an outstanding connection request has been made, and
201    /// provides a socket to read and write data packets to and from. The returned [`ConnectionState`]
202    /// can be used to wait for the connection to be closed.
203    pub async fn accept(
204        &self,
205        request: ConnectionRequest,
206        socket: S,
207    ) -> Result<ConnectionState, Error> {
208        let address = request.address;
209        let notify_closed_rx;
210        if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
211            let VsockConnectionState::ConnectingIncoming = &conn.state else {
212                return Err(Error::other(format!(
213                    "Attempted to accept connection that was not waiting at {address:?}"
214                )));
215            };
216
217            let (read_socket, write_socket) = socket.split();
218            let writer = Arc::new(Mutex::new(OverflowWriter::new(write_socket)));
219            let notify_closed = mpsc::channel(2);
220            notify_closed_rx = notify_closed.1;
221            let notify_closed = notify_closed.0;
222            let pause_state = PauseState::new();
223
224            let reader_task = Scope::new_with_name("connection-reader");
225            reader_task.spawn(Self::run_socket(
226                read_socket,
227                address,
228                self.packet_filler.clone(),
229                Arc::clone(&pause_state),
230            ));
231
232            conn.state = VsockConnectionState::Connected {
233                writer,
234                _reader_scope: reader_task,
235                notify_closed,
236                pause_state,
237            };
238        } else {
239            return Err(Error::other(format!(
240                "Attempting to accept connection that did not exist at {address:?}"
241            )));
242        }
243        let header = &mut Header::new(PacketType::Accept);
244        header.set_address(&address);
245        self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
246        Ok(ConnectionState(notify_closed_rx))
247    }
248
249    /// Rejects a pending connection request from the other side.
250    pub async fn reject(&self, request: ConnectionRequest) -> Result<(), Error> {
251        let address = request.address;
252        match self.connections.lock().unwrap().entry(address.clone()) {
253            Entry::Occupied(entry) => {
254                let VsockConnectionState::ConnectingIncoming = &entry.get().state else {
255                    return Err(Error::other(format!(
256                        "Attempted to reject connection that was not waiting at {address:?}"
257                    )));
258                };
259                entry.remove();
260            }
261            Entry::Vacant(_) => {
262                return Err(Error::other(format!(
263                    "Attempted to reject connection that was not waiting at {address:?}"
264                )));
265            }
266        }
267
268        let header = &mut Header::new(PacketType::Reset);
269        header.set_address(&address);
270        self.packet_filler
271            .write_vsock_packet(&Packet { header, payload: &[] })
272            .await
273            .expect("accept packet should never be too large for packet buffer");
274        Ok(())
275    }
276
277    async fn handle_data_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
278        // all zero data packets go to the control channel
279        if address.is_zeros() {
280            if let Some(writer) = self.control_socket_writer.as_ref() {
281                writer.lock().await.write_all(payload).await?;
282            } else {
283                trace!("Discarding {} bytes of data sent to control socket", payload.len());
284            }
285            Ok(())
286        } else {
287            let payload_socket;
288            if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
289                let VsockConnectionState::Connected { writer, .. } = &conn.state else {
290                    warn!(
291                        "Received data packet for connection in unexpected state for {address:?}"
292                    );
293                    return Ok(());
294                };
295                payload_socket = writer.clone();
296            } else {
297                warn!("Received data packet for connection that didn't exist at {address:?}");
298                return Ok(());
299            }
300            let mut socket_guard = payload_socket.lock().await;
301            match socket_guard.write_all(payload) {
302                Err(err) => {
303                    debug!(
304                        "Write to socket address {address:?} failed, \
305                         resetting connection immediately: {err:?}"
306                    );
307                    self.reset(&address)
308                        .await
309                        .inspect_err(|err| {
310                            warn!(
311                                "Attempt to reset connection to {address:?} \
312                                   failed after write error: {err:?}"
313                            );
314                        })
315                        .ok();
316                }
317                Ok(status) => {
318                    if status.overflowed() {
319                        if self.protocol_version.has_pause_packets() {
320                            let header = &mut Header::new(PacketType::Pause);
321                            header.set_address(&address);
322                            self.packet_filler
323                                .write_vsock_packet(&Packet {
324                                    header,
325                                    payload: &PausePacket::Pause.bytes(),
326                                })
327                                .await
328                                .expect(
329                                    "pause packet should never be too large to fit in a usb packet",
330                                );
331                        }
332
333                        let weak_payload_socket = Arc::downgrade(&payload_socket);
334                        let connections = Arc::clone(&self.connections);
335                        let has_pause_packets = self.protocol_version.has_pause_packets();
336                        let packet_filler = Arc::clone(&self.packet_filler);
337                        self.task_scope.spawn(async move {
338                            let res = OverflowHandleFut::new(weak_payload_socket).await;
339
340                            if let Err(err) = res {
341                                debug!(
342                                    "Write to socket address {address:?} failed while \
343                                     processing backlog, resetting connection at next poll: {err:?}"
344                                );
345                                if let Err(err) = reset(&address, &connections, &packet_filler).await {
346                                    debug!("Error sending reset frame after overflow write failed: {err:?}");
347                                }
348                            } else if has_pause_packets {
349                                let header = &mut Header::new(PacketType::Pause);
350                                header.set_address(&address);
351                                packet_filler
352                                    .write_vsock_packet(&Packet { header, payload: &PausePacket::UnPause.bytes() })
353                                    .await
354                                    .expect("pause packet should never be too large to fit in a usb packet");
355                            }
356                        });
357                    }
358                }
359            }
360            Ok(())
361        }
362    }
363
364    async fn handle_echo_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
365        debug!("received echo for {address:?} with payload {payload:?}");
366        let header = &mut Header::new(PacketType::EchoReply);
367        header.payload_len.set(payload.len() as u32);
368        header.set_address(&address);
369        self.packet_filler
370            .write_vsock_packet(&Packet { header, payload })
371            .await
372            .map_err(|_| Error::other("Echo packet was too large to be sent back"))
373    }
374
375    async fn handle_echo_reply_packet(
376        &self,
377        address: Address,
378        payload: &[u8],
379    ) -> Result<(), Error> {
380        // ignore but log replies
381        debug!("received echo reply for {address:?} with payload {payload:?}");
382        Ok(())
383    }
384
385    async fn handle_accept_packet(&self, address: Address) -> Result<(), Error> {
386        if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
387            let state = std::mem::replace(&mut conn.state, VsockConnectionState::Invalid);
388            let VsockConnectionState::ConnectingOutgoing(writer, read_socket, connected_tx) = state
389            else {
390                warn!("Received accept packet for connection in unexpected state for {address:?}");
391                return Ok(());
392            };
393            let (notify_closed, notify_closed_rx) = mpsc::channel(2);
394            if connected_tx.send(Ok(ConnectionState(notify_closed_rx))).is_err() {
395                warn!(
396                    "Accept packet received for {address:?} but connect caller stopped waiting for it"
397                );
398            }
399            let pause_state = PauseState::new();
400
401            let reader_task = Scope::new_with_name("connection-reader");
402            reader_task.spawn(Self::run_socket(
403                read_socket,
404                address,
405                self.packet_filler.clone(),
406                Arc::clone(&pause_state),
407            ));
408            conn.state = VsockConnectionState::Connected {
409                writer,
410                _reader_scope: reader_task,
411                notify_closed,
412                pause_state,
413            };
414        } else {
415            warn!("Got accept packet for connection that was not being made at {address:?}");
416            return Ok(());
417        }
418        Ok(())
419    }
420
421    async fn handle_connect_packet(&self, address: Address) -> Result<(), Error> {
422        trace!("received connect packet for {address:?}");
423        match self.connections.lock().unwrap().entry(address.clone()) {
424            Entry::Vacant(entry) => {
425                debug!("valid connect request for {address:?}");
426                entry.insert(VsockConnection {
427                    _address: address,
428                    state: VsockConnectionState::ConnectingIncoming,
429                });
430            }
431            Entry::Occupied(_) => {
432                warn!(
433                    "Received connect packet for already existing \
434                     connection for address {address:?}. Ignoring"
435                );
436                return Ok(());
437            }
438        }
439
440        trace!("sending incoming connection request to client for {address:?}");
441        let connection_request = ConnectionRequest { address };
442        self.incoming_requests_tx
443            .clone()
444            .send(connection_request)
445            .await
446            .inspect(|_| trace!("sent incoming request for {address:?}"))
447            .map_err(|_| Error::other("Failed to send connection request"))
448    }
449
450    async fn handle_finish_packet(&self, address: Address) -> Result<(), Error> {
451        trace!("received finish packet for {address:?}");
452        let mut notify;
453        if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
454            let VsockConnectionState::Connected { notify_closed, .. } = conn.state else {
455                warn!(
456                    "Received finish (close) packet for {address:?} \
457                     which was not in a connected state. Ignoring and dropping connection state."
458                );
459                return Ok(());
460            };
461            notify = notify_closed;
462        } else {
463            warn!(
464                "Received finish (close) packet for connection that didn't exist \
465                 on address {address:?}. Ignoring"
466            );
467            return Ok(());
468        }
469
470        notify.send(Ok(())).await.ok();
471
472        let header = &mut Header::new(PacketType::Reset);
473        header.set_address(&address);
474        self.packet_filler
475            .write_vsock_packet(&Packet { header, payload: &[] })
476            .await
477            .expect("accept packet should never be too large for packet buffer");
478        Ok(())
479    }
480
481    async fn handle_reset_packet(&self, address: Address) -> Result<(), Error> {
482        trace!("received reset packet for {address:?}");
483        let mut notify = None;
484        if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
485            if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
486                notify = Some(notify_closed);
487            } else {
488                debug!(
489                    "Received reset packet for connection that wasn't in a connecting or \
490                     disconnected state on address {address:?}."
491                );
492            }
493        } else {
494            warn!(
495                "Received reset packet for connection that didn't \
496                 exist on address {address:?}. Ignoring"
497            );
498        }
499
500        if let Some(mut notify) = notify {
501            notify.send(Ok(())).await.ok();
502        }
503        Ok(())
504    }
505
506    async fn handle_pause_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
507        if !self.protocol_version.has_pause_packets() {
508            warn!(
509                "Got a pause packet while using protocol \
510                 version {} which does not support them. Ignoring",
511                self.protocol_version
512            );
513            return Ok(());
514        }
515
516        let pause = match payload {
517            [1] => true,
518            [0] => false,
519            other => {
520                warn!("Ignoring unexpected pause packet payload {other:?}");
521                return Ok(());
522            }
523        };
524
525        if let Some(conn) = self.connections.lock().unwrap().get(&address) {
526            if let VsockConnectionState::Connected { pause_state, .. } = &conn.state {
527                pause_state.set_paused(pause);
528            } else {
529                warn!("Received pause packet for unestablished connection. Ignoring");
530            };
531        } else {
532            warn!(
533                "Received pause packet for connection that didn't exist on address {address:?}. Ignoring"
534            );
535        }
536
537        Ok(())
538    }
539
540    /// Dispatches the given vsock packet type and handles its effect on any outstanding connections
541    /// or the overall state of the connection.
542    pub async fn handle_vsock_packet(&self, packet: Packet<'_>) -> Result<(), Error> {
543        trace!("received vsock packet {header:?}", header = packet.header);
544        let payload_len = packet.header.payload_len.get() as usize;
545        let payload = &packet.payload[..payload_len];
546        let address = Address::from(packet.header);
547        match packet.header.packet_type {
548            PacketType::Sync => Err(Error::other("Received sync packet mid-stream")),
549            PacketType::Data => self.handle_data_packet(address, payload).await,
550            PacketType::Accept => self.handle_accept_packet(address).await,
551            PacketType::Connect => self.handle_connect_packet(address).await,
552            PacketType::Finish => self.handle_finish_packet(address).await,
553            PacketType::Reset => self.handle_reset_packet(address).await,
554            PacketType::Echo => self.handle_echo_packet(address, payload).await,
555            PacketType::EchoReply => self.handle_echo_reply_packet(address, payload).await,
556            PacketType::Pause => self.handle_pause_packet(address, payload).await,
557        }
558    }
559
560    /// Provides a packet builder for the state machine to write packets to. Returns a future that
561    /// will be fulfilled when there is data available to send on the packet.
562    ///
563    /// # Panics
564    ///
565    /// Panics if called while another [`Self::fill_usb_packet`] future is pending.
566    pub async fn fill_usb_packet(&self, builder: UsbPacketBuilder<B>) -> UsbPacketBuilder<B> {
567        self.packet_filler.fill_usb_packet(builder).await
568    }
569}
570
571async fn reset<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static>(
572    address: &Address,
573    connections: &std::sync::Mutex<HashMap<Address, VsockConnection<S>>>,
574    packet_filler: &UsbPacketFiller<B>,
575) -> Result<(), Error> {
576    let mut notify = None;
577    if let Some(conn) = connections.lock().unwrap().remove(&address) {
578        if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
579            notify = Some(notify_closed);
580        }
581    } else {
582        return Err(Error::other(
583            "Client asked to reset connection {address:?} that did not exist",
584        ));
585    }
586
587    if let Some(mut notify) = notify {
588        notify.send(Err(ErrorKind::ConnectionReset.into())).await.ok();
589    }
590
591    let header = &mut Header::new(PacketType::Reset);
592    header.set_address(address);
593    packet_filler
594        .write_vsock_packet(&Packet { header, payload: &[] })
595        .await
596        .expect("Reset packet should never be too big");
597    Ok(())
598}
599
/// Lifecycle state of a single non-control vsock connection.
enum VsockConnectionState<S> {
    // We sent a `Connect` packet and are waiting for the other end's answer. Holds the socket's
    // write and read halves until the connection is established, plus the sender that completes
    // the pending `Connection::connect` call.
    ConnectingOutgoing(
        Arc<Mutex<OverflowWriter<S>>>,
        ReadHalf<S>,
        oneshot::Sender<Result<ConnectionState, Error>>,
    ),
    // The other end sent a `Connect` packet; waiting for `Connection::accept` or
    // `Connection::reject`.
    ConnectingIncoming,
    // Established connection.
    Connected {
        // Destination for data packets received for this connection.
        writer: Arc<Mutex<OverflowWriter<S>>>,
        // Notifies the connection's `ConnectionState` handle when the connection closes.
        notify_closed: mpsc::Sender<Result<(), Error>>,
        // Pause flag set by the other end's pause packets; gates reads in `run_socket`.
        pause_state: Arc<PauseState>,
        // Scope the socket reader task was spawned on; kept alive with the connection state.
        _reader_scope: Scope,
    },
    // Placeholder left behind while the previous state is temporarily moved out.
    Invalid,
}
615
/// One entry in the connection map of a `Connection`.
struct VsockConnection<S> {
    // The address this entry is keyed by; not read anywhere in this file.
    _address: Address,
    // Where the connection currently is in its lifecycle.
    state: VsockConnectionState<S>,
}
620
621/// A handle for the state of a connection established with either [`Connection::connect`] or
622/// [`Connection::accept`]. Use this to get notified when the connection has been closed without
623/// needing to hold on to the Socket end.
624#[derive(Debug)]
625pub struct ConnectionState(mpsc::Receiver<Result<(), Error>>);
626
627impl ConnectionState {
628    /// Wait for this connection to close. Returns Ok(()) if the connection was closed without error,
629    /// and an error if it closed because of an error.
630    pub async fn wait_for_close(mut self) -> Result<(), Error> {
631        self.0
632            .next()
633            .await
634            .ok_or_else(|| Error::other("Connection state's other end was dropped"))?
635    }
636}
637
638/// An outstanding connection request that needs to be either [`Connection::accept`]ed or
639/// [`Connection::reject`]ed.
640#[derive(Debug)]
641pub struct ConnectionRequest {
642    address: Address,
643}
644
645impl ConnectionRequest {
646    /// Creates a new connection request for the given address.
647    pub fn new(address: Address) -> Self {
648        Self { address }
649    }
650
651    /// The address this connection request is being made for.
652    pub fn address(&self) -> &Address {
653        &self.address
654    }
655}
656
657#[cfg(test)]
658mod test {
659    use std::sync::Arc;
660
661    use crate::VsockPacketIterator;
662
663    use super::*;
664
665    #[cfg(not(target_os = "fuchsia"))]
666    use fuchsia_async::emulated_handle::Socket as SyncSocket;
667    use fuchsia_async::{Socket, Task};
668    use futures::StreamExt;
669    #[cfg(target_os = "fuchsia")]
670    use zx::Socket as SyncSocket;
671
672    async fn usb_echo_server(echo_connection: Arc<Connection<Vec<u8>, Socket>>) {
673        let mut builder = UsbPacketBuilder::new(vec![0; 128]);
674        loop {
675            println!("waiting for usb packet");
676            builder = echo_connection.fill_usb_packet(builder).await;
677            let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
678            println!("got usb packet, echoing it back to the other side");
679            let mut packet_count = 0;
680            for packet in packets {
681                let packet = packet.unwrap();
682                match packet.header.packet_type {
683                    PacketType::Connect => {
684                        // respond with an accept packet
685                        let mut reply_header = packet.header.clone();
686                        reply_header.packet_type = PacketType::Accept;
687                        echo_connection
688                            .handle_vsock_packet(Packet { header: &reply_header, payload: &[] })
689                            .await
690                            .unwrap();
691                    }
692                    PacketType::Accept => {
693                        // just ignore it
694                    }
695                    _ => echo_connection.handle_vsock_packet(packet).await.unwrap(),
696                }
697                packet_count += 1;
698            }
699            println!("handled {packet_count} packets");
700        }
701    }
702
703    #[fuchsia::test]
704    async fn data_over_control_socket() {
705        let (socket, other_socket) = SyncSocket::create_stream();
706        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
707        let mut socket = Socket::from_socket(socket);
708        let connection = Arc::new(Connection::new(
709            ProtocolVersion::LATEST,
710            Some(Socket::from_socket(other_socket)),
711            incoming_requests_tx,
712        ));
713
714        let echo_task = Task::spawn(usb_echo_server(connection.clone()));
715
716        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
717            println!("round tripping packet of size {size}");
718            socket.write_all(&vec![size; size as usize]).await.unwrap();
719            let mut buf = vec![0u8; size as usize];
720            socket.read_exact(&mut buf).await.unwrap();
721            assert_eq!(buf, vec![size; size as usize]);
722        }
723        echo_task.abort().await;
724    }
725
726    #[fuchsia::test]
727    async fn data_over_normal_outgoing_socket() {
728        let (_control_socket, other_socket) = SyncSocket::create_stream();
729        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
730        let connection = Arc::new(Connection::new(
731            ProtocolVersion::LATEST,
732            Some(Socket::from_socket(other_socket)),
733            incoming_requests_tx,
734        ));
735
736        let echo_task = Task::spawn(usb_echo_server(connection.clone()));
737
738        let (socket, other_socket) = SyncSocket::create_stream();
739        let mut socket = Socket::from_socket(socket);
740        connection
741            .connect(
742                Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 },
743                Socket::from_socket(other_socket),
744            )
745            .await
746            .unwrap();
747
748        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
749            println!("round tripping packet of size {size}");
750            socket.write_all(&vec![size; size as usize]).await.unwrap();
751            let mut buf = vec![0u8; size as usize];
752            socket.read_exact(&mut buf).await.unwrap();
753            assert_eq!(buf, vec![size; size as usize]);
754        }
755        echo_task.abort().await;
756    }
757
758    #[fuchsia::test]
759    async fn data_over_normal_incoming_socket() {
760        let (_control_socket, other_socket) = SyncSocket::create_stream();
761        let (incoming_requests_tx, mut incoming_requests) = mpsc::channel(5);
762        let connection = Arc::new(Connection::new(
763            ProtocolVersion::LATEST,
764            Some(Socket::from_socket(other_socket)),
765            incoming_requests_tx,
766        ));
767
768        let echo_task = Task::spawn(usb_echo_server(connection.clone()));
769
770        let header = &mut Header::new(PacketType::Connect);
771        header.set_address(&Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 });
772        connection.handle_vsock_packet(Packet { header, payload: &[] }).await.unwrap();
773
774        let request = incoming_requests.next().await.unwrap();
775        assert_eq!(
776            request.address,
777            Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
778        );
779
780        let (socket, other_socket) = SyncSocket::create_stream();
781        let mut socket = Socket::from_socket(socket);
782        connection.accept(request, Socket::from_socket(other_socket)).await.unwrap();
783
784        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
785            println!("round tripping packet of size {size}");
786            socket.write_all(&vec![size; size as usize]).await.unwrap();
787            let mut buf = vec![0u8; size as usize];
788            socket.read_exact(&mut buf).await.unwrap();
789            assert_eq!(buf, vec![size; size as usize]);
790        }
791        echo_task.abort().await;
792    }
793
794    async fn copy_connection(from: &Connection<Vec<u8>, Socket>, to: &Connection<Vec<u8>, Socket>) {
795        let mut builder = UsbPacketBuilder::new(vec![0; 1024]);
796        loop {
797            builder = from.fill_usb_packet(builder).await;
798            let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
799            for packet in packets {
800                println!("forwarding vsock packet");
801                to.handle_vsock_packet(packet.unwrap()).await.unwrap();
802            }
803        }
804    }
805
806    pub(crate) trait EndToEndTestFn<R>:
807        AsyncFnOnce(Arc<Connection<Vec<u8>, Socket>>, mpsc::Receiver<ConnectionRequest>) -> R
808    {
809    }
810    impl<T, R> EndToEndTestFn<R> for T where
811        T: AsyncFnOnce(Arc<Connection<Vec<u8>, Socket>>, mpsc::Receiver<ConnectionRequest>) -> R
812    {
813    }
814
815    pub(crate) async fn end_to_end_test<R1, R2>(
816        left_side: impl EndToEndTestFn<R1>,
817        right_side: impl EndToEndTestFn<R2>,
818    ) -> (R1, R2) {
819        type Connection = crate::Connection<Vec<u8>, Socket>;
820        let (_control_socket1, other_socket1) = SyncSocket::create_stream();
821        let (_control_socket2, other_socket2) = SyncSocket::create_stream();
822        let (incoming_requests_tx1, incoming_requests1) = mpsc::channel(5);
823        let (incoming_requests_tx2, incoming_requests2) = mpsc::channel(5);
824
825        let connection1 = Arc::new(Connection::new(
826            ProtocolVersion::LATEST,
827            Some(Socket::from_socket(other_socket1)),
828            incoming_requests_tx1,
829        ));
830        let connection2 = Arc::new(Connection::new(
831            ProtocolVersion::LATEST,
832            Some(Socket::from_socket(other_socket2)),
833            incoming_requests_tx2,
834        ));
835
836        let conn1 = connection1.clone();
837        let conn2 = connection2.clone();
838        let passthrough_task = Task::spawn(async move {
839            futures::join!(copy_connection(&conn1, &conn2), copy_connection(&conn2, &conn1),);
840            println!("passthrough task loop ended");
841        });
842
843        let res = futures::join!(
844            left_side(connection1, incoming_requests1),
845            right_side(connection2, incoming_requests2)
846        );
847        passthrough_task.abort().await;
848        res
849    }
850
851    #[fuchsia::test]
852    async fn data_over_end_to_end() {
853        end_to_end_test(
854            async |conn, _incoming| {
855                println!("sending request on connection 1");
856                let (socket, other_socket) = SyncSocket::create_stream();
857                let mut socket = Socket::from_socket(socket);
858                let state = conn
859                    .connect(
860                        Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 },
861                        Socket::from_socket(other_socket),
862                    )
863                    .await
864                    .unwrap();
865
866                for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
867                    println!("round tripping packet of size {size}");
868                    socket.write_all(&vec![size; size as usize]).await.unwrap();
869                }
870                drop(socket);
871                state.wait_for_close().await.unwrap();
872            },
873            async |conn, mut incoming| {
874                println!("accepting request on connection 2");
875                let request = incoming.next().await.unwrap();
876                assert_eq!(
877                    request.address,
878                    Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
879                );
880
881                let (socket, other_socket) = SyncSocket::create_stream();
882                let mut socket = Socket::from_socket(socket);
883                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
884
885                println!("accepted request on connection 2");
886                for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
887                    let mut buf = vec![0u8; size as usize];
888                    socket.read_exact(&mut buf).await.unwrap();
889                    assert_eq!(buf, vec![size; size as usize]);
890                }
891                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
892                state.wait_for_close().await.unwrap();
893            },
894        )
895        .await;
896    }
897
898    #[fuchsia::test]
899    async fn normal_close_end_to_end() {
900        let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
901        end_to_end_test(
902            async |conn, _incoming| {
903                let (socket, other_socket) = SyncSocket::create_stream();
904                let mut socket = Socket::from_socket(socket);
905                let state =
906                    conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
907                conn.close(&addr).await;
908                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
909                state.wait_for_close().await.unwrap();
910            },
911            async |conn, mut incoming| {
912                println!("accepting request on connection 2");
913                let request = incoming.next().await.unwrap();
914                assert_eq!(request.address, addr.clone(),);
915
916                let (socket, other_socket) = SyncSocket::create_stream();
917                let mut socket = Socket::from_socket(socket);
918                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
919                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
920                state.wait_for_close().await.unwrap();
921            },
922        )
923        .await;
924    }
925
926    #[fuchsia::test]
927    async fn reset_end_to_end() {
928        let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
929        end_to_end_test(
930            async |conn, _incoming| {
931                let (socket, other_socket) = SyncSocket::create_stream();
932                let mut socket = Socket::from_socket(socket);
933                let state =
934                    conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
935                conn.reset(&addr).await.unwrap();
936                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
937                state.wait_for_close().await.expect_err("expected reset");
938            },
939            async |conn, mut incoming| {
940                println!("accepting request on connection 2");
941                let request = incoming.next().await.unwrap();
942                assert_eq!(request.address, addr.clone(),);
943
944                let (socket, other_socket) = SyncSocket::create_stream();
945                let mut socket = Socket::from_socket(socket);
946                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
947                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
948                state.wait_for_close().await.unwrap();
949            },
950        )
951        .await;
952    }
953}