1use futures::channel::{mpsc, oneshot};
6use futures::lock::{Mutex, OwnedMutexGuard};
7use log::{debug, trace, warn};
8use std::collections::HashMap;
9use std::collections::hash_map::Entry;
10use std::future::Future;
11use std::io::{Error, ErrorKind};
12use std::ops::DerefMut;
13use std::pin::Pin;
14use std::sync::Arc;
15use std::task::{Context, Poll, Waker, ready};
16
17use fuchsia_async::Scope;
18use futures::io::{ReadHalf, WriteHalf};
19use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, FutureExt, SinkExt, StreamExt};
20
21use crate::connection::overflow_writer::OverflowHandleFut;
22use crate::{
23 Address, Header, Packet, PacketType, ProtocolVersion, UsbPacketBuilder, UsbPacketFiller,
24};
25
26mod overflow_writer;
27mod pause_state;
28
29use overflow_writer::OverflowWriter;
30use pause_state::PauseState;
31
/// A mutable byte buffer that can back a USB packet.
///
/// Any owned, sendable, `Unpin` type that dereferences to `[u8]` qualifies (for example
/// `Vec<u8>` or `Box<[u8]>`); the blanket impl below makes this a pure trait alias.
pub trait PacketBuffer: DerefMut<Target = [u8]> + Send + Unpin + 'static {}

impl<Buf: DerefMut<Target = [u8]> + Send + Unpin + 'static> PacketBuffer for Buf {}
35
/// The one-byte payload carried by a `PacketType::Pause` packet.
#[derive(Copy, Clone, PartialEq, Eq)]
enum PausePacket {
    /// Ask the peer to stop reading from the socket for this address (payload `[1]`).
    Pause,
    /// Ask the peer to resume reading (payload `[0]`).
    UnPause,
}

impl PausePacket {
    /// Encodes this request as its wire payload: `[1]` to pause, `[0]` to un-pause.
    fn bytes(&self) -> [u8; 1] {
        [u8::from(matches!(self, PausePacket::Pause))]
    }
}
50
/// Handle for a connection that has been set up (its table entry is expected to be in
/// the `Connected` state) but does not yet have a local socket attached.
///
/// Returned by `Connection::connect_late` and `Connection::accept_late`. Call
/// `finish_connect` with the local socket to start moving data in both directions.
pub struct ReadyConnect<B, S> {
    // Connection table shared with the owning `Connection`.
    connections: Arc<fuchsia_sync::Mutex<HashMap<Address, VsockConnection<S>>>>,
    // Shared queue of outgoing vsock packets, handed to the socket reader task.
    packet_filler: Arc<UsbPacketFiller<B>>,
    // Address of the connection this handle completes.
    address: Address,
}
58
59impl<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static> ReadyConnect<B, S> {
60 pub async fn finish_connect(self, socket: S) {
62 let (read_socket, write_socket) = socket.split();
63 let writer = {
64 let conns = self.connections.lock();
65 let Some(conn) = conns.get(&self.address) else {
66 warn!("Connection state was missing after connection success!");
67 return;
68 };
69 let VsockConnectionState::Connected { writer, reader_scope, pause_state, .. } =
70 &conn.state
71 else {
72 warn!("Connection state was invalid after connection success!");
73 return;
74 };
75 reader_scope.spawn(Connection::<B, S>::run_socket(
76 read_socket,
77 self.address,
78 self.packet_filler,
79 Arc::clone(pause_state),
80 ));
81 Arc::clone(writer)
82 };
83 let mut writer = writer.lock().await;
84 let ConnectionStateWriter::NotYetAvailable(wakers) = std::mem::replace(
85 &mut *writer,
86 ConnectionStateWriter::Available(OverflowWriter::new(write_socket)),
87 ) else {
88 unreachable!("Connection completed multiple times!")
89 };
90
91 wakers.into_iter().for_each(Waker::wake);
92 }
93}
94
/// Multiplexes vsock connections over a single USB link.
///
/// Outgoing vsock packets are queued on an internal packet filler and drained with
/// `fill_usb_packet`. Incoming packets are fed in through `handle_vsock_packet`.
pub struct Connection<B, S> {
    // Write half of the optional control socket. Data packets addressed to the all-zero
    // address are written here, or discarded when there is no control socket.
    control_socket_writer: Option<Mutex<WriteHalf<S>>>,
    // Queue of outgoing vsock packets, shared with the per-socket reader tasks.
    packet_filler: Arc<UsbPacketFiller<B>>,
    // Protocol version in use; decides whether pause packets are sent and honored.
    protocol_version: ProtocolVersion,
    // Every known connection, keyed by address.
    connections: Arc<fuchsia_sync::Mutex<HashMap<Address, VsockConnection<S>>>>,
    // Incoming connect requests are forwarded to the client over this channel.
    incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
    // Scope for background tasks (control-socket reader, overflow backlog drains).
    task_scope: Scope,
}
115
116impl<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static> Connection<B, S> {
117 pub fn new(
124 protocol_version: ProtocolVersion,
125 control_socket: Option<S>,
126 incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
127 ) -> Self {
128 let packet_filler = Arc::new(UsbPacketFiller::default());
129 let connections = Default::default();
130 let task_scope = Scope::new_with_name("vsock_usb");
131 let control_socket_writer = control_socket.map(|control_socket| {
132 let (control_socket_reader, control_socket_writer) = control_socket.split();
133 task_scope.spawn(Self::run_socket(
134 control_socket_reader,
135 Address::default(),
136 packet_filler.clone(),
137 PauseState::new(),
138 ));
139 Mutex::new(control_socket_writer)
140 });
141 Self {
142 control_socket_writer,
143 packet_filler,
144 connections,
145 incoming_requests_tx,
146 protocol_version,
147 task_scope,
148 }
149 }
150
151 async fn send_close_packet(address: &Address, usb_packet_filler: &Arc<UsbPacketFiller<B>>) {
152 let header = &mut Header::new(PacketType::Finish);
153 header.set_address(address);
154 usb_packet_filler
155 .write_vsock_packet(&Packet { header, payload: &[] })
156 .await
157 .expect("Finish packet should never be too big");
158 }
159
160 async fn run_socket(
161 mut reader: ReadHalf<S>,
162 address: Address,
163 usb_packet_filler: Arc<UsbPacketFiller<B>>,
164 pause_state: Arc<PauseState>,
165 ) {
166 let mut buf = [0; 4096];
167 loop {
168 log::trace!("reading from control socket");
169 let read = match pause_state.while_unpaused(reader.read(&mut buf)).await {
170 Ok(0) => {
171 if !address.is_zeros() {
172 Self::send_close_packet(&address, &usb_packet_filler).await;
173 }
174 return;
175 }
176 Ok(read) => read,
177 Err(err) => {
178 if address.is_zeros() {
179 log::error!("Error reading usb socket: {err:?}");
180 } else {
181 Self::send_close_packet(&address, &usb_packet_filler).await;
182 }
183 return;
184 }
185 };
186 log::trace!("writing {read} bytes to vsock packet");
187 usb_packet_filler.write_vsock_data_all(&address, &buf[..read]).await;
188 log::trace!("wrote {read} bytes to vsock packet");
189 }
190 }
191
192 fn set_connection(
193 &self,
194 address: Address,
195 state: VsockConnectionState<S>,
196 ) -> Result<(), Error> {
197 let mut connections = self.connections.lock();
198 if !connections.contains_key(&address) {
199 connections.insert(address.clone(), VsockConnection { _address: address, state });
200 Ok(())
201 } else {
202 Err(Error::other(format!("connection on address {address:?} already set")))
203 }
204 }
205
206 pub async fn send_empty_echo(&self) {
209 debug!("Sending empty echo packet");
210 let header = &mut Header::new(PacketType::Echo);
211 self.packet_filler
212 .write_vsock_packet(&Packet { header, payload: &[] })
213 .await
214 .expect("empty echo packet should never be too large to fit in a usb packet");
215 }
216
217 pub async fn connect(&self, addr: Address, socket: S) -> Result<ConnectionState, Error> {
222 let (ready, state) = self.connect_late(addr).await?;
223 ready.finish_connect(socket).await;
224 Ok(state)
225 }
226
227 pub async fn connect_late(
234 &self,
235 addr: Address,
236 ) -> Result<(ReadyConnect<B, S>, ConnectionState), Error> {
237 let (connected_tx, connected_rx) = oneshot::channel();
238
239 self.set_connection(addr.clone(), VsockConnectionState::ConnectingOutgoing(connected_tx))?;
240
241 let header = &mut Header::new(PacketType::Connect);
242 header.set_address(&addr);
243 self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
244 let Ok(conn_state) = connected_rx.await else {
245 return Err(Error::other("Accept was never received for {addr:?}"));
246 };
247
248 Ok((
249 ReadyConnect {
250 connections: Arc::clone(&self.connections),
251 packet_filler: Arc::clone(&self.packet_filler),
252 address: addr,
253 },
254 conn_state,
255 ))
256 }
257
258 pub async fn close(&self, address: &Address) {
260 Self::send_close_packet(address, &self.packet_filler).await
261 }
262
263 pub async fn reset(&self, address: &Address) -> Result<(), Error> {
265 reset(address, &self.connections, &self.packet_filler).await
266 }
267
268 pub async fn accept(
272 &self,
273 request: ConnectionRequest,
274 socket: S,
275 ) -> Result<ConnectionState, Error> {
276 let (ready, state) = self.accept_late(request).await?;
277 ready.finish_connect(socket).await;
278 Ok(state)
279 }
280
281 pub async fn accept_late(
285 &self,
286 request: ConnectionRequest,
287 ) -> Result<(ReadyConnect<B, S>, ConnectionState), Error> {
288 let address = request.address;
289 let notify_closed_rx;
290 if let Some(conn) = self.connections.lock().get_mut(&address) {
291 let VsockConnectionState::ConnectingIncoming = &conn.state else {
292 return Err(Error::other(format!(
293 "Attempted to accept connection that was not waiting at {address:?}"
294 )));
295 };
296
297 let notify_closed = mpsc::channel(2);
298 notify_closed_rx = notify_closed.1;
299 let notify_closed = notify_closed.0;
300 let pause_state = PauseState::new();
301
302 let reader_scope = Scope::new_with_name("connection-reader");
303
304 conn.state = VsockConnectionState::Connected {
305 writer: Arc::new(Mutex::new(ConnectionStateWriter::NotYetAvailable(Vec::new()))),
306 reader_scope,
307 notify_closed,
308 pause_state,
309 };
310 } else {
311 return Err(Error::other(format!(
312 "Attempting to accept connection that did not exist at {address:?}"
313 )));
314 }
315 let header = &mut Header::new(PacketType::Accept);
316 header.set_address(&address);
317 self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
318 Ok((
319 ReadyConnect {
320 connections: Arc::clone(&self.connections),
321 packet_filler: Arc::clone(&self.packet_filler),
322 address,
323 },
324 ConnectionState(notify_closed_rx),
325 ))
326 }
327
328 pub async fn reject(&self, request: ConnectionRequest) -> Result<(), Error> {
330 let address = request.address;
331 match self.connections.lock().entry(address.clone()) {
332 Entry::Occupied(entry) => {
333 let VsockConnectionState::ConnectingIncoming = &entry.get().state else {
334 return Err(Error::other(format!(
335 "Attempted to reject connection that was not waiting at {address:?}"
336 )));
337 };
338 entry.remove();
339 }
340 Entry::Vacant(_) => {
341 return Err(Error::other(format!(
342 "Attempted to reject connection that was not waiting at {address:?}"
343 )));
344 }
345 }
346
347 let header = &mut Header::new(PacketType::Reset);
348 header.set_address(&address);
349 self.packet_filler
350 .write_vsock_packet(&Packet { header, payload: &[] })
351 .await
352 .expect("accept packet should never be too large for packet buffer");
353 Ok(())
354 }
355
/// Delivers the payload of an incoming `Data` packet to the local socket for `address`.
///
/// Data for the all-zero address goes to the control socket, or is discarded if there is
/// none. For other addresses, this waits until the connection's writer is attached (see
/// `ConnectionStateWriter::wait_available`) and then writes the payload.
///
/// If the write fails, the connection is reset. If it overflowed, a `Pause` packet is sent
/// (when the protocol supports it). A background task then drains the backlog and either
/// resets the connection on failure or sends an `UnPause` packet on success.
///
/// Data for unknown or not-yet-connected addresses is logged and ignored.
///
/// # Errors
///
/// Returns an error only if writing to the control socket fails.
async fn handle_data_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
    if address.is_zeros() {
        if let Some(writer) = self.control_socket_writer.as_ref() {
            writer.lock().await.write_all(payload).await?;
        } else {
            trace!("Discarding {} bytes of data sent to control socket", payload.len());
        }
        Ok(())
    } else {
        // Clone the writer handle out of the table so the table lock is not held while
        // awaiting below.
        let payload_socket;
        if let Some(conn) = self.connections.lock().get_mut(&address) {
            let VsockConnectionState::Connected { writer, .. } = &conn.state else {
                warn!(
                    "Received data packet for connection in unexpected state for {address:?}"
                );
                return Ok(());
            };
            payload_socket = writer.clone();
        } else {
            warn!("Received data packet for connection that didn't exist at {address:?}");
            return Ok(());
        }
        // Blocks until `finish_connect` has attached a socket to this connection. The
        // writer lock is then held for the rest of this branch, including the pause write.
        let mut socket_guard =
            ConnectionStateWriter::wait_available(Arc::clone(&payload_socket)).await;
        let ConnectionStateWriter::Available(socket) = &mut *socket_guard else {
            unreachable!("wait_available didn't wait until socket was available!");
        };
        match socket.write_all(payload) {
            Err(err) => {
                debug!(
                    "Write to socket address {address:?} failed, \
                     resetting connection immediately: {err:?}"
                );
                self.reset(&address)
                    .await
                    .inspect_err(|err| {
                        warn!(
                            "Attempt to reset connection to {address:?} \
                             failed after write error: {err:?}"
                        );
                    })
                    .ok();
            }
            Ok(status) => {
                if status.overflowed() {
                    // Ask the peer to stop sending until the backlog has been written.
                    if self.protocol_version.has_pause_packets() {
                        let header = &mut Header::new(PacketType::Pause);
                        let payload = &PausePacket::Pause.bytes();
                        header.set_address(&address);
                        header.payload_len.set(payload.len() as u32);
                        self.packet_filler
                            .write_vsock_packet(&Packet { header, payload })
                            .await
                            .expect(
                                "pause packet should never be too large to fit in a usb packet",
                            );
                    }

                    // The drain task only holds a weak reference to the writer, so it does
                    // not keep the writer alive after the connection is removed from the table.
                    let weak_payload_socket = Arc::downgrade(&payload_socket);
                    let connections = Arc::clone(&self.connections);
                    let has_pause_packets = self.protocol_version.has_pause_packets();
                    let packet_filler = Arc::clone(&self.packet_filler);
                    self.task_scope.spawn(async move {
                        let res = OverflowHandleFut::new(weak_payload_socket).await;

                        if let Err(err) = res {
                            debug!(
                                "Write to socket address {address:?} failed while \
                                 processing backlog, resetting connection at next poll: {err:?}"
                            );
                            if let Err(err) = reset(&address, &connections, &packet_filler).await {
                                debug!("Error sending reset frame after overflow write failed: {err:?}");
                            }
                        } else if has_pause_packets {
                            // Backlog drained: let the peer resume sending.
                            let header = &mut Header::new(PacketType::Pause);
                            let payload = &PausePacket::UnPause.bytes();
                            header.set_address(&address);
                            header.payload_len.set(payload.len() as u32);
                            packet_filler
                                .write_vsock_packet(&Packet { header, payload })
                                .await
                                .expect("pause packet should never be too large to fit in a usb packet");
                        }
                    });
                }
            }
        }
        Ok(())
    }
}
446 }
447
448 async fn handle_echo_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
449 debug!("received echo for {address:?} with payload {payload:?}");
450 let header = &mut Header::new(PacketType::EchoReply);
451 header.payload_len.set(payload.len() as u32);
452 header.set_address(&address);
453 self.packet_filler
454 .write_vsock_packet(&Packet { header, payload })
455 .await
456 .map_err(|_| Error::other("Echo packet was too large to be sent back"))
457 }
458
459 async fn handle_echo_reply_packet(
460 &self,
461 address: Address,
462 payload: &[u8],
463 ) -> Result<(), Error> {
464 debug!("received echo reply for {address:?} with payload {payload:?}");
466 Ok(())
467 }
468
469 async fn handle_accept_packet(&self, address: Address) -> Result<(), Error> {
470 if let Some(conn) = self.connections.lock().get_mut(&address) {
471 let state = std::mem::replace(&mut conn.state, VsockConnectionState::Invalid);
472 let VsockConnectionState::ConnectingOutgoing(connected_tx) = state else {
473 warn!("Received accept packet for connection in unexpected state for {address:?}");
474 return Ok(());
475 };
476 let (notify_closed, notify_closed_rx) = mpsc::channel(2);
477 if connected_tx.send(ConnectionState(notify_closed_rx)).is_err() {
478 warn!(
479 "Accept packet received for {address:?} but connect caller stopped waiting for it"
480 );
481 }
482 let pause_state = PauseState::new();
483
484 let reader_scope = Scope::new_with_name("connection-reader");
485 conn.state = VsockConnectionState::Connected {
486 writer: Arc::new(Mutex::new(ConnectionStateWriter::NotYetAvailable(Vec::new()))),
487 reader_scope,
488 notify_closed,
489 pause_state,
490 };
491 } else {
492 warn!("Got accept packet for connection that was not being made at {address:?}");
493 return Ok(());
494 }
495 Ok(())
496 }
497
498 async fn handle_connect_packet(&self, address: Address) -> Result<(), Error> {
499 trace!("received connect packet for {address:?}");
500 match self.connections.lock().entry(address.clone()) {
501 Entry::Vacant(entry) => {
502 debug!("valid connect request for {address:?}");
503 entry.insert(VsockConnection {
504 _address: address,
505 state: VsockConnectionState::ConnectingIncoming,
506 });
507 }
508 Entry::Occupied(_) => {
509 warn!(
510 "Received connect packet for already existing \
511 connection for address {address:?}. Ignoring"
512 );
513 return Ok(());
514 }
515 }
516
517 trace!("sending incoming connection request to client for {address:?}");
518 let connection_request = ConnectionRequest { address };
519 self.incoming_requests_tx
520 .clone()
521 .send(connection_request)
522 .await
523 .inspect(|_| trace!("sent incoming request for {address:?}"))
524 .map_err(|_| Error::other("Failed to send connection request"))
525 }
526
527 async fn handle_finish_packet(&self, address: Address) -> Result<(), Error> {
528 trace!("received finish packet for {address:?}");
529 let mut notify;
530 if let Some(conn) = self.connections.lock().remove(&address) {
531 let VsockConnectionState::Connected { notify_closed, .. } = conn.state else {
532 warn!(
533 "Received finish (close) packet for {address:?} \
534 which was not in a connected state. Ignoring and dropping connection state."
535 );
536 return Ok(());
537 };
538 notify = notify_closed;
539 } else {
540 warn!(
541 "Received finish (close) packet for connection that didn't exist \
542 on address {address:?}. Ignoring"
543 );
544 return Ok(());
545 }
546
547 notify.send(Ok(())).await.ok();
548
549 let header = &mut Header::new(PacketType::Reset);
550 header.set_address(&address);
551 self.packet_filler
552 .write_vsock_packet(&Packet { header, payload: &[] })
553 .await
554 .expect("accept packet should never be too large for packet buffer");
555 Ok(())
556 }
557
558 async fn handle_reset_packet(&self, address: Address) -> Result<(), Error> {
559 trace!("received reset packet for {address:?}");
560 let mut notify = None;
561 if let Some(conn) = self.connections.lock().remove(&address) {
562 if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
563 notify = Some(notify_closed);
564 } else {
565 debug!(
566 "Received reset packet for connection that wasn't in a connecting or \
567 disconnected state on address {address:?}."
568 );
569 }
570 } else {
571 warn!(
572 "Received reset packet for connection that didn't \
573 exist on address {address:?}. Ignoring"
574 );
575 }
576
577 if let Some(mut notify) = notify {
578 notify.send(Ok(())).await.ok();
579 }
580 Ok(())
581 }
582
583 async fn handle_pause_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
584 if !self.protocol_version.has_pause_packets() {
585 warn!(
586 "Got a pause packet while using protocol \
587 version {} which does not support them. Ignoring",
588 self.protocol_version
589 );
590 return Ok(());
591 }
592
593 let pause = match payload {
594 [1] => true,
595 [0] => false,
596 other => {
597 warn!("Ignoring unexpected pause packet payload {other:?}");
598 return Ok(());
599 }
600 };
601
602 if let Some(conn) = self.connections.lock().get(&address) {
603 if let VsockConnectionState::Connected { pause_state, .. } = &conn.state {
604 pause_state.set_paused(pause);
605 } else {
606 warn!("Received pause packet for unestablished connection. Ignoring");
607 };
608 } else {
609 warn!(
610 "Received pause packet for connection that didn't exist on address {address:?}. Ignoring"
611 );
612 }
613
614 Ok(())
615 }
616
617 pub async fn handle_vsock_packet(&self, packet: Packet<'_>) -> Result<(), Error> {
620 trace!("received vsock packet {header:?}", header = packet.header);
621 let payload_len = packet.header.payload_len.get() as usize;
622 let payload = &packet.payload[..payload_len];
623 let address = Address::from(packet.header);
624 match packet.header.packet_type {
625 PacketType::Sync => Err(Error::other("Received sync packet mid-stream")),
626 PacketType::Data => self.handle_data_packet(address, payload).await,
627 PacketType::Accept => self.handle_accept_packet(address).await,
628 PacketType::Connect => self.handle_connect_packet(address).await,
629 PacketType::Finish => self.handle_finish_packet(address).await,
630 PacketType::Reset => self.handle_reset_packet(address).await,
631 PacketType::Echo => self.handle_echo_packet(address, payload).await,
632 PacketType::EchoReply => self.handle_echo_reply_packet(address, payload).await,
633 PacketType::Pause => self.handle_pause_packet(address, payload).await,
634 }
635 }
636
637 pub async fn fill_usb_packet(&self, builder: UsbPacketBuilder<B>) -> UsbPacketBuilder<B> {
644 self.packet_filler.fill_usb_packet(builder).await
645 }
646}
647
648async fn reset<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static>(
649 address: &Address,
650 connections: &fuchsia_sync::Mutex<HashMap<Address, VsockConnection<S>>>,
651 packet_filler: &UsbPacketFiller<B>,
652) -> Result<(), Error> {
653 let mut notify = None;
654 if let Some(conn) = connections.lock().remove(&address) {
655 if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
656 notify = Some(notify_closed);
657 }
658 } else {
659 return Err(Error::other(
660 "Client asked to reset connection {address:?} that did not exist",
661 ));
662 }
663
664 if let Some(mut notify) = notify {
665 notify.send(Err(ErrorKind::ConnectionReset.into())).await.ok();
666 }
667
668 let header = &mut Header::new(PacketType::Reset);
669 header.set_address(address);
670 packet_filler
671 .write_vsock_packet(&Packet { header, payload: &[] })
672 .await
673 .expect("Reset packet should never be too big");
674 Ok(())
675}
676
677enum ConnectionStateWriter<S> {
682 NotYetAvailable(Vec<Waker>),
683 Available(OverflowWriter<S>),
684}
685
686impl<S> ConnectionStateWriter<S> {
687 fn wait_available(this: Arc<Mutex<ConnectionStateWriter<S>>>) -> ConnectionStateWriterFut<S> {
689 ConnectionStateWriterFut { writer: this, lock_fut: None }
690 }
691}
692
/// Future returned by `ConnectionStateWriter::wait_available`.
///
/// Resolves to an owned lock guard once the writer is in the `Available` state.
struct ConnectionStateWriterFut<S> {
    // The shared writer slot being waited on.
    writer: Arc<Mutex<ConnectionStateWriter<S>>>,
    // In-flight lock acquisition; `None` between acquisitions.
    lock_fut: Option<futures::lock::OwnedMutexLockFuture<ConnectionStateWriter<S>>>,
}

impl<S> Future for ConnectionStateWriterFut<S> {
    type Output = OwnedMutexGuard<ConnectionStateWriter<S>>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Start (or continue) acquiring the lock.
        let writer = Arc::clone(&self.writer);
        let lock_fut = self.lock_fut.get_or_insert_with(|| writer.lock_owned());
        let mut lock = ready!(lock_fut.poll_unpin(cx));
        // Lock acquired; a later poll must start a fresh acquisition.
        self.lock_fut = None;
        match &mut *lock {
            ConnectionStateWriter::Available(_) => Poll::Ready(lock),
            ConnectionStateWriter::NotYetAvailable(queue) => {
                // Register for a wake-up and return Pending; this drops `lock` so the code
                // that installs the socket (`ReadyConnect::finish_connect`) can take it.
                // Note that every Pending poll pushes another waker clone.
                queue.push(cx.waker().clone());
                Poll::Pending
            }
        }
    }
}
716
/// Lifecycle state of one vsock connection in the connection table.
enum VsockConnectionState<S> {
    // We sent a `Connect`; the sender is fulfilled when the peer's `Accept` arrives.
    // Dropping it (by removing the entry) makes the waiting `connect_late` call fail.
    ConnectingOutgoing(oneshot::Sender<ConnectionState>),
    // The peer sent a `Connect`; waiting for the client to accept or reject it.
    ConnectingIncoming,
    // Connection established on both sides.
    Connected {
        // Local socket's write half, or the wakers waiting for it to be attached.
        writer: Arc<Mutex<ConnectionStateWriter<S>>>,
        // Reports how the connection ended to the matching `ConnectionState`.
        notify_closed: mpsc::Sender<Result<(), Error>>,
        // Toggled by incoming `Pause` packets to gate reads from the local socket.
        pause_state: Arc<PauseState>,
        // Scope that runs the local socket's reader task.
        reader_scope: Scope,
    },
    // Placeholder left in place while the previous state is moved out with
    // `std::mem::replace` (see `handle_accept_packet`).
    Invalid,
}

/// An entry in the connection table.
struct VsockConnection<S> {
    // Address of the connection. Also the table key; not otherwise read in this module.
    _address: Address,
    state: VsockConnectionState<S>,
}
733
734#[derive(Debug)]
738pub struct ConnectionState(mpsc::Receiver<Result<(), Error>>);
739
740impl ConnectionState {
741 pub async fn wait_for_close(mut self) -> Result<(), Error> {
744 self.0
745 .next()
746 .await
747 .ok_or_else(|| Error::other("Connection state's other end was dropped"))?
748 }
749}
750
751#[derive(Debug)]
754pub struct ConnectionRequest {
755 address: Address,
756}
757
758impl ConnectionRequest {
759 pub fn new(address: Address) -> Self {
761 Self { address }
762 }
763
764 pub fn address(&self) -> &Address {
766 &self.address
767 }
768}
769
770#[cfg(test)]
771mod test {
772 use std::sync::Arc;
773
774 use crate::VsockPacketIterator;
775
776 use super::*;
777
778 #[cfg(not(target_os = "fuchsia"))]
779 use fuchsia_async::emulated_handle::Socket as SyncSocket;
780 use fuchsia_async::{Socket, Task};
781 use futures::StreamExt;
782 #[cfg(target_os = "fuchsia")]
783 use zx::Socket as SyncSocket;
784
/// Test helper that loops every outgoing USB packet of `echo_connection` back into itself.
///
/// Data packets are delivered back to the socket that sent them. Outgoing `Connect`
/// packets are answered with a synthesized `Accept`.
async fn usb_echo_server(echo_connection: Arc<Connection<Vec<u8>, Socket>>) {
    let mut builder = UsbPacketBuilder::new(vec![0; 128]);
    loop {
        println!("waiting for usb packet");
        builder = echo_connection.fill_usb_packet(builder).await;
        let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
        println!("got usb packet, echoing it back to the other side");
        let mut packet_count = 0;
        for packet in packets {
            let packet = packet.unwrap();
            match packet.header.packet_type {
                PacketType::Connect => {
                    // Pretend the far side accepted our outgoing connect.
                    let mut reply_header = packet.header.clone();
                    reply_header.packet_type = PacketType::Accept;
                    echo_connection
                        .handle_vsock_packet(Packet { header: &reply_header, payload: &[] })
                        .await
                        .unwrap();
                }
                PacketType::Accept => {
                    // Our own accept (sent by `accept`) must not be fed back in, or it
                    // would be treated as the peer accepting an outgoing connect.
                }
                _ => echo_connection.handle_vsock_packet(packet).await.unwrap(),
            }
            packet_count += 1;
        }
        println!("handled {packet_count} packets");
    }
}
815
/// Bytes written to the control socket come back unchanged through the echo server.
#[fuchsia::test]
async fn data_over_control_socket() {
    let (socket, other_socket) = SyncSocket::create_stream();
    let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
    let mut socket = Socket::from_socket(socket);
    let connection = Arc::new(Connection::new(
        ProtocolVersion::LATEST,
        Some(Socket::from_socket(other_socket)),
        incoming_requests_tx,
    ));

    let echo_task = Task::spawn(usb_echo_server(connection.clone()));

    for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
        println!("round tripping packet of size {size}");
        socket.write_all(&vec![size; size as usize]).await.unwrap();
        let mut buf = vec![0u8; size as usize];
        socket.read_exact(&mut buf).await.unwrap();
        assert_eq!(buf, vec![size; size as usize]);
    }
    echo_task.abort().await;
}
838
/// Data round-trips over a connection opened with `connect`; the echo server supplies the
/// `Accept`.
#[fuchsia::test]
async fn data_over_normal_outgoing_socket() {
    let (_control_socket, other_socket) = SyncSocket::create_stream();
    let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
    let connection = Arc::new(Connection::new(
        ProtocolVersion::LATEST,
        Some(Socket::from_socket(other_socket)),
        incoming_requests_tx,
    ));

    let echo_task = Task::spawn(usb_echo_server(connection.clone()));

    let (socket, other_socket) = SyncSocket::create_stream();
    let mut socket = Socket::from_socket(socket);
    connection
        .connect(
            Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 },
            Socket::from_socket(other_socket),
        )
        .await
        .unwrap();

    for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
        println!("round tripping packet of size {size}");
        socket.write_all(&vec![size; size as usize]).await.unwrap();
        let mut buf = vec![0u8; size as usize];
        socket.read_exact(&mut buf).await.unwrap();
        assert_eq!(buf, vec![size; size as usize]);
    }
    echo_task.abort().await;
}
870
/// Data round-trips over a connection created from a hand-injected `Connect` packet and
/// accepted with `accept`.
#[fuchsia::test]
async fn data_over_normal_incoming_socket() {
    let (_control_socket, other_socket) = SyncSocket::create_stream();
    let (incoming_requests_tx, mut incoming_requests) = mpsc::channel(5);
    let connection = Arc::new(Connection::new(
        ProtocolVersion::LATEST,
        Some(Socket::from_socket(other_socket)),
        incoming_requests_tx,
    ));

    let echo_task = Task::spawn(usb_echo_server(connection.clone()));

    // Simulate the peer asking to connect.
    let header = &mut Header::new(PacketType::Connect);
    header.set_address(&Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 });
    connection.handle_vsock_packet(Packet { header, payload: &[] }).await.unwrap();

    let request = incoming_requests.next().await.unwrap();
    assert_eq!(
        request.address,
        Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
    );

    let (socket, other_socket) = SyncSocket::create_stream();
    let mut socket = Socket::from_socket(socket);
    connection.accept(request, Socket::from_socket(other_socket)).await.unwrap();

    for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
        println!("round tripping packet of size {size}");
        socket.write_all(&vec![size; size as usize]).await.unwrap();
        let mut buf = vec![0u8; size as usize];
        socket.read_exact(&mut buf).await.unwrap();
        assert_eq!(buf, vec![size; size as usize]);
    }
    echo_task.abort().await;
}
906
/// Test helper that forwards every outgoing USB packet of `from` into `to`, forever.
async fn copy_connection(from: &Connection<Vec<u8>, Socket>, to: &Connection<Vec<u8>, Socket>) {
    let mut builder = UsbPacketBuilder::new(vec![0; 1024]);
    loop {
        builder = from.fill_usb_packet(builder).await;
        let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
        for packet in packets {
            println!("forwarding vsock packet");
            to.handle_vsock_packet(packet.unwrap()).await.unwrap();
        }
    }
}
918
/// Alias for the async closures driving one side of `end_to_end_test`. Each closure
/// receives its connection and the channel of incoming connection requests.
pub(crate) trait EndToEndTestFn<R>:
    AsyncFnOnce(Arc<Connection<Vec<u8>, Socket>>, mpsc::Receiver<ConnectionRequest>) -> R
{
}
// Blanket impl so any matching async closure can be passed directly.
impl<T, R> EndToEndTestFn<R> for T where
    T: AsyncFnOnce(Arc<Connection<Vec<u8>, Socket>>, mpsc::Receiver<ConnectionRequest>) -> R
{
}
927
/// Test harness that wires two `Connection`s back to back and runs one closure per side.
///
/// A background task shuttles USB packets in both directions until both closures finish.
/// Returns both closures' results.
pub(crate) async fn end_to_end_test<R1, R2>(
    left_side: impl EndToEndTestFn<R1>,
    right_side: impl EndToEndTestFn<R2>,
) -> (R1, R2) {
    type Connection = crate::Connection<Vec<u8>, Socket>;
    let (_control_socket1, other_socket1) = SyncSocket::create_stream();
    let (_control_socket2, other_socket2) = SyncSocket::create_stream();
    let (incoming_requests_tx1, incoming_requests1) = mpsc::channel(5);
    let (incoming_requests_tx2, incoming_requests2) = mpsc::channel(5);

    let connection1 = Arc::new(Connection::new(
        ProtocolVersion::LATEST,
        Some(Socket::from_socket(other_socket1)),
        incoming_requests_tx1,
    ));
    let connection2 = Arc::new(Connection::new(
        ProtocolVersion::LATEST,
        Some(Socket::from_socket(other_socket2)),
        incoming_requests_tx2,
    ));

    // Forward packets 1 -> 2 and 2 -> 1 concurrently; both loops run until aborted.
    let conn1 = connection1.clone();
    let conn2 = connection2.clone();
    let passthrough_task = Task::spawn(async move {
        futures::join!(copy_connection(&conn1, &conn2), copy_connection(&conn2, &conn1),);
        println!("passthrough task loop ended");
    });

    let res = futures::join!(
        left_side(connection1, incoming_requests1),
        right_side(connection2, incoming_requests2)
    );
    passthrough_task.abort().await;
    res
}
963
/// Data written on one side arrives on the other. Dropping the sending socket then closes
/// both sides cleanly.
#[fuchsia::test]
async fn data_over_end_to_end() {
    end_to_end_test(
        async |conn, _incoming| {
            println!("sending request on connection 1");
            let (socket, other_socket) = SyncSocket::create_stream();
            let mut socket = Socket::from_socket(socket);
            let state = conn
                .connect(
                    Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 },
                    Socket::from_socket(other_socket),
                )
                .await
                .unwrap();

            for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
                println!("round tripping packet of size {size}");
                socket.write_all(&vec![size; size as usize]).await.unwrap();
            }
            // EOF on this socket makes the reader task send a Finish packet.
            drop(socket);
            state.wait_for_close().await.unwrap();
        },
        async |conn, mut incoming| {
            println!("accepting request on connection 2");
            let request = incoming.next().await.unwrap();
            assert_eq!(
                request.address,
                Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
            );

            let (socket, other_socket) = SyncSocket::create_stream();
            let mut socket = Socket::from_socket(socket);
            let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();

            println!("accepted request on connection 2");
            for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
                let mut buf = vec![0u8; size as usize];
                socket.read_exact(&mut buf).await.unwrap();
                assert_eq!(buf, vec![size; size as usize]);
            }
            assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
            state.wait_for_close().await.unwrap();
        },
    )
    .await;
}
1010
/// `close` on one side ends the connection on both sides. Both local sockets see EOF and
/// both `ConnectionState`s report `Ok`.
#[fuchsia::test]
async fn normal_close_end_to_end() {
    let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
    end_to_end_test(
        async |conn, _incoming| {
            let (socket, other_socket) = SyncSocket::create_stream();
            let mut socket = Socket::from_socket(socket);
            let state =
                conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
            conn.close(&addr).await;
            assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
            state.wait_for_close().await.unwrap();
        },
        async |conn, mut incoming| {
            println!("accepting request on connection 2");
            let request = incoming.next().await.unwrap();
            assert_eq!(request.address, addr.clone(),);

            let (socket, other_socket) = SyncSocket::create_stream();
            let mut socket = Socket::from_socket(socket);
            let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
            assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
            state.wait_for_close().await.unwrap();
        },
    )
    .await;
}
1038
/// `reset` on one side ends the connection on both sides. The resetting side's
/// `ConnectionState` reports an error; the peer's reports `Ok`.
#[fuchsia::test]
async fn reset_end_to_end() {
    let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
    end_to_end_test(
        async |conn, _incoming| {
            let (socket, other_socket) = SyncSocket::create_stream();
            let mut socket = Socket::from_socket(socket);
            let state =
                conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
            conn.reset(&addr).await.unwrap();
            assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
            state.wait_for_close().await.expect_err("expected reset");
        },
        async |conn, mut incoming| {
            println!("accepting request on connection 2");
            let request = incoming.next().await.unwrap();
            assert_eq!(request.address, addr.clone(),);

            let (socket, other_socket) = SyncSocket::create_stream();
            let mut socket = Socket::from_socket(socket);
            let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
            assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
            state.wait_for_close().await.unwrap();
        },
    )
    .await;
}
1066}