1use fdf_component::{driver_register, Driver, DriverContext, Node};
6use fidl::endpoints::create_endpoints;
7use fuchsia_async::scope::ScopeStream;
8use fuchsia_async::{Scope, Socket, TimeoutExt};
9use fuchsia_component::server::ServiceFs;
10use futures::channel::{mpsc, oneshot};
11use futures::future::{select, Either};
12use futures::io::{ReadHalf, WriteHalf};
13use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, StreamExt, TryStreamExt};
14use log::{debug, error, info, warn};
15use std::io::{Error, ErrorKind};
16use std::pin::pin;
17use std::sync::Arc;
18use std::time::Duration;
19use usb_vsock::{
20 Connection, ConnectionRequest, Header, Packet, PacketType, ProtocolVersion, UsbPacketBuilder,
21 VsockPacketIterator, CID_HOST,
22};
23use zx::{SocketOpts, Status};
24use {fidl_fuchsia_hardware_overnet as overnet, fidl_fuchsia_hardware_vsock as vsock};
25
26mod vsock_service;
27
28use vsock_service::VsockService;
29
/// Size, in bytes, of the buffers used to read or build one USB packet (a single datagram on
/// the USB socket).
///
/// This is a `const` rather than a `static` because it is used as an array length and as a
/// const generic argument, both of which require a compile-time constant.
const MTU: usize = 1024;
31
/// Driver that offers a `fuchsia.hardware.vsock` device service in its outgoing directory and
/// bridges it to the host over USB links obtained from the overnet USB device.
struct UsbVsockServiceDriver {
    /// Owns the task spawned in `start` that serves the outgoing directory. Kept for the
    /// driver's lifetime (presumably dropping the scope cancels that task — confirm).
    _scope: Scope,
    /// Node taken from the driver context in `start`; held so it lives as long as the driver.
    _node: Node,
}

// Registers `UsbVsockServiceDriver` as this component's driver implementation.
driver_register!(UsbVsockServiceDriver);
41
/// One USB link to the host: the split datagram socket handed over by the USB driver, plus
/// what is needed to install each connection negotiated on it into the vsock service.
struct UsbConnection {
    /// Service that each newly negotiated `Connection` on this link is installed into.
    vsock_service: Arc<VsockService<Vec<u8>>>,
    /// Read half of the datagram socket received from the USB driver.
    usb_socket_reader: ReadHalf<Socket>,
    /// Write half of the same socket.
    usb_socket_writer: WriteHalf<Socket>,
    /// Sender cloned into every `Connection` created on this link; presumably this is how
    /// connection requests arriving from the host reach the vsock service — confirm.
    connection_tx: mpsc::Sender<ConnectionRequest>,
}
51
52impl UsbConnection {
53 fn new(
54 vsock_service: Arc<VsockService<Vec<u8>>>,
55 usb_socket: zx::Socket,
56 connection_tx: mpsc::Sender<ConnectionRequest>,
57 ) -> Self {
58 assert!(
59 usb_socket.info().unwrap().options.contains(SocketOpts::DATAGRAM),
60 "USB socket must be a datagram socket"
61 );
62 let (usb_socket_reader, usb_socket_writer) = Socket::from_socket(usb_socket).split();
63 Self { vsock_service, usb_socket_reader, usb_socket_writer, connection_tx }
64 }
65
66 async fn clear_host_requests(&mut self, found_magic: &[u8]) -> Option<()> {
69 let mut data = [0; MTU];
70 for _ in 0..10 {
71 let header = &mut Header::new(PacketType::Echo);
72 header.payload_len.set(found_magic.len() as u32);
73 let packet = Packet { header, payload: &found_magic };
74 packet.write_to_unchecked(&mut data);
75 if let Err(err) = self.usb_socket_writer.write(&data[..packet.size()]).await {
76 error!("Error writing echo to the usb socket: {err:?}");
77 return None;
78 }
79 let next_packet = read_packet_stream(&mut self.usb_socket_reader, &mut data)
80 .on_timeout(Duration::from_millis(100), || Err(ErrorKind::TimedOut.into()))
81 .await;
82 let mut packets = match next_packet {
83 Ok(None) => {
84 debug!("Usb socket closed");
85 return None;
86 }
87 Err(err) if err.kind() == ErrorKind::TimedOut => {
88 error!("Timed out waiting for matching packet, trying again");
89 continue;
90 }
91 Err(err) => {
92 error!("Unexpected error on usb socket: {err}");
93 return None;
94 }
95 Ok(Some(packets)) => packets,
96 };
97
98 while let Some(packet) = packets.next() {
99 match packet {
102 Ok(Packet {
103 header: Header { packet_type: PacketType::EchoReply, .. },
104 payload,
105 }) => {
106 if payload == found_magic {
107 debug!("host replied to echo packet and it was received, continuing synchronization");
108 return Some(());
109 } else {
110 warn!("Got echo reply with incorrect payload, ignoring.")
111 }
112 }
113 Ok(packet) => {
114 warn!("Got unexpected packet of type {:?} and length {} while waiting for sync packet. Ignoring.", packet.header.packet_type, packet.header.payload_len);
115 }
116 Err(err) => {
117 warn!("Got invalid vsock packet while waiting for sync packet: {err:?}");
118 }
119 }
120 }
121 }
122 warn!("Failed to receive echo response in time, giving up but still trying to connect");
124 Some(())
125 }
126
127 async fn next_socket(
130 &mut self,
131 mut found_magic: Option<Vec<u8>>,
132 ) -> Option<(ProtocolVersion, u32)> {
133 let mut data = [0; MTU];
134 while found_magic.is_none() {
135 let mut packets = match read_packet_stream(&mut self.usb_socket_reader, &mut data).await
136 {
137 Ok(None) => {
138 debug!("Usb socket closed");
139 return None;
140 }
141 Err(err) => {
142 error!("Unexpected error on usb socket: {err}");
143 return None;
144 }
145 Ok(Some(packets)) => packets,
146 };
147
148 while let Some(packet) = packets.next() {
149 match packet {
152 Ok(Packet {
153 header: Header { packet_type: PacketType::Sync, .. },
154 payload,
155 }) => {
156 found_magic = Some(payload.to_owned());
157 }
158 Ok(packet) => {
159 warn!("Got unexpected packet of type {:?} and length {} while waiting for sync packet. Ignoring.", packet.header.packet_type, packet.header.payload_len);
160 }
161 Err(err) => {
162 warn!("Got invalid vsock packet while waiting for sync packet: {err:?}");
163 }
164 }
165 }
166 }
167 let found_magic =
168 found_magic.expect("read loop should not terminate until sync packet is read");
169
170 let incoming_version = ProtocolVersion::from_magic(&found_magic);
171
172 self.clear_host_requests(&found_magic).await?;
176
177 let outgoing_version = if let Some(incoming_version) = incoming_version {
178 let Some(outgoing_version) = ProtocolVersion::LATEST.negotiate(&incoming_version)
179 else {
180 error!("Could not negotiate protocol version: driver has {}, host wants {incoming_version}", ProtocolVersion::LATEST);
181 return None;
182 };
183
184 outgoing_version
185 } else {
186 warn!(
187 "Did not recognize protocol magic {found_magic:?}. Trying version {}",
188 ProtocolVersion::LATEST
189 );
190
191 ProtocolVersion::LATEST
192 };
193
194 let outgoing_magic = outgoing_version.magic();
195
196 debug!("Read sync packet, sending it back and setting up a new link");
197 let mut header = Header::new(PacketType::Sync);
198 header.payload_len = (outgoing_magic.len() as u32).into();
199 header.device_cid.set(self.vsock_service.current_cid());
200 header.host_cid.set(CID_HOST);
201 let packet = Packet { header: &header, payload: outgoing_magic };
202 packet.write_to_unchecked(&mut data);
203 if let Err(err) = self.usb_socket_writer.write(&data[..packet.size()]).await {
204 error!("Error writing overnet magic string to the usb socket: {err:?}");
205 return None;
206 }
207 loop {
209 let mut packets = match read_packet_stream(&mut self.usb_socket_reader, &mut data).await
210 {
211 Ok(None) => {
212 debug!("Usb socket closed");
213 return None;
214 }
215 Err(err) => {
216 error!("Unexpected error on usb socket: {err}");
217 return None;
218 }
219 Ok(Some(packets)) => packets,
220 };
221
222 while let Some(packet) = packets.next() {
223 match packet {
226 Ok(Packet {
227 header: Header { packet_type: PacketType::Sync, device_cid, .. },
228 payload,
229 }) => {
230 if payload != outgoing_magic {
231 error!("Host gave unsupported protocol version string {payload:?}. Giving up.");
232 return None;
233 }
234 return Some((outgoing_version, device_cid.get()));
235 }
236 Ok(packet) => {
237 warn!("Got unexpected packet of type {:?} and length {} while waiting for sync packet. Ignoring.", packet.header.packet_type, packet.header.payload_len);
238 }
239 Err(err) => {
240 warn!("Got invalid vsock packet while waiting for sync packet: {err:?}");
241 }
242 }
243 }
244 }
245 }
246
247 async fn run(mut self, mut synchronized: Option<oneshot::Sender<()>>) {
248 let mut found_magic = None;
249 loop {
250 let Some((protocol_version, cid)) = self.next_socket(found_magic).await else {
251 info!("USB socket closed or failed");
252 return;
253 };
254 found_magic = None;
256 info!("Bridge established with CID {cid}");
257 let connection =
261 Arc::new(Connection::new(protocol_version, None, self.connection_tx.clone()));
262 self.vsock_service.set_connection(connection.clone(), cid).await;
263 log::trace!("vsock connection set");
264 if let Some(synchronized) = synchronized.take() {
265 synchronized.send(()).ok();
267 }
268 let usb_socket_writer =
269 usb_socket_writer::<MTU>(&connection, &mut self.usb_socket_writer);
270 let usb_socket_reader = usb_socket_reader::<MTU>(
271 &mut found_magic,
272 &mut self.usb_socket_reader,
273 &connection,
274 );
275 let client_socket_copy = pin!(usb_socket_writer);
276 let usb_socket_copy = pin!(usb_socket_reader);
277 let res = select(client_socket_copy, usb_socket_copy).await;
278 match res {
279 Either::Left((Err(err), _)) => {
280 warn!("Error on client to usb socket transfer: {err:?}");
281 }
282 Either::Left((Ok(_), _)) => {
283 debug!("client to usb socket closed normally");
284 }
285 Either::Right((Err(err), _)) => {
286 warn!("Error on usb to client socket transfer: {err:?}");
287 }
288 Either::Right((Ok(_), _)) => {
289 info!("usb to client socket closed normally");
290 }
291 }
292 }
293 }
294}
295
296async fn read_packet_stream<'a>(
297 reader: &mut (impl AsyncRead + Unpin),
298 mut buffer: &'a mut [u8],
299) -> Result<Option<VsockPacketIterator<'a>>, std::io::Error> {
300 let size = reader.read(&mut buffer).await?;
301 if size == 0 {
302 return Ok(None);
303 }
304 Ok(Some(VsockPacketIterator::new(&buffer[0..size])))
305}
306
307async fn usb_socket_writer<const MTU: usize>(
308 connection: &Connection<Vec<u8>, Socket>,
309 usb_writer: &mut (impl AsyncWrite + Unpin),
310) -> Result<(), Error> {
311 let mut builder = UsbPacketBuilder::new(vec![0; MTU]);
312 loop {
313 builder = connection.fill_usb_packet(builder).await;
314 let buf = builder.take_usb_packet().unwrap();
315 assert_eq!(
316 buf.len(),
317 usb_writer.write(buf).await?,
318 "datagram socket sent incomplete packet"
319 );
320 }
321}
322
323async fn usb_socket_reader<const MTU: usize>(
324 found_magic: &mut Option<Vec<u8>>,
325 usb_reader: &mut (impl AsyncRead + Unpin),
326 connection: &Connection<Vec<u8>, Socket>,
327) -> Result<(), Error> {
328 let mut data = [0; MTU];
329 loop {
330 let Some(mut packets) = read_packet_stream(usb_reader, &mut data).await? else {
331 break;
332 };
333 while let Some(packet) = packets.next() {
334 match packet {
335 Ok(Packet { header: Header { packet_type: PacketType::Sync, .. }, payload }) => {
336 debug!("Found sync packet, ending stream");
337 *found_magic = Some(payload.to_owned());
338 return Ok(());
339 }
340 Ok(packet) => connection.handle_vsock_packet(packet).await?,
341 Err(err) => {
342 error!("Failed to parse vsock packet, going back to waiting for sync packet: {err:?}");
343 break;
344 }
345 }
346 }
347 }
348 Ok(())
349}
350
/// Serves the USB device's `Callback` protocol, running a `UsbConnection` for each new link
/// it announces.
struct UsbCallbackHandler {
    /// Request stream for the callback endpoint registered with the USB device.
    usb_callback_server: overnet::CallbackRequestStream,
    /// Cloned into every `UsbConnection` this handler creates.
    connection_tx: mpsc::Sender<ConnectionRequest>,
}
357
358impl UsbCallbackHandler {
359 async fn run(
360 mut self,
361 vsock_service: Arc<VsockService<Vec<u8>>>,
362 mut synchronized: Option<oneshot::Sender<()>>,
363 ) -> Result<(), fidl::Error> {
364 use overnet::CallbackRequest::*;
365 while let Some(req) = self.usb_callback_server.try_next().await? {
366 let NewLink { socket, responder } = req;
367 responder.send()?;
368
369 debug!("Received new socket from usb driver");
370 UsbConnection::new(vsock_service.clone(), socket, self.connection_tx.clone())
371 .run(synchronized.take())
372 .await;
373 }
374 Ok(())
375 }
376}
377
378impl Driver for UsbVsockServiceDriver {
379 const NAME: &str = "usb-vsock-service";
380
381 async fn start(mut context: DriverContext) -> Result<Self, Status> {
382 let node = context.take_node()?;
383 let scope = Scope::new_with_name(Self::NAME);
384 let mut outgoing = ServiceFs::new();
385
386 let usb_device = get_usb_device(&context)?;
387
388 info!("Offering a vsock service in the outgoing directory");
389 outgoing.dir("svc").add_fidl_service_instance("default", move |i| {
390 let vsock::ServiceRequest::Device(request_stream) = i;
391 request_stream
392 });
393
394 context.serve_outgoing(&mut outgoing)?;
395
396 scope.spawn(async move {
397 while let Some(request_stream) = outgoing.next().await {
398 let (usb_callback, usb_callback_server) = create_endpoints();
399 usb_device.set_callback(usb_callback).await.expect("usb device service went away");
400
401 run_connection(usb_callback_server.into_stream(), request_stream, None).await
402 }
403 });
404
405 Ok(Self { _scope: scope, _node: node })
406 }
407
408 async fn stop(&self) {}
409}
410
411async fn run_connection(
412 usb_callback_server: overnet::CallbackRequestStream,
413 mut request_stream: vsock::DeviceRequestStream,
414 synchronized: Option<oneshot::Sender<()>>,
415) {
416 debug!("Waiting for start message on vsock implementation service");
417 let (connection_tx, incoming_connections) = mpsc::channel(1);
418 let svc = match VsockService::wait_for_start(incoming_connections, &mut request_stream).await {
419 Ok(svc) => svc,
420 Err(err) => {
421 error!("Error while waiting for start message from vsock client: {err:?}");
422 return;
423 }
424 };
425 debug!(
426 "Received start message on vsock implementation service, waiting for usb socket handles"
427 );
428
429 let svc = Arc::new(svc);
430 let (mut scopes_stream, scopes) = ScopeStream::new_with_name("usb-vsock-connection".to_owned());
431
432 let usb_callback_handler =
433 UsbCallbackHandler { usb_callback_server, connection_tx: connection_tx.clone() };
434 let usb_svc = svc.clone();
435 scopes.push(async move {
436 if let Err(err) = usb_callback_handler.run(usb_svc, synchronized).await {
437 error!("Error while waiting for usb device callbacks: {err:?}");
438 }
439 });
440 scopes.push(async move {
441 if let Err(err) = svc.run(request_stream).await {
442 error!("Error while servicing vsock client: {err:?}");
443 }
444 });
445 scopes_stream.next().await;
447}
448
449fn get_usb_device(context: &DriverContext) -> Result<overnet::UsbProxy, Status> {
450 let service_proxy = context.incoming.service_marker(overnet::UsbServiceMarker).connect()?;
451
452 service_proxy.connect_to_device().map_err(|err| {
453 error!("Error connecting to usb device proxy at driver startup: {err}");
454 Status::INTERNAL
455 })
456}
457
458#[cfg(test)]
459mod tests {
460 use fidl::endpoints::create_endpoints;
461 use fidl_fuchsia_vsock as vsock_api;
462 use futures::future::join;
463 use log::trace;
464 use usb_vsock::CID_ANY;
465
466 use super::*;
467
468 async fn end_to_end_test(
469 device_side: impl AsyncFn(vsock_api::ConnectorProxy),
470 host_side: impl AsyncFn(
471 Arc<Connection<Vec<u8>, Socket>>,
472 u32,
473 mpsc::Receiver<ConnectionRequest>,
474 ),
475 ) {
476 let scope = Scope::new();
477 let (vsock_impl_client, vsock_impl_server) = create_endpoints::<vsock::DeviceMarker>();
478 let (usb_callback_client, usb_callback_server) =
479 create_endpoints::<overnet::CallbackMarker>();
480 let (started_tx, started_rx) = oneshot::channel();
481 scope.spawn(run_connection(
482 usb_callback_server.into_stream(),
483 vsock_impl_server.into_stream(),
484 Some(started_tx),
485 ));
486 let usb_callback_client = usb_callback_client.into_proxy();
487
488 let (vsock_api_service, vsock_api_future) =
489 vsock_service_lib::Vsock::new(Some(vsock_impl_client.into_proxy()), None)
490 .await
491 .unwrap();
492 scope.spawn_local(async move {
493 vsock_api_future.await.unwrap();
494 });
495
496 let (vsock_api_client, vsock_api_server) = create_endpoints::<vsock_api::ConnectorMarker>();
497 scope.spawn_local(vsock_api_service.run_client_connection(vsock_api_server.into_stream()));
498 let vsock_api_client = vsock_api_client.into_proxy();
499
500 let (usb_packet_socket, usb_packet_server) = zx::Socket::create_datagram();
501 let (mut usb_packet_reader, mut usb_packet_writer) =
502 Socket::from_socket(usb_packet_socket).split();
503 usb_callback_client.new_link(usb_packet_server).await.unwrap();
504
505 let (incoming_tx, incoming_rx) = mpsc::channel(1);
506 let host_connection = Arc::new(Connection::new(ProtocolVersion::LATEST, None, incoming_tx));
507
508 let header = &mut Header::new(PacketType::Sync);
510 let payload = ProtocolVersion::LATEST.magic();
511 header.host_cid.set(CID_HOST);
512 header.device_cid.set(CID_ANY);
513 header.payload_len.set(payload.len() as u32);
514 let sync_packet = Packet { header, payload };
515 let mut buf = [0; 1024];
516 sync_packet.write_to_unchecked(&mut buf);
517 assert_eq!(
518 sync_packet.size(),
519 usb_packet_writer.write(&buf[..sync_packet.size()]).await.unwrap()
520 );
521
522 let mut buf = vec![0; 4096];
524 loop {
525 let packet = read_packet_stream(&mut usb_packet_reader, &mut buf)
526 .await
527 .unwrap()
528 .unwrap()
529 .next()
530 .unwrap()
531 .unwrap();
532 trace!("received packet {packet:?}");
533 match packet.header.packet_type {
534 PacketType::Sync => {
535 assert_eq!(packet.payload, ProtocolVersion::LATEST.magic());
536 assert_eq!(packet.header.device_cid.get(), 3);
537 assert_eq!(packet.header.host_cid.get(), CID_HOST);
538 break;
539 }
540 PacketType::Echo => {
541 let header = &mut Header::new(PacketType::EchoReply);
542 let payload = packet.payload;
543 header.payload_len.set(payload.len() as u32);
544 let sync_packet = Packet { header, payload };
545 let mut buf = [0; 1024];
546 sync_packet.write_to_unchecked(&mut buf);
547 assert_eq!(
548 sync_packet.size(),
549 usb_packet_writer.write(&buf[..sync_packet.size()]).await.unwrap()
550 );
551 }
552 other => panic!("Unexpected packet type while syncing {other:?}"),
553 }
554 }
555
556 let device_cid = 300;
558 let header = &mut Header::new(PacketType::Sync);
559 let payload = ProtocolVersion::LATEST.magic();
560 header.host_cid.set(CID_HOST);
561 header.device_cid.set(device_cid);
562 header.payload_len.set(payload.len() as u32);
563 let sync_packet = Packet { header, payload };
564 let mut buf = [0; 1024];
565 sync_packet.write_to_unchecked(&mut buf);
566 assert_eq!(
567 sync_packet.size(),
568 usb_packet_writer.write(&buf[..sync_packet.size()]).await.unwrap()
569 );
570
571 started_rx.await.unwrap();
572
573 let writer_connection = host_connection.clone();
574 scope.spawn(async move {
575 let mut buf = UsbPacketBuilder::new(vec![0; 4096]);
576 loop {
577 buf = writer_connection.fill_usb_packet(buf).await;
578 let buf = buf.take_usb_packet().unwrap();
579 for packet in VsockPacketIterator::new(buf) {
580 let packet = packet.unwrap();
581 trace!("sending packet {packet:?}");
582 }
583 let _ = usb_packet_writer.write(buf).await.unwrap();
584 }
585 });
586
587 let reader_connection = host_connection.clone();
588 scope.spawn(async move {
589 let mut buf = vec![0; 4096];
590 while let Ok(bytes) = usb_packet_reader.read(&mut buf).await {
591 for packet in VsockPacketIterator::new(&buf[..bytes]) {
592 let packet = packet.unwrap();
593 trace!("received packet {packet:?}");
594 reader_connection.handle_vsock_packet(packet).await.unwrap();
595 }
596 }
597 });
598
599 let device = device_side(vsock_api_client);
600 let host = host_side(host_connection, device_cid, incoming_rx);
601 join(device, host).await;
602 }
603
604 #[fuchsia::test(allow_stalls = false)]
605 async fn test_device_to_host_connection() {
606 end_to_end_test(
607 async move |vsock_api_client| {
608 let (socket, data) = zx::Socket::create_stream();
609 let mut socket = Socket::from_socket(socket);
610 let (_con, con) = create_endpoints();
611 vsock_api_client
612 .connect(CID_HOST, 200, vsock_api::ConnectionTransport { data, con })
613 .await
614 .unwrap()
615 .map_err(Status::from_raw)
616 .unwrap();
617 let mut buf = [0; 4];
618 socket.read_exact(&mut buf).await.unwrap();
619 assert_eq!(&buf, b"boom");
620 socket.write_all(b"zoom").await.unwrap();
621 assert_eq!(0, socket.read(&mut buf).await.unwrap());
622 trace!("vsock api fin");
623 },
624 async move |host_connection, _device_cid, mut incoming_rx| {
625 let incoming = incoming_rx.next().await.unwrap();
626 trace!("{incoming:?}");
627 let (socket, other_end) = zx::Socket::create_stream();
628 let mut socket = Socket::from_socket(socket);
629 let _state =
630 host_connection.accept(incoming, Socket::from_socket(other_end)).await.unwrap();
631 socket.write_all(b"boom").await.unwrap();
632 let mut buf = [0; 4];
633 socket.read_exact(&mut buf).await.unwrap();
634 assert_eq!(&buf, b"zoom");
635 trace!("host fin");
636 },
637 )
638 .await;
639 }
640
641 #[fuchsia::test(allow_stalls = false)]
642 async fn test_host_to_device_connection() {
643 end_to_end_test(
644 async move |vsock_api_client| {
645 let (other_end, acceptor) = create_endpoints::<vsock_api::AcceptorMarker>();
646 let mut acceptor = acceptor.into_stream();
647 vsock_api_client.listen(200, other_end).await.unwrap().unwrap();
648 let vsock_api::AcceptorRequest::Accept { addr, responder } =
649 acceptor.next().await.unwrap().unwrap();
650 assert_eq!(
651 addr,
652 vsock::Addr { local_port: 200, remote_cid: CID_HOST, remote_port: 9000 }
653 );
654
655 let (socket, data) = zx::Socket::create_stream();
656 let mut socket = Socket::from_socket(socket);
657 let (_con, con) = create_endpoints();
658 responder.send(Some(vsock_api::ConnectionTransport { data, con })).unwrap();
659
660 let mut buf = [0; 4];
661 socket.read_exact(&mut buf).await.unwrap();
662 assert_eq!(&buf, b"boom");
663 socket.write_all(b"zoom").await.unwrap();
664 assert_eq!(0, socket.read(&mut buf).await.unwrap());
665 trace!("vsock api fin");
666 },
667 async move |host_connection, device_cid, _incoming_rx| {
668 let (socket, other_end) = zx::Socket::create_stream();
669 let mut socket = Socket::from_socket(socket);
670 let _state = host_connection
671 .connect(
672 usb_vsock::Address {
673 host_cid: CID_HOST,
674 host_port: 9000,
675 device_cid,
676 device_port: 200,
677 },
678 Socket::from_socket(other_end),
679 )
680 .await
681 .unwrap();
682
683 socket.write_all(b"boom").await.unwrap();
684 let mut buf = [0; 4];
685 socket.read_exact(&mut buf).await.unwrap();
686 assert_eq!(&buf, b"zoom");
687 trace!("host fin");
688 },
689 )
690 .await;
691 }
692}