1use std::cmp::min;
6use std::future::Future;
7use std::iter::FusedIterator;
8use std::ops::DerefMut;
9use std::pin::Pin;
10use std::sync::{Mutex, MutexGuard};
11use std::task::{Context, Poll, Waker};
12
13use futures::task::AtomicWaker;
14use zerocopy::{Immutable, IntoBytes, KnownLayout, TryFromBytes, Unaligned, little_endian};
15
16use crate::Address;
17
18#[repr(u8)]
21#[derive(
22 Debug,
23 TryFromBytes,
24 IntoBytes,
25 KnownLayout,
26 Immutable,
27 Unaligned,
28 PartialEq,
29 Eq,
30 PartialOrd,
31 Ord,
32 Hash,
33 Clone,
34 Copy,
35)]
36pub enum PacketType {
37 Sync = b'S',
43 Echo = b'E',
46 EchoReply = b'e',
49 Connect = b'C',
51 Finish = b'F',
55 Reset = b'R',
60 Accept = b'A',
63 Data = b'D',
68 Pause = b'X',
73}
74
75#[repr(C, packed(1))]
78#[derive(
79 Debug,
80 TryFromBytes,
81 IntoBytes,
82 KnownLayout,
83 Immutable,
84 Unaligned,
85 PartialEq,
86 Eq,
87 PartialOrd,
88 Ord,
89 Hash,
90 Clone,
91)]
92pub struct Header {
93 magic: [u8; 3],
94 pub packet_type: PacketType,
96 pub device_cid: little_endian::U32,
101 pub host_cid: little_endian::U32,
106 pub device_port: little_endian::U32,
111 pub host_port: little_endian::U32,
116 pub payload_len: little_endian::U32,
119}
120
121impl Header {
122 pub const SIZE: usize = size_of::<Self>();
124 const MAGIC: &'static [u8; 3] = b"ffx";
125
126 pub fn new(packet_type: PacketType) -> Self {
129 let device_cid = 0.into();
130 let host_cid = 0.into();
131 let device_port = 0.into();
132 let host_port = 0.into();
133 let payload_len = 0.into();
134 Header {
135 magic: *Self::MAGIC,
136 packet_type,
137 device_cid,
138 host_cid,
139 device_port,
140 host_port,
141 payload_len,
142 }
143 }
144
145 pub fn packet_size(&self) -> usize {
148 Packet::size_with_payload(self.payload_len.get() as usize)
149 }
150
151 pub fn set_address(&mut self, addr: &Address) {
153 self.device_cid.set(addr.device_cid);
154 self.host_cid.set(addr.host_cid);
155 self.device_port.set(addr.device_port);
156 self.host_port.set(addr.host_port);
157 }
158}
159
160#[derive(Debug, Eq, PartialEq, PartialOrd, Ord)]
162pub struct Packet<'a> {
163 pub header: &'a Header,
165 pub payload: &'a [u8],
167}
168
169impl<'a> Packet<'a> {
170 pub fn size(&self) -> usize {
173 self.header.packet_size()
174 }
175
176 fn size_with_payload(payload_size: usize) -> usize {
177 size_of::<Header>() + payload_size
178 }
179
180 fn parse_next(buf: &'a [u8]) -> Result<(Self, &'a [u8]), std::io::Error> {
181 let Some((header, body)) = buf.split_at_checked(size_of::<Header>()) else {
183 return Err(std::io::Error::other("insufficient data for last packet"));
184 };
185 let header = Header::try_ref_from_bytes(header).map_err(|err| {
186 std::io::Error::other(format!("failed to parse usb vsock header: {err:?}"))
187 })?;
188 if header.magic != *Header::MAGIC {
189 return Err(std::io::Error::other(format!("invalid magic bytes on usb vsock header")));
190 }
191 let payload_len = Into::<u64>::into(header.payload_len) as usize;
193 let body_len = body.len();
194 if payload_len > body_len {
195 return Err(std::io::Error::other(format!(
196 "payload length on usb vsock header ({payload_len}) was larger than available in buffer {body_len}"
197 )));
198 }
199
200 let (payload, remain) = body.split_at(payload_len);
201 Ok((Packet { header, payload }, remain))
202 }
203
204 pub fn write_to_unchecked(&'a self, buf: &'a mut [u8]) -> &'a mut [u8] {
213 let (packet, remain) = buf.split_at_mut(self.size());
214 let payload_len = u32::from(self.header.payload_len) as usize;
215 self.header.write_to_prefix(packet).unwrap();
216 self.payload[..payload_len].write_to_suffix(packet).unwrap();
217 remain
218 }
219}
220
221#[derive(Debug, Eq, PartialEq, PartialOrd, Ord)]
223pub struct PacketMut<'a> {
224 pub header: &'a mut Header,
226 pub payload: &'a mut [u8],
228}
229
230impl<'a> PacketMut<'a> {
231 pub fn new_in(packet_type: PacketType, buf: &'a mut [u8]) -> Self {
243 Header::new(packet_type)
244 .write_to_prefix(buf)
245 .expect("not enough room in buffer for packet header");
246 let (header_bytes, payload) = buf.split_at_mut(Header::SIZE);
247 let header = Header::try_mut_from_bytes(header_bytes).unwrap();
248 PacketMut { header, payload }
249 }
250
251 pub fn finish(self, payload_len: usize) -> Result<usize, PacketTooBigError> {
254 if payload_len <= self.payload.len() {
255 self.header.payload_len.set(u32::try_from(payload_len).map_err(|_| PacketTooBigError)?);
256 Ok(Header::SIZE + payload_len)
257 } else {
258 Err(PacketTooBigError)
259 }
260 }
261}
262
/// Iterator over vsock packets laid out back to back in one USB transfer buffer.
///
/// Yields `Ok(packet)` per parsed packet. After the buffer is exhausted, or after the first
/// parse error has been yielded, it produces nothing more.
pub struct VsockPacketIterator<'a> {
    /// Bytes not yet parsed; `None` once iteration has finished or failed.
    buf: Option<&'a [u8]>,
}

impl<'a> VsockPacketIterator<'a> {
    /// Starts iterating over the packets contained in `buf`.
    pub fn new(buf: &'a [u8]) -> Self {
        let remaining = Some(buf);
        Self { buf: remaining }
    }
}
275
276impl<'a> FusedIterator for VsockPacketIterator<'a> {}
277impl<'a> Iterator for VsockPacketIterator<'a> {
278 type Item = Result<Packet<'a>, std::io::Error>;
279
280 fn next(&mut self) -> Option<Self::Item> {
281 let data = self.buf.take()?;
283
284 if data.len() == 0 {
286 return None;
287 }
288
289 match Packet::parse_next(data) {
290 Ok((header, rest)) => {
291 self.buf = Some(rest);
293 Some(Ok(header))
294 }
295 Err(err) => Some(Err(err)),
296 }
297 }
298}
299
300pub struct UsbPacketBuilder<B> {
303 buffer: B,
304 offset: usize,
305 space_waker: AtomicWaker,
306 packet_waker: AtomicWaker,
307}
308
/// Error returned when a packet does not fit in the space available for it.
///
/// Returned both when a buffer lacks room for a packet and when a payload length cannot be
/// represented in a header.
#[derive(Debug, Copy, Clone)]
pub struct PacketTooBigError;

impl std::fmt::Display for PacketTooBigError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("packet does not fit in the available buffer space")
    }
}

/// Lets callers propagate this error with `?` into `Box<dyn Error>`, `anyhow`, and similar.
impl std::error::Error for PacketTooBigError {}
312
313impl<B> UsbPacketBuilder<B> {
314 pub fn new(buffer: B) -> Self {
318 let offset = 0;
319 let space_waker = AtomicWaker::default();
320 let packet_waker = AtomicWaker::default();
321 Self { buffer, offset, space_waker, packet_waker }
322 }
323
324 pub fn has_data(&self) -> bool {
326 self.offset > 0
327 }
328}
329
330impl<B> UsbPacketBuilder<B>
331where
332 B: std::ops::DerefMut<Target = [u8]>,
333{
334 pub fn available(&self) -> usize {
336 self.buffer.len() - self.offset
337 }
338
339 pub fn write_vsock_packet(&mut self, packet: &Packet<'_>) -> Result<(), PacketTooBigError> {
343 let packet_size = packet.size();
344 if self.available() >= packet_size {
345 packet.write_to_unchecked(&mut self.buffer[self.offset..]);
346 self.offset += packet_size;
347 self.packet_waker.wake();
348 Ok(())
349 } else {
350 Err(PacketTooBigError)
351 }
352 }
353
354 pub fn take_usb_packet(&mut self) -> Option<&mut [u8]> {
358 let written = self.offset;
359 if written == 0 {
360 return None;
361 }
362 self.offset = 0;
363 self.space_waker.wake();
364 Some(&mut self.buffer[0..written])
365 }
366}
367
368pub(crate) struct UsbPacketFiller<B> {
369 current_out_packet: Mutex<Option<UsbPacketBuilder<B>>>,
370 out_packet_wakers: Mutex<Vec<Waker>>,
371 filled_packet_waker: AtomicWaker,
372}
373
374impl<B> Default for UsbPacketFiller<B> {
375 fn default() -> Self {
376 let current_out_packet = Mutex::default();
377 let out_packet_wakers = Mutex::default();
378 let filled_packet_waker = AtomicWaker::default();
379 Self { current_out_packet, out_packet_wakers, filled_packet_waker }
380 }
381}
382
383impl<B: DerefMut<Target = [u8]> + Unpin> UsbPacketFiller<B> {
384 fn wait_for_fillable(&self, min_packet_size: usize) -> WaitForFillable<'_, B> {
385 WaitForFillable { filler: &self, min_packet_size }
386 }
387
388 pub async fn write_vsock_packet(&self, packet: &Packet<'_>) -> Result<(), PacketTooBigError> {
389 let mut builder = self.wait_for_fillable(packet.size()).await;
390 builder.as_mut().unwrap().write_vsock_packet(packet)?;
391 self.filled_packet_waker.wake();
392 Ok(())
393 }
394
395 pub async fn write_vsock_data(&self, address: &Address, payload: &[u8]) -> usize {
396 let header = &mut Header::new(PacketType::Data);
397 header.set_address(&address);
398 let mut builder = self.wait_for_fillable(Header::SIZE + 1).await;
399 let builder = builder.as_mut().unwrap();
400 let writing = min(payload.len(), builder.available() - Header::SIZE);
401 header.payload_len.set(writing as u32);
402 builder.write_vsock_packet(&Packet { header, payload: &payload[..writing] }).unwrap();
403 self.filled_packet_waker.wake();
404 writing
405 }
406
407 pub async fn write_vsock_data_all(&self, address: &Address, payload: &[u8]) {
408 let mut written = 0;
409 while written < payload.len() {
410 written += self.write_vsock_data(address, &payload[written..]).await;
411 }
412 }
413
414 pub fn fill_usb_packet(&self, builder: UsbPacketBuilder<B>) -> FillUsbPacket<'_, B> {
421 FillUsbPacket(&self, Some(builder))
422 }
423}
424
425pub(crate) struct FillUsbPacket<'a, B>(&'a UsbPacketFiller<B>, Option<UsbPacketBuilder<B>>);
426
427impl<'a, B: Unpin> Future for FillUsbPacket<'a, B> {
428 type Output = UsbPacketBuilder<B>;
429
430 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
431 if let Some(builder) = self.1.take() {
434 if builder.has_data() {
436 return Poll::Ready(builder);
437 }
438
439 let mut current_out_packet = self.0.current_out_packet.lock().unwrap();
440 assert!(current_out_packet.is_none(), "Can't fill more than one packet at a time");
441 current_out_packet.replace(builder);
442
443 self.0.filled_packet_waker.register(cx.waker());
445 drop(current_out_packet);
447
448 let mut wakers = self.0.out_packet_wakers.lock().unwrap();
451 for waker in wakers.drain(..) {
452 waker.wake();
453 }
454 Poll::Pending
455 } else {
456 let mut current_out_packet = self.0.current_out_packet.lock().unwrap();
457 let Some(builder) = current_out_packet.take() else {
458 panic!("Packet builder was somehow removed from connection prematurely");
459 };
460
461 if builder.has_data() {
462 self.0.filled_packet_waker.wake();
463 Poll::Ready(builder)
464 } else {
465 current_out_packet.replace(builder);
468 Poll::Pending
469 }
470 }
471 }
472}
473
474pub(crate) struct WaitForFillable<'a, B> {
475 filler: &'a UsbPacketFiller<B>,
476 min_packet_size: usize,
477}
478
479impl<'a, B: DerefMut<Target = [u8]> + Unpin> Future for WaitForFillable<'a, B> {
480 type Output = MutexGuard<'a, Option<UsbPacketBuilder<B>>>;
481
482 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
483 let current_out_packet = self.filler.current_out_packet.lock().unwrap();
484 let Some(builder) = &*current_out_packet else {
485 self.filler.out_packet_wakers.lock().unwrap().push(cx.waker().clone());
486 return Poll::Pending;
487 };
488 if builder.available() >= self.min_packet_size {
489 Poll::Ready(current_out_packet)
490 } else {
491 self.filler.out_packet_wakers.lock().unwrap().push(cx.waker().clone());
492 Poll::Pending
493 }
494 }
495}
496
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use fuchsia_async::Task;
    use futures::poll;

    /// Polls `fut` exactly once and panics if it has already completed.
    async fn assert_pending<F: Future>(fut: F) {
        let fut = std::pin::pin!(fut);
        if let Poll::Ready(_) = poll!(fut) {
            panic!("Future was ready when it shouldn't have been");
        }
    }
    /// The header here says `payload_len == 0` while the payload slice holds one byte. The buffer
    /// is sized from the header, so only the header is written and the extra payload byte must
    /// not overwrite it.
    #[test]
    fn packet_write_doesnt_clobber_header_with_incorrect_len() {
        let header = &mut Header::new(PacketType::Pause);
        let payload = &[1];
        let packet = Packet { header, payload };
        let mut buf = vec![0; packet.size()];
        packet.write_to_unchecked(&mut buf);

        let read_packet = Packet::parse_next(&buf).unwrap().0;
        assert_eq!(&read_packet.header, &packet.header);
    }

    /// One packet goes through a builder sized exactly for it and is parsed back unchanged.
    #[fuchsia::test]
    async fn roundtrip_packet() {
        let payload = b"hello world!";
        let packet = Packet {
            payload,
            header: &Header {
                device_cid: 1.into(),
                host_cid: 2.into(),
                device_port: 3.into(),
                host_port: 4.into(),
                payload_len: little_endian::U32::from(payload.len() as u32),
                ..Header::new(PacketType::Data)
            },
        };
        let buffer = vec![0; packet.size()];
        let builder = UsbPacketBuilder::new(buffer);
        let filler = UsbPacketFiller::default();
        let mut filled_fut = filler.fill_usb_packet(builder);
        // First poll lends the empty builder to the filler and stays pending.
        println!("we should not be ready to pull a usb packet off yet");
        assert_pending(&mut filled_fut).await;

        println!("we should be able to write a packet though ({} bytes)", packet.size());
        filler.write_vsock_packet(&packet).await.unwrap();

        // The buffer was sized for exactly one packet, so zero bytes remain.
        println!("we shouldn't have any space for another packet now");
        assert_pending(filler.wait_for_fillable(1)).await;

        println!("but we should have a new usb packet available");
        let mut builder = filled_fut.await;
        let buffer = builder.take_usb_packet().unwrap();

        println!("the packet we get back out should be the same one we put in");
        let (read_packet, remain) = Packet::parse_next(buffer).unwrap();
        assert_eq!(packet, read_packet);
        assert!(remain.is_empty());
    }

    /// A spawned task writes 1024 numbered packets through the filler. Meanwhile this task keeps
    /// lending the same 256-byte builder, draining it, and checking the packets arrive in order.
    #[fuchsia::test]
    async fn many_packets() {
        fn make_numbered_packet(num: u32) -> (Header, String) {
            let payload = format!("packet #{num}!");
            let header = Header {
                device_cid: num.into(),
                device_port: num.into(),
                host_cid: num.into(),
                host_port: num.into(),
                payload_len: little_endian::U32::from(payload.len() as u32),
                ..Header::new(PacketType::Data)
            };
            (header, payload)
        }
        const BUFFER_SIZE: usize = 256;
        let mut builder = UsbPacketBuilder::new(vec![0; BUFFER_SIZE]);
        let filler = Arc::new(UsbPacketFiller::default());

        let send_filler = filler.clone();
        let send_task = Task::spawn(async move {
            for packet_num in 0..1024 {
                let next_packet = make_numbered_packet(packet_num);
                let next_packet =
                    Packet { header: &next_packet.0, payload: next_packet.1.as_ref() };
                send_filler.write_vsock_packet(&next_packet).await.unwrap();
            }
        });

        let mut read_packet_num = 0;
        while read_packet_num < 1024 {
            builder = filler.fill_usb_packet(builder).await;
            let buffer = builder.take_usb_packet().unwrap();
            let mut num_packets = 0;
            for packet in VsockPacketIterator::new(&buffer) {
                let packet_compare = make_numbered_packet(read_packet_num);
                let packet_compare =
                    Packet { header: &packet_compare.0, payload: &packet_compare.1.as_ref() };
                assert_eq!(packet.unwrap(), packet_compare);
                read_packet_num += 1;
                num_packets += 1;
            }
            println!(
                "Read {num_packets} vsock packets from usb packet buffer, had {count} bytes left",
                count = BUFFER_SIZE - buffer.len()
            );
        }
        send_task.await;
        assert_eq!(1024, read_packet_num);
    }

    /// Exercises the ordering between `wait_for_fillable` and `fill_usb_packet` over several
    /// rounds.
    /// - A waiter registered before any builder exists becomes ready once one is lent.
    /// - Two 99-byte packets fit in a 1024-byte buffer.
    /// - Asking for one byte more than what remains stays pending.
    #[fuchsia::test]
    async fn packet_fillable_futures() {
        let filler = UsbPacketFiller::default();

        for _ in 0..10 {
            println!("register an interest in filling a usb packet");
            let mut fillable_fut = filler.wait_for_fillable(1);
            println!("make sure we have nothing to fill");
            assert!(poll!(&mut fillable_fut).is_pending());

            println!("register a packet for filling");
            let mut filled_fut = filler.fill_usb_packet(UsbPacketBuilder::new(vec![0; 1024]));
            println!("make sure we've registered the buffer");
            assert!(poll!(&mut filled_fut).is_pending());

            println!("now put some things in the packet");
            let header = &mut Header::new(PacketType::Data);
            header.payload_len.set(99);
            let Poll::Ready(mut builder) = poll!(fillable_fut) else {
                panic!("should have been ready to fill a packet")
            };
            builder
                .as_mut()
                .unwrap()
                .write_vsock_packet(&Packet { header, payload: &[b'a'; 99] })
                .unwrap();
            // Release the slot lock before polling another waiter on the same mutex.
            drop(builder);
            let Poll::Ready(mut builder) = poll!(filler.wait_for_fillable(1)) else {
                panic!("should have been ready to fill a packet(2)")
            };
            builder
                .as_mut()
                .unwrap()
                .write_vsock_packet(&Packet { header, payload: &[b'a'; 99] })
                .unwrap();
            drop(builder);

            // Free space is now 1024 minus two full packets (header + 99 bytes each), which is
            // less than the amount requested here.
            println!("but if we ask for too much space we'll get pending");
            assert!(poll!(filler.wait_for_fillable(1024 - (99 * 2) + 1)).is_pending());

            println!("and now resolve the filled future and get our data back");
            let mut filled = filled_fut.await;
            let packets =
                Vec::from_iter(VsockPacketIterator::new(filled.take_usb_packet().unwrap()));
            assert_eq!(packets.len(), 2);
        }
    }
}