1use futures::channel::{mpsc, oneshot};
6use futures::lock::{Mutex, OwnedMutexGuard};
7use log::{debug, trace, warn};
8use std::collections::hash_map::Entry;
9use std::collections::HashMap;
10use std::future::Future;
11use std::io::{Error, ErrorKind};
12use std::ops::DerefMut;
13use std::pin::Pin;
14use std::sync::Arc;
15use std::task::{ready, Context, Poll, Waker};
16
17use fuchsia_async::Scope;
18use futures::io::{ReadHalf, WriteHalf};
19use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, FutureExt, SinkExt, StreamExt};
20
21use crate::connection::overflow_writer::OverflowHandleFut;
22use crate::{
23 Address, Header, Packet, PacketType, ProtocolVersion, UsbPacketBuilder, UsbPacketFiller,
24};
25
26mod overflow_writer;
27mod pause_state;
28
29use overflow_writer::OverflowWriter;
30use pause_state::PauseState;
31
/// A mutable byte buffer usable as the backing storage type `B` of the
/// `UsbPacketFiller<B>` / `UsbPacketBuilder<B>` values used in this file.
///
/// The trait has no methods of its own; it only bundles the bounds a buffer must satisfy:
/// mutable dereference to `[u8]`, `Send`, `Unpin` and `'static`.
pub trait PacketBuffer: DerefMut<Target = [u8]> + Send + Unpin + 'static {}

/// Blanket implementation: every type that meets the supertrait bounds is a [`PacketBuffer`].
impl<T: DerefMut<Target = [u8]> + Send + Unpin + 'static> PacketBuffer for T {}
35
/// Payload carried by a `PacketType::Pause` packet: either a request to pause or to resume.
#[derive(Clone, Copy, Eq, PartialEq)]
enum PausePacket {
    /// Encoded as the single byte `1`.
    Pause,
    /// Encoded as the single byte `0`.
    UnPause,
}

impl PausePacket {
    /// Returns the one-byte payload for this variant: `[1]` for `Pause`, `[0]` for `UnPause`.
    fn bytes(&self) -> [u8; 1] {
        let flag: u8 = match *self {
            Self::Pause => 1,
            Self::UnPause => 0,
        };
        [flag]
    }
}
50
51pub struct ReadyConnect<B, S> {
54 connections: Arc<std::sync::Mutex<HashMap<Address, VsockConnection<S>>>>,
55 packet_filler: Arc<UsbPacketFiller<B>>,
56 address: Address,
57 connection_state: ConnectionState,
58}
59
60impl<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static> ReadyConnect<B, S> {
61 pub async fn finish_connect(self, socket: S) -> ConnectionState {
63 let (read_socket, write_socket) = socket.split();
64 let writer = {
65 let conns = self.connections.lock().unwrap();
66 let Some(conn) = conns.get(&self.address) else {
67 warn!("Connection state was missing after connection success!");
68 return self.connection_state;
69 };
70 let VsockConnectionState::Connected { writer, reader_scope, pause_state, .. } =
71 &conn.state
72 else {
73 warn!("Connection state was invalid after connection success!");
74 return self.connection_state;
75 };
76 reader_scope.spawn(Connection::<B, S>::run_socket(
77 read_socket,
78 self.address,
79 self.packet_filler,
80 Arc::clone(pause_state),
81 ));
82 Arc::clone(writer)
83 };
84 let mut writer = writer.lock().await;
85 let ConnectionStateWriter::NotYetAvailable(wakers) = std::mem::replace(
86 &mut *writer,
87 ConnectionStateWriter::Available(OverflowWriter::new(write_socket)),
88 ) else {
89 unreachable!("Connection completed multiple times!")
90 };
91
92 wakers.into_iter().for_each(Waker::wake);
93 self.connection_state
94 }
95}
96
97pub struct Connection<B, S> {
110 control_socket_writer: Option<Mutex<WriteHalf<S>>>,
111 packet_filler: Arc<UsbPacketFiller<B>>,
112 protocol_version: ProtocolVersion,
113 connections: Arc<std::sync::Mutex<HashMap<Address, VsockConnection<S>>>>,
114 incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
115 task_scope: Scope,
116}
117
118impl<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static> Connection<B, S> {
119 pub fn new(
126 protocol_version: ProtocolVersion,
127 control_socket: Option<S>,
128 incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
129 ) -> Self {
130 let packet_filler = Arc::new(UsbPacketFiller::default());
131 let connections = Default::default();
132 let task_scope = Scope::new_with_name("vsock_usb");
133 let control_socket_writer = control_socket.map(|control_socket| {
134 let (control_socket_reader, control_socket_writer) = control_socket.split();
135 task_scope.spawn(Self::run_socket(
136 control_socket_reader,
137 Address::default(),
138 packet_filler.clone(),
139 PauseState::new(),
140 ));
141 Mutex::new(control_socket_writer)
142 });
143 Self {
144 control_socket_writer,
145 packet_filler,
146 connections,
147 incoming_requests_tx,
148 protocol_version,
149 task_scope,
150 }
151 }
152
153 async fn send_close_packet(address: &Address, usb_packet_filler: &Arc<UsbPacketFiller<B>>) {
154 let header = &mut Header::new(PacketType::Finish);
155 header.set_address(address);
156 usb_packet_filler
157 .write_vsock_packet(&Packet { header, payload: &[] })
158 .await
159 .expect("Finish packet should never be too big");
160 }
161
162 async fn run_socket(
163 mut reader: ReadHalf<S>,
164 address: Address,
165 usb_packet_filler: Arc<UsbPacketFiller<B>>,
166 pause_state: Arc<PauseState>,
167 ) {
168 let mut buf = [0; 4096];
169 loop {
170 log::trace!("reading from control socket");
171 let read = match pause_state.while_unpaused(reader.read(&mut buf)).await {
172 Ok(0) => {
173 if !address.is_zeros() {
174 Self::send_close_packet(&address, &usb_packet_filler).await;
175 }
176 return;
177 }
178 Ok(read) => read,
179 Err(err) => {
180 if address.is_zeros() {
181 log::error!("Error reading usb socket: {err:?}");
182 } else {
183 Self::send_close_packet(&address, &usb_packet_filler).await;
184 }
185 return;
186 }
187 };
188 log::trace!("writing {read} bytes to vsock packet");
189 usb_packet_filler.write_vsock_data_all(&address, &buf[..read]).await;
190 log::trace!("wrote {read} bytes to vsock packet");
191 }
192 }
193
194 fn set_connection(
195 &self,
196 address: Address,
197 state: VsockConnectionState<S>,
198 ) -> Result<(), Error> {
199 let mut connections = self.connections.lock().unwrap();
200 if !connections.contains_key(&address) {
201 connections.insert(address.clone(), VsockConnection { _address: address, state });
202 Ok(())
203 } else {
204 Err(Error::other(format!("connection on address {address:?} already set")))
205 }
206 }
207
208 pub async fn send_empty_echo(&self) {
211 debug!("Sending empty echo packet");
212 let header = &mut Header::new(PacketType::Echo);
213 self.packet_filler
214 .write_vsock_packet(&Packet { header, payload: &[] })
215 .await
216 .expect("empty echo packet should never be too large to fit in a usb packet");
217 }
218
219 pub async fn connect(&self, addr: Address, socket: S) -> Result<ConnectionState, Error> {
224 let ready = self.connect_late(addr).await?;
225 Ok(ready.finish_connect(socket).await)
226 }
227
228 pub async fn connect_late(&self, addr: Address) -> Result<ReadyConnect<B, S>, Error> {
235 let (connected_tx, connected_rx) = oneshot::channel();
236
237 self.set_connection(addr.clone(), VsockConnectionState::ConnectingOutgoing(connected_tx))?;
238
239 let header = &mut Header::new(PacketType::Connect);
240 header.set_address(&addr);
241 self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
242 let Ok(conn_state) = connected_rx.await else {
243 return Err(Error::other("Accept was never received for {addr:?}"));
244 };
245
246 Ok(ReadyConnect {
247 connections: Arc::clone(&self.connections),
248 packet_filler: Arc::clone(&self.packet_filler),
249 address: addr,
250 connection_state: conn_state,
251 })
252 }
253
254 pub async fn close(&self, address: &Address) {
256 Self::send_close_packet(address, &self.packet_filler).await
257 }
258
259 pub async fn reset(&self, address: &Address) -> Result<(), Error> {
261 reset(address, &self.connections, &self.packet_filler).await
262 }
263
264 pub async fn accept(
268 &self,
269 request: ConnectionRequest,
270 socket: S,
271 ) -> Result<ConnectionState, Error> {
272 let ready = self.accept_late(request).await?;
273 Ok(ready.finish_connect(socket).await)
274 }
275
276 pub async fn accept_late(
280 &self,
281 request: ConnectionRequest,
282 ) -> Result<ReadyConnect<B, S>, Error> {
283 let address = request.address;
284 let notify_closed_rx;
285 if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
286 let VsockConnectionState::ConnectingIncoming = &conn.state else {
287 return Err(Error::other(format!(
288 "Attempted to accept connection that was not waiting at {address:?}"
289 )));
290 };
291
292 let notify_closed = mpsc::channel(2);
293 notify_closed_rx = notify_closed.1;
294 let notify_closed = notify_closed.0;
295 let pause_state = PauseState::new();
296
297 let reader_scope = Scope::new_with_name("connection-reader");
298
299 conn.state = VsockConnectionState::Connected {
300 writer: Arc::new(Mutex::new(ConnectionStateWriter::NotYetAvailable(Vec::new()))),
301 reader_scope,
302 notify_closed,
303 pause_state,
304 };
305 } else {
306 return Err(Error::other(format!(
307 "Attempting to accept connection that did not exist at {address:?}"
308 )));
309 }
310 let header = &mut Header::new(PacketType::Accept);
311 header.set_address(&address);
312 self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
313 Ok(ReadyConnect {
314 connections: Arc::clone(&self.connections),
315 packet_filler: Arc::clone(&self.packet_filler),
316 address,
317 connection_state: ConnectionState(notify_closed_rx),
318 })
319 }
320
321 pub async fn reject(&self, request: ConnectionRequest) -> Result<(), Error> {
323 let address = request.address;
324 match self.connections.lock().unwrap().entry(address.clone()) {
325 Entry::Occupied(entry) => {
326 let VsockConnectionState::ConnectingIncoming = &entry.get().state else {
327 return Err(Error::other(format!(
328 "Attempted to reject connection that was not waiting at {address:?}"
329 )));
330 };
331 entry.remove();
332 }
333 Entry::Vacant(_) => {
334 return Err(Error::other(format!(
335 "Attempted to reject connection that was not waiting at {address:?}"
336 )));
337 }
338 }
339
340 let header = &mut Header::new(PacketType::Reset);
341 header.set_address(&address);
342 self.packet_filler
343 .write_vsock_packet(&Packet { header, payload: &[] })
344 .await
345 .expect("accept packet should never be too large for packet buffer");
346 Ok(())
347 }
348
349 async fn handle_data_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
350 if address.is_zeros() {
352 if let Some(writer) = self.control_socket_writer.as_ref() {
353 writer.lock().await.write_all(payload).await?;
354 } else {
355 trace!("Discarding {} bytes of data sent to control socket", payload.len());
356 }
357 Ok(())
358 } else {
359 let payload_socket;
360 if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
361 let VsockConnectionState::Connected { writer, .. } = &conn.state else {
362 warn!(
363 "Received data packet for connection in unexpected state for {address:?}"
364 );
365 return Ok(());
366 };
367 payload_socket = writer.clone();
368 } else {
369 warn!("Received data packet for connection that didn't exist at {address:?}");
370 return Ok(());
371 }
372 let mut socket_guard =
373 ConnectionStateWriter::wait_available(Arc::clone(&payload_socket)).await;
374 let ConnectionStateWriter::Available(socket) = &mut *socket_guard else {
375 unreachable!("wait_available didn't wait until socket was available!");
376 };
377 match socket.write_all(payload) {
378 Err(err) => {
379 debug!(
380 "Write to socket address {address:?} failed, \
381 resetting connection immediately: {err:?}"
382 );
383 self.reset(&address)
384 .await
385 .inspect_err(|err| {
386 warn!(
387 "Attempt to reset connection to {address:?} \
388 failed after write error: {err:?}"
389 );
390 })
391 .ok();
392 }
393 Ok(status) => {
394 if status.overflowed() {
395 if self.protocol_version.has_pause_packets() {
396 let header = &mut Header::new(PacketType::Pause);
397 header.set_address(&address);
398 self.packet_filler
399 .write_vsock_packet(&Packet {
400 header,
401 payload: &PausePacket::Pause.bytes(),
402 })
403 .await
404 .expect(
405 "pause packet should never be too large to fit in a usb packet",
406 );
407 }
408
409 let weak_payload_socket = Arc::downgrade(&payload_socket);
410 let connections = Arc::clone(&self.connections);
411 let has_pause_packets = self.protocol_version.has_pause_packets();
412 let packet_filler = Arc::clone(&self.packet_filler);
413 self.task_scope.spawn(async move {
414 let res = OverflowHandleFut::new(weak_payload_socket).await;
415
416 if let Err(err) = res {
417 debug!(
418 "Write to socket address {address:?} failed while \
419 processing backlog, resetting connection at next poll: {err:?}"
420 );
421 if let Err(err) = reset(&address, &connections, &packet_filler).await {
422 debug!("Error sending reset frame after overflow write failed: {err:?}");
423 }
424 } else if has_pause_packets {
425 let header = &mut Header::new(PacketType::Pause);
426 header.set_address(&address);
427 packet_filler
428 .write_vsock_packet(&Packet { header, payload: &PausePacket::UnPause.bytes() })
429 .await
430 .expect("pause packet should never be too large to fit in a usb packet");
431 }
432 });
433 }
434 }
435 }
436 Ok(())
437 }
438 }
439
440 async fn handle_echo_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
441 debug!("received echo for {address:?} with payload {payload:?}");
442 let header = &mut Header::new(PacketType::EchoReply);
443 header.payload_len.set(payload.len() as u32);
444 header.set_address(&address);
445 self.packet_filler
446 .write_vsock_packet(&Packet { header, payload })
447 .await
448 .map_err(|_| Error::other("Echo packet was too large to be sent back"))
449 }
450
451 async fn handle_echo_reply_packet(
452 &self,
453 address: Address,
454 payload: &[u8],
455 ) -> Result<(), Error> {
456 debug!("received echo reply for {address:?} with payload {payload:?}");
458 Ok(())
459 }
460
461 async fn handle_accept_packet(&self, address: Address) -> Result<(), Error> {
462 if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
463 let state = std::mem::replace(&mut conn.state, VsockConnectionState::Invalid);
464 let VsockConnectionState::ConnectingOutgoing(connected_tx) = state else {
465 warn!("Received accept packet for connection in unexpected state for {address:?}");
466 return Ok(());
467 };
468 let (notify_closed, notify_closed_rx) = mpsc::channel(2);
469 if connected_tx.send(ConnectionState(notify_closed_rx)).is_err() {
470 warn!(
471 "Accept packet received for {address:?} but connect caller stopped waiting for it"
472 );
473 }
474 let pause_state = PauseState::new();
475
476 let reader_scope = Scope::new_with_name("connection-reader");
477 conn.state = VsockConnectionState::Connected {
478 writer: Arc::new(Mutex::new(ConnectionStateWriter::NotYetAvailable(Vec::new()))),
479 reader_scope,
480 notify_closed,
481 pause_state,
482 };
483 } else {
484 warn!("Got accept packet for connection that was not being made at {address:?}");
485 return Ok(());
486 }
487 Ok(())
488 }
489
490 async fn handle_connect_packet(&self, address: Address) -> Result<(), Error> {
491 trace!("received connect packet for {address:?}");
492 match self.connections.lock().unwrap().entry(address.clone()) {
493 Entry::Vacant(entry) => {
494 debug!("valid connect request for {address:?}");
495 entry.insert(VsockConnection {
496 _address: address,
497 state: VsockConnectionState::ConnectingIncoming,
498 });
499 }
500 Entry::Occupied(_) => {
501 warn!(
502 "Received connect packet for already existing \
503 connection for address {address:?}. Ignoring"
504 );
505 return Ok(());
506 }
507 }
508
509 trace!("sending incoming connection request to client for {address:?}");
510 let connection_request = ConnectionRequest { address };
511 self.incoming_requests_tx
512 .clone()
513 .send(connection_request)
514 .await
515 .inspect(|_| trace!("sent incoming request for {address:?}"))
516 .map_err(|_| Error::other("Failed to send connection request"))
517 }
518
519 async fn handle_finish_packet(&self, address: Address) -> Result<(), Error> {
520 trace!("received finish packet for {address:?}");
521 let mut notify;
522 if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
523 let VsockConnectionState::Connected { notify_closed, .. } = conn.state else {
524 warn!(
525 "Received finish (close) packet for {address:?} \
526 which was not in a connected state. Ignoring and dropping connection state."
527 );
528 return Ok(());
529 };
530 notify = notify_closed;
531 } else {
532 warn!(
533 "Received finish (close) packet for connection that didn't exist \
534 on address {address:?}. Ignoring"
535 );
536 return Ok(());
537 }
538
539 notify.send(Ok(())).await.ok();
540
541 let header = &mut Header::new(PacketType::Reset);
542 header.set_address(&address);
543 self.packet_filler
544 .write_vsock_packet(&Packet { header, payload: &[] })
545 .await
546 .expect("accept packet should never be too large for packet buffer");
547 Ok(())
548 }
549
550 async fn handle_reset_packet(&self, address: Address) -> Result<(), Error> {
551 trace!("received reset packet for {address:?}");
552 let mut notify = None;
553 if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
554 if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
555 notify = Some(notify_closed);
556 } else {
557 debug!(
558 "Received reset packet for connection that wasn't in a connecting or \
559 disconnected state on address {address:?}."
560 );
561 }
562 } else {
563 warn!(
564 "Received reset packet for connection that didn't \
565 exist on address {address:?}. Ignoring"
566 );
567 }
568
569 if let Some(mut notify) = notify {
570 notify.send(Ok(())).await.ok();
571 }
572 Ok(())
573 }
574
575 async fn handle_pause_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
576 if !self.protocol_version.has_pause_packets() {
577 warn!(
578 "Got a pause packet while using protocol \
579 version {} which does not support them. Ignoring",
580 self.protocol_version
581 );
582 return Ok(());
583 }
584
585 let pause = match payload {
586 [1] => true,
587 [0] => false,
588 other => {
589 warn!("Ignoring unexpected pause packet payload {other:?}");
590 return Ok(());
591 }
592 };
593
594 if let Some(conn) = self.connections.lock().unwrap().get(&address) {
595 if let VsockConnectionState::Connected { pause_state, .. } = &conn.state {
596 pause_state.set_paused(pause);
597 } else {
598 warn!("Received pause packet for unestablished connection. Ignoring");
599 };
600 } else {
601 warn!(
602 "Received pause packet for connection that didn't exist on address {address:?}. Ignoring"
603 );
604 }
605
606 Ok(())
607 }
608
609 pub async fn handle_vsock_packet(&self, packet: Packet<'_>) -> Result<(), Error> {
612 trace!("received vsock packet {header:?}", header = packet.header);
613 let payload_len = packet.header.payload_len.get() as usize;
614 let payload = &packet.payload[..payload_len];
615 let address = Address::from(packet.header);
616 match packet.header.packet_type {
617 PacketType::Sync => Err(Error::other("Received sync packet mid-stream")),
618 PacketType::Data => self.handle_data_packet(address, payload).await,
619 PacketType::Accept => self.handle_accept_packet(address).await,
620 PacketType::Connect => self.handle_connect_packet(address).await,
621 PacketType::Finish => self.handle_finish_packet(address).await,
622 PacketType::Reset => self.handle_reset_packet(address).await,
623 PacketType::Echo => self.handle_echo_packet(address, payload).await,
624 PacketType::EchoReply => self.handle_echo_reply_packet(address, payload).await,
625 PacketType::Pause => self.handle_pause_packet(address, payload).await,
626 }
627 }
628
629 pub async fn fill_usb_packet(&self, builder: UsbPacketBuilder<B>) -> UsbPacketBuilder<B> {
636 self.packet_filler.fill_usb_packet(builder).await
637 }
638}
639
640async fn reset<B: PacketBuffer, S: AsyncRead + AsyncWrite + Send + 'static>(
641 address: &Address,
642 connections: &std::sync::Mutex<HashMap<Address, VsockConnection<S>>>,
643 packet_filler: &UsbPacketFiller<B>,
644) -> Result<(), Error> {
645 let mut notify = None;
646 if let Some(conn) = connections.lock().unwrap().remove(&address) {
647 if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
648 notify = Some(notify_closed);
649 }
650 } else {
651 return Err(Error::other(
652 "Client asked to reset connection {address:?} that did not exist",
653 ));
654 }
655
656 if let Some(mut notify) = notify {
657 notify.send(Err(ErrorKind::ConnectionReset.into())).await.ok();
658 }
659
660 let header = &mut Header::new(PacketType::Reset);
661 header.set_address(address);
662 packet_filler
663 .write_vsock_packet(&Packet { header, payload: &[] })
664 .await
665 .expect("Reset packet should never be too big");
666 Ok(())
667}
668
669enum ConnectionStateWriter<S> {
674 NotYetAvailable(Vec<Waker>),
675 Available(OverflowWriter<S>),
676}
677
678impl<S> ConnectionStateWriter<S> {
679 fn wait_available(this: Arc<Mutex<ConnectionStateWriter<S>>>) -> ConnectionStateWriterFut<S> {
681 ConnectionStateWriterFut { writer: this, lock_fut: None }
682 }
683}
684
685struct ConnectionStateWriterFut<S> {
687 writer: Arc<Mutex<ConnectionStateWriter<S>>>,
688 lock_fut: Option<futures::lock::OwnedMutexLockFuture<ConnectionStateWriter<S>>>,
689}
690
691impl<S> Future for ConnectionStateWriterFut<S> {
692 type Output = OwnedMutexGuard<ConnectionStateWriter<S>>;
693
694 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
695 let writer = Arc::clone(&self.writer);
696 let lock_fut = self.lock_fut.get_or_insert_with(|| writer.lock_owned());
697 let mut lock = ready!(lock_fut.poll_unpin(cx));
698 self.lock_fut = None;
699 match &mut *lock {
700 ConnectionStateWriter::Available(_) => Poll::Ready(lock),
701 ConnectionStateWriter::NotYetAvailable(queue) => {
702 queue.push(cx.waker().clone());
703 Poll::Pending
704 }
705 }
706 }
707}
708
/// State of a single entry in the connection table.
enum VsockConnectionState<S> {
    /// A connect packet was sent; the sender delivers the [`ConnectionState`] to the
    /// `connect_late` caller once the peer's accept packet arrives.
    ConnectingOutgoing(oneshot::Sender<ConnectionState>),
    /// The peer asked to connect; waiting for a local accept or reject.
    ConnectingIncoming,
    /// The connection is established.
    Connected {
        /// Write side of the local socket; a placeholder until the socket is attached.
        writer: Arc<Mutex<ConnectionStateWriter<S>>>,
        /// Sender used to tell the holder of the [`ConnectionState`] how the connection ended.
        notify_closed: mpsc::Sender<Result<(), Error>>,
        /// Pause flag shared with the task reading from the local socket.
        pause_state: Arc<PauseState>,
        /// Scope on which the task reading from the local socket is spawned.
        reader_scope: Scope,
    },
    /// Placeholder written by `handle_accept_packet` while it takes ownership of the
    /// previous state.
    Invalid,
}
720
/// One entry in the connection table.
struct VsockConnection<S> {
    /// Address of the connection; a copy of the key under which the entry is stored.
    _address: Address,
    /// Current state of the connection.
    state: VsockConnectionState<S>,
}
725
726#[derive(Debug)]
730pub struct ConnectionState(mpsc::Receiver<Result<(), Error>>);
731
732impl ConnectionState {
733 pub async fn wait_for_close(mut self) -> Result<(), Error> {
736 self.0
737 .next()
738 .await
739 .ok_or_else(|| Error::other("Connection state's other end was dropped"))?
740 }
741}
742
743#[derive(Debug)]
746pub struct ConnectionRequest {
747 address: Address,
748}
749
750impl ConnectionRequest {
751 pub fn new(address: Address) -> Self {
753 Self { address }
754 }
755
756 pub fn address(&self) -> &Address {
758 &self.address
759 }
760}
761
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::VsockPacketIterator;

    use super::*;

    #[cfg(not(target_os = "fuchsia"))]
    use fuchsia_async::emulated_handle::Socket as SyncSocket;
    use fuchsia_async::{Socket, Task};
    use futures::StreamExt;
    #[cfg(target_os = "fuchsia")]
    use zx::Socket as SyncSocket;

    /// Connection type used by every test in this module.
    type TestConnection = Connection<Vec<u8>, Socket>;

    /// Payload sizes exercised by the data tests; each payload is `size` bytes of value `size`.
    const PAYLOAD_SIZES: [u8; 8] = [1, 2, 8, 16, 32, 64, 128, 255];

    /// Address used by the tests that open a connection.
    fn test_address() -> Address {
        Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
    }

    /// Creates a connected pair of async stream sockets.
    fn socket_pair() -> (Socket, Socket) {
        let (near, far) = SyncSocket::create_stream();
        (Socket::from_socket(near), Socket::from_socket(far))
    }

    /// Builds a connection on the latest protocol version with `control_socket` attached.
    fn new_connection(
        control_socket: SyncSocket,
        incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
    ) -> Arc<TestConnection> {
        Arc::new(Connection::new(
            ProtocolVersion::LATEST,
            Some(Socket::from_socket(control_socket)),
            incoming_requests_tx,
        ))
    }

    /// Writes each test payload to `socket` and checks that the same bytes come back.
    async fn round_trip_all_sizes(socket: &mut Socket) {
        for size in PAYLOAD_SIZES {
            println!("round tripping packet of size {size}");
            let payload = vec![size; size as usize];
            socket.write_all(&payload).await.unwrap();
            let mut buf = vec![0u8; size as usize];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, payload);
        }
    }

    /// Accepts the next incoming request (expected to be for `expected`), then waits for the
    /// attached socket to hit end-of-stream and for the connection to close cleanly.
    async fn accept_then_expect_close(
        conn: Arc<TestConnection>,
        mut incoming: mpsc::Receiver<ConnectionRequest>,
        expected: Address,
    ) {
        println!("accepting request on connection 2");
        let request = incoming.next().await.unwrap();
        assert_eq!(request.address, expected);

        let (mut socket, other_socket) = socket_pair();
        let state = conn.accept(request, other_socket).await.unwrap();
        assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
        state.wait_for_close().await.unwrap();
    }

    /// Loops every packet the connection emits straight back into it. Connect packets are
    /// answered with an accept, and accept packets are dropped.
    async fn usb_echo_server(echo_connection: Arc<TestConnection>) {
        let mut builder = UsbPacketBuilder::new(vec![0; 128]);
        loop {
            println!("waiting for usb packet");
            builder = echo_connection.fill_usb_packet(builder).await;
            let usb_packet = builder.take_usb_packet().unwrap();
            println!("got usb packet, echoing it back to the other side");
            let mut packet_count = 0;
            for packet in VsockPacketIterator::new(usb_packet) {
                let packet = packet.unwrap();
                match packet.header.packet_type {
                    PacketType::Connect => {
                        // Reply with the same header, with the type changed to Accept.
                        let mut reply_header = packet.header.clone();
                        reply_header.packet_type = PacketType::Accept;
                        let reply = Packet { header: &reply_header, payload: &[] };
                        echo_connection.handle_vsock_packet(reply).await.unwrap();
                    }
                    PacketType::Accept => {}
                    _ => echo_connection.handle_vsock_packet(packet).await.unwrap(),
                }
                packet_count += 1;
            }
            println!("handled {packet_count} packets");
        }
    }

    #[fuchsia::test]
    async fn data_over_control_socket() {
        let (socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
        let mut socket = Socket::from_socket(socket);
        let connection = new_connection(other_socket, incoming_requests_tx);

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        round_trip_all_sizes(&mut socket).await;
        echo_task.abort().await;
    }

    #[fuchsia::test]
    async fn data_over_normal_outgoing_socket() {
        let (_control_socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
        let connection = new_connection(other_socket, incoming_requests_tx);

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        let (mut socket, other_socket) = socket_pair();
        connection.connect(test_address(), other_socket).await.unwrap();

        round_trip_all_sizes(&mut socket).await;
        echo_task.abort().await;
    }

    #[fuchsia::test]
    async fn data_over_normal_incoming_socket() {
        let (_control_socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, mut incoming_requests) = mpsc::channel(5);
        let connection = new_connection(other_socket, incoming_requests_tx);

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        let header = &mut Header::new(PacketType::Connect);
        header.set_address(&test_address());
        connection.handle_vsock_packet(Packet { header, payload: &[] }).await.unwrap();

        let request = incoming_requests.next().await.unwrap();
        assert_eq!(request.address, test_address());

        let (mut socket, other_socket) = socket_pair();
        connection.accept(request, other_socket).await.unwrap();

        round_trip_all_sizes(&mut socket).await;
        echo_task.abort().await;
    }

    /// Forwards every vsock packet emitted by `from` into `to`, forever.
    async fn copy_connection(from: &Connection<Vec<u8>, Socket>, to: &Connection<Vec<u8>, Socket>) {
        let mut builder = UsbPacketBuilder::new(vec![0; 1024]);
        loop {
            builder = from.fill_usb_packet(builder).await;
            let usb_packet = builder.take_usb_packet().unwrap();
            for packet in VsockPacketIterator::new(usb_packet) {
                println!("forwarding vsock packet");
                to.handle_vsock_packet(packet.unwrap()).await.unwrap();
            }
        }
    }

    /// One side of an end-to-end test: an async closure given its connection and the
    /// receiver of that connection's incoming requests.
    pub(crate) trait EndToEndTestFn<R>:
        AsyncFnOnce(Arc<Connection<Vec<u8>, Socket>>, mpsc::Receiver<ConnectionRequest>) -> R
    {
    }
    impl<T, R> EndToEndTestFn<R> for T where
        T: AsyncFnOnce(Arc<Connection<Vec<u8>, Socket>>, mpsc::Receiver<ConnectionRequest>) -> R
    {
    }

    /// Runs `left_side` and `right_side` concurrently against two connections whose packets
    /// are forwarded to each other, and returns both results.
    pub(crate) async fn end_to_end_test<R1, R2>(
        left_side: impl EndToEndTestFn<R1>,
        right_side: impl EndToEndTestFn<R2>,
    ) -> (R1, R2) {
        let (_control_socket1, other_socket1) = SyncSocket::create_stream();
        let (_control_socket2, other_socket2) = SyncSocket::create_stream();
        let (incoming_requests_tx1, incoming_requests1) = mpsc::channel(5);
        let (incoming_requests_tx2, incoming_requests2) = mpsc::channel(5);

        let connection1 = new_connection(other_socket1, incoming_requests_tx1);
        let connection2 = new_connection(other_socket2, incoming_requests_tx2);

        let passthrough_task = Task::spawn({
            let (conn1, conn2) = (Arc::clone(&connection1), Arc::clone(&connection2));
            async move {
                futures::join!(copy_connection(&conn1, &conn2), copy_connection(&conn2, &conn1),);
                println!("passthrough task loop ended");
            }
        });

        let res = futures::join!(
            left_side(connection1, incoming_requests1),
            right_side(connection2, incoming_requests2)
        );
        passthrough_task.abort().await;
        res
    }

    #[fuchsia::test]
    async fn data_over_end_to_end() {
        end_to_end_test(
            async |conn, _incoming| {
                println!("sending request on connection 1");
                let (mut socket, other_socket) = socket_pair();
                let state = conn.connect(test_address(), other_socket).await.unwrap();

                for size in PAYLOAD_SIZES {
                    println!("round tripping packet of size {size}");
                    socket.write_all(&vec![size; size as usize]).await.unwrap();
                }
                drop(socket);
                state.wait_for_close().await.unwrap();
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, test_address());

                let (mut socket, other_socket) = socket_pair();
                let state = conn.accept(request, other_socket).await.unwrap();

                println!("accepted request on connection 2");
                for size in PAYLOAD_SIZES {
                    let mut buf = vec![0u8; size as usize];
                    socket.read_exact(&mut buf).await.unwrap();
                    assert_eq!(buf, vec![size; size as usize]);
                }
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    #[fuchsia::test]
    async fn normal_close_end_to_end() {
        let addr = test_address();
        end_to_end_test(
            async |conn, _incoming| {
                let (mut socket, other_socket) = socket_pair();
                let state = conn.connect(addr.clone(), other_socket).await.unwrap();
                conn.close(&addr).await;
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
            async |conn, incoming| accept_then_expect_close(conn, incoming, addr.clone()).await,
        )
        .await;
    }

    #[fuchsia::test]
    async fn reset_end_to_end() {
        let addr = test_address();
        end_to_end_test(
            async |conn, _incoming| {
                let (mut socket, other_socket) = socket_pair();
                let state = conn.connect(addr.clone(), other_socket).await.unwrap();
                conn.reset(&addr).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.expect_err("expected reset");
            },
            async |conn, incoming| accept_then_expect_close(conn, incoming, addr.clone()).await,
        )
        .await;
    }
}