1use futures::channel::{mpsc, oneshot};
6use futures::lock::Mutex;
7use log::{debug, trace, warn};
8use std::collections::hash_map::Entry;
9use std::collections::HashMap;
10use std::io::{Error, ErrorKind};
11use std::ops::DerefMut;
12use std::sync::Arc;
13
14use fuchsia_async::Scope;
15use futures::io::{ReadHalf, WriteHalf};
16use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, SinkExt, StreamExt};
17
18use crate::connection::overflow_writer::OverflowHandleFut;
19use crate::{
20 Address, Header, Packet, PacketType, ProtocolVersion, UsbPacketBuilder, UsbPacketFiller,
21};
22
23mod overflow_writer;
24mod pause_state;
25
26use overflow_writer::OverflowWriter;
27use pause_state::PauseState;
28
/// A byte buffer that can back a USB packet.
///
/// Any owned, sendable, unpinned type that mutably dereferences to `[u8]` qualifies
/// automatically (e.g. `Vec<u8>` or `Box<[u8]>`); there is nothing to implement by hand.
pub trait PacketBuffer: DerefMut<Target = [u8]> + Send + Unpin + 'static {}

impl<T: DerefMut<Target = [u8]> + Send + Unpin + 'static> PacketBuffer for T {}
32
/// The one-byte payload carried by a `PacketType::Pause` packet.
#[derive(Copy, Clone, PartialEq, Eq)]
enum PausePacket {
    /// Ask the other end to stop sending data on the connection.
    Pause,
    /// Ask the other end to resume sending data on the connection.
    UnPause,
}

impl PausePacket {
    /// Encodes this request as its wire payload: `[1]` for pause, `[0]` for unpause.
    fn bytes(&self) -> [u8; 1] {
        let flag = if *self == Self::Pause { 1 } else { 0 };
        [flag]
    }
}
47
/// Multiplexes vsock connections (plus an optional control socket) over a USB packet stream.
///
/// Outgoing vsock packets are queued in `packet_filler` and drained into USB packets by
/// `fill_usb_packet`; incoming vsock packets are dispatched by `handle_vsock_packet`.
pub struct Connection<B, S> {
    /// Write half of the optional control socket. Data packets addressed to the all-zeros
    /// address are written here (or discarded when this is `None`).
    control_socket_writer: Option<Mutex<WriteHalf<S>>>,
    /// Shared queue of outgoing vsock packets, also handed to the per-socket reader tasks.
    packet_filler: Arc<UsbPacketFiller<B>>,
    /// Negotiated protocol version; gates whether pause packets are sent and honored.
    protocol_version: ProtocolVersion,
    /// Per-address connection state. Shared (via `Arc`) with tasks spawned to drain
    /// overflowed writes so they can reset the connection on failure.
    connections: Arc<std::sync::Mutex<HashMap<Address, VsockConnection<S>>>>,
    /// Incoming connect requests from the other end are forwarded here for the client to
    /// `accept` or `reject`.
    incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
    /// Scope owning the control socket reader task and the overflow-draining tasks.
    task_scope: Scope,
}
68
69impl<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static> Connection<B, S> {
70 pub fn new(
77 protocol_version: ProtocolVersion,
78 control_socket: Option<S>,
79 incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
80 ) -> Self {
81 let packet_filler = Arc::new(UsbPacketFiller::default());
82 let connections = Default::default();
83 let task_scope = Scope::new_with_name("vsock_usb");
84 let control_socket_writer = control_socket.map(|control_socket| {
85 let (control_socket_reader, control_socket_writer) = control_socket.split();
86 task_scope.spawn(Self::run_socket(
87 control_socket_reader,
88 Address::default(),
89 packet_filler.clone(),
90 PauseState::new(),
91 ));
92 Mutex::new(control_socket_writer)
93 });
94 Self {
95 control_socket_writer,
96 packet_filler,
97 connections,
98 incoming_requests_tx,
99 protocol_version,
100 task_scope,
101 }
102 }
103
104 async fn send_close_packet(address: &Address, usb_packet_filler: &Arc<UsbPacketFiller<B>>) {
105 let header = &mut Header::new(PacketType::Finish);
106 header.set_address(address);
107 usb_packet_filler
108 .write_vsock_packet(&Packet { header, payload: &[] })
109 .await
110 .expect("Finish packet should never be too big");
111 }
112
113 async fn run_socket(
114 mut reader: ReadHalf<S>,
115 address: Address,
116 usb_packet_filler: Arc<UsbPacketFiller<B>>,
117 pause_state: Arc<PauseState>,
118 ) {
119 let mut buf = [0; 4096];
120 loop {
121 log::trace!("reading from control socket");
122 let read = match pause_state.while_unpaused(reader.read(&mut buf)).await {
123 Ok(0) => {
124 if !address.is_zeros() {
125 Self::send_close_packet(&address, &usb_packet_filler).await;
126 }
127 return;
128 }
129 Ok(read) => read,
130 Err(err) => {
131 if address.is_zeros() {
132 log::error!("Error reading usb socket: {err:?}");
133 } else {
134 Self::send_close_packet(&address, &usb_packet_filler).await;
135 }
136 return;
137 }
138 };
139 log::trace!("writing {read} bytes to vsock packet");
140 usb_packet_filler.write_vsock_data_all(&address, &buf[..read]).await;
141 log::trace!("wrote {read} bytes to vsock packet");
142 }
143 }
144
145 fn set_connection(
146 &self,
147 address: Address,
148 state: VsockConnectionState<S>,
149 ) -> Result<(), Error> {
150 let mut connections = self.connections.lock().unwrap();
151 if !connections.contains_key(&address) {
152 connections.insert(address.clone(), VsockConnection { _address: address, state });
153 Ok(())
154 } else {
155 Err(Error::other(format!("connection on address {address:?} already set")))
156 }
157 }
158
159 pub async fn send_empty_echo(&self) {
162 debug!("Sending empty echo packet");
163 let header = &mut Header::new(PacketType::Echo);
164 self.packet_filler
165 .write_vsock_packet(&Packet { header, payload: &[] })
166 .await
167 .expect("empty echo packet should never be too large to fit in a usb packet");
168 }
169
170 pub async fn connect(&self, addr: Address, socket: S) -> Result<ConnectionState, Error> {
175 let (read_socket, write_socket) = socket.split();
176 let write_socket = Arc::new(Mutex::new(OverflowWriter::new(write_socket)));
177 let (connected_tx, connected_rx) = oneshot::channel();
178
179 self.set_connection(
180 addr.clone(),
181 VsockConnectionState::ConnectingOutgoing(write_socket, read_socket, connected_tx),
182 )?;
183
184 let header = &mut Header::new(PacketType::Connect);
185 header.set_address(&addr);
186 self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
187 connected_rx.await.map_err(|_| Error::other("Accept was never received for {addr:?}"))?
188 }
189
190 pub async fn close(&self, address: &Address) {
192 Self::send_close_packet(address, &self.packet_filler).await
193 }
194
195 pub async fn reset(&self, address: &Address) -> Result<(), Error> {
197 reset(address, &self.connections, &self.packet_filler).await
198 }
199
200 pub async fn accept(
204 &self,
205 request: ConnectionRequest,
206 socket: S,
207 ) -> Result<ConnectionState, Error> {
208 let address = request.address;
209 let notify_closed_rx;
210 if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
211 let VsockConnectionState::ConnectingIncoming = &conn.state else {
212 return Err(Error::other(format!(
213 "Attempted to accept connection that was not waiting at {address:?}"
214 )));
215 };
216
217 let (read_socket, write_socket) = socket.split();
218 let writer = Arc::new(Mutex::new(OverflowWriter::new(write_socket)));
219 let notify_closed = mpsc::channel(2);
220 notify_closed_rx = notify_closed.1;
221 let notify_closed = notify_closed.0;
222 let pause_state = PauseState::new();
223
224 let reader_task = Scope::new_with_name("connection-reader");
225 reader_task.spawn(Self::run_socket(
226 read_socket,
227 address,
228 self.packet_filler.clone(),
229 Arc::clone(&pause_state),
230 ));
231
232 conn.state = VsockConnectionState::Connected {
233 writer,
234 _reader_scope: reader_task,
235 notify_closed,
236 pause_state,
237 };
238 } else {
239 return Err(Error::other(format!(
240 "Attempting to accept connection that did not exist at {address:?}"
241 )));
242 }
243 let header = &mut Header::new(PacketType::Accept);
244 header.set_address(&address);
245 self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
246 Ok(ConnectionState(notify_closed_rx))
247 }
248
249 pub async fn reject(&self, request: ConnectionRequest) -> Result<(), Error> {
251 let address = request.address;
252 match self.connections.lock().unwrap().entry(address.clone()) {
253 Entry::Occupied(entry) => {
254 let VsockConnectionState::ConnectingIncoming = &entry.get().state else {
255 return Err(Error::other(format!(
256 "Attempted to reject connection that was not waiting at {address:?}"
257 )));
258 };
259 entry.remove();
260 }
261 Entry::Vacant(_) => {
262 return Err(Error::other(format!(
263 "Attempted to reject connection that was not waiting at {address:?}"
264 )));
265 }
266 }
267
268 let header = &mut Header::new(PacketType::Reset);
269 header.set_address(&address);
270 self.packet_filler
271 .write_vsock_packet(&Packet { header, payload: &[] })
272 .await
273 .expect("accept packet should never be too large for packet buffer");
274 Ok(())
275 }
276
277 async fn handle_data_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
278 if address.is_zeros() {
280 if let Some(writer) = self.control_socket_writer.as_ref() {
281 writer.lock().await.write_all(payload).await?;
282 } else {
283 trace!("Discarding {} bytes of data sent to control socket", payload.len());
284 }
285 Ok(())
286 } else {
287 let payload_socket;
288 if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
289 let VsockConnectionState::Connected { writer, .. } = &conn.state else {
290 warn!(
291 "Received data packet for connection in unexpected state for {address:?}"
292 );
293 return Ok(());
294 };
295 payload_socket = writer.clone();
296 } else {
297 warn!("Received data packet for connection that didn't exist at {address:?}");
298 return Ok(());
299 }
300 let mut socket_guard = payload_socket.lock().await;
301 match socket_guard.write_all(payload) {
302 Err(err) => {
303 debug!(
304 "Write to socket address {address:?} failed, \
305 resetting connection immediately: {err:?}"
306 );
307 self.reset(&address)
308 .await
309 .inspect_err(|err| {
310 warn!(
311 "Attempt to reset connection to {address:?} \
312 failed after write error: {err:?}"
313 );
314 })
315 .ok();
316 }
317 Ok(status) => {
318 if status.overflowed() {
319 if self.protocol_version.has_pause_packets() {
320 let header = &mut Header::new(PacketType::Pause);
321 header.set_address(&address);
322 self.packet_filler
323 .write_vsock_packet(&Packet {
324 header,
325 payload: &PausePacket::Pause.bytes(),
326 })
327 .await
328 .expect(
329 "pause packet should never be too large to fit in a usb packet",
330 );
331 }
332
333 let weak_payload_socket = Arc::downgrade(&payload_socket);
334 let connections = Arc::clone(&self.connections);
335 let has_pause_packets = self.protocol_version.has_pause_packets();
336 let packet_filler = Arc::clone(&self.packet_filler);
337 self.task_scope.spawn(async move {
338 let res = OverflowHandleFut::new(weak_payload_socket).await;
339
340 if let Err(err) = res {
341 debug!(
342 "Write to socket address {address:?} failed while \
343 processing backlog, resetting connection at next poll: {err:?}"
344 );
345 if let Err(err) = reset(&address, &connections, &packet_filler).await {
346 debug!("Error sending reset frame after overflow write failed: {err:?}");
347 }
348 } else if has_pause_packets {
349 let header = &mut Header::new(PacketType::Pause);
350 header.set_address(&address);
351 packet_filler
352 .write_vsock_packet(&Packet { header, payload: &PausePacket::UnPause.bytes() })
353 .await
354 .expect("pause packet should never be too large to fit in a usb packet");
355 }
356 });
357 }
358 }
359 }
360 Ok(())
361 }
362 }
363
364 async fn handle_echo_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
365 debug!("received echo for {address:?} with payload {payload:?}");
366 let header = &mut Header::new(PacketType::EchoReply);
367 header.payload_len.set(payload.len() as u32);
368 header.set_address(&address);
369 self.packet_filler
370 .write_vsock_packet(&Packet { header, payload })
371 .await
372 .map_err(|_| Error::other("Echo packet was too large to be sent back"))
373 }
374
375 async fn handle_echo_reply_packet(
376 &self,
377 address: Address,
378 payload: &[u8],
379 ) -> Result<(), Error> {
380 debug!("received echo reply for {address:?} with payload {payload:?}");
382 Ok(())
383 }
384
385 async fn handle_accept_packet(&self, address: Address) -> Result<(), Error> {
386 if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
387 let state = std::mem::replace(&mut conn.state, VsockConnectionState::Invalid);
388 let VsockConnectionState::ConnectingOutgoing(writer, read_socket, connected_tx) = state
389 else {
390 warn!("Received accept packet for connection in unexpected state for {address:?}");
391 return Ok(());
392 };
393 let (notify_closed, notify_closed_rx) = mpsc::channel(2);
394 if connected_tx.send(Ok(ConnectionState(notify_closed_rx))).is_err() {
395 warn!(
396 "Accept packet received for {address:?} but connect caller stopped waiting for it"
397 );
398 }
399 let pause_state = PauseState::new();
400
401 let reader_task = Scope::new_with_name("connection-reader");
402 reader_task.spawn(Self::run_socket(
403 read_socket,
404 address,
405 self.packet_filler.clone(),
406 Arc::clone(&pause_state),
407 ));
408 conn.state = VsockConnectionState::Connected {
409 writer,
410 _reader_scope: reader_task,
411 notify_closed,
412 pause_state,
413 };
414 } else {
415 warn!("Got accept packet for connection that was not being made at {address:?}");
416 return Ok(());
417 }
418 Ok(())
419 }
420
421 async fn handle_connect_packet(&self, address: Address) -> Result<(), Error> {
422 trace!("received connect packet for {address:?}");
423 match self.connections.lock().unwrap().entry(address.clone()) {
424 Entry::Vacant(entry) => {
425 debug!("valid connect request for {address:?}");
426 entry.insert(VsockConnection {
427 _address: address,
428 state: VsockConnectionState::ConnectingIncoming,
429 });
430 }
431 Entry::Occupied(_) => {
432 warn!(
433 "Received connect packet for already existing \
434 connection for address {address:?}. Ignoring"
435 );
436 return Ok(());
437 }
438 }
439
440 trace!("sending incoming connection request to client for {address:?}");
441 let connection_request = ConnectionRequest { address };
442 self.incoming_requests_tx
443 .clone()
444 .send(connection_request)
445 .await
446 .inspect(|_| trace!("sent incoming request for {address:?}"))
447 .map_err(|_| Error::other("Failed to send connection request"))
448 }
449
450 async fn handle_finish_packet(&self, address: Address) -> Result<(), Error> {
451 trace!("received finish packet for {address:?}");
452 let mut notify;
453 if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
454 let VsockConnectionState::Connected { notify_closed, .. } = conn.state else {
455 warn!(
456 "Received finish (close) packet for {address:?} \
457 which was not in a connected state. Ignoring and dropping connection state."
458 );
459 return Ok(());
460 };
461 notify = notify_closed;
462 } else {
463 warn!(
464 "Received finish (close) packet for connection that didn't exist \
465 on address {address:?}. Ignoring"
466 );
467 return Ok(());
468 }
469
470 notify.send(Ok(())).await.ok();
471
472 let header = &mut Header::new(PacketType::Reset);
473 header.set_address(&address);
474 self.packet_filler
475 .write_vsock_packet(&Packet { header, payload: &[] })
476 .await
477 .expect("accept packet should never be too large for packet buffer");
478 Ok(())
479 }
480
481 async fn handle_reset_packet(&self, address: Address) -> Result<(), Error> {
482 trace!("received reset packet for {address:?}");
483 let mut notify = None;
484 if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
485 if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
486 notify = Some(notify_closed);
487 } else {
488 debug!(
489 "Received reset packet for connection that wasn't in a connecting or \
490 disconnected state on address {address:?}."
491 );
492 }
493 } else {
494 warn!(
495 "Received reset packet for connection that didn't \
496 exist on address {address:?}. Ignoring"
497 );
498 }
499
500 if let Some(mut notify) = notify {
501 notify.send(Ok(())).await.ok();
502 }
503 Ok(())
504 }
505
506 async fn handle_pause_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
507 if !self.protocol_version.has_pause_packets() {
508 warn!(
509 "Got a pause packet while using protocol \
510 version {} which does not support them. Ignoring",
511 self.protocol_version
512 );
513 return Ok(());
514 }
515
516 let pause = match payload {
517 [1] => true,
518 [0] => false,
519 other => {
520 warn!("Ignoring unexpected pause packet payload {other:?}");
521 return Ok(());
522 }
523 };
524
525 if let Some(conn) = self.connections.lock().unwrap().get(&address) {
526 if let VsockConnectionState::Connected { pause_state, .. } = &conn.state {
527 pause_state.set_paused(pause);
528 } else {
529 warn!("Received pause packet for unestablished connection. Ignoring");
530 };
531 } else {
532 warn!(
533 "Received pause packet for connection that didn't exist on address {address:?}. Ignoring"
534 );
535 }
536
537 Ok(())
538 }
539
540 pub async fn handle_vsock_packet(&self, packet: Packet<'_>) -> Result<(), Error> {
543 trace!("received vsock packet {header:?}", header = packet.header);
544 let payload_len = packet.header.payload_len.get() as usize;
545 let payload = &packet.payload[..payload_len];
546 let address = Address::from(packet.header);
547 match packet.header.packet_type {
548 PacketType::Sync => Err(Error::other("Received sync packet mid-stream")),
549 PacketType::Data => self.handle_data_packet(address, payload).await,
550 PacketType::Accept => self.handle_accept_packet(address).await,
551 PacketType::Connect => self.handle_connect_packet(address).await,
552 PacketType::Finish => self.handle_finish_packet(address).await,
553 PacketType::Reset => self.handle_reset_packet(address).await,
554 PacketType::Echo => self.handle_echo_packet(address, payload).await,
555 PacketType::EchoReply => self.handle_echo_reply_packet(address, payload).await,
556 PacketType::Pause => self.handle_pause_packet(address, payload).await,
557 }
558 }
559
560 pub async fn fill_usb_packet(&self, builder: UsbPacketBuilder<B>) -> UsbPacketBuilder<B> {
567 self.packet_filler.fill_usb_packet(builder).await
568 }
569}
570
571async fn reset<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static>(
572 address: &Address,
573 connections: &std::sync::Mutex<HashMap<Address, VsockConnection<S>>>,
574 packet_filler: &UsbPacketFiller<B>,
575) -> Result<(), Error> {
576 let mut notify = None;
577 if let Some(conn) = connections.lock().unwrap().remove(&address) {
578 if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
579 notify = Some(notify_closed);
580 }
581 } else {
582 return Err(Error::other(
583 "Client asked to reset connection {address:?} that did not exist",
584 ));
585 }
586
587 if let Some(mut notify) = notify {
588 notify.send(Err(ErrorKind::ConnectionReset.into())).await.ok();
589 }
590
591 let header = &mut Header::new(PacketType::Reset);
592 header.set_address(address);
593 packet_filler
594 .write_vsock_packet(&Packet { header, payload: &[] })
595 .await
596 .expect("Reset packet should never be too big");
597 Ok(())
598}
599
/// Lifecycle state of a single vsock connection tracked by `Connection`.
enum VsockConnectionState<S> {
    /// We sent a connect packet and are waiting for the other end's accept. Holds the
    /// socket's write and read halves, plus the sender that resolves the pending `connect`
    /// call once the accept arrives.
    ConnectingOutgoing(
        Arc<Mutex<OverflowWriter<S>>>,
        ReadHalf<S>,
        oneshot::Sender<Result<ConnectionState, Error>>,
    ),
    /// The other end sent a connect packet; waiting for the client to `accept` or `reject`.
    ConnectingIncoming,
    /// The connection is established.
    Connected {
        /// Destination for incoming data packets.
        writer: Arc<Mutex<OverflowWriter<S>>>,
        /// Reports closure (or reset) to the connection's `ConnectionState`.
        notify_closed: mpsc::Sender<Result<(), Error>>,
        /// Toggled by incoming pause packets; gates the socket reader task.
        pause_state: Arc<PauseState>,
        /// Scope owning the task that forwards socket reads as data packets; only held so
        /// that it lives as long as this state does.
        _reader_scope: Scope,
    },
    /// Placeholder left by `std::mem::replace` while another state's contents are moved out
    /// (see `handle_accept_packet`).
    Invalid,
}
615
/// Entry in `Connection`'s connection map.
struct VsockConnection<S> {
    /// Address this connection is registered under (duplicates the map key; not read).
    _address: Address,
    /// Current lifecycle state of the connection.
    state: VsockConnectionState<S>,
}
620
621#[derive(Debug)]
625pub struct ConnectionState(mpsc::Receiver<Result<(), Error>>);
626
627impl ConnectionState {
628 pub async fn wait_for_close(mut self) -> Result<(), Error> {
631 self.0
632 .next()
633 .await
634 .ok_or_else(|| Error::other("Connection state's other end was dropped"))?
635 }
636}
637
638#[derive(Debug)]
641pub struct ConnectionRequest {
642 address: Address,
643}
644
645impl ConnectionRequest {
646 pub fn new(address: Address) -> Self {
648 Self { address }
649 }
650
651 pub fn address(&self) -> &Address {
653 &self.address
654 }
655}
656
// Tests drive `Connection` without real USB hardware: packets produced by
// `fill_usb_packet` are either echoed back into the same connection or forwarded to a
// second connection.
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::VsockPacketIterator;

    use super::*;

    #[cfg(not(target_os = "fuchsia"))]
    use fuchsia_async::emulated_handle::Socket as SyncSocket;
    use fuchsia_async::{Socket, Task};
    use futures::StreamExt;
    #[cfg(target_os = "fuchsia")]
    use zx::Socket as SyncSocket;

    /// Loops forever, feeding every outgoing packet of `echo_connection` back into itself.
    ///
    /// A connect packet is answered with an accept for the same address, so outgoing
    /// `connect` calls succeed. Data therefore round-trips to the socket it came from.
    async fn usb_echo_server(echo_connection: Arc<Connection<Vec<u8>, Socket>>) {
        let mut builder = UsbPacketBuilder::new(vec![0; 128]);
        loop {
            println!("waiting for usb packet");
            builder = echo_connection.fill_usb_packet(builder).await;
            let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
            println!("got usb packet, echoing it back to the other side");
            let mut packet_count = 0;
            for packet in packets {
                let packet = packet.unwrap();
                match packet.header.packet_type {
                    PacketType::Connect => {
                        // Pretend the other end accepted: reply with an accept on the same address.
                        let mut reply_header = packet.header.clone();
                        reply_header.packet_type = PacketType::Accept;
                        echo_connection
                            .handle_vsock_packet(Packet { header: &reply_header, payload: &[] })
                            .await
                            .unwrap();
                    }
                    PacketType::Accept => {
                        // Accepts sent by `accept()` are swallowed. Echoing one back would hand
                        // the connection an accept for a connection it did not initiate.
                    }
                    _ => echo_connection.handle_vsock_packet(packet).await.unwrap(),
                }
                packet_count += 1;
            }
            println!("handled {packet_count} packets");
        }
    }

    /// Bytes written into the control socket come back out of it via the echo server.
    #[fuchsia::test]
    async fn data_over_control_socket() {
        let (socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
        let mut socket = Socket::from_socket(socket);
        let connection = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket)),
            incoming_requests_tx,
        ));

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
            println!("round tripping packet of size {size}");
            socket.write_all(&vec![size; size as usize]).await.unwrap();
            let mut buf = vec![0u8; size as usize];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, vec![size; size as usize]);
        }
        echo_task.abort().await;
    }

    /// An outgoing connection (accepted by the echo server) round-trips data.
    #[fuchsia::test]
    async fn data_over_normal_outgoing_socket() {
        let (_control_socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
        let connection = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket)),
            incoming_requests_tx,
        ));

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        let (socket, other_socket) = SyncSocket::create_stream();
        let mut socket = Socket::from_socket(socket);
        connection
            .connect(
                Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 },
                Socket::from_socket(other_socket),
            )
            .await
            .unwrap();

        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
            println!("round tripping packet of size {size}");
            socket.write_all(&vec![size; size as usize]).await.unwrap();
            let mut buf = vec![0u8; size as usize];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, vec![size; size as usize]);
        }
        echo_task.abort().await;
    }

    /// An injected connect packet surfaces as a request; once accepted, data round-trips.
    #[fuchsia::test]
    async fn data_over_normal_incoming_socket() {
        let (_control_socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, mut incoming_requests) = mpsc::channel(5);
        let connection = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket)),
            incoming_requests_tx,
        ));

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        let header = &mut Header::new(PacketType::Connect);
        header.set_address(&Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 });
        connection.handle_vsock_packet(Packet { header, payload: &[] }).await.unwrap();

        let request = incoming_requests.next().await.unwrap();
        assert_eq!(
            request.address,
            Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
        );

        let (socket, other_socket) = SyncSocket::create_stream();
        let mut socket = Socket::from_socket(socket);
        connection.accept(request, Socket::from_socket(other_socket)).await.unwrap();

        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
            println!("round tripping packet of size {size}");
            socket.write_all(&vec![size; size as usize]).await.unwrap();
            let mut buf = vec![0u8; size as usize];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, vec![size; size as usize]);
        }
        echo_task.abort().await;
    }

    /// Forwards every packet `from` produces into `to`, forever.
    async fn copy_connection(from: &Connection<Vec<u8>, Socket>, to: &Connection<Vec<u8>, Socket>) {
        let mut builder = UsbPacketBuilder::new(vec![0; 1024]);
        loop {
            builder = from.fill_usb_packet(builder).await;
            let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
            for packet in packets {
                println!("forwarding vsock packet");
                to.handle_vsock_packet(packet.unwrap()).await.unwrap();
            }
        }
    }

    /// Shorthand for the async closures accepted by `end_to_end_test`.
    pub(crate) trait EndToEndTestFn<R>:
        AsyncFnOnce(Arc<Connection<Vec<u8>, Socket>>, mpsc::Receiver<ConnectionRequest>) -> R
    {
    }
    impl<T, R> EndToEndTestFn<R> for T where
        T: AsyncFnOnce(Arc<Connection<Vec<u8>, Socket>>, mpsc::Receiver<ConnectionRequest>) -> R
    {
    }

    /// Wires two `Connection`s back to back and runs `left_side` and `right_side` concurrently.
    ///
    /// Each side gets its own connection and incoming-request receiver. Returns both results
    /// once both sides finish, then aborts the packet-forwarding task.
    pub(crate) async fn end_to_end_test<R1, R2>(
        left_side: impl EndToEndTestFn<R1>,
        right_side: impl EndToEndTestFn<R2>,
    ) -> (R1, R2) {
        type Connection = crate::Connection<Vec<u8>, Socket>;
        let (_control_socket1, other_socket1) = SyncSocket::create_stream();
        let (_control_socket2, other_socket2) = SyncSocket::create_stream();
        let (incoming_requests_tx1, incoming_requests1) = mpsc::channel(5);
        let (incoming_requests_tx2, incoming_requests2) = mpsc::channel(5);

        let connection1 = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket1)),
            incoming_requests_tx1,
        ));
        let connection2 = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(other_socket2)),
            incoming_requests_tx2,
        ));

        let conn1 = connection1.clone();
        let conn2 = connection2.clone();
        let passthrough_task = Task::spawn(async move {
            futures::join!(copy_connection(&conn1, &conn2), copy_connection(&conn2, &conn1),);
            println!("passthrough task loop ended");
        });

        let res = futures::join!(
            left_side(connection1, incoming_requests1),
            right_side(connection2, incoming_requests2)
        );
        passthrough_task.abort().await;
        res
    }

    /// Data written on one side arrives on the other. Dropping the writer's socket closes
    /// both sides cleanly.
    #[fuchsia::test]
    async fn data_over_end_to_end() {
        end_to_end_test(
            async |conn, _incoming| {
                println!("sending request on connection 1");
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn
                    .connect(
                        Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 },
                        Socket::from_socket(other_socket),
                    )
                    .await
                    .unwrap();

                for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
                    println!("round tripping packet of size {size}");
                    socket.write_all(&vec![size; size as usize]).await.unwrap();
                }
                drop(socket);
                state.wait_for_close().await.unwrap();
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(
                    request.address,
                    Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
                );

                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();

                println!("accepted request on connection 2");
                for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
                    let mut buf = vec![0u8; size as usize];
                    socket.read_exact(&mut buf).await.unwrap();
                    assert_eq!(buf, vec![size; size as usize]);
                }
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    /// `close` on one side closes the sockets on both sides, and both report an orderly close.
    #[fuchsia::test]
    async fn normal_close_end_to_end() {
        let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
        end_to_end_test(
            async |conn, _incoming| {
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state =
                    conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
                conn.close(&addr).await;
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, addr.clone(),);

                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    /// `reset` on one side closes both sockets. The resetting side sees an error from
    /// `wait_for_close`; the other side sees `Ok`.
    #[fuchsia::test]
    async fn reset_end_to_end() {
        let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
        end_to_end_test(
            async |conn, _incoming| {
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state =
                    conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
                conn.reset(&addr).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.expect_err("expected reset");
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, addr.clone(),);

                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }
}