// usb_vsock/connection.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use futures::channel::{mpsc, oneshot};
6use futures::lock::Mutex;
7use log::{debug, trace, warn};
8use std::collections::hash_map::Entry;
9use std::collections::HashMap;
10use std::io::{Error, ErrorKind};
11use std::ops::DerefMut;
12use std::sync::Arc;
13
14use fuchsia_async::{Scope, Socket};
15use futures::io::{ReadHalf, WriteHalf};
16use futures::{AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt};
17
18use crate::{Address, Header, Packet, PacketType, UsbPacketBuilder, UsbPacketFiller};
19
/// Marker trait for buffer types a [`Connection`] can use to hold usb packets.
///
/// Any owned type that is [`Send`], [`Unpin`], `'static`, and dereferences mutably to a byte
/// slice qualifies automatically via the blanket implementation below; there is nothing to
/// implement by hand.
pub trait PacketBuffer: DerefMut<Target = [u8]> + Send + Unpin + 'static {}

impl<T: DerefMut<Target = [u8]> + Send + Unpin + 'static> PacketBuffer for T {}
23
24/// Manages the state of a vsock-over-usb connection and the sockets over which data is being
25/// transmitted for them.
26///
27/// This implementation aims to be agnostic to both the underlying transport and the buffers used
28/// to read and write from it. The buffer type must conform to [`PacketBuffer`], which is essentially
29/// a type that holds a mutable slice of bytes and is [`Send`] and [`Unpin`]-able.
30///
31/// The client of this library will:
32/// - Use methods on this struct to initiate actions like connecting and accepting
33/// connections to the other end.
34/// - Provide buffers to be filled and sent to the other end with [`Connection::fill_usb_packet`].
35/// - Pump usb packets received into it using [`Connection::handle_vsock_packet`].
36pub struct Connection<B> {
37    control_socket_writer: Mutex<WriteHalf<Socket>>,
38    packet_filler: Arc<UsbPacketFiller<B>>,
39    connections: std::sync::Mutex<HashMap<Address, VsockConnection>>,
40    incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
41    _task_scope: Scope,
42}
43
44impl<B: PacketBuffer> Connection<B> {
45    /// Creates a new connection with:
46    /// - a `control_socket`, over which data addressed to and from cid 0, port 0 (a control channel
47    /// between host and device) can be read and written from.
48    /// - An `incoming_requests_tx` that is the sender half of a request queue for incoming
49    /// connection requests from the other side.
50    pub fn new(
51        control_socket: Socket,
52        incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
53    ) -> Self {
54        let (control_socket_reader, control_socket_writer) = control_socket.split();
55        let control_socket_writer = Mutex::new(control_socket_writer);
56        let packet_filler = Arc::new(UsbPacketFiller::default());
57        let connections = Default::default();
58        let task_scope = Scope::new_with_name("vsock_usb");
59        task_scope.spawn(Self::run_socket(
60            control_socket_reader,
61            Address::default(),
62            packet_filler.clone(),
63        ));
64        Self {
65            control_socket_writer,
66            packet_filler,
67            connections,
68            incoming_requests_tx,
69            _task_scope: task_scope,
70        }
71    }
72
73    async fn send_close_packet(address: &Address, usb_packet_filler: &Arc<UsbPacketFiller<B>>) {
74        let header = &mut Header::new(PacketType::Finish);
75        header.set_address(address);
76        usb_packet_filler
77            .write_vsock_packet(&Packet { header, payload: &[] })
78            .await
79            .expect("Finish packet should never be too big");
80    }
81
82    async fn run_socket(
83        mut reader: ReadHalf<Socket>,
84        address: Address,
85        usb_packet_filler: Arc<UsbPacketFiller<B>>,
86    ) {
87        let mut buf = [0; 4096];
88        loop {
89            log::trace!("reading from control socket");
90            let read = match reader.read(&mut buf).await {
91                Ok(0) => {
92                    if !address.is_zeros() {
93                        Self::send_close_packet(&address, &usb_packet_filler).await;
94                    }
95                    return;
96                }
97                Ok(read) => read,
98                Err(err) => {
99                    if address.is_zeros() {
100                        log::error!("Error reading usb socket: {err:?}");
101                    } else {
102                        Self::send_close_packet(&address, &usb_packet_filler).await;
103                    }
104                    return;
105                }
106            };
107            log::trace!("writing {read} bytes to vsock packet");
108            usb_packet_filler.write_vsock_data_all(&address, &buf[..read]).await;
109            log::trace!("wrote {read} bytes to vsock packet");
110        }
111    }
112
113    fn set_connection(&self, address: Address, state: VsockConnectionState) -> Result<(), Error> {
114        let mut connections = self.connections.lock().unwrap();
115        if !connections.contains_key(&address) {
116            connections.insert(address.clone(), VsockConnection { _address: address, state });
117            Ok(())
118        } else {
119            Err(Error::other(format!("connection on address {address:?} already set")))
120        }
121    }
122
123    /// Sends an echo packet to the remote end that you don't care about the reply, so it doesn't
124    /// have a distinct target address or payload.
125    pub async fn send_empty_echo(&self) {
126        debug!("Sending empty echo packet");
127        let header = &mut Header::new(PacketType::Echo);
128        self.packet_filler
129            .write_vsock_packet(&Packet { header, payload: &[] })
130            .await
131            .expect("empty echo packet should never be too large to fit in a usb packet");
132    }
133
134    /// Starts a connection attempt to the other end of the USB connection, and provides a socket
135    /// to read and write from. The function will complete when the other end has accepted or
136    /// rejected the connection, and the returned [`ConnectionState`] handle can be used to wait
137    /// for the connection to be closed.
138    pub async fn connect(&self, addr: Address, socket: Socket) -> Result<ConnectionState, Error> {
139        let (read_socket, write_socket) = socket.split();
140        let write_socket = Arc::new(Mutex::new(write_socket));
141        let (connected_tx, connected_rx) = oneshot::channel();
142
143        self.set_connection(
144            addr.clone(),
145            VsockConnectionState::ConnectingOutgoing(write_socket, read_socket, connected_tx),
146        )?;
147
148        let header = &mut Header::new(PacketType::Connect);
149        header.set_address(&addr);
150        self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
151        connected_rx.await.map_err(|_| Error::other("Accept was never received for {addr:?}"))?
152    }
153
154    /// Sends a request for the other end to close the connection.
155    pub async fn close(&self, address: &Address) {
156        Self::send_close_packet(address, &self.packet_filler).await
157    }
158
159    /// Resets the named connection without going through a close request.
160    pub async fn reset(&self, address: &Address) -> Result<(), Error> {
161        let mut notify = None;
162        if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
163            if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
164                notify = Some(notify_closed);
165            }
166        } else {
167            return Err(Error::other(
168                "Client asked to reset connection {address:?} that did not exist",
169            ));
170        }
171
172        if let Some(mut notify) = notify {
173            notify.send(Err(ErrorKind::ConnectionReset.into())).await.ok();
174        }
175
176        let header = &mut Header::new(PacketType::Reset);
177        header.set_address(address);
178        self.packet_filler
179            .write_vsock_packet(&Packet { header, payload: &[] })
180            .await
181            .expect("Reset packet should never be too big");
182        Ok(())
183    }
184
185    /// Accepts a connection for which an outstanding connection request has been made, and
186    /// provides a socket to read and write data packets to and from. The returned [`ConnectionState`]
187    /// can be used to wait for the connection to be closed.
188    pub async fn accept(
189        &self,
190        request: ConnectionRequest,
191        socket: Socket,
192    ) -> Result<ConnectionState, Error> {
193        let address = request.address;
194        let notify_closed_rx;
195        if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
196            let VsockConnectionState::ConnectingIncoming = &conn.state else {
197                return Err(Error::other(format!(
198                    "Attempted to accept connection that was not waiting at {address:?}"
199                )));
200            };
201
202            let (read_socket, write_socket) = socket.split();
203            let writer = Arc::new(Mutex::new(write_socket));
204            let notify_closed = mpsc::channel(2);
205            notify_closed_rx = notify_closed.1;
206            let notify_closed = notify_closed.0;
207
208            let reader_task = Scope::new_with_name("connection-reader");
209            reader_task.spawn(Self::run_socket(read_socket, address, self.packet_filler.clone()));
210
211            conn.state = VsockConnectionState::Connected {
212                writer,
213                _reader_scope: reader_task,
214                notify_closed,
215            };
216        } else {
217            return Err(Error::other(format!(
218                "Attempting to accept connection that did not exist at {address:?}"
219            )));
220        }
221        let header = &mut Header::new(PacketType::Accept);
222        header.set_address(&address);
223        self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
224        Ok(ConnectionState(notify_closed_rx))
225    }
226
227    /// Rejects a pending connection request from the other side.
228    pub async fn reject(&self, request: ConnectionRequest) -> Result<(), Error> {
229        let address = request.address;
230        match self.connections.lock().unwrap().entry(address.clone()) {
231            Entry::Occupied(entry) => {
232                let VsockConnectionState::ConnectingIncoming = &entry.get().state else {
233                    return Err(Error::other(format!(
234                        "Attempted to reject connection that was not waiting at {address:?}"
235                    )));
236                };
237                entry.remove();
238            }
239            Entry::Vacant(_) => {
240                return Err(Error::other(format!(
241                    "Attempted to reject connection that was not waiting at {address:?}"
242                )));
243            }
244        }
245
246        let header = &mut Header::new(PacketType::Reset);
247        header.set_address(&address);
248        self.packet_filler
249            .write_vsock_packet(&Packet { header, payload: &[] })
250            .await
251            .expect("accept packet should never be too large for packet buffer");
252        Ok(())
253    }
254
255    async fn handle_data_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
256        // all zero data packets go to the control channel
257        if address.is_zeros() {
258            let written = self.control_socket_writer.lock().await.write(payload).await?;
259            assert_eq!(written, payload.len());
260            Ok(())
261        } else {
262            let payload_socket;
263            if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
264                let VsockConnectionState::Connected { writer, .. } = &conn.state else {
265                    warn!(
266                        "Received data packet for connection in unexpected state for {address:?}"
267                    );
268                    return Ok(());
269                };
270                payload_socket = writer.clone();
271            } else {
272                warn!("Received data packet for connection that didn't exist at {address:?}");
273                return Ok(());
274            }
275            if let Err(err) = payload_socket.lock().await.write_all(payload).await {
276                debug!(
277                    "Write to socket address {address:?} failed, resetting connection immediately: {err:?}"
278                );
279                self.reset(&address).await.inspect_err(|err| warn!("Attempt to reset connection to {address:?} failed after write error: {err:?}")).ok();
280            }
281            Ok(())
282        }
283    }
284
285    async fn handle_echo_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
286        debug!("received echo for {address:?} with payload {payload:?}");
287        let header = &mut Header::new(PacketType::EchoReply);
288        header.payload_len.set(payload.len() as u32);
289        header.set_address(&address);
290        self.packet_filler
291            .write_vsock_packet(&Packet { header, payload })
292            .await
293            .map_err(|_| Error::other("Echo packet was too large to be sent back"))
294    }
295
296    async fn handle_echo_reply_packet(
297        &self,
298        address: Address,
299        payload: &[u8],
300    ) -> Result<(), Error> {
301        // ignore but log replies
302        debug!("received echo reply for {address:?} with payload {payload:?}");
303        Ok(())
304    }
305
306    async fn handle_accept_packet(&self, address: Address) -> Result<(), Error> {
307        if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
308            let state = std::mem::replace(&mut conn.state, VsockConnectionState::Invalid);
309            let VsockConnectionState::ConnectingOutgoing(writer, read_socket, connected_tx) = state
310            else {
311                warn!("Received accept packet for connection in unexpected state for {address:?}");
312                return Ok(());
313            };
314            let (notify_closed, notify_closed_rx) = mpsc::channel(2);
315            if connected_tx.send(Ok(ConnectionState(notify_closed_rx))).is_err() {
316                warn!(
317                    "Accept packet received for {address:?} but connect caller stopped waiting for it"
318                );
319            }
320
321            let reader_task = Scope::new_with_name("connection-reader");
322            reader_task.spawn(Self::run_socket(read_socket, address, self.packet_filler.clone()));
323            conn.state = VsockConnectionState::Connected {
324                writer,
325                _reader_scope: reader_task,
326                notify_closed,
327            };
328        } else {
329            warn!("Got accept packet for connection that was not being made at {address:?}");
330            return Ok(());
331        }
332        Ok(())
333    }
334
335    async fn handle_connect_packet(&self, address: Address) -> Result<(), Error> {
336        trace!("received connect packet for {address:?}");
337        match self.connections.lock().unwrap().entry(address.clone()) {
338            Entry::Vacant(entry) => {
339                debug!("valid connect request for {address:?}");
340                entry.insert(VsockConnection {
341                    _address: address,
342                    state: VsockConnectionState::ConnectingIncoming,
343                });
344            }
345            Entry::Occupied(_) => {
346                warn!(
347                    "Received connect packet for already existing connection for address {address:?}. Ignoring"
348                );
349                return Ok(());
350            }
351        }
352
353        trace!("sending incoming connection request to client for {address:?}");
354        let connection_request = ConnectionRequest { address };
355        self.incoming_requests_tx
356            .clone()
357            .send(connection_request)
358            .await
359            .inspect(|_| trace!("sent incoming request for {address:?}"))
360            .map_err(|_| Error::other("Failed to send connection request"))
361    }
362
363    async fn handle_finish_packet(&self, address: Address) -> Result<(), Error> {
364        trace!("received finish packet for {address:?}");
365        let mut notify;
366        if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
367            let VsockConnectionState::Connected { notify_closed, .. } = conn.state else {
368                warn!(
369                    "Received finish (close) packet for {address:?} which was not in a connected state. Ignoring and dropping connection state."
370                );
371                return Ok(());
372            };
373            notify = notify_closed;
374        } else {
375            warn!(
376                "Received finish (close) packet for connection that didn't exist on address {address:?}. Ignoring"
377            );
378            return Ok(());
379        }
380
381        notify.send(Ok(())).await.ok();
382
383        let header = &mut Header::new(PacketType::Reset);
384        header.set_address(&address);
385        self.packet_filler
386            .write_vsock_packet(&Packet { header, payload: &[] })
387            .await
388            .expect("accept packet should never be too large for packet buffer");
389        Ok(())
390    }
391
392    async fn handle_reset_packet(&self, address: Address) -> Result<(), Error> {
393        trace!("received reset packet for {address:?}");
394        let mut notify = None;
395        if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
396            if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
397                notify = Some(notify_closed);
398            } else {
399                debug!(
400                    "Received reset packet for connection that wasn't in a connecting or disconnected state on address {address:?}."
401                );
402            }
403        } else {
404            warn!(
405                "Received reset packet for connection that didn't exist on address {address:?}. Ignoring"
406            );
407        }
408
409        if let Some(mut notify) = notify {
410            notify.send(Ok(())).await.ok();
411        }
412        Ok(())
413    }
414
415    /// Dispatches the given vsock packet type and handles its effect on any outstanding connections
416    /// or the overall state of the connection.
417    pub async fn handle_vsock_packet(&self, packet: Packet<'_>) -> Result<(), Error> {
418        trace!("received vsock packet {header:?}", header = packet.header);
419        let payload_len = packet.header.payload_len.get() as usize;
420        let payload = &packet.payload[..payload_len];
421        let address = Address::from(packet.header);
422        match packet.header.packet_type {
423            PacketType::Sync => Err(Error::other("Received sync packet mid-stream")),
424            PacketType::Data => self.handle_data_packet(address, payload).await,
425            PacketType::Accept => self.handle_accept_packet(address).await,
426            PacketType::Connect => self.handle_connect_packet(address).await,
427            PacketType::Finish => self.handle_finish_packet(address).await,
428            PacketType::Reset => self.handle_reset_packet(address).await,
429            PacketType::Echo => self.handle_echo_packet(address, payload).await,
430            PacketType::EchoReply => self.handle_echo_reply_packet(address, payload).await,
431        }
432    }
433
434    /// Provides a packet builder for the state machine to write packets to. Returns a future that
435    /// will be fulfilled when there is data available to send on the packet.
436    ///
437    /// # Panics
438    ///
439    /// Panics if called while another [`Self::fill_usb_packet`] future is pending.
440    pub async fn fill_usb_packet(&self, builder: UsbPacketBuilder<B>) -> UsbPacketBuilder<B> {
441        self.packet_filler.fill_usb_packet(builder).await
442    }
443}
444
445enum VsockConnectionState {
446    ConnectingOutgoing(
447        Arc<Mutex<WriteHalf<Socket>>>,
448        ReadHalf<Socket>,
449        oneshot::Sender<Result<ConnectionState, Error>>,
450    ),
451    ConnectingIncoming,
452    Connected {
453        writer: Arc<Mutex<WriteHalf<Socket>>>,
454        notify_closed: mpsc::Sender<Result<(), Error>>,
455        _reader_scope: Scope,
456    },
457    Invalid,
458}
459
460struct VsockConnection {
461    _address: Address,
462    state: VsockConnectionState,
463}
464
465/// A handle for the state of a connection established with either [`Connection::connect`] or
466/// [`Connection::accept`]. Use this to get notified when the connection has been closed without
467/// needing to hold on to the Socket end.
468#[derive(Debug)]
469pub struct ConnectionState(mpsc::Receiver<Result<(), Error>>);
470
471impl ConnectionState {
472    /// Wait for this connection to close. Returns Ok(()) if the connection was closed without error,
473    /// and an error if it closed because of an error.
474    pub async fn wait_for_close(mut self) -> Result<(), Error> {
475        self.0
476            .next()
477            .await
478            .ok_or_else(|| Error::other("Connection state's other end was dropped"))?
479    }
480}
481
482/// An outstanding connection request that needs to be either [`Connection::accept`]ed or
483/// [`Connection::reject`]ed.
484#[derive(Debug)]
485pub struct ConnectionRequest {
486    address: Address,
487}
488
489impl ConnectionRequest {
490    /// Creates a new connection request for the given address.
491    pub fn new(address: Address) -> Self {
492        Self { address }
493    }
494
495    /// The address this connection request is being made for.
496    pub fn address(&self) -> &Address {
497        &self.address
498    }
499}
500
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::VsockPacketIterator;

    use super::*;

    #[cfg(not(target_os = "fuchsia"))]
    use fuchsia_async::emulated_handle::Socket as SyncSocket;
    use fuchsia_async::Task;
    use futures::StreamExt;
    #[cfg(target_os = "fuchsia")]
    use zx::Socket as SyncSocket;

    /// Plays the remote end for a single `Connection`: accepts every connect request and
    /// reflects every other packet (except accepts) straight back into the same connection.
    async fn usb_echo_server(echo_connection: Arc<Connection<Vec<u8>>>) {
        let mut builder = UsbPacketBuilder::new(vec![0; 128]);
        loop {
            println!("waiting for usb packet");
            builder = echo_connection.fill_usb_packet(builder).await;
            let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
            println!("got usb packet, echoing it back to the other side");
            let mut packet_count = 0;
            for packet in packets {
                let packet = packet.unwrap();
                match packet.header.packet_type {
                    PacketType::Connect => {
                        // respond with an accept packet
                        let mut reply_header = packet.header.clone();
                        reply_header.packet_type = PacketType::Accept;
                        echo_connection
                            .handle_vsock_packet(Packet { header: &reply_header, payload: &[] })
                            .await
                            .unwrap();
                    }
                    PacketType::Accept => {
                        // just ignore it
                    }
                    _ => echo_connection.handle_vsock_packet(packet).await.unwrap(),
                }
                packet_count += 1;
            }
            println!("handled {packet_count} packets");
        }
    }

    #[fuchsia::test]
    async fn data_over_control_socket() {
        let (socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
        let mut socket = Socket::from_socket(socket);
        let connection =
            Arc::new(Connection::new(Socket::from_socket(other_socket), incoming_requests_tx));

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
            println!("round tripping packet of size {size}");
            socket.write_all(&vec![size; size as usize]).await.unwrap();
            let mut buf = vec![0u8; size as usize];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, vec![size; size as usize]);
        }
        echo_task.abort().await;
    }

    #[fuchsia::test]
    async fn data_over_normal_outgoing_socket() {
        let (_control_socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
        let connection =
            Arc::new(Connection::new(Socket::from_socket(other_socket), incoming_requests_tx));

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        let (socket, other_socket) = SyncSocket::create_stream();
        let mut socket = Socket::from_socket(socket);
        connection
            .connect(
                Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 },
                Socket::from_socket(other_socket),
            )
            .await
            .unwrap();

        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
            println!("round tripping packet of size {size}");
            socket.write_all(&vec![size; size as usize]).await.unwrap();
            let mut buf = vec![0u8; size as usize];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, vec![size; size as usize]);
        }
        echo_task.abort().await;
    }

    #[fuchsia::test]
    async fn data_over_normal_incoming_socket() {
        let (_control_socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, mut incoming_requests) = mpsc::channel(5);
        let connection =
            Arc::new(Connection::new(Socket::from_socket(other_socket), incoming_requests_tx));

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        let header = &mut Header::new(PacketType::Connect);
        header.set_address(&Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 });
        connection.handle_vsock_packet(Packet { header, payload: &[] }).await.unwrap();

        let request = incoming_requests.next().await.unwrap();
        assert_eq!(
            request.address,
            Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
        );

        let (socket, other_socket) = SyncSocket::create_stream();
        let mut socket = Socket::from_socket(socket);
        connection.accept(request, Socket::from_socket(other_socket)).await.unwrap();

        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
            println!("round tripping packet of size {size}");
            socket.write_all(&vec![size; size as usize]).await.unwrap();
            let mut buf = vec![0u8; size as usize];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, vec![size; size as usize]);
        }
        echo_task.abort().await;
    }

    /// Operations that require an existing connection must fail cleanly, without sending
    /// anything, when no connection exists at the address.
    #[fuchsia::test]
    async fn operations_on_unknown_address_fail() {
        let (_control_socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
        let connection: Connection<Vec<u8>> =
            Connection::new(Socket::from_socket(other_socket), incoming_requests_tx);
        let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };

        connection.reset(&addr).await.expect_err("reset of unknown connection should fail");
        connection
            .reject(ConnectionRequest::new(addr.clone()))
            .await
            .expect_err("reject of unknown connection should fail");
        let (_socket, other_socket) = SyncSocket::create_stream();
        connection
            .accept(ConnectionRequest::new(addr.clone()), Socket::from_socket(other_socket))
            .await
            .expect_err("accept of unknown connection should fail");
    }

    /// Pumps every usb packet produced by `from` into `to`, forever.
    async fn copy_connection(from: &Connection<Vec<u8>>, to: &Connection<Vec<u8>>) {
        let mut builder = UsbPacketBuilder::new(vec![0; 1024]);
        loop {
            builder = from.fill_usb_packet(builder).await;
            let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
            for packet in packets {
                println!("forwarding vsock packet");
                to.handle_vsock_packet(packet.unwrap()).await.unwrap();
            }
        }
    }

    pub(crate) trait EndToEndTestFn<R>:
        AsyncFnOnce(Arc<Connection<Vec<u8>>>, mpsc::Receiver<ConnectionRequest>) -> R
    {
    }
    impl<T, R> EndToEndTestFn<R> for T where
        T: AsyncFnOnce(Arc<Connection<Vec<u8>>>, mpsc::Receiver<ConnectionRequest>) -> R
    {
    }

    /// Wires two `Connection`s back to back and runs `left_side` and `right_side` concurrently,
    /// each with its own connection and incoming-request queue.
    pub(crate) async fn end_to_end_test<R1, R2>(
        left_side: impl EndToEndTestFn<R1>,
        right_side: impl EndToEndTestFn<R2>,
    ) -> (R1, R2) {
        type Connection = crate::Connection<Vec<u8>>;
        let (_control_socket1, other_socket1) = SyncSocket::create_stream();
        let (_control_socket2, other_socket2) = SyncSocket::create_stream();
        let (incoming_requests_tx1, incoming_requests1) = mpsc::channel(5);
        let (incoming_requests_tx2, incoming_requests2) = mpsc::channel(5);

        let connection1 =
            Arc::new(Connection::new(Socket::from_socket(other_socket1), incoming_requests_tx1));
        let connection2 =
            Arc::new(Connection::new(Socket::from_socket(other_socket2), incoming_requests_tx2));

        let conn1 = connection1.clone();
        let conn2 = connection2.clone();
        let passthrough_task = Task::spawn(async move {
            futures::join!(copy_connection(&conn1, &conn2), copy_connection(&conn2, &conn1),);
            println!("passthrough task loop ended");
        });

        let res = futures::join!(
            left_side(connection1, incoming_requests1),
            right_side(connection2, incoming_requests2)
        );
        passthrough_task.abort().await;
        res
    }

    #[fuchsia::test]
    async fn data_over_end_to_end() {
        end_to_end_test(
            async |conn, _incoming| {
                println!("sending request on connection 1");
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn
                    .connect(
                        Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 },
                        Socket::from_socket(other_socket),
                    )
                    .await
                    .unwrap();

                for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
                    println!("round tripping packet of size {size}");
                    socket.write_all(&vec![size; size as usize]).await.unwrap();
                }
                drop(socket);
                state.wait_for_close().await.unwrap();
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(
                    request.address,
                    Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
                );

                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();

                println!("accepted request on connection 2");
                for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
                    let mut buf = vec![0u8; size as usize];
                    socket.read_exact(&mut buf).await.unwrap();
                    assert_eq!(buf, vec![size; size as usize]);
                }
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    #[fuchsia::test]
    async fn normal_close_end_to_end() {
        let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
        end_to_end_test(
            async |conn, _incoming| {
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state =
                    conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
                conn.close(&addr).await;
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, addr.clone(),);

                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    #[fuchsia::test]
    async fn reset_end_to_end() {
        let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
        end_to_end_test(
            async |conn, _incoming| {
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state =
                    conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
                conn.reset(&addr).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.expect_err("expected reset");
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, addr.clone(),);

                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    /// A rejected connection request makes the pending `connect` fail and closes the socket
    /// that was handed to it.
    #[fuchsia::test]
    async fn rejected_connect_end_to_end() {
        let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
        end_to_end_test(
            async |conn, _incoming| {
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                conn.connect(addr.clone(), Socket::from_socket(other_socket))
                    .await
                    .expect_err("expected connection to be rejected");
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
            },
            async |conn, mut incoming| {
                println!("rejecting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, addr.clone());
                conn.reject(request).await.unwrap();
            },
        )
        .await;
    }
}