1use futures::channel::{mpsc, oneshot};
6use futures::lock::{Mutex, OwnedMutexGuard};
7use log::{debug, trace, warn};
8use std::collections::HashMap;
9use std::collections::hash_map::Entry;
10use std::future::Future;
11use std::io::{Error, ErrorKind};
12use std::ops::DerefMut;
13use std::pin::Pin;
14use std::sync::Arc;
15use std::task::{Context, Poll, Waker, ready};
16
17use fuchsia_async::Scope;
18use futures::io::{ReadHalf, WriteHalf};
19use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, FutureExt, SinkExt, StreamExt};
20
21use crate::connection::overflow_writer::OverflowHandleFut;
22use crate::{
23 Address, Header, Packet, PacketType, ProtocolVersion, UsbPacketBuilder, UsbPacketFiller,
24};
25
26mod overflow_writer;
27mod pause_state;
28
29use overflow_writer::OverflowWriter;
30use pause_state::PauseState;
31
/// A byte buffer that can back a USB packet.
///
/// This is purely a named bound: any type that mutably dereferences to a byte slice and is
/// `Send + Unpin + 'static` implements it automatically through the blanket impl below.
pub trait PacketBuffer: DerefMut<Target = [u8]> + Send + Unpin + 'static {}

impl<T: DerefMut<Target = [u8]> + Send + Unpin + 'static> PacketBuffer for T {}
35
/// Payload carried by a `PacketType::Pause` packet.
#[derive(Copy, Clone, PartialEq, Eq)]
enum PausePacket {
    /// Ask the peer to stop reading from its local socket for this connection.
    Pause,
    /// Tell the peer it may resume reading from its local socket for this connection.
    UnPause,
}

impl PausePacket {
    /// Returns the single-byte wire encoding: `[1]` for pause, `[0]` for unpause.
    fn bytes(&self) -> [u8; 1] {
        let flag = match *self {
            Self::UnPause => 0,
            Self::Pause => 1,
        };
        [flag]
    }
}
50
51pub struct ReadyConnect<B, S> {
54 connections: Arc<std::sync::Mutex<HashMap<Address, VsockConnection<S>>>>,
55 packet_filler: Arc<UsbPacketFiller<B>>,
56 address: Address,
57 connection_state: ConnectionState,
58}
59
60impl<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static> ReadyConnect<B, S> {
61 pub async fn finish_connect(self, socket: S) -> ConnectionState {
63 let (read_socket, write_socket) = socket.split();
64 let writer = {
65 let conns = self.connections.lock().unwrap();
66 let Some(conn) = conns.get(&self.address) else {
67 warn!("Connection state was missing after connection success!");
68 return self.connection_state;
69 };
70 let VsockConnectionState::Connected { writer, reader_scope, pause_state, .. } =
71 &conn.state
72 else {
73 warn!("Connection state was invalid after connection success!");
74 return self.connection_state;
75 };
76 reader_scope.spawn(Connection::<B, S>::run_socket(
77 read_socket,
78 self.address,
79 self.packet_filler,
80 Arc::clone(pause_state),
81 ));
82 Arc::clone(writer)
83 };
84 let mut writer = writer.lock().await;
85 let ConnectionStateWriter::NotYetAvailable(wakers) = std::mem::replace(
86 &mut *writer,
87 ConnectionStateWriter::Available(OverflowWriter::new(write_socket)),
88 ) else {
89 unreachable!("Connection completed multiple times!")
90 };
91
92 wakers.into_iter().for_each(Waker::wake);
93 self.connection_state
94 }
95}
96
97pub struct Connection<B, S> {
110 control_socket_writer: Option<Mutex<WriteHalf<S>>>,
111 packet_filler: Arc<UsbPacketFiller<B>>,
112 protocol_version: ProtocolVersion,
113 connections: Arc<std::sync::Mutex<HashMap<Address, VsockConnection<S>>>>,
114 incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
115 task_scope: Scope,
116}
117
118impl<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static> Connection<B, S> {
119 pub fn new(
126 protocol_version: ProtocolVersion,
127 control_socket: Option<S>,
128 incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
129 ) -> Self {
130 let packet_filler = Arc::new(UsbPacketFiller::default());
131 let connections = Default::default();
132 let task_scope = Scope::new_with_name("vsock_usb");
133 let control_socket_writer = control_socket.map(|control_socket| {
134 let (control_socket_reader, control_socket_writer) = control_socket.split();
135 task_scope.spawn(Self::run_socket(
136 control_socket_reader,
137 Address::default(),
138 packet_filler.clone(),
139 PauseState::new(),
140 ));
141 Mutex::new(control_socket_writer)
142 });
143 Self {
144 control_socket_writer,
145 packet_filler,
146 connections,
147 incoming_requests_tx,
148 protocol_version,
149 task_scope,
150 }
151 }
152
153 async fn send_close_packet(address: &Address, usb_packet_filler: &Arc<UsbPacketFiller<B>>) {
154 let header = &mut Header::new(PacketType::Finish);
155 header.set_address(address);
156 usb_packet_filler
157 .write_vsock_packet(&Packet { header, payload: &[] })
158 .await
159 .expect("Finish packet should never be too big");
160 }
161
162 async fn run_socket(
163 mut reader: ReadHalf<S>,
164 address: Address,
165 usb_packet_filler: Arc<UsbPacketFiller<B>>,
166 pause_state: Arc<PauseState>,
167 ) {
168 let mut buf = [0; 4096];
169 loop {
170 log::trace!("reading from control socket");
171 let read = match pause_state.while_unpaused(reader.read(&mut buf)).await {
172 Ok(0) => {
173 if !address.is_zeros() {
174 Self::send_close_packet(&address, &usb_packet_filler).await;
175 }
176 return;
177 }
178 Ok(read) => read,
179 Err(err) => {
180 if address.is_zeros() {
181 log::error!("Error reading usb socket: {err:?}");
182 } else {
183 Self::send_close_packet(&address, &usb_packet_filler).await;
184 }
185 return;
186 }
187 };
188 log::trace!("writing {read} bytes to vsock packet");
189 usb_packet_filler.write_vsock_data_all(&address, &buf[..read]).await;
190 log::trace!("wrote {read} bytes to vsock packet");
191 }
192 }
193
194 fn set_connection(
195 &self,
196 address: Address,
197 state: VsockConnectionState<S>,
198 ) -> Result<(), Error> {
199 let mut connections = self.connections.lock().unwrap();
200 if !connections.contains_key(&address) {
201 connections.insert(address.clone(), VsockConnection { _address: address, state });
202 Ok(())
203 } else {
204 Err(Error::other(format!("connection on address {address:?} already set")))
205 }
206 }
207
208 pub async fn send_empty_echo(&self) {
211 debug!("Sending empty echo packet");
212 let header = &mut Header::new(PacketType::Echo);
213 self.packet_filler
214 .write_vsock_packet(&Packet { header, payload: &[] })
215 .await
216 .expect("empty echo packet should never be too large to fit in a usb packet");
217 }
218
219 pub async fn connect(&self, addr: Address, socket: S) -> Result<ConnectionState, Error> {
224 let ready = self.connect_late(addr).await?;
225 Ok(ready.finish_connect(socket).await)
226 }
227
228 pub async fn connect_late(&self, addr: Address) -> Result<ReadyConnect<B, S>, Error> {
235 let (connected_tx, connected_rx) = oneshot::channel();
236
237 self.set_connection(addr.clone(), VsockConnectionState::ConnectingOutgoing(connected_tx))?;
238
239 let header = &mut Header::new(PacketType::Connect);
240 header.set_address(&addr);
241 self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
242 let Ok(conn_state) = connected_rx.await else {
243 return Err(Error::other("Accept was never received for {addr:?}"));
244 };
245
246 Ok(ReadyConnect {
247 connections: Arc::clone(&self.connections),
248 packet_filler: Arc::clone(&self.packet_filler),
249 address: addr,
250 connection_state: conn_state,
251 })
252 }
253
254 pub async fn close(&self, address: &Address) {
256 Self::send_close_packet(address, &self.packet_filler).await
257 }
258
259 pub async fn reset(&self, address: &Address) -> Result<(), Error> {
261 reset(address, &self.connections, &self.packet_filler).await
262 }
263
264 pub async fn accept(
268 &self,
269 request: ConnectionRequest,
270 socket: S,
271 ) -> Result<ConnectionState, Error> {
272 let ready = self.accept_late(request).await?;
273 Ok(ready.finish_connect(socket).await)
274 }
275
276 pub async fn accept_late(
280 &self,
281 request: ConnectionRequest,
282 ) -> Result<ReadyConnect<B, S>, Error> {
283 let address = request.address;
284 let notify_closed_rx;
285 if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
286 let VsockConnectionState::ConnectingIncoming = &conn.state else {
287 return Err(Error::other(format!(
288 "Attempted to accept connection that was not waiting at {address:?}"
289 )));
290 };
291
292 let notify_closed = mpsc::channel(2);
293 notify_closed_rx = notify_closed.1;
294 let notify_closed = notify_closed.0;
295 let pause_state = PauseState::new();
296
297 let reader_scope = Scope::new_with_name("connection-reader");
298
299 conn.state = VsockConnectionState::Connected {
300 writer: Arc::new(Mutex::new(ConnectionStateWriter::NotYetAvailable(Vec::new()))),
301 reader_scope,
302 notify_closed,
303 pause_state,
304 };
305 } else {
306 return Err(Error::other(format!(
307 "Attempting to accept connection that did not exist at {address:?}"
308 )));
309 }
310 let header = &mut Header::new(PacketType::Accept);
311 header.set_address(&address);
312 self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
313 Ok(ReadyConnect {
314 connections: Arc::clone(&self.connections),
315 packet_filler: Arc::clone(&self.packet_filler),
316 address,
317 connection_state: ConnectionState(notify_closed_rx),
318 })
319 }
320
321 pub async fn reject(&self, request: ConnectionRequest) -> Result<(), Error> {
323 let address = request.address;
324 match self.connections.lock().unwrap().entry(address.clone()) {
325 Entry::Occupied(entry) => {
326 let VsockConnectionState::ConnectingIncoming = &entry.get().state else {
327 return Err(Error::other(format!(
328 "Attempted to reject connection that was not waiting at {address:?}"
329 )));
330 };
331 entry.remove();
332 }
333 Entry::Vacant(_) => {
334 return Err(Error::other(format!(
335 "Attempted to reject connection that was not waiting at {address:?}"
336 )));
337 }
338 }
339
340 let header = &mut Header::new(PacketType::Reset);
341 header.set_address(&address);
342 self.packet_filler
343 .write_vsock_packet(&Packet { header, payload: &[] })
344 .await
345 .expect("accept packet should never be too large for packet buffer");
346 Ok(())
347 }
348
349 async fn handle_data_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
350 if address.is_zeros() {
352 if let Some(writer) = self.control_socket_writer.as_ref() {
353 writer.lock().await.write_all(payload).await?;
354 } else {
355 trace!("Discarding {} bytes of data sent to control socket", payload.len());
356 }
357 Ok(())
358 } else {
359 let payload_socket;
360 if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
361 let VsockConnectionState::Connected { writer, .. } = &conn.state else {
362 warn!(
363 "Received data packet for connection in unexpected state for {address:?}"
364 );
365 return Ok(());
366 };
367 payload_socket = writer.clone();
368 } else {
369 warn!("Received data packet for connection that didn't exist at {address:?}");
370 return Ok(());
371 }
372 let mut socket_guard =
373 ConnectionStateWriter::wait_available(Arc::clone(&payload_socket)).await;
374 let ConnectionStateWriter::Available(socket) = &mut *socket_guard else {
375 unreachable!("wait_available didn't wait until socket was available!");
376 };
377 match socket.write_all(payload) {
378 Err(err) => {
379 debug!(
380 "Write to socket address {address:?} failed, \
381 resetting connection immediately: {err:?}"
382 );
383 self.reset(&address)
384 .await
385 .inspect_err(|err| {
386 warn!(
387 "Attempt to reset connection to {address:?} \
388 failed after write error: {err:?}"
389 );
390 })
391 .ok();
392 }
393 Ok(status) => {
394 if status.overflowed() {
395 if self.protocol_version.has_pause_packets() {
396 let header = &mut Header::new(PacketType::Pause);
397 let payload = &PausePacket::Pause.bytes();
398 header.set_address(&address);
399 header.payload_len.set(payload.len() as u32);
400 self.packet_filler
401 .write_vsock_packet(&Packet { header, payload })
402 .await
403 .expect(
404 "pause packet should never be too large to fit in a usb packet",
405 );
406 }
407
408 let weak_payload_socket = Arc::downgrade(&payload_socket);
409 let connections = Arc::clone(&self.connections);
410 let has_pause_packets = self.protocol_version.has_pause_packets();
411 let packet_filler = Arc::clone(&self.packet_filler);
412 self.task_scope.spawn(async move {
413 let res = OverflowHandleFut::new(weak_payload_socket).await;
414
415 if let Err(err) = res {
416 debug!(
417 "Write to socket address {address:?} failed while \
418 processing backlog, resetting connection at next poll: {err:?}"
419 );
420 if let Err(err) = reset(&address, &connections, &packet_filler).await {
421 debug!("Error sending reset frame after overflow write failed: {err:?}");
422 }
423 } else if has_pause_packets {
424 let header = &mut Header::new(PacketType::Pause);
425 let payload = &PausePacket::UnPause.bytes();
426 header.set_address(&address);
427 header.payload_len.set(payload.len() as u32);
428 packet_filler
429 .write_vsock_packet(&Packet { header, payload })
430 .await
431 .expect("pause packet should never be too large to fit in a usb packet");
432 }
433 });
434 }
435 }
436 }
437 Ok(())
438 }
439 }
440
441 async fn handle_echo_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
442 debug!("received echo for {address:?} with payload {payload:?}");
443 let header = &mut Header::new(PacketType::EchoReply);
444 header.payload_len.set(payload.len() as u32);
445 header.set_address(&address);
446 self.packet_filler
447 .write_vsock_packet(&Packet { header, payload })
448 .await
449 .map_err(|_| Error::other("Echo packet was too large to be sent back"))
450 }
451
452 async fn handle_echo_reply_packet(
453 &self,
454 address: Address,
455 payload: &[u8],
456 ) -> Result<(), Error> {
457 debug!("received echo reply for {address:?} with payload {payload:?}");
459 Ok(())
460 }
461
462 async fn handle_accept_packet(&self, address: Address) -> Result<(), Error> {
463 if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
464 let state = std::mem::replace(&mut conn.state, VsockConnectionState::Invalid);
465 let VsockConnectionState::ConnectingOutgoing(connected_tx) = state else {
466 warn!("Received accept packet for connection in unexpected state for {address:?}");
467 return Ok(());
468 };
469 let (notify_closed, notify_closed_rx) = mpsc::channel(2);
470 if connected_tx.send(ConnectionState(notify_closed_rx)).is_err() {
471 warn!(
472 "Accept packet received for {address:?} but connect caller stopped waiting for it"
473 );
474 }
475 let pause_state = PauseState::new();
476
477 let reader_scope = Scope::new_with_name("connection-reader");
478 conn.state = VsockConnectionState::Connected {
479 writer: Arc::new(Mutex::new(ConnectionStateWriter::NotYetAvailable(Vec::new()))),
480 reader_scope,
481 notify_closed,
482 pause_state,
483 };
484 } else {
485 warn!("Got accept packet for connection that was not being made at {address:?}");
486 return Ok(());
487 }
488 Ok(())
489 }
490
491 async fn handle_connect_packet(&self, address: Address) -> Result<(), Error> {
492 trace!("received connect packet for {address:?}");
493 match self.connections.lock().unwrap().entry(address.clone()) {
494 Entry::Vacant(entry) => {
495 debug!("valid connect request for {address:?}");
496 entry.insert(VsockConnection {
497 _address: address,
498 state: VsockConnectionState::ConnectingIncoming,
499 });
500 }
501 Entry::Occupied(_) => {
502 warn!(
503 "Received connect packet for already existing \
504 connection for address {address:?}. Ignoring"
505 );
506 return Ok(());
507 }
508 }
509
510 trace!("sending incoming connection request to client for {address:?}");
511 let connection_request = ConnectionRequest { address };
512 self.incoming_requests_tx
513 .clone()
514 .send(connection_request)
515 .await
516 .inspect(|_| trace!("sent incoming request for {address:?}"))
517 .map_err(|_| Error::other("Failed to send connection request"))
518 }
519
520 async fn handle_finish_packet(&self, address: Address) -> Result<(), Error> {
521 trace!("received finish packet for {address:?}");
522 let mut notify;
523 if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
524 let VsockConnectionState::Connected { notify_closed, .. } = conn.state else {
525 warn!(
526 "Received finish (close) packet for {address:?} \
527 which was not in a connected state. Ignoring and dropping connection state."
528 );
529 return Ok(());
530 };
531 notify = notify_closed;
532 } else {
533 warn!(
534 "Received finish (close) packet for connection that didn't exist \
535 on address {address:?}. Ignoring"
536 );
537 return Ok(());
538 }
539
540 notify.send(Ok(())).await.ok();
541
542 let header = &mut Header::new(PacketType::Reset);
543 header.set_address(&address);
544 self.packet_filler
545 .write_vsock_packet(&Packet { header, payload: &[] })
546 .await
547 .expect("accept packet should never be too large for packet buffer");
548 Ok(())
549 }
550
551 async fn handle_reset_packet(&self, address: Address) -> Result<(), Error> {
552 trace!("received reset packet for {address:?}");
553 let mut notify = None;
554 if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
555 if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
556 notify = Some(notify_closed);
557 } else {
558 debug!(
559 "Received reset packet for connection that wasn't in a connecting or \
560 disconnected state on address {address:?}."
561 );
562 }
563 } else {
564 warn!(
565 "Received reset packet for connection that didn't \
566 exist on address {address:?}. Ignoring"
567 );
568 }
569
570 if let Some(mut notify) = notify {
571 notify.send(Ok(())).await.ok();
572 }
573 Ok(())
574 }
575
576 async fn handle_pause_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
577 if !self.protocol_version.has_pause_packets() {
578 warn!(
579 "Got a pause packet while using protocol \
580 version {} which does not support them. Ignoring",
581 self.protocol_version
582 );
583 return Ok(());
584 }
585
586 let pause = match payload {
587 [1] => true,
588 [0] => false,
589 other => {
590 warn!("Ignoring unexpected pause packet payload {other:?}");
591 return Ok(());
592 }
593 };
594
595 if let Some(conn) = self.connections.lock().unwrap().get(&address) {
596 if let VsockConnectionState::Connected { pause_state, .. } = &conn.state {
597 pause_state.set_paused(pause);
598 } else {
599 warn!("Received pause packet for unestablished connection. Ignoring");
600 };
601 } else {
602 warn!(
603 "Received pause packet for connection that didn't exist on address {address:?}. Ignoring"
604 );
605 }
606
607 Ok(())
608 }
609
610 pub async fn handle_vsock_packet(&self, packet: Packet<'_>) -> Result<(), Error> {
613 trace!("received vsock packet {header:?}", header = packet.header);
614 let payload_len = packet.header.payload_len.get() as usize;
615 let payload = &packet.payload[..payload_len];
616 let address = Address::from(packet.header);
617 match packet.header.packet_type {
618 PacketType::Sync => Err(Error::other("Received sync packet mid-stream")),
619 PacketType::Data => self.handle_data_packet(address, payload).await,
620 PacketType::Accept => self.handle_accept_packet(address).await,
621 PacketType::Connect => self.handle_connect_packet(address).await,
622 PacketType::Finish => self.handle_finish_packet(address).await,
623 PacketType::Reset => self.handle_reset_packet(address).await,
624 PacketType::Echo => self.handle_echo_packet(address, payload).await,
625 PacketType::EchoReply => self.handle_echo_reply_packet(address, payload).await,
626 PacketType::Pause => self.handle_pause_packet(address, payload).await,
627 }
628 }
629
630 pub async fn fill_usb_packet(&self, builder: UsbPacketBuilder<B>) -> UsbPacketBuilder<B> {
637 self.packet_filler.fill_usb_packet(builder).await
638 }
639}
640
641async fn reset<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static>(
642 address: &Address,
643 connections: &std::sync::Mutex<HashMap<Address, VsockConnection<S>>>,
644 packet_filler: &UsbPacketFiller<B>,
645) -> Result<(), Error> {
646 let mut notify = None;
647 if let Some(conn) = connections.lock().unwrap().remove(&address) {
648 if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
649 notify = Some(notify_closed);
650 }
651 } else {
652 return Err(Error::other(
653 "Client asked to reset connection {address:?} that did not exist",
654 ));
655 }
656
657 if let Some(mut notify) = notify {
658 notify.send(Err(ErrorKind::ConnectionReset.into())).await.ok();
659 }
660
661 let header = &mut Header::new(PacketType::Reset);
662 header.set_address(address);
663 packet_filler
664 .write_vsock_packet(&Packet { header, payload: &[] })
665 .await
666 .expect("Reset packet should never be too big");
667 Ok(())
668}
669
670enum ConnectionStateWriter<S> {
675 NotYetAvailable(Vec<Waker>),
676 Available(OverflowWriter<S>),
677}
678
679impl<S> ConnectionStateWriter<S> {
680 fn wait_available(this: Arc<Mutex<ConnectionStateWriter<S>>>) -> ConnectionStateWriterFut<S> {
682 ConnectionStateWriterFut { writer: this, lock_fut: None }
683 }
684}
685
686struct ConnectionStateWriterFut<S> {
688 writer: Arc<Mutex<ConnectionStateWriter<S>>>,
689 lock_fut: Option<futures::lock::OwnedMutexLockFuture<ConnectionStateWriter<S>>>,
690}
691
692impl<S> Future for ConnectionStateWriterFut<S> {
693 type Output = OwnedMutexGuard<ConnectionStateWriter<S>>;
694
695 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
696 let writer = Arc::clone(&self.writer);
697 let lock_fut = self.lock_fut.get_or_insert_with(|| writer.lock_owned());
698 let mut lock = ready!(lock_fut.poll_unpin(cx));
699 self.lock_fut = None;
700 match &mut *lock {
701 ConnectionStateWriter::Available(_) => Poll::Ready(lock),
702 ConnectionStateWriter::NotYetAvailable(queue) => {
703 queue.push(cx.waker().clone());
704 Poll::Pending
705 }
706 }
707 }
708}
709
710enum VsockConnectionState<S> {
711 ConnectingOutgoing(oneshot::Sender<ConnectionState>),
712 ConnectingIncoming,
713 Connected {
714 writer: Arc<Mutex<ConnectionStateWriter<S>>>,
715 notify_closed: mpsc::Sender<Result<(), Error>>,
716 pause_state: Arc<PauseState>,
717 reader_scope: Scope,
718 },
719 Invalid,
720}
721
722struct VsockConnection<S> {
723 _address: Address,
724 state: VsockConnectionState<S>,
725}
726
727#[derive(Debug)]
731pub struct ConnectionState(mpsc::Receiver<Result<(), Error>>);
732
733impl ConnectionState {
734 pub async fn wait_for_close(mut self) -> Result<(), Error> {
737 self.0
738 .next()
739 .await
740 .ok_or_else(|| Error::other("Connection state's other end was dropped"))?
741 }
742}
743
744#[derive(Debug)]
747pub struct ConnectionRequest {
748 address: Address,
749}
750
751impl ConnectionRequest {
752 pub fn new(address: Address) -> Self {
754 Self { address }
755 }
756
757 pub fn address(&self) -> &Address {
759 &self.address
760 }
761}
762
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::VsockPacketIterator;

    use super::*;

    #[cfg(not(target_os = "fuchsia"))]
    use fuchsia_async::emulated_handle::Socket as SyncSocket;
    use fuchsia_async::{Socket, Task};
    use futures::StreamExt;
    #[cfg(target_os = "fuchsia")]
    use zx::Socket as SyncSocket;

    /// Payload sizes used by the data tests; each payload is filled with its own size.
    const ROUND_TRIP_SIZES: [u8; 8] = [1, 2, 8, 16, 32, 64, 128, 255];

    /// Address used by every non-control connection in these tests.
    fn test_address() -> Address {
        Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
    }

    /// Writes each round-trip payload to `socket` and checks it is echoed back unchanged.
    async fn round_trip_all_sizes(socket: &mut Socket) {
        for size in ROUND_TRIP_SIZES {
            println!("round tripping packet of size {size}");
            let expected = vec![size; usize::from(size)];
            socket.write_all(&expected).await.unwrap();
            let mut echoed = vec![0u8; expected.len()];
            socket.read_exact(&mut echoed).await.unwrap();
            assert_eq!(echoed, expected);
        }
    }

    /// Plays the USB peer of `echo_connection` by feeding every emitted packet back into it.
    ///
    /// Connect packets are answered with an accept for the same header, and accept packets
    /// are dropped.
    async fn usb_echo_server(echo_connection: Arc<Connection<Vec<u8>, Socket>>) {
        let mut builder = UsbPacketBuilder::new(vec![0; 128]);
        loop {
            println!("waiting for usb packet");
            builder = echo_connection.fill_usb_packet(builder).await;
            let usb_packet = builder.take_usb_packet().unwrap();
            println!("got usb packet, echoing it back to the other side");
            let mut handled = 0;
            for packet in VsockPacketIterator::new(usb_packet) {
                let packet = packet.unwrap();
                match packet.header.packet_type {
                    PacketType::Connect => {
                        let mut accept_header = packet.header.clone();
                        accept_header.packet_type = PacketType::Accept;
                        echo_connection
                            .handle_vsock_packet(Packet { header: &accept_header, payload: &[] })
                            .await
                            .unwrap();
                    }
                    PacketType::Accept => {}
                    _ => echo_connection.handle_vsock_packet(packet).await.unwrap(),
                }
                handled += 1;
            }
            println!("handled {handled} packets");
        }
    }

    #[fuchsia::test]
    async fn data_over_control_socket() {
        let (local, remote) = SyncSocket::create_stream();
        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
        let mut socket = Socket::from_socket(local);
        let connection = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(remote)),
            incoming_requests_tx,
        ));

        let echo_task = Task::spawn(usb_echo_server(Arc::clone(&connection)));

        round_trip_all_sizes(&mut socket).await;
        echo_task.abort().await;
    }

    #[fuchsia::test]
    async fn data_over_normal_outgoing_socket() {
        let (_control_socket, control_remote) = SyncSocket::create_stream();
        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
        let connection = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(control_remote)),
            incoming_requests_tx,
        ));

        let echo_task = Task::spawn(usb_echo_server(Arc::clone(&connection)));

        let (local, remote) = SyncSocket::create_stream();
        let mut socket = Socket::from_socket(local);
        connection.connect(test_address(), Socket::from_socket(remote)).await.unwrap();

        round_trip_all_sizes(&mut socket).await;
        echo_task.abort().await;
    }

    #[fuchsia::test]
    async fn data_over_normal_incoming_socket() {
        let (_control_socket, control_remote) = SyncSocket::create_stream();
        let (incoming_requests_tx, mut incoming_requests) = mpsc::channel(5);
        let connection = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(control_remote)),
            incoming_requests_tx,
        ));

        let echo_task = Task::spawn(usb_echo_server(Arc::clone(&connection)));

        let header = &mut Header::new(PacketType::Connect);
        header.set_address(&test_address());
        connection.handle_vsock_packet(Packet { header, payload: &[] }).await.unwrap();

        let request = incoming_requests.next().await.unwrap();
        assert_eq!(request.address, test_address());

        let (local, remote) = SyncSocket::create_stream();
        let mut socket = Socket::from_socket(local);
        connection.accept(request, Socket::from_socket(remote)).await.unwrap();

        round_trip_all_sizes(&mut socket).await;
        echo_task.abort().await;
    }

    /// Forwards every USB packet produced by `from` into `to`, forever.
    async fn copy_connection(from: &Connection<Vec<u8>, Socket>, to: &Connection<Vec<u8>, Socket>) {
        let mut builder = UsbPacketBuilder::new(vec![0; 1024]);
        loop {
            builder = from.fill_usb_packet(builder).await;
            for packet in VsockPacketIterator::new(builder.take_usb_packet().unwrap()) {
                println!("forwarding vsock packet");
                to.handle_vsock_packet(packet.unwrap()).await.unwrap();
            }
        }
    }

    /// One side of an end-to-end test: receives its `Connection` and its incoming requests.
    pub(crate) trait EndToEndTestFn<R>:
        AsyncFnOnce(Arc<Connection<Vec<u8>, Socket>>, mpsc::Receiver<ConnectionRequest>) -> R
    {
    }
    impl<T, R> EndToEndTestFn<R> for T where
        T: AsyncFnOnce(Arc<Connection<Vec<u8>, Socket>>, mpsc::Receiver<ConnectionRequest>) -> R
    {
    }

    /// Runs `left_side` and `right_side` against two `Connection`s whose USB packets are
    /// shuttled to each other, and returns both sides' results.
    pub(crate) async fn end_to_end_test<R1, R2>(
        left_side: impl EndToEndTestFn<R1>,
        right_side: impl EndToEndTestFn<R2>,
    ) -> (R1, R2) {
        type Connection = crate::Connection<Vec<u8>, Socket>;
        let (_control_socket1, control_remote1) = SyncSocket::create_stream();
        let (_control_socket2, control_remote2) = SyncSocket::create_stream();
        let (incoming_requests_tx1, incoming_requests1) = mpsc::channel(5);
        let (incoming_requests_tx2, incoming_requests2) = mpsc::channel(5);

        let left = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(control_remote1)),
            incoming_requests_tx1,
        ));
        let right = Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(control_remote2)),
            incoming_requests_tx2,
        ));

        let passthrough_task = Task::spawn({
            let left = Arc::clone(&left);
            let right = Arc::clone(&right);
            async move {
                futures::join!(copy_connection(&left, &right), copy_connection(&right, &left),);
                println!("passthrough task loop ended");
            }
        });

        let results = futures::join!(
            left_side(left, incoming_requests1),
            right_side(right, incoming_requests2)
        );
        passthrough_task.abort().await;
        results
    }

    #[fuchsia::test]
    async fn data_over_end_to_end() {
        end_to_end_test(
            async |conn, _incoming| {
                println!("sending request on connection 1");
                let (local, remote) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(local);
                let state =
                    conn.connect(test_address(), Socket::from_socket(remote)).await.unwrap();

                for size in ROUND_TRIP_SIZES {
                    println!("round tripping packet of size {size}");
                    socket.write_all(&vec![size; usize::from(size)]).await.unwrap();
                }
                drop(socket);
                state.wait_for_close().await.unwrap();
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, test_address());

                let (local, remote) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(local);
                let state = conn.accept(request, Socket::from_socket(remote)).await.unwrap();

                println!("accepted request on connection 2");
                for size in ROUND_TRIP_SIZES {
                    let mut received = vec![0u8; usize::from(size)];
                    socket.read_exact(&mut received).await.unwrap();
                    assert_eq!(received, vec![size; usize::from(size)]);
                }
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    #[fuchsia::test]
    async fn normal_close_end_to_end() {
        let addr = test_address();
        end_to_end_test(
            async |conn, _incoming| {
                let (local, remote) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(local);
                let state = conn.connect(addr.clone(), Socket::from_socket(remote)).await.unwrap();
                conn.close(&addr).await;
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, addr.clone());

                let (local, remote) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(local);
                let state = conn.accept(request, Socket::from_socket(remote)).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    #[fuchsia::test]
    async fn reset_end_to_end() {
        let addr = test_address();
        end_to_end_test(
            async |conn, _incoming| {
                let (local, remote) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(local);
                let state = conn.connect(addr.clone(), Socket::from_socket(remote)).await.unwrap();
                conn.reset(&addr).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.expect_err("expected reset");
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, addr.clone());

                let (local, remote) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(local);
                let state = conn.accept(request, Socket::from_socket(remote)).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }
}
1059}