usb_vsock/
connection.rs

1// Copyright 2025 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use futures::channel::{mpsc, oneshot};
6use futures::lock::Mutex;
7use log::{debug, trace, warn};
8use std::collections::hash_map::Entry;
9use std::collections::HashMap;
10use std::io::{Error, ErrorKind};
11use std::ops::DerefMut;
12use std::sync::Arc;
13
14use fuchsia_async::{Scope, Socket};
15use futures::io::{ReadHalf, WriteHalf};
16use futures::{AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt};
17
18use crate::{Address, Header, Packet, PacketType, UsbPacketBuilder, UsbPacketFiller};
19
/// Marker trait for byte buffers that a [`Connection`] can use for its packets.
///
/// It is never implemented by hand: the blanket impl below covers every type that dereferences
/// mutably to a `[u8]` and is `Send`, `Unpin` and `'static`.
pub trait PacketBuffer
where
    Self: DerefMut<Target = [u8]> + Send + Unpin + 'static,
{
}

impl<T: DerefMut<Target = [u8]> + Send + Unpin + 'static> PacketBuffer for T {}
23
24/// Manages the state of a vsock-over-usb connection and the sockets over which data is being
25/// transmitted for them.
26///
27/// This implementation aims to be agnostic to both the underlying transport and the buffers used
28/// to read and write from it. The buffer type must conform to [`PacketBuffer`], which is essentially
29/// a type that holds a mutable slice of bytes and is [`Send`] and [`Unpin`]-able.
30///
31/// The client of this library will:
32/// - Use methods on this struct to initiate actions like connecting and accepting
33/// connections to the other end.
34/// - Provide buffers to be filled and sent to the other end with [`Connection::fill_usb_packet`].
35/// - Pump usb packets received into it using [`Connection::handle_vsock_packet`].
36pub struct Connection<B> {
37    control_socket_writer: Mutex<WriteHalf<Socket>>,
38    packet_filler: Arc<UsbPacketFiller<B>>,
39    connections: std::sync::Mutex<HashMap<Address, VsockConnection>>,
40    incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
41    _task_scope: Scope,
42}
43
44impl<B: PacketBuffer> Connection<B> {
45    /// Creates a new connection with:
46    /// - a `control_socket`, over which data addressed to and from cid 0, port 0 (a control channel
47    /// between host and device) can be read and written from.
48    /// - An `incoming_requests_tx` that is the sender half of a request queue for incoming
49    /// connection requests from the other side.
50    pub fn new(
51        control_socket: Socket,
52        incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
53    ) -> Self {
54        let (control_socket_reader, control_socket_writer) = control_socket.split();
55        let control_socket_writer = Mutex::new(control_socket_writer);
56        let packet_filler = Arc::new(UsbPacketFiller::default());
57        let connections = Default::default();
58        let task_scope = Scope::new_with_name("vsock_usb");
59        task_scope.spawn(Self::run_socket(
60            control_socket_reader,
61            Address::default(),
62            packet_filler.clone(),
63        ));
64        Self {
65            control_socket_writer,
66            packet_filler,
67            connections,
68            incoming_requests_tx,
69            _task_scope: task_scope,
70        }
71    }
72
73    async fn send_close_packet(address: &Address, usb_packet_filler: &Arc<UsbPacketFiller<B>>) {
74        let header = &mut Header::new(PacketType::Finish);
75        header.set_address(address);
76        usb_packet_filler
77            .write_vsock_packet(&Packet { header, payload: &[] })
78            .await
79            .expect("Finish packet should never be too big");
80    }
81
82    async fn run_socket(
83        mut reader: ReadHalf<Socket>,
84        address: Address,
85        usb_packet_filler: Arc<UsbPacketFiller<B>>,
86    ) {
87        let mut buf = [0; 4096];
88        loop {
89            log::trace!("reading from control socket");
90            let read = match reader.read(&mut buf).await {
91                Ok(0) => {
92                    if !address.is_zeros() {
93                        Self::send_close_packet(&address, &usb_packet_filler).await;
94                    }
95                    return;
96                }
97                Ok(read) => read,
98                Err(err) => {
99                    if address.is_zeros() {
100                        log::error!("Error reading usb socket: {err:?}");
101                    } else {
102                        Self::send_close_packet(&address, &usb_packet_filler).await;
103                    }
104                    return;
105                }
106            };
107            log::trace!("writing {read} bytes to vsock packet");
108            usb_packet_filler.write_vsock_data_all(&address, &buf[..read]).await;
109            log::trace!("wrote {read} bytes to vsock packet");
110        }
111    }
112
113    fn set_connection(&self, address: Address, state: VsockConnectionState) -> Result<(), Error> {
114        let mut connections = self.connections.lock().unwrap();
115        if !connections.contains_key(&address) {
116            connections.insert(address.clone(), VsockConnection { _address: address, state });
117            Ok(())
118        } else {
119            Err(Error::other(format!("connection on address {address:?} already set")))
120        }
121    }
122
123    /// Starts a connection attempt to the other end of the USB connection, and provides a socket
124    /// to read and write from. The function will complete when the other end has accepted or
125    /// rejected the connection, and the returned [`ConnectionState`] handle can be used to wait
126    /// for the connection to be closed.
127    pub async fn connect(&self, addr: Address, socket: Socket) -> Result<ConnectionState, Error> {
128        let (read_socket, write_socket) = socket.split();
129        let write_socket = Arc::new(Mutex::new(write_socket));
130        let (connected_tx, connected_rx) = oneshot::channel();
131
132        self.set_connection(
133            addr.clone(),
134            VsockConnectionState::ConnectingOutgoing(write_socket, read_socket, connected_tx),
135        )?;
136
137        let header = &mut Header::new(PacketType::Connect);
138        header.set_address(&addr);
139        self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
140        connected_rx.await.map_err(|_| Error::other("Accept was never received for {addr:?}"))?
141    }
142
143    /// Sends a request for the other end to close the connection.
144    pub async fn close(&self, address: &Address) {
145        Self::send_close_packet(address, &self.packet_filler).await
146    }
147
148    /// Resets the named connection without going through a close request.
149    pub async fn reset(&self, address: &Address) -> Result<(), Error> {
150        let mut notify = None;
151        if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
152            if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
153                notify = Some(notify_closed);
154            }
155        } else {
156            return Err(Error::other(
157                "Client asked to reset connection {address:?} that did not exist",
158            ));
159        }
160
161        if let Some(mut notify) = notify {
162            notify.send(Err(ErrorKind::ConnectionReset.into())).await.ok();
163        }
164
165        let header = &mut Header::new(PacketType::Reset);
166        header.set_address(address);
167        self.packet_filler
168            .write_vsock_packet(&Packet { header, payload: &[] })
169            .await
170            .expect("Reset packet should never be too big");
171        Ok(())
172    }
173
174    /// Accepts a connection for which an outstanding connection request has been made, and
175    /// provides a socket to read and write data packets to and from. The returned [`ConnectionState`]
176    /// can be used to wait for the connection to be closed.
177    pub async fn accept(
178        &self,
179        request: ConnectionRequest,
180        socket: Socket,
181    ) -> Result<ConnectionState, Error> {
182        let address = request.address;
183        let notify_closed_rx;
184        if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
185            let VsockConnectionState::ConnectingIncoming = &conn.state else {
186                return Err(Error::other(format!(
187                    "Attempted to accept connection that was not waiting at {address:?}"
188                )));
189            };
190
191            let (read_socket, write_socket) = socket.split();
192            let writer = Arc::new(Mutex::new(write_socket));
193            let notify_closed = mpsc::channel(2);
194            notify_closed_rx = notify_closed.1;
195            let notify_closed = notify_closed.0;
196
197            let reader_task = Scope::new_with_name("connection-reader");
198            reader_task.spawn(Self::run_socket(read_socket, address, self.packet_filler.clone()));
199
200            conn.state = VsockConnectionState::Connected {
201                writer,
202                _reader_scope: reader_task,
203                notify_closed,
204            };
205        } else {
206            return Err(Error::other(format!(
207                "Attempting to accept connection that did not exist at {address:?}"
208            )));
209        }
210        let header = &mut Header::new(PacketType::Accept);
211        header.set_address(&address);
212        self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
213        Ok(ConnectionState(notify_closed_rx))
214    }
215
216    /// Rejects a pending connection request from the other side.
217    pub async fn reject(&self, request: ConnectionRequest) -> Result<(), Error> {
218        let address = request.address;
219        match self.connections.lock().unwrap().entry(address.clone()) {
220            Entry::Occupied(entry) => {
221                let VsockConnectionState::ConnectingIncoming = &entry.get().state else {
222                    return Err(Error::other(format!(
223                        "Attempted to reject connection that was not waiting at {address:?}"
224                    )));
225                };
226                entry.remove();
227            }
228            Entry::Vacant(_) => {
229                return Err(Error::other(format!(
230                    "Attempted to reject connection that was not waiting at {address:?}"
231                )));
232            }
233        }
234
235        let header = &mut Header::new(PacketType::Reset);
236        header.set_address(&address);
237        self.packet_filler
238            .write_vsock_packet(&Packet { header, payload: &[] })
239            .await
240            .expect("accept packet should never be too large for packet buffer");
241        Ok(())
242    }
243
244    async fn handle_data_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
245        // all zero data packets go to the control channel
246        if address.is_zeros() {
247            let written = self.control_socket_writer.lock().await.write(payload).await?;
248            assert_eq!(written, payload.len());
249            Ok(())
250        } else {
251            let payload_socket;
252            if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
253                let VsockConnectionState::Connected { writer, .. } = &conn.state else {
254                    warn!(
255                        "Received data packet for connection in unexpected state for {address:?}"
256                    );
257                    return Ok(());
258                };
259                payload_socket = writer.clone();
260            } else {
261                warn!("Received data packet for connection that didn't exist at {address:?}");
262                return Ok(());
263            }
264            if let Err(err) = payload_socket.lock().await.write_all(payload).await {
265                debug!("Write to socket address {address:?} failed, resetting connection immediately: {err:?}");
266                self.reset(&address).await.inspect_err(|err| warn!("Attempt to reset connection to {address:?} failed after write error: {err:?}")).ok();
267            }
268            Ok(())
269        }
270    }
271
272    async fn handle_accept_packet(&self, address: Address) -> Result<(), Error> {
273        if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
274            let state = std::mem::replace(&mut conn.state, VsockConnectionState::Invalid);
275            let VsockConnectionState::ConnectingOutgoing(writer, read_socket, connected_tx) = state
276            else {
277                warn!("Received accept packet for connection in unexpected state for {address:?}");
278                return Ok(());
279            };
280            let (notify_closed, notify_closed_rx) = mpsc::channel(2);
281            if connected_tx.send(Ok(ConnectionState(notify_closed_rx))).is_err() {
282                warn!("Accept packet received for {address:?} but connect caller stopped waiting for it");
283            }
284
285            let reader_task = Scope::new_with_name("connection-reader");
286            reader_task.spawn(Self::run_socket(read_socket, address, self.packet_filler.clone()));
287            conn.state = VsockConnectionState::Connected {
288                writer,
289                _reader_scope: reader_task,
290                notify_closed,
291            };
292        } else {
293            warn!("Got accept packet for connection that was not being made at {address:?}");
294            return Ok(());
295        }
296        Ok(())
297    }
298
299    async fn handle_connect_packet(&self, address: Address) -> Result<(), Error> {
300        trace!("received connect packet for {address:?}");
301        match self.connections.lock().unwrap().entry(address.clone()) {
302            Entry::Vacant(entry) => {
303                debug!("valid connect request for {address:?}");
304                entry.insert(VsockConnection {
305                    _address: address,
306                    state: VsockConnectionState::ConnectingIncoming,
307                });
308            }
309            Entry::Occupied(_) => {
310                warn!("Received connect packet for already existing connection for address {address:?}. Ignoring");
311                return Ok(());
312            }
313        }
314
315        trace!("sending incoming connection request to client for {address:?}");
316        let connection_request = ConnectionRequest { address };
317        self.incoming_requests_tx
318            .clone()
319            .send(connection_request)
320            .await
321            .inspect(|_| trace!("sent incoming request for {address:?}"))
322            .map_err(|_| Error::other("Failed to send connection request"))
323    }
324
325    async fn handle_finish_packet(&self, address: Address) -> Result<(), Error> {
326        trace!("received finish packet for {address:?}");
327        let mut notify;
328        if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
329            let VsockConnectionState::Connected { notify_closed, .. } = conn.state else {
330                warn!("Received finish (close) packet for {address:?} which was not in a connected state. Ignoring and dropping connection state.");
331                return Ok(());
332            };
333            notify = notify_closed;
334        } else {
335            warn!("Received finish (close) packet for connection that didn't exist on address {address:?}. Ignoring");
336            return Ok(());
337        }
338
339        notify.send(Ok(())).await.ok();
340
341        let header = &mut Header::new(PacketType::Reset);
342        header.set_address(&address);
343        self.packet_filler
344            .write_vsock_packet(&Packet { header, payload: &[] })
345            .await
346            .expect("accept packet should never be too large for packet buffer");
347        Ok(())
348    }
349
350    async fn handle_reset_packet(&self, address: Address) -> Result<(), Error> {
351        trace!("received reset packet for {address:?}");
352        let mut notify = None;
353        if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
354            if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
355                notify = Some(notify_closed);
356            } else {
357                debug!("Received reset packet for connection that wasn't in a connecting or disconnected state on address {address:?}.");
358            }
359        } else {
360            warn!("Received reset packet for connection that didn't exist on address {address:?}. Ignoring");
361        }
362
363        if let Some(mut notify) = notify {
364            notify.send(Ok(())).await.ok();
365        }
366        Ok(())
367    }
368
369    /// Dispatches the given vsock packet type and handles its effect on any outstanding connections
370    /// or the overall state of the connection.
371    pub async fn handle_vsock_packet(&self, packet: Packet<'_>) -> Result<(), Error> {
372        trace!("received vsock packet {header:?}", header = packet.header);
373        let payload_len = packet.header.payload_len.get() as usize;
374        let payload = &packet.payload[..payload_len];
375        let address = Address::from(packet.header);
376        match packet.header.packet_type {
377            PacketType::Sync => Err(Error::other("Received sync packet mid-stream")),
378            PacketType::Data => self.handle_data_packet(address, payload).await,
379            PacketType::Accept => self.handle_accept_packet(address).await,
380            PacketType::Connect => self.handle_connect_packet(address).await,
381            PacketType::Finish => self.handle_finish_packet(address).await,
382            PacketType::Reset => self.handle_reset_packet(address).await,
383        }
384    }
385
386    /// Provides a packet builder for the state machine to write packets to. Returns a future that
387    /// will be fulfilled when there is data available to send on the packet.
388    ///
389    /// # Panics
390    ///
391    /// Panics if called while another [`Self::fill_usb_packet`] future is pending.
392    pub async fn fill_usb_packet(&self, builder: UsbPacketBuilder<B>) -> UsbPacketBuilder<B> {
393        self.packet_filler.fill_usb_packet(builder).await
394    }
395}
396
397enum VsockConnectionState {
398    ConnectingOutgoing(
399        Arc<Mutex<WriteHalf<Socket>>>,
400        ReadHalf<Socket>,
401        oneshot::Sender<Result<ConnectionState, Error>>,
402    ),
403    ConnectingIncoming,
404    Connected {
405        writer: Arc<Mutex<WriteHalf<Socket>>>,
406        notify_closed: mpsc::Sender<Result<(), Error>>,
407        _reader_scope: Scope,
408    },
409    Invalid,
410}
411
412struct VsockConnection {
413    _address: Address,
414    state: VsockConnectionState,
415}
416
417/// A handle for the state of a connection established with either [`Connection::connect`] or
418/// [`Connection::accept`]. Use this to get notified when the connection has been closed without
419/// needing to hold on to the Socket end.
420#[derive(Debug)]
421pub struct ConnectionState(mpsc::Receiver<Result<(), Error>>);
422
423impl ConnectionState {
424    /// Wait for this connection to close. Returns Ok(()) if the connection was closed without error,
425    /// and an error if it closed because of an error.
426    pub async fn wait_for_close(mut self) -> Result<(), Error> {
427        self.0
428            .next()
429            .await
430            .ok_or_else(|| Error::other("Connection state's other end was dropped"))?
431    }
432}
433
434/// An outstanding connection request that needs to be either [`Connection::accept`]ed or
435/// [`Connection::reject`]ed.
436#[derive(Debug)]
437pub struct ConnectionRequest {
438    address: Address,
439}
440
441impl ConnectionRequest {
442    /// Creates a new connection request for the given address.
443    pub fn new(address: Address) -> Self {
444        Self { address }
445    }
446
447    /// The address this connection request is being made for.
448    pub fn address(&self) -> &Address {
449        &self.address
450    }
451}
452
453#[cfg(test)]
454mod test {
455    use std::sync::Arc;
456
457    use crate::VsockPacketIterator;
458
459    use super::*;
460
461    #[cfg(not(target_os = "fuchsia"))]
462    use fuchsia_async::emulated_handle::Socket as SyncSocket;
463    use fuchsia_async::Task;
464    use futures::StreamExt;
465    #[cfg(target_os = "fuchsia")]
466    use zx::Socket as SyncSocket;
467
468    async fn usb_echo_server(echo_connection: Arc<Connection<Vec<u8>>>) {
469        let mut builder = UsbPacketBuilder::new(vec![0; 128]);
470        loop {
471            println!("waiting for usb packet");
472            builder = echo_connection.fill_usb_packet(builder).await;
473            let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
474            println!("got usb packet, echoing it back to the other side");
475            let mut packet_count = 0;
476            for packet in packets {
477                let packet = packet.unwrap();
478                match packet.header.packet_type {
479                    PacketType::Connect => {
480                        // respond with an accept packet
481                        let mut reply_header = packet.header.clone();
482                        reply_header.packet_type = PacketType::Accept;
483                        echo_connection
484                            .handle_vsock_packet(Packet { header: &reply_header, payload: &[] })
485                            .await
486                            .unwrap();
487                    }
488                    PacketType::Accept => {
489                        // just ignore it
490                    }
491                    _ => echo_connection.handle_vsock_packet(packet).await.unwrap(),
492                }
493                packet_count += 1;
494            }
495            println!("handled {packet_count} packets");
496        }
497    }
498
499    #[fuchsia::test]
500    async fn data_over_control_socket() {
501        let (socket, other_socket) = SyncSocket::create_stream();
502        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
503        let mut socket = Socket::from_socket(socket);
504        let connection =
505            Arc::new(Connection::new(Socket::from_socket(other_socket), incoming_requests_tx));
506
507        let echo_task = Task::spawn(usb_echo_server(connection.clone()));
508
509        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
510            println!("round tripping packet of size {size}");
511            socket.write_all(&vec![size; size as usize]).await.unwrap();
512            let mut buf = vec![0u8; size as usize];
513            socket.read_exact(&mut buf).await.unwrap();
514            assert_eq!(buf, vec![size; size as usize]);
515        }
516        echo_task.cancel().await;
517    }
518
519    #[fuchsia::test]
520    async fn data_over_normal_outgoing_socket() {
521        let (_control_socket, other_socket) = SyncSocket::create_stream();
522        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
523        let connection =
524            Arc::new(Connection::new(Socket::from_socket(other_socket), incoming_requests_tx));
525
526        let echo_task = Task::spawn(usb_echo_server(connection.clone()));
527
528        let (socket, other_socket) = SyncSocket::create_stream();
529        let mut socket = Socket::from_socket(socket);
530        connection
531            .connect(
532                Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 },
533                Socket::from_socket(other_socket),
534            )
535            .await
536            .unwrap();
537
538        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
539            println!("round tripping packet of size {size}");
540            socket.write_all(&vec![size; size as usize]).await.unwrap();
541            let mut buf = vec![0u8; size as usize];
542            socket.read_exact(&mut buf).await.unwrap();
543            assert_eq!(buf, vec![size; size as usize]);
544        }
545        echo_task.cancel().await;
546    }
547
548    #[fuchsia::test]
549    async fn data_over_normal_incoming_socket() {
550        let (_control_socket, other_socket) = SyncSocket::create_stream();
551        let (incoming_requests_tx, mut incoming_requests) = mpsc::channel(5);
552        let connection =
553            Arc::new(Connection::new(Socket::from_socket(other_socket), incoming_requests_tx));
554
555        let echo_task = Task::spawn(usb_echo_server(connection.clone()));
556
557        let header = &mut Header::new(PacketType::Connect);
558        header.set_address(&Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 });
559        connection.handle_vsock_packet(Packet { header, payload: &[] }).await.unwrap();
560
561        let request = incoming_requests.next().await.unwrap();
562        assert_eq!(
563            request.address,
564            Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
565        );
566
567        let (socket, other_socket) = SyncSocket::create_stream();
568        let mut socket = Socket::from_socket(socket);
569        connection.accept(request, Socket::from_socket(other_socket)).await.unwrap();
570
571        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
572            println!("round tripping packet of size {size}");
573            socket.write_all(&vec![size; size as usize]).await.unwrap();
574            let mut buf = vec![0u8; size as usize];
575            socket.read_exact(&mut buf).await.unwrap();
576            assert_eq!(buf, vec![size; size as usize]);
577        }
578        echo_task.cancel().await;
579    }
580
581    async fn copy_connection(from: &Connection<Vec<u8>>, to: &Connection<Vec<u8>>) {
582        let mut builder = UsbPacketBuilder::new(vec![0; 1024]);
583        loop {
584            builder = from.fill_usb_packet(builder).await;
585            let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
586            for packet in packets {
587                println!("forwarding vsock packet");
588                to.handle_vsock_packet(packet.unwrap()).await.unwrap();
589            }
590        }
591    }
592
593    pub(crate) trait EndToEndTestFn<R>:
594        AsyncFnOnce(Arc<Connection<Vec<u8>>>, mpsc::Receiver<ConnectionRequest>) -> R
595    {
596    }
597    impl<T, R> EndToEndTestFn<R> for T where
598        T: AsyncFnOnce(Arc<Connection<Vec<u8>>>, mpsc::Receiver<ConnectionRequest>) -> R
599    {
600    }
601
602    pub(crate) async fn end_to_end_test<R1, R2>(
603        left_side: impl EndToEndTestFn<R1>,
604        right_side: impl EndToEndTestFn<R2>,
605    ) -> (R1, R2) {
606        type Connection = crate::Connection<Vec<u8>>;
607        let (_control_socket1, other_socket1) = SyncSocket::create_stream();
608        let (_control_socket2, other_socket2) = SyncSocket::create_stream();
609        let (incoming_requests_tx1, incoming_requests1) = mpsc::channel(5);
610        let (incoming_requests_tx2, incoming_requests2) = mpsc::channel(5);
611
612        let connection1 =
613            Arc::new(Connection::new(Socket::from_socket(other_socket1), incoming_requests_tx1));
614        let connection2 =
615            Arc::new(Connection::new(Socket::from_socket(other_socket2), incoming_requests_tx2));
616
617        let conn1 = connection1.clone();
618        let conn2 = connection2.clone();
619        let passthrough_task = Task::spawn(async move {
620            futures::join!(copy_connection(&conn1, &conn2), copy_connection(&conn2, &conn1),);
621            println!("passthrough task loop ended");
622        });
623
624        let res = futures::join!(
625            left_side(connection1, incoming_requests1),
626            right_side(connection2, incoming_requests2)
627        );
628        passthrough_task.cancel().await;
629        res
630    }
631
632    #[fuchsia::test]
633    async fn data_over_end_to_end() {
634        end_to_end_test(
635            async |conn, _incoming| {
636                println!("sending request on connection 1");
637                let (socket, other_socket) = SyncSocket::create_stream();
638                let mut socket = Socket::from_socket(socket);
639                let state = conn
640                    .connect(
641                        Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 },
642                        Socket::from_socket(other_socket),
643                    )
644                    .await
645                    .unwrap();
646
647                for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
648                    println!("round tripping packet of size {size}");
649                    socket.write_all(&vec![size; size as usize]).await.unwrap();
650                }
651                drop(socket);
652                state.wait_for_close().await.unwrap();
653            },
654            async |conn, mut incoming| {
655                println!("accepting request on connection 2");
656                let request = incoming.next().await.unwrap();
657                assert_eq!(
658                    request.address,
659                    Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
660                );
661
662                let (socket, other_socket) = SyncSocket::create_stream();
663                let mut socket = Socket::from_socket(socket);
664                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
665
666                println!("accepted request on connection 2");
667                for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
668                    let mut buf = vec![0u8; size as usize];
669                    socket.read_exact(&mut buf).await.unwrap();
670                    assert_eq!(buf, vec![size; size as usize]);
671                }
672                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
673                state.wait_for_close().await.unwrap();
674            },
675        )
676        .await;
677    }
678
679    #[fuchsia::test]
680    async fn normal_close_end_to_end() {
681        let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
682        end_to_end_test(
683            async |conn, _incoming| {
684                let (socket, other_socket) = SyncSocket::create_stream();
685                let mut socket = Socket::from_socket(socket);
686                let state =
687                    conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
688                conn.close(&addr).await;
689                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
690                state.wait_for_close().await.unwrap();
691            },
692            async |conn, mut incoming| {
693                println!("accepting request on connection 2");
694                let request = incoming.next().await.unwrap();
695                assert_eq!(request.address, addr.clone(),);
696
697                let (socket, other_socket) = SyncSocket::create_stream();
698                let mut socket = Socket::from_socket(socket);
699                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
700                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
701                state.wait_for_close().await.unwrap();
702            },
703        )
704        .await;
705    }
706
707    #[fuchsia::test]
708    async fn reset_end_to_end() {
709        let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
710        end_to_end_test(
711            async |conn, _incoming| {
712                let (socket, other_socket) = SyncSocket::create_stream();
713                let mut socket = Socket::from_socket(socket);
714                let state =
715                    conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
716                conn.reset(&addr).await.unwrap();
717                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
718                state.wait_for_close().await.expect_err("expected reset");
719            },
720            async |conn, mut incoming| {
721                println!("accepting request on connection 2");
722                let request = incoming.next().await.unwrap();
723                assert_eq!(request.address, addr.clone(),);
724
725                let (socket, other_socket) = SyncSocket::create_stream();
726                let mut socket = Socket::from_socket(socket);
727                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
728                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
729                state.wait_for_close().await.unwrap();
730            },
731        )
732        .await;
733    }
734}