1use futures::channel::{mpsc, oneshot};
6use futures::lock::Mutex;
7use log::{debug, trace, warn};
8use std::collections::hash_map::Entry;
9use std::collections::HashMap;
10use std::io::{Error, ErrorKind};
11use std::ops::DerefMut;
12use std::sync::Arc;
13
14use fuchsia_async::{Scope, Socket};
15use futures::io::{ReadHalf, WriteHalf};
16use futures::{AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt};
17
18use crate::{Address, Header, Packet, PacketType, UsbPacketBuilder, UsbPacketFiller};
19
/// Buffer type usable as the `B` parameter of [`Connection`] (and of the `UsbPacketBuilder` /
/// `UsbPacketFiller` it drives): an owned, sendable value that dereferences to a mutable byte
/// slice, such as `Vec<u8>` or `Box<[u8]>`.
pub trait PacketBuffer: DerefMut<Target = [u8]> + Send + Unpin + 'static {}

// Blanket impl: anything meeting the supertrait bounds automatically qualifies.
impl<T: DerefMut<Target = [u8]> + Send + Unpin + 'static> PacketBuffer for T {}
23
/// Multiplexes vsock connections over a stream of USB packets.
///
/// Outgoing vsock packets are queued in `packet_filler` and drained into USB packets by
/// [`Connection::fill_usb_packet`]; vsock packets received from the other side are fed in through
/// [`Connection::handle_vsock_packet`].
pub struct Connection<B> {
    /// Write half of the control socket; payloads of data packets addressed to the all-zeros
    /// address are written here.
    control_socket_writer: Mutex<WriteHalf<Socket>>,
    /// Collects outgoing vsock packets (via `write_vsock_packet` / `write_vsock_data_all`) until
    /// they are pulled out by `fill_usb_packet`. Shared with every socket reader task.
    packet_filler: Arc<UsbPacketFiller<B>>,
    /// State of every known vsock connection, keyed by its address.
    connections: std::sync::Mutex<HashMap<Address, VsockConnection>>,
    /// Connect requests received from the other side are forwarded here.
    incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
    /// Owns the control socket reader task spawned in `Connection::new`.
    _task_scope: Scope,
}
43
44impl<B: PacketBuffer> Connection<B> {
45 pub fn new(
51 control_socket: Socket,
52 incoming_requests_tx: mpsc::Sender<ConnectionRequest>,
53 ) -> Self {
54 let (control_socket_reader, control_socket_writer) = control_socket.split();
55 let control_socket_writer = Mutex::new(control_socket_writer);
56 let packet_filler = Arc::new(UsbPacketFiller::default());
57 let connections = Default::default();
58 let task_scope = Scope::new_with_name("vsock_usb");
59 task_scope.spawn(Self::run_socket(
60 control_socket_reader,
61 Address::default(),
62 packet_filler.clone(),
63 ));
64 Self {
65 control_socket_writer,
66 packet_filler,
67 connections,
68 incoming_requests_tx,
69 _task_scope: task_scope,
70 }
71 }
72
73 async fn send_close_packet(address: &Address, usb_packet_filler: &Arc<UsbPacketFiller<B>>) {
74 let header = &mut Header::new(PacketType::Finish);
75 header.set_address(address);
76 usb_packet_filler
77 .write_vsock_packet(&Packet { header, payload: &[] })
78 .await
79 .expect("Finish packet should never be too big");
80 }
81
82 async fn run_socket(
83 mut reader: ReadHalf<Socket>,
84 address: Address,
85 usb_packet_filler: Arc<UsbPacketFiller<B>>,
86 ) {
87 let mut buf = [0; 4096];
88 loop {
89 log::trace!("reading from control socket");
90 let read = match reader.read(&mut buf).await {
91 Ok(0) => {
92 if !address.is_zeros() {
93 Self::send_close_packet(&address, &usb_packet_filler).await;
94 }
95 return;
96 }
97 Ok(read) => read,
98 Err(err) => {
99 if address.is_zeros() {
100 log::error!("Error reading usb socket: {err:?}");
101 } else {
102 Self::send_close_packet(&address, &usb_packet_filler).await;
103 }
104 return;
105 }
106 };
107 log::trace!("writing {read} bytes to vsock packet");
108 usb_packet_filler.write_vsock_data_all(&address, &buf[..read]).await;
109 log::trace!("wrote {read} bytes to vsock packet");
110 }
111 }
112
113 fn set_connection(&self, address: Address, state: VsockConnectionState) -> Result<(), Error> {
114 let mut connections = self.connections.lock().unwrap();
115 if !connections.contains_key(&address) {
116 connections.insert(address.clone(), VsockConnection { _address: address, state });
117 Ok(())
118 } else {
119 Err(Error::other(format!("connection on address {address:?} already set")))
120 }
121 }
122
123 pub async fn send_empty_echo(&self) {
126 debug!("Sending empty echo packet");
127 let header = &mut Header::new(PacketType::Echo);
128 self.packet_filler
129 .write_vsock_packet(&Packet { header, payload: &[] })
130 .await
131 .expect("empty echo packet should never be too large to fit in a usb packet");
132 }
133
134 pub async fn connect(&self, addr: Address, socket: Socket) -> Result<ConnectionState, Error> {
139 let (read_socket, write_socket) = socket.split();
140 let write_socket = Arc::new(Mutex::new(write_socket));
141 let (connected_tx, connected_rx) = oneshot::channel();
142
143 self.set_connection(
144 addr.clone(),
145 VsockConnectionState::ConnectingOutgoing(write_socket, read_socket, connected_tx),
146 )?;
147
148 let header = &mut Header::new(PacketType::Connect);
149 header.set_address(&addr);
150 self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
151 connected_rx.await.map_err(|_| Error::other("Accept was never received for {addr:?}"))?
152 }
153
154 pub async fn close(&self, address: &Address) {
156 Self::send_close_packet(address, &self.packet_filler).await
157 }
158
159 pub async fn reset(&self, address: &Address) -> Result<(), Error> {
161 let mut notify = None;
162 if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
163 if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
164 notify = Some(notify_closed);
165 }
166 } else {
167 return Err(Error::other(
168 "Client asked to reset connection {address:?} that did not exist",
169 ));
170 }
171
172 if let Some(mut notify) = notify {
173 notify.send(Err(ErrorKind::ConnectionReset.into())).await.ok();
174 }
175
176 let header = &mut Header::new(PacketType::Reset);
177 header.set_address(address);
178 self.packet_filler
179 .write_vsock_packet(&Packet { header, payload: &[] })
180 .await
181 .expect("Reset packet should never be too big");
182 Ok(())
183 }
184
185 pub async fn accept(
189 &self,
190 request: ConnectionRequest,
191 socket: Socket,
192 ) -> Result<ConnectionState, Error> {
193 let address = request.address;
194 let notify_closed_rx;
195 if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
196 let VsockConnectionState::ConnectingIncoming = &conn.state else {
197 return Err(Error::other(format!(
198 "Attempted to accept connection that was not waiting at {address:?}"
199 )));
200 };
201
202 let (read_socket, write_socket) = socket.split();
203 let writer = Arc::new(Mutex::new(write_socket));
204 let notify_closed = mpsc::channel(2);
205 notify_closed_rx = notify_closed.1;
206 let notify_closed = notify_closed.0;
207
208 let reader_task = Scope::new_with_name("connection-reader");
209 reader_task.spawn(Self::run_socket(read_socket, address, self.packet_filler.clone()));
210
211 conn.state = VsockConnectionState::Connected {
212 writer,
213 _reader_scope: reader_task,
214 notify_closed,
215 };
216 } else {
217 return Err(Error::other(format!(
218 "Attempting to accept connection that did not exist at {address:?}"
219 )));
220 }
221 let header = &mut Header::new(PacketType::Accept);
222 header.set_address(&address);
223 self.packet_filler.write_vsock_packet(&Packet { header, payload: &[] }).await.unwrap();
224 Ok(ConnectionState(notify_closed_rx))
225 }
226
227 pub async fn reject(&self, request: ConnectionRequest) -> Result<(), Error> {
229 let address = request.address;
230 match self.connections.lock().unwrap().entry(address.clone()) {
231 Entry::Occupied(entry) => {
232 let VsockConnectionState::ConnectingIncoming = &entry.get().state else {
233 return Err(Error::other(format!(
234 "Attempted to reject connection that was not waiting at {address:?}"
235 )));
236 };
237 entry.remove();
238 }
239 Entry::Vacant(_) => {
240 return Err(Error::other(format!(
241 "Attempted to reject connection that was not waiting at {address:?}"
242 )));
243 }
244 }
245
246 let header = &mut Header::new(PacketType::Reset);
247 header.set_address(&address);
248 self.packet_filler
249 .write_vsock_packet(&Packet { header, payload: &[] })
250 .await
251 .expect("accept packet should never be too large for packet buffer");
252 Ok(())
253 }
254
255 async fn handle_data_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
256 if address.is_zeros() {
258 let written = self.control_socket_writer.lock().await.write(payload).await?;
259 assert_eq!(written, payload.len());
260 Ok(())
261 } else {
262 let payload_socket;
263 if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
264 let VsockConnectionState::Connected { writer, .. } = &conn.state else {
265 warn!(
266 "Received data packet for connection in unexpected state for {address:?}"
267 );
268 return Ok(());
269 };
270 payload_socket = writer.clone();
271 } else {
272 warn!("Received data packet for connection that didn't exist at {address:?}");
273 return Ok(());
274 }
275 if let Err(err) = payload_socket.lock().await.write_all(payload).await {
276 debug!(
277 "Write to socket address {address:?} failed, resetting connection immediately: {err:?}"
278 );
279 self.reset(&address).await.inspect_err(|err| warn!("Attempt to reset connection to {address:?} failed after write error: {err:?}")).ok();
280 }
281 Ok(())
282 }
283 }
284
285 async fn handle_echo_packet(&self, address: Address, payload: &[u8]) -> Result<(), Error> {
286 debug!("received echo for {address:?} with payload {payload:?}");
287 let header = &mut Header::new(PacketType::EchoReply);
288 header.payload_len.set(payload.len() as u32);
289 header.set_address(&address);
290 self.packet_filler
291 .write_vsock_packet(&Packet { header, payload })
292 .await
293 .map_err(|_| Error::other("Echo packet was too large to be sent back"))
294 }
295
296 async fn handle_echo_reply_packet(
297 &self,
298 address: Address,
299 payload: &[u8],
300 ) -> Result<(), Error> {
301 debug!("received echo reply for {address:?} with payload {payload:?}");
303 Ok(())
304 }
305
306 async fn handle_accept_packet(&self, address: Address) -> Result<(), Error> {
307 if let Some(conn) = self.connections.lock().unwrap().get_mut(&address) {
308 let state = std::mem::replace(&mut conn.state, VsockConnectionState::Invalid);
309 let VsockConnectionState::ConnectingOutgoing(writer, read_socket, connected_tx) = state
310 else {
311 warn!("Received accept packet for connection in unexpected state for {address:?}");
312 return Ok(());
313 };
314 let (notify_closed, notify_closed_rx) = mpsc::channel(2);
315 if connected_tx.send(Ok(ConnectionState(notify_closed_rx))).is_err() {
316 warn!(
317 "Accept packet received for {address:?} but connect caller stopped waiting for it"
318 );
319 }
320
321 let reader_task = Scope::new_with_name("connection-reader");
322 reader_task.spawn(Self::run_socket(read_socket, address, self.packet_filler.clone()));
323 conn.state = VsockConnectionState::Connected {
324 writer,
325 _reader_scope: reader_task,
326 notify_closed,
327 };
328 } else {
329 warn!("Got accept packet for connection that was not being made at {address:?}");
330 return Ok(());
331 }
332 Ok(())
333 }
334
335 async fn handle_connect_packet(&self, address: Address) -> Result<(), Error> {
336 trace!("received connect packet for {address:?}");
337 match self.connections.lock().unwrap().entry(address.clone()) {
338 Entry::Vacant(entry) => {
339 debug!("valid connect request for {address:?}");
340 entry.insert(VsockConnection {
341 _address: address,
342 state: VsockConnectionState::ConnectingIncoming,
343 });
344 }
345 Entry::Occupied(_) => {
346 warn!(
347 "Received connect packet for already existing connection for address {address:?}. Ignoring"
348 );
349 return Ok(());
350 }
351 }
352
353 trace!("sending incoming connection request to client for {address:?}");
354 let connection_request = ConnectionRequest { address };
355 self.incoming_requests_tx
356 .clone()
357 .send(connection_request)
358 .await
359 .inspect(|_| trace!("sent incoming request for {address:?}"))
360 .map_err(|_| Error::other("Failed to send connection request"))
361 }
362
363 async fn handle_finish_packet(&self, address: Address) -> Result<(), Error> {
364 trace!("received finish packet for {address:?}");
365 let mut notify;
366 if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
367 let VsockConnectionState::Connected { notify_closed, .. } = conn.state else {
368 warn!(
369 "Received finish (close) packet for {address:?} which was not in a connected state. Ignoring and dropping connection state."
370 );
371 return Ok(());
372 };
373 notify = notify_closed;
374 } else {
375 warn!(
376 "Received finish (close) packet for connection that didn't exist on address {address:?}. Ignoring"
377 );
378 return Ok(());
379 }
380
381 notify.send(Ok(())).await.ok();
382
383 let header = &mut Header::new(PacketType::Reset);
384 header.set_address(&address);
385 self.packet_filler
386 .write_vsock_packet(&Packet { header, payload: &[] })
387 .await
388 .expect("accept packet should never be too large for packet buffer");
389 Ok(())
390 }
391
392 async fn handle_reset_packet(&self, address: Address) -> Result<(), Error> {
393 trace!("received reset packet for {address:?}");
394 let mut notify = None;
395 if let Some(conn) = self.connections.lock().unwrap().remove(&address) {
396 if let VsockConnectionState::Connected { notify_closed, .. } = conn.state {
397 notify = Some(notify_closed);
398 } else {
399 debug!(
400 "Received reset packet for connection that wasn't in a connecting or disconnected state on address {address:?}."
401 );
402 }
403 } else {
404 warn!(
405 "Received reset packet for connection that didn't exist on address {address:?}. Ignoring"
406 );
407 }
408
409 if let Some(mut notify) = notify {
410 notify.send(Ok(())).await.ok();
411 }
412 Ok(())
413 }
414
415 pub async fn handle_vsock_packet(&self, packet: Packet<'_>) -> Result<(), Error> {
418 trace!("received vsock packet {header:?}", header = packet.header);
419 let payload_len = packet.header.payload_len.get() as usize;
420 let payload = &packet.payload[..payload_len];
421 let address = Address::from(packet.header);
422 match packet.header.packet_type {
423 PacketType::Sync => Err(Error::other("Received sync packet mid-stream")),
424 PacketType::Data => self.handle_data_packet(address, payload).await,
425 PacketType::Accept => self.handle_accept_packet(address).await,
426 PacketType::Connect => self.handle_connect_packet(address).await,
427 PacketType::Finish => self.handle_finish_packet(address).await,
428 PacketType::Reset => self.handle_reset_packet(address).await,
429 PacketType::Echo => self.handle_echo_packet(address, payload).await,
430 PacketType::EchoReply => self.handle_echo_reply_packet(address, payload).await,
431 }
432 }
433
434 pub async fn fill_usb_packet(&self, builder: UsbPacketBuilder<B>) -> UsbPacketBuilder<B> {
441 self.packet_filler.fill_usb_packet(builder).await
442 }
443}
444
/// Lifecycle state of one vsock connection tracked in `Connection::connections`.
enum VsockConnectionState {
    /// A `Connect` packet was sent by [`Connection::connect`] and the other side's `Accept` has
    /// not arrived yet. Holds the local socket's write half, its read half, and the sender that
    /// resolves the pending `connect` call (dropping it makes `connect` fail).
    ConnectingOutgoing(
        Arc<Mutex<WriteHalf<Socket>>>,
        ReadHalf<Socket>,
        oneshot::Sender<Result<ConnectionState, Error>>,
    ),
    /// The other side sent a `Connect` packet; waiting for the local client to call
    /// [`Connection::accept`] or [`Connection::reject`].
    ConnectingIncoming,
    /// Established connection.
    Connected {
        /// Write half of the local socket; data packets from the other side are written here.
        writer: Arc<Mutex<WriteHalf<Socket>>>,
        /// Notifies the [`ConnectionState`] holder. It receives `Ok(())` when the other side
        /// finishes or resets the connection, and `Err(ConnectionReset)` when it is reset locally
        /// through `Connection::reset`.
        notify_closed: mpsc::Sender<Result<(), Error>>,
        /// Owns the task that forwards reads from the local socket to the other side.
        _reader_scope: Scope,
    },
    /// Placeholder left behind while `handle_accept_packet` temporarily moves the real state out.
    Invalid,
}
459
/// Entry in `Connection::connections` describing one vsock connection.
struct VsockConnection {
    /// Address of the connection (the same value as its map key); currently never read.
    _address: Address,
    /// Where the connection is in its lifecycle.
    state: VsockConnectionState,
}
464
/// Handle returned by [`Connection::connect`] and [`Connection::accept`] that lets the caller
/// wait for the connection to be closed or reset (see [`ConnectionState::wait_for_close`]).
#[derive(Debug)]
pub struct ConnectionState(mpsc::Receiver<Result<(), Error>>);
470
471impl ConnectionState {
472 pub async fn wait_for_close(mut self) -> Result<(), Error> {
475 self.0
476 .next()
477 .await
478 .ok_or_else(|| Error::other("Connection state's other end was dropped"))?
479 }
480}
481
/// An incoming connection request from the other side, delivered on the channel passed to
/// [`Connection::new`]. Hand it to [`Connection::accept`] or [`Connection::reject`].
#[derive(Debug)]
pub struct ConnectionRequest {
    /// Address the other side asked to connect on.
    address: Address,
}
488
489impl ConnectionRequest {
490 pub fn new(address: Address) -> Self {
492 Self { address }
493 }
494
495 pub fn address(&self) -> &Address {
497 &self.address
498 }
499}
500
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::VsockPacketIterator;

    use super::*;

    // The synchronous socket type differs between host and Fuchsia targets; both are wrapped
    // with `Socket::from_socket` below.
    #[cfg(not(target_os = "fuchsia"))]
    use fuchsia_async::emulated_handle::Socket as SyncSocket;
    use fuchsia_async::Task;
    use futures::StreamExt;
    #[cfg(target_os = "fuchsia")]
    use zx::Socket as SyncSocket;

    /// Loops forever, draining USB packets produced by `echo_connection` and feeding their vsock
    /// packets straight back into the same connection.
    ///
    /// `Connect` packets are answered with an `Accept` so `connect` calls complete, and `Accept`
    /// packets are dropped. Everything else (e.g. data) is echoed back unchanged.
    async fn usb_echo_server(echo_connection: Arc<Connection<Vec<u8>>>) {
        let mut builder = UsbPacketBuilder::new(vec![0; 128]);
        loop {
            println!("waiting for usb packet");
            builder = echo_connection.fill_usb_packet(builder).await;
            let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
            println!("got usb packet, echoing it back to the other side");
            let mut packet_count = 0;
            for packet in packets {
                let packet = packet.unwrap();
                match packet.header.packet_type {
                    PacketType::Connect => {
                        // Pretend the other side accepted: reply with the same header as an Accept.
                        let mut reply_header = packet.header.clone();
                        reply_header.packet_type = PacketType::Accept;
                        echo_connection
                            .handle_vsock_packet(Packet { header: &reply_header, payload: &[] })
                            .await
                            .unwrap();
                    }
                    PacketType::Accept => {
                        // An Accept we sent (from `accept`); echoing it back would hit
                        // `handle_accept_packet` for a connection that is already connected.
                    }
                    _ => echo_connection.handle_vsock_packet(packet).await.unwrap(),
                }
                packet_count += 1;
            }
            println!("handled {packet_count} packets");
        }
    }

    /// Bytes written into the control socket come back out of it after a round trip through the
    /// echo server.
    #[fuchsia::test]
    async fn data_over_control_socket() {
        let (socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
        let mut socket = Socket::from_socket(socket);
        let connection =
            Arc::new(Connection::new(Socket::from_socket(other_socket), incoming_requests_tx));

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
            println!("round tripping packet of size {size}");
            socket.write_all(&vec![size; size as usize]).await.unwrap();
            let mut buf = vec![0u8; size as usize];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, vec![size; size as usize]);
        }
        echo_task.abort().await;
    }

    /// An outgoing connection (accepted by the echo server) echoes data written to its socket.
    #[fuchsia::test]
    async fn data_over_normal_outgoing_socket() {
        let (_control_socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, _incoming_requests) = mpsc::channel(5);
        let connection =
            Arc::new(Connection::new(Socket::from_socket(other_socket), incoming_requests_tx));

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        let (socket, other_socket) = SyncSocket::create_stream();
        let mut socket = Socket::from_socket(socket);
        connection
            .connect(
                Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 },
                Socket::from_socket(other_socket),
            )
            .await
            .unwrap();

        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
            println!("round tripping packet of size {size}");
            socket.write_all(&vec![size; size as usize]).await.unwrap();
            let mut buf = vec![0u8; size as usize];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, vec![size; size as usize]);
        }
        echo_task.abort().await;
    }

    /// A Connect packet injected by hand surfaces as a request; once accepted, the connection
    /// echoes data written to its socket.
    #[fuchsia::test]
    async fn data_over_normal_incoming_socket() {
        let (_control_socket, other_socket) = SyncSocket::create_stream();
        let (incoming_requests_tx, mut incoming_requests) = mpsc::channel(5);
        let connection =
            Arc::new(Connection::new(Socket::from_socket(other_socket), incoming_requests_tx));

        let echo_task = Task::spawn(usb_echo_server(connection.clone()));

        let header = &mut Header::new(PacketType::Connect);
        header.set_address(&Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 });
        connection.handle_vsock_packet(Packet { header, payload: &[] }).await.unwrap();

        let request = incoming_requests.next().await.unwrap();
        assert_eq!(
            request.address,
            Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
        );

        let (socket, other_socket) = SyncSocket::create_stream();
        let mut socket = Socket::from_socket(socket);
        connection.accept(request, Socket::from_socket(other_socket)).await.unwrap();

        for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
            println!("round tripping packet of size {size}");
            socket.write_all(&vec![size; size as usize]).await.unwrap();
            let mut buf = vec![0u8; size as usize];
            socket.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, vec![size; size as usize]);
        }
        echo_task.abort().await;
    }

    /// Forever moves every vsock packet produced by `from` into `to`, simulating the USB link.
    async fn copy_connection(from: &Connection<Vec<u8>>, to: &Connection<Vec<u8>>) {
        let mut builder = UsbPacketBuilder::new(vec![0; 1024]);
        loop {
            builder = from.fill_usb_packet(builder).await;
            let packets = VsockPacketIterator::new(builder.take_usb_packet().unwrap());
            for packet in packets {
                println!("forwarding vsock packet");
                to.handle_vsock_packet(packet.unwrap()).await.unwrap();
            }
        }
    }

    /// One side of an end-to-end test: receives its `Connection` and its incoming request stream.
    pub(crate) trait EndToEndTestFn<R>:
        AsyncFnOnce(Arc<Connection<Vec<u8>>>, mpsc::Receiver<ConnectionRequest>) -> R
    {
    }
    impl<T, R> EndToEndTestFn<R> for T where
        T: AsyncFnOnce(Arc<Connection<Vec<u8>>>, mpsc::Receiver<ConnectionRequest>) -> R
    {
    }

    /// Wires two `Connection`s back to back (packets from each are fed to the other), runs
    /// `left_side` and `right_side` concurrently on them, and returns both results.
    pub(crate) async fn end_to_end_test<R1, R2>(
        left_side: impl EndToEndTestFn<R1>,
        right_side: impl EndToEndTestFn<R2>,
    ) -> (R1, R2) {
        type Connection = crate::Connection<Vec<u8>>;
        let (_control_socket1, other_socket1) = SyncSocket::create_stream();
        let (_control_socket2, other_socket2) = SyncSocket::create_stream();
        let (incoming_requests_tx1, incoming_requests1) = mpsc::channel(5);
        let (incoming_requests_tx2, incoming_requests2) = mpsc::channel(5);

        let connection1 =
            Arc::new(Connection::new(Socket::from_socket(other_socket1), incoming_requests_tx1));
        let connection2 =
            Arc::new(Connection::new(Socket::from_socket(other_socket2), incoming_requests_tx2));

        let conn1 = connection1.clone();
        let conn2 = connection2.clone();
        let passthrough_task = Task::spawn(async move {
            futures::join!(copy_connection(&conn1, &conn2), copy_connection(&conn2, &conn1),);
            println!("passthrough task loop ended");
        });

        let res = futures::join!(
            left_side(connection1, incoming_requests1),
            right_side(connection2, incoming_requests2)
        );
        passthrough_task.abort().await;
        res
    }

    /// Data written on one side arrives on the other. Dropping the sending socket closes the
    /// connection on both ends.
    #[fuchsia::test]
    async fn data_over_end_to_end() {
        end_to_end_test(
            async |conn, _incoming| {
                println!("sending request on connection 1");
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn
                    .connect(
                        Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 },
                        Socket::from_socket(other_socket),
                    )
                    .await
                    .unwrap();

                for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
                    println!("round tripping packet of size {size}");
                    socket.write_all(&vec![size; size as usize]).await.unwrap();
                }
                // EOF on our socket makes the reader task send Finish to the other side.
                drop(socket);
                state.wait_for_close().await.unwrap();
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(
                    request.address,
                    Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 }
                );

                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();

                println!("accepted request on connection 2");
                for size in [1u8, 2, 8, 16, 32, 64, 128, 255] {
                    let mut buf = vec![0u8; size as usize];
                    socket.read_exact(&mut buf).await.unwrap();
                    assert_eq!(buf, vec![size; size as usize]);
                }
                // The connection is torn down after the Finish, so our socket sees EOF.
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    /// `Connection::close` on one side shuts the connection down cleanly on both sides.
    #[fuchsia::test]
    async fn normal_close_end_to_end() {
        let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
        end_to_end_test(
            async |conn, _incoming| {
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state =
                    conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
                conn.close(&addr).await;
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, addr.clone(),);

                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }

    /// A local `Connection::reset` reports an error to the resetting side, while the other side
    /// sees a plain close.
    #[fuchsia::test]
    async fn reset_end_to_end() {
        let addr = Address { device_cid: 1, host_cid: 2, device_port: 3, host_port: 4 };
        end_to_end_test(
            async |conn, _incoming| {
                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state =
                    conn.connect(addr.clone(), Socket::from_socket(other_socket)).await.unwrap();
                conn.reset(&addr).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.expect_err("expected reset");
            },
            async |conn, mut incoming| {
                println!("accepting request on connection 2");
                let request = incoming.next().await.unwrap();
                assert_eq!(request.address, addr.clone(),);

                let (socket, other_socket) = SyncSocket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let state = conn.accept(request, Socket::from_socket(other_socket)).await.unwrap();
                assert_eq!(socket.read(&mut [0u8; 1]).await.unwrap(), 0);
                state.wait_for_close().await.unwrap();
            },
        )
        .await;
    }
}