1use fdf_component::{driver_register, Driver, DriverContext, Node};
6use fidl::endpoints::create_endpoints;
7use fuchsia_async::scope::ScopeStream;
8use fuchsia_async::{Scope, Socket, TimeoutExt};
9use fuchsia_component::server::ServiceFs;
10use futures::channel::{mpsc, oneshot};
11use futures::future::{select, Either};
12use futures::io::{ReadHalf, WriteHalf};
13use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, StreamExt, TryStreamExt};
14use log::{debug, error, info, warn};
15use std::io::{Error, ErrorKind};
16use std::pin::pin;
17use std::sync::Arc;
18use std::time::Duration;
19use usb_vsock::{
20 Connection, ConnectionRequest, Header, Packet, PacketType, UsbPacketBuilder,
21 VsockPacketIterator, CID_HOST, VSOCK_MAGIC,
22};
23use zx::{SocketOpts, Status};
24use {fidl_fuchsia_hardware_overnet as overnet, fidl_fuchsia_hardware_vsock as vsock};
25
26mod vsock_service;
27
28use vsock_service::VsockService;
29
/// Size in bytes of the buffers used to build outgoing and receive incoming USB datagrams.
///
/// Declared as a `const` rather than a `static`: it is used as an array length (`[0; MTU]`) and
/// as a const-generic argument (`usb_socket_writer::<MTU>`). Both need a compile-time constant,
/// not a value read from a static memory location.
const MTU: usize = 1024;
31
32struct UsbVsockServiceDriver {
33 _scope: Scope,
35 _node: Node,
38}
39
40driver_register!(UsbVsockServiceDriver);
41
42struct UsbConnection {
46 vsock_service: Arc<VsockService<Vec<u8>>>,
47 usb_socket_reader: ReadHalf<Socket>,
48 usb_socket_writer: WriteHalf<Socket>,
49 connection_tx: mpsc::Sender<ConnectionRequest>,
50}
51
52impl UsbConnection {
53 fn new(
54 vsock_service: Arc<VsockService<Vec<u8>>>,
55 usb_socket: zx::Socket,
56 connection_tx: mpsc::Sender<ConnectionRequest>,
57 ) -> Self {
58 assert!(
59 usb_socket.info().unwrap().options.contains(SocketOpts::DATAGRAM),
60 "USB socket must be a datagram socket"
61 );
62 let (usb_socket_reader, usb_socket_writer) = Socket::from_socket(usb_socket).split();
63 Self { vsock_service, usb_socket_reader, usb_socket_writer, connection_tx }
64 }
65
66 async fn clear_host_requests(&mut self, found_magic: &[u8]) -> Option<()> {
69 let mut data = [0; MTU];
70 for _ in 0..10 {
71 let header = &mut Header::new(PacketType::Echo);
72 header.payload_len.set(found_magic.len() as u32);
73 let packet = Packet { header, payload: &found_magic };
74 packet.write_to_unchecked(&mut data);
75 if let Err(err) = self.usb_socket_writer.write(&data[..packet.size()]).await {
76 error!("Error writing echo to the usb socket: {err:?}");
77 return None;
78 }
79 let next_packet = read_packet_stream(&mut self.usb_socket_reader, &mut data)
80 .on_timeout(Duration::from_millis(100), || Err(ErrorKind::TimedOut.into()))
81 .await;
82 let mut packets = match next_packet {
83 Ok(None) => {
84 debug!("Usb socket closed");
85 return None;
86 }
87 Err(err) if err.kind() == ErrorKind::TimedOut => {
88 error!("Timed out waiting for matching packet, trying again");
89 continue;
90 }
91 Err(err) => {
92 error!("Unexpected error on usb socket: {err}");
93 return None;
94 }
95 Ok(Some(packets)) => packets,
96 };
97
98 while let Some(packet) = packets.next() {
99 match packet {
102 Ok(Packet {
103 header: Header { packet_type: PacketType::EchoReply, .. },
104 payload,
105 }) => {
106 if payload == found_magic {
107 debug!("host replied to echo packet and it was received, continuing synchronization");
108 return Some(());
109 } else {
110 warn!("Got echo reply with incorrect payload, ignoring.")
111 }
112 }
113 Ok(packet) => {
114 warn!("Got unexpected packet of type {:?} and length {} while waiting for sync packet. Ignoring.", packet.header.packet_type, packet.header.payload_len);
115 }
116 Err(err) => {
117 warn!("Got invalid vsock packet while waiting for sync packet: {err:?}");
118 }
119 }
120 }
121 }
122 warn!("Failed to receive echo response in time, giving up but still trying to connect");
124 Some(())
125 }
126
127 async fn next_socket(&mut self, mut found_magic: Option<Vec<u8>>) -> Option<u32> {
130 let mut data = [0; MTU];
131 while found_magic.is_none() {
132 let mut packets = match read_packet_stream(&mut self.usb_socket_reader, &mut data).await
133 {
134 Ok(None) => {
135 debug!("Usb socket closed");
136 return None;
137 }
138 Err(err) => {
139 error!("Unexpected error on usb socket: {err}");
140 return None;
141 }
142 Ok(Some(packets)) => packets,
143 };
144
145 while let Some(packet) = packets.next() {
146 match packet {
149 Ok(Packet {
150 header: Header { packet_type: PacketType::Sync, .. },
151 payload,
152 }) => {
153 found_magic = Some(payload.to_owned());
154 }
155 Ok(packet) => {
156 warn!("Got unexpected packet of type {:?} and length {} while waiting for sync packet. Ignoring.", packet.header.packet_type, packet.header.payload_len);
157 }
158 Err(err) => {
159 warn!("Got invalid vsock packet while waiting for sync packet: {err:?}");
160 }
161 }
162 }
163 }
164 let found_magic =
165 found_magic.expect("read loop should not terminate until sync packet is read");
166
167 self.clear_host_requests(&found_magic).await?;
171
172 debug!("Read sync packet, sending it back and setting up a new link");
173 let mut header = Header::new(PacketType::Sync);
174 header.payload_len = (found_magic.len() as u32).into();
175 header.device_cid.set(self.vsock_service.current_cid());
176 header.host_cid.set(CID_HOST);
177 let packet = Packet { header: &header, payload: VSOCK_MAGIC };
178 packet.write_to_unchecked(&mut data);
179 if let Err(err) = self.usb_socket_writer.write(&data[..packet.size()]).await {
180 error!("Error writing overnet magic string to the usb socket: {err:?}");
181 return None;
182 }
183 loop {
185 let mut packets = match read_packet_stream(&mut self.usb_socket_reader, &mut data).await
186 {
187 Ok(None) => {
188 debug!("Usb socket closed");
189 return None;
190 }
191 Err(err) => {
192 error!("Unexpected error on usb socket: {err}");
193 return None;
194 }
195 Ok(Some(packets)) => packets,
196 };
197
198 while let Some(packet) = packets.next() {
199 match packet {
202 Ok(Packet {
203 header: Header { packet_type: PacketType::Sync, device_cid, .. },
204 payload,
205 }) => {
206 if payload != VSOCK_MAGIC {
207 error!("Host gave unsupported protocol version string {payload:?}. Giving up.");
208 return None;
209 }
210 return Some(device_cid.get());
211 }
212 Ok(packet) => {
213 warn!("Got unexpected packet of type {:?} and length {} while waiting for sync packet. Ignoring.", packet.header.packet_type, packet.header.payload_len);
214 }
215 Err(err) => {
216 warn!("Got invalid vsock packet while waiting for sync packet: {err:?}");
217 }
218 }
219 }
220 }
221 }
222
223 async fn run(mut self, mut synchronized: Option<oneshot::Sender<()>>) {
224 let mut found_magic = None;
225 loop {
226 let Some(cid) = self.next_socket(found_magic).await else {
227 info!("USB socket closed or failed");
228 return;
229 };
230 found_magic = None;
232 info!("Bridge established with CID {cid}");
233 let (control_socket, _other_end) = zx::Socket::create_stream();
237 let connection = Arc::new(Connection::new(
238 Socket::from_socket(control_socket),
239 self.connection_tx.clone(),
240 ));
241 self.vsock_service.set_connection(connection.clone(), cid).await;
242 log::trace!("vsock connection set");
243 if let Some(synchronized) = synchronized.take() {
244 synchronized.send(()).ok();
246 }
247 let usb_socket_writer =
248 usb_socket_writer::<MTU>(&connection, &mut self.usb_socket_writer);
249 let usb_socket_reader = usb_socket_reader::<MTU>(
250 &mut found_magic,
251 &mut self.usb_socket_reader,
252 &connection,
253 );
254 let client_socket_copy = pin!(usb_socket_writer);
255 let usb_socket_copy = pin!(usb_socket_reader);
256 let res = select(client_socket_copy, usb_socket_copy).await;
257 match res {
258 Either::Left((Err(err), _)) => {
259 warn!("Error on client to usb socket transfer: {err:?}");
260 }
261 Either::Left((Ok(_), _)) => {
262 debug!("client to usb socket closed normally");
263 }
264 Either::Right((Err(err), _)) => {
265 warn!("Error on usb to client socket transfer: {err:?}");
266 }
267 Either::Right((Ok(_), _)) => {
268 info!("usb to client socket closed normally");
269 }
270 }
271 }
272 }
273}
274
275async fn read_packet_stream<'a>(
276 reader: &mut (impl AsyncRead + Unpin),
277 mut buffer: &'a mut [u8],
278) -> Result<Option<VsockPacketIterator<'a>>, std::io::Error> {
279 let size = reader.read(&mut buffer).await?;
280 if size == 0 {
281 return Ok(None);
282 }
283 Ok(Some(VsockPacketIterator::new(&buffer[0..size])))
284}
285
286async fn usb_socket_writer<const MTU: usize>(
287 connection: &Connection<Vec<u8>>,
288 usb_writer: &mut (impl AsyncWrite + Unpin),
289) -> Result<(), Error> {
290 let mut builder = UsbPacketBuilder::new(vec![0; MTU]);
291 loop {
292 builder = connection.fill_usb_packet(builder).await;
293 let buf = builder.take_usb_packet().unwrap();
294 assert_eq!(
295 buf.len(),
296 usb_writer.write(buf).await?,
297 "datagram socket sent incomplete packet"
298 );
299 }
300}
301
302async fn usb_socket_reader<const MTU: usize>(
303 found_magic: &mut Option<Vec<u8>>,
304 usb_reader: &mut (impl AsyncRead + Unpin),
305 connection: &Connection<Vec<u8>>,
306) -> Result<(), Error> {
307 let mut data = [0; MTU];
308 loop {
309 let Some(mut packets) = read_packet_stream(usb_reader, &mut data).await? else {
310 break;
311 };
312 while let Some(packet) = packets.next() {
313 match packet {
314 Ok(Packet { header: Header { packet_type: PacketType::Sync, .. }, payload }) => {
315 debug!("Found sync packet, ending stream");
316 *found_magic = Some(payload.to_owned());
317 return Ok(());
318 }
319 Ok(packet) => connection.handle_vsock_packet(packet).await?,
320 Err(err) => {
321 error!("Failed to parse vsock packet, going back to waiting for sync packet: {err:?}");
322 break;
323 }
324 }
325 }
326 }
327 Ok(())
328}
329
330struct UsbCallbackHandler {
333 usb_callback_server: overnet::CallbackRequestStream,
334 connection_tx: mpsc::Sender<ConnectionRequest>,
335}
336
337impl UsbCallbackHandler {
338 async fn run(
339 mut self,
340 vsock_service: Arc<VsockService<Vec<u8>>>,
341 mut synchronized: Option<oneshot::Sender<()>>,
342 ) -> Result<(), fidl::Error> {
343 use overnet::CallbackRequest::*;
344 while let Some(req) = self.usb_callback_server.try_next().await? {
345 let NewLink { socket, responder } = req;
346 responder.send()?;
347
348 debug!("Received new socket from usb driver");
349 UsbConnection::new(vsock_service.clone(), socket, self.connection_tx.clone())
350 .run(synchronized.take())
351 .await;
352 }
353 Ok(())
354 }
355}
356
357impl Driver for UsbVsockServiceDriver {
358 const NAME: &str = "usb-vsock-service";
359
360 async fn start(mut context: DriverContext) -> Result<Self, Status> {
361 let node = context.take_node()?;
362 let scope = Scope::new_with_name(Self::NAME);
363 let mut outgoing = ServiceFs::new();
364
365 let usb_device = get_usb_device(&context)?;
366
367 info!("Offering a vsock service in the outgoing directory");
368 outgoing.dir("svc").add_fidl_service_instance("default", move |i| {
369 let vsock::ServiceRequest::Device(request_stream) = i;
370 request_stream
371 });
372
373 context.serve_outgoing(&mut outgoing)?;
374
375 scope.spawn(async move {
376 while let Some(request_stream) = outgoing.next().await {
377 let (usb_callback, usb_callback_server) = create_endpoints();
378 usb_device.set_callback(usb_callback).await.expect("usb device service went away");
379
380 run_connection(usb_callback_server.into_stream(), request_stream, None).await
381 }
382 });
383
384 Ok(Self { _scope: scope, _node: node })
385 }
386
387 async fn stop(&self) {}
388}
389
390async fn run_connection(
391 usb_callback_server: overnet::CallbackRequestStream,
392 mut request_stream: vsock::DeviceRequestStream,
393 synchronized: Option<oneshot::Sender<()>>,
394) {
395 debug!("Waiting for start message on vsock implementation service");
396 let (connection_tx, incoming_connections) = mpsc::channel(1);
397 let svc = match VsockService::wait_for_start(incoming_connections, &mut request_stream).await {
398 Ok(svc) => svc,
399 Err(err) => {
400 error!("Error while waiting for start message from vsock client: {err:?}");
401 return;
402 }
403 };
404 debug!(
405 "Received start message on vsock implementation service, waiting for usb socket handles"
406 );
407
408 let svc = Arc::new(svc);
409 let (mut scopes_stream, scopes) = ScopeStream::new_with_name("usb-vsock-connection".to_owned());
410
411 let usb_callback_handler =
412 UsbCallbackHandler { usb_callback_server, connection_tx: connection_tx.clone() };
413 let usb_svc = svc.clone();
414 scopes.push(async move {
415 if let Err(err) = usb_callback_handler.run(usb_svc, synchronized).await {
416 error!("Error while waiting for usb device callbacks: {err:?}");
417 }
418 });
419 scopes.push(async move {
420 if let Err(err) = svc.run(request_stream).await {
421 error!("Error while servicing vsock client: {err:?}");
422 }
423 });
424 scopes_stream.next().await;
426}
427
428fn get_usb_device(context: &DriverContext) -> Result<overnet::UsbProxy, Status> {
429 let service_proxy = context.incoming.service_marker(overnet::UsbServiceMarker).connect()?;
430
431 service_proxy.connect_to_device().map_err(|err| {
432 error!("Error connecting to usb device proxy at driver startup: {err}");
433 Status::INTERNAL
434 })
435}
436
#[cfg(test)]
mod tests {
    use fidl::endpoints::create_endpoints;
    use fidl_fuchsia_vsock as vsock_api;
    use futures::future::join;
    use log::trace;
    use usb_vsock::CID_ANY;

    use super::*;

    /// Wires the device-side stack (`run_connection` plus a `vsock_service_lib::Vsock`) to an
    /// in-process fake host, then runs `device_side` and `host_side` concurrently.
    ///
    /// The fake host talks to the device over a datagram socket standing in for the USB link
    /// and performs the Sync/Echo handshake, assigning device CID 300.
    ///
    /// `device_side` receives a `fuchsia.vsock/Connector` client. `host_side` receives three
    /// arguments:
    /// - the host `Connection`;
    /// - the assigned device CID;
    /// - the receiver of connection requests arriving at the host.
    async fn end_to_end_test(
        device_side: impl AsyncFn(vsock_api::ConnectorProxy),
        host_side: impl AsyncFn(Arc<Connection<Vec<u8>>>, u32, mpsc::Receiver<ConnectionRequest>),
    ) {
        let scope = Scope::new();
        let (vsock_impl_client, vsock_impl_server) = create_endpoints::<vsock::DeviceMarker>();
        let (usb_callback_client, usb_callback_server) =
            create_endpoints::<overnet::CallbackMarker>();
        let (started_tx, started_rx) = oneshot::channel();
        scope.spawn(run_connection(
            usb_callback_server.into_stream(),
            vsock_impl_server.into_stream(),
            Some(started_tx),
        ));
        let usb_callback_client = usb_callback_client.into_proxy();

        let (vsock_api_service, vsock_api_future) =
            vsock_service_lib::Vsock::new(Some(vsock_impl_client.into_proxy()), None)
                .await
                .unwrap();
        scope.spawn_local(async move {
            vsock_api_future.await.unwrap();
        });

        let (vsock_api_client, vsock_api_server) = create_endpoints::<vsock_api::ConnectorMarker>();
        scope.spawn_local(vsock_api_service.run_client_connection(vsock_api_server.into_stream()));
        let vsock_api_client = vsock_api_client.into_proxy();

        let (usb_packet_socket, usb_packet_server) = zx::Socket::create_datagram();
        let (mut usb_packet_reader, mut usb_packet_writer) =
            Socket::from_socket(usb_packet_socket).split();
        usb_callback_client.new_link(usb_packet_server).await.unwrap();

        let (incoming_tx, incoming_rx) = mpsc::channel(1);
        let (_control_socket, other_end) = zx::Socket::create_stream();
        let host_connection =
            Arc::new(Connection::new(Socket::from_socket(other_end), incoming_tx));

        // Host opens the handshake with a Sync packet that does not yet name a device CID.
        let header = &mut Header::new(PacketType::Sync);
        let payload = VSOCK_MAGIC;
        header.host_cid.set(CID_HOST);
        header.device_cid.set(CID_ANY);
        header.payload_len.set(payload.len() as u32);
        let sync_packet = Packet { header, payload };
        let mut buf = [0; 1024];
        sync_packet.write_to_unchecked(&mut buf);
        assert_eq!(
            sync_packet.size(),
            usb_packet_writer.write(&buf[..sync_packet.size()]).await.unwrap()
        );

        // Answer the device's Echo packets until it replies with its own Sync packet.
        let mut buf = vec![0; 4096];
        loop {
            let packet = read_packet_stream(&mut usb_packet_reader, &mut buf)
                .await
                .unwrap()
                .unwrap()
                .next()
                .unwrap()
                .unwrap();
            trace!("received packet {packet:?}");
            match packet.header.packet_type {
                PacketType::Sync => {
                    assert_eq!(packet.payload, VSOCK_MAGIC);
                    assert_eq!(packet.header.device_cid.get(), 3);
                    assert_eq!(packet.header.host_cid.get(), CID_HOST);
                    break;
                }
                PacketType::Echo => {
                    let header = &mut Header::new(PacketType::EchoReply);
                    let payload = packet.payload;
                    header.payload_len.set(payload.len() as u32);
                    let sync_packet = Packet { header, payload };
                    let mut buf = [0; 1024];
                    sync_packet.write_to_unchecked(&mut buf);
                    assert_eq!(
                        sync_packet.size(),
                        usb_packet_writer.write(&buf[..sync_packet.size()]).await.unwrap()
                    );
                }
                other => panic!("Unexpected packet type while syncing {other:?}"),
            }
        }

        // Complete the handshake by assigning the device its CID.
        let device_cid = 300;
        let header = &mut Header::new(PacketType::Sync);
        let payload = VSOCK_MAGIC;
        header.host_cid.set(CID_HOST);
        header.device_cid.set(device_cid);
        header.payload_len.set(payload.len() as u32);
        let sync_packet = Packet { header, payload };
        let mut buf = [0; 1024];
        sync_packet.write_to_unchecked(&mut buf);
        assert_eq!(
            sync_packet.size(),
            usb_packet_writer.write(&buf[..sync_packet.size()]).await.unwrap()
        );

        started_rx.await.unwrap();

        let writer_connection = host_connection.clone();
        scope.spawn(async move {
            let mut buf = UsbPacketBuilder::new(vec![0; 4096]);
            loop {
                buf = writer_connection.fill_usb_packet(buf).await;
                let buf = buf.take_usb_packet().unwrap();
                for packet in VsockPacketIterator::new(buf) {
                    let packet = packet.unwrap();
                    trace!("sending packet {packet:?}");
                }
                let _ = usb_packet_writer.write(buf).await.unwrap();
            }
        });

        let reader_connection = host_connection.clone();
        scope.spawn(async move {
            let mut buf = vec![0; 4096];
            loop {
                // Stop on errors and on end of stream. Once the device end of the socket is
                // closed, every read returns `Ok(0)` immediately, so continuing would spin
                // forever and hang the test instead of letting it fail.
                let bytes = match usb_packet_reader.read(&mut buf).await {
                    Ok(0) | Err(_) => break,
                    Ok(bytes) => bytes,
                };
                for packet in VsockPacketIterator::new(&buf[..bytes]) {
                    let packet = packet.unwrap();
                    trace!("received packet {packet:?}");
                    reader_connection.handle_vsock_packet(packet).await.unwrap();
                }
            }
        });

        let device = device_side(vsock_api_client);
        let host = host_side(host_connection, device_cid, incoming_rx);
        join(device, host).await;
    }

    /// Device connects to host port 200. The host accepts and sends "boom", the device
    /// answers "zoom", and the device then observes end of stream.
    #[fuchsia::test(allow_stalls = false)]
    async fn test_device_to_host_connection() {
        end_to_end_test(
            async move |vsock_api_client| {
                let (socket, data) = zx::Socket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let (_con, con) = create_endpoints();
                vsock_api_client
                    .connect(CID_HOST, 200, vsock_api::ConnectionTransport { data, con })
                    .await
                    .unwrap()
                    .map_err(Status::from_raw)
                    .unwrap();
                let mut buf = [0; 4];
                socket.read_exact(&mut buf).await.unwrap();
                assert_eq!(&buf, b"boom");
                socket.write_all(b"zoom").await.unwrap();
                assert_eq!(0, socket.read(&mut buf).await.unwrap());
                trace!("vsock api fin");
            },
            async move |host_connection, _device_cid, mut incoming_rx| {
                let incoming = incoming_rx.next().await.unwrap();
                trace!("{incoming:?}");
                let (socket, other_end) = zx::Socket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let _state =
                    host_connection.accept(incoming, Socket::from_socket(other_end)).await.unwrap();
                socket.write_all(b"boom").await.unwrap();
                let mut buf = [0; 4];
                socket.read_exact(&mut buf).await.unwrap();
                assert_eq!(&buf, b"zoom");
                trace!("host fin");
            },
        )
        .await;
    }

    /// Host connects from port 9000 to device port 200. The device accepts through a
    /// `fuchsia.vsock/Acceptor`; the host sends "boom" and the device answers "zoom".
    #[fuchsia::test(allow_stalls = false)]
    async fn test_host_to_device_connection() {
        end_to_end_test(
            async move |vsock_api_client| {
                let (other_end, acceptor) = create_endpoints::<vsock_api::AcceptorMarker>();
                let mut acceptor = acceptor.into_stream();
                vsock_api_client.listen(200, other_end).await.unwrap().unwrap();
                let vsock_api::AcceptorRequest::Accept { addr, responder } =
                    acceptor.next().await.unwrap().unwrap();
                assert_eq!(
                    addr,
                    vsock::Addr { local_port: 200, remote_cid: CID_HOST, remote_port: 9000 }
                );

                let (socket, data) = zx::Socket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let (_con, con) = create_endpoints();
                responder.send(Some(vsock_api::ConnectionTransport { data, con })).unwrap();

                let mut buf = [0; 4];
                socket.read_exact(&mut buf).await.unwrap();
                assert_eq!(&buf, b"boom");
                socket.write_all(b"zoom").await.unwrap();
                assert_eq!(0, socket.read(&mut buf).await.unwrap());
                trace!("vsock api fin");
            },
            async move |host_connection, device_cid, _incoming_rx| {
                let (socket, other_end) = zx::Socket::create_stream();
                let mut socket = Socket::from_socket(socket);
                let _state = host_connection
                    .connect(
                        usb_vsock::Address {
                            host_cid: CID_HOST,
                            host_port: 9000,
                            device_cid,
                            device_port: 200,
                        },
                        Socket::from_socket(other_end),
                    )
                    .await
                    .unwrap();

                socket.write_all(b"boom").await.unwrap();
                let mut buf = [0; 4];
                socket.read_exact(&mut buf).await.unwrap();
                assert_eq!(&buf, b"zoom");
                trace!("host fin");
            },
        )
        .await;
    }
}