1use fuchsia_async::{DurationExt, OnTimeout, TimeoutExt};
6use fuchsia_bluetooth::types::Channel;
7use fuchsia_sync::Mutex;
8use futures::future::{FusedFuture, MaybeDone};
9use futures::stream::{FusedStream, Stream};
10use futures::task::{Context, Poll, Waker};
11use futures::{ready, Future, FutureExt, TryFutureExt};
12use log::{info, trace, warn};
13use packet_encoding::{Decodable, Encodable};
14use slab::Slab;
15use std::collections::VecDeque;
16use std::marker::PhantomData;
17use std::mem;
18use std::pin::Pin;
19use std::sync::Arc;
20use zx::{self as zx, MonotonicDuration};
21
22#[cfg(test)]
23mod tests;
24
25mod rtp;
26mod stream_endpoint;
27mod types;
28
29use crate::types::{SignalIdentifier, SignalingHeader, SignalingMessageType, TxLabel};
30
31pub use crate::rtp::{RtpError, RtpHeader};
32pub use crate::stream_endpoint::{
33 MediaStream, StreamEndpoint, StreamEndpointUpdateCallback, StreamState,
34};
35pub use crate::types::{
36 ContentProtectionType, EndpointType, Error, ErrorCode, MediaCodecType, MediaType, RemoteReject,
37 Result, ServiceCapability, ServiceCategory, StreamEndpointId, StreamInformation,
38};
39
40#[derive(Debug, Clone)]
50pub struct Peer {
51 inner: Arc<PeerInner>,
52}
53
54impl Peer {
55 pub fn new(signaling: Channel) -> Self {
57 Self {
58 inner: Arc::new(PeerInner {
59 signaling,
60 response_waiters: Mutex::new(Slab::<ResponseWaiter>::new()),
61 incoming_requests: Mutex::<RequestQueue>::default(),
62 }),
63 }
64 }
65
66 #[track_caller]
69 pub fn take_request_stream(&self) -> RequestStream {
70 {
71 let mut lock = self.inner.incoming_requests.lock();
72 if let RequestListener::None = lock.listener {
73 lock.listener = RequestListener::New;
74 } else {
75 panic!("Request stream has already been taken");
76 }
77 }
78
79 RequestStream::new(self.inner.clone())
80 }
81
82 pub fn discover(&self) -> impl Future<Output = Result<Vec<StreamInformation>>> {
86 self.send_command::<DiscoverResponse>(SignalIdentifier::Discover, &[]).ok_into()
87 }
88
89 pub fn get_capabilities(
96 &self,
97 stream_id: &StreamEndpointId,
98 ) -> impl Future<Output = Result<Vec<ServiceCapability>>> {
99 let stream_params = &[stream_id.to_msg()];
100 self.send_command::<GetCapabilitiesResponse>(
101 SignalIdentifier::GetCapabilities,
102 stream_params,
103 )
104 .ok_into()
105 }
106
107 pub fn get_all_capabilities(
113 &self,
114 stream_id: &StreamEndpointId,
115 ) -> impl Future<Output = Result<Vec<ServiceCapability>>> {
116 let stream_params = &[stream_id.to_msg()];
117 self.send_command::<GetCapabilitiesResponse>(
118 SignalIdentifier::GetAllCapabilities,
119 stream_params,
120 )
121 .ok_into()
122 }
123
124 pub fn set_configuration(
131 &self,
132 stream_id: &StreamEndpointId,
133 local_stream_id: &StreamEndpointId,
134 capabilities: &[ServiceCapability],
135 ) -> impl Future<Output = Result<()>> {
136 assert!(!capabilities.is_empty(), "must set at least one capability");
137 let mut params: Vec<u8> = vec![0; capabilities.iter().fold(2, |a, x| a + x.encoded_len())];
138 params[0] = stream_id.to_msg();
139 params[1] = local_stream_id.to_msg();
140 let mut idx = 2;
141 for capability in capabilities {
142 if let Err(e) = capability.encode(&mut params[idx..]) {
143 return futures::future::err(e).left_future();
144 }
145 idx += capability.encoded_len();
146 }
147 self.send_command::<SimpleResponse>(SignalIdentifier::SetConfiguration, ¶ms)
148 .ok_into()
149 .right_future()
150 }
151
152 pub fn get_configuration(
158 &self,
159 stream_id: &StreamEndpointId,
160 ) -> impl Future<Output = Result<Vec<ServiceCapability>>> {
161 let stream_params = &[stream_id.to_msg()];
162 self.send_command::<GetCapabilitiesResponse>(
163 SignalIdentifier::GetConfiguration,
164 stream_params,
165 )
166 .ok_into()
167 }
168
169 pub fn reconfigure(
178 &self,
179 stream_id: &StreamEndpointId,
180 capabilities: &[ServiceCapability],
181 ) -> impl Future<Output = Result<()>> {
182 assert!(!capabilities.is_empty(), "must set at least one capability");
183 let mut params: Vec<u8> = vec![0; capabilities.iter().fold(1, |a, x| a + x.encoded_len())];
184 params[0] = stream_id.to_msg();
185 let mut idx = 1;
186 for capability in capabilities {
187 if !capability.is_application() {
188 return futures::future::err(Error::Encoding).left_future();
189 }
190 if let Err(e) = capability.encode(&mut params[idx..]) {
191 return futures::future::err(e).left_future();
192 }
193 idx += capability.encoded_len();
194 }
195 self.send_command::<SimpleResponse>(SignalIdentifier::Reconfigure, ¶ms)
196 .ok_into()
197 .right_future()
198 }
199
200 pub fn open(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> {
204 let stream_params = &[stream_id.to_msg()];
205 self.send_command::<SimpleResponse>(SignalIdentifier::Open, stream_params).ok_into()
206 }
207
208 pub fn start(&self, stream_ids: &[StreamEndpointId]) -> impl Future<Output = Result<()>> {
213 let mut stream_params = Vec::with_capacity(stream_ids.len());
214 for stream_id in stream_ids {
215 stream_params.push(stream_id.to_msg());
216 }
217 self.send_command::<SimpleResponse>(SignalIdentifier::Start, &stream_params).ok_into()
218 }
219
220 pub fn close(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> {
223 let stream_params = &[stream_id.to_msg()];
224 let response: CommandResponseFut<SimpleResponse> =
225 self.send_command::<SimpleResponse>(SignalIdentifier::Close, stream_params);
226 response.ok_into()
227 }
228
229 pub fn suspend(&self, stream_ids: &[StreamEndpointId]) -> impl Future<Output = Result<()>> {
233 let mut stream_params = Vec::with_capacity(stream_ids.len());
234 for stream_id in stream_ids {
235 stream_params.push(stream_id.to_msg());
236 }
237 let response: CommandResponseFut<SimpleResponse> =
238 self.send_command::<SimpleResponse>(SignalIdentifier::Suspend, &stream_params);
239 response.ok_into()
240 }
241
242 pub fn abort(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> {
247 let stream_params = &[stream_id.to_msg()];
248 self.send_command::<SimpleResponse>(SignalIdentifier::Abort, stream_params).ok_into()
249 }
250
251 pub fn delay_report(
255 &self,
256 stream_id: &StreamEndpointId,
257 delay: u16,
258 ) -> impl Future<Output = Result<()>> {
259 let delay_bytes: [u8; 2] = delay.to_be_bytes();
260 let params = &[stream_id.to_msg(), delay_bytes[0], delay_bytes[1]];
261 self.send_command::<SimpleResponse>(SignalIdentifier::DelayReport, params).ok_into()
262 }
263
264 const RTX_SIG_TIMER_MS: i64 = 3000;
266 const COMMAND_TIMEOUT: MonotonicDuration =
267 MonotonicDuration::from_millis(Peer::RTX_SIG_TIMER_MS);
268
269 fn send_command<D: Decodable<Error = Error>>(
272 &self,
273 signal: SignalIdentifier,
274 payload: &[u8],
275 ) -> CommandResponseFut<D> {
276 let send_result = (|| {
277 let id = self.inner.add_response_waiter()?;
278 let header = SignalingHeader::new(id, signal, SignalingMessageType::Command);
279 let mut buf = vec![0; header.encoded_len()];
280 header.encode(buf.as_mut_slice())?;
281 buf.extend_from_slice(payload);
282 self.inner.send_signal(buf.as_slice())?;
283 Ok(header)
284 })();
285
286 CommandResponseFut::new(send_result, self.inner.clone())
287 }
288}
289
290struct CommandResponseFut<D: Decodable> {
292 id: SignalIdentifier,
293 fut: Pin<Box<MaybeDone<OnTimeout<CommandResponse, fn() -> Result<Vec<u8>>>>>>,
294 _phantom: PhantomData<D>,
295}
296
297impl<D: Decodable> Unpin for CommandResponseFut<D> {}
298
299impl<D: Decodable<Error = Error>> CommandResponseFut<D> {
300 fn new(send_result: Result<SignalingHeader>, inner: Arc<PeerInner>) -> Self {
301 let header = match send_result {
302 Err(e) => {
303 return Self {
304 id: SignalIdentifier::Abort,
305 fut: Box::pin(MaybeDone::Done(Err(e))),
306 _phantom: PhantomData,
307 }
308 }
309 Ok(header) => header,
310 };
311 let response = CommandResponse { id: header.label(), inner: Some(inner) };
312 let err_timeout: fn() -> Result<Vec<u8>> = || Err(Error::Timeout);
313 let timedout_fut = response.on_timeout(Peer::COMMAND_TIMEOUT.after_now(), err_timeout);
314
315 Self {
316 id: header.signal(),
317 fut: Box::pin(futures::future::maybe_done(timedout_fut)),
318 _phantom: PhantomData,
319 }
320 }
321}
322
323impl<D: Decodable<Error = Error>> Future for CommandResponseFut<D> {
324 type Output = Result<D>;
325
326 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
327 ready!(self.fut.poll_unpin(cx));
328 Poll::Ready(
329 self.fut
330 .as_mut()
331 .take_output()
332 .unwrap_or(Err(Error::AlreadyReceived))
333 .and_then(|buf| decode_signaling_response(self.id, buf)),
334 )
335 }
336}
337
/// A request received from the remote peer, parsed from a signaling command.
///
/// Every variant carries a responder whose `send` accepts the request and whose `reject`
/// refuses it with an error code.
#[derive(Debug)]
pub enum Request {
    /// Discover: answer with `StreamInformation` for the local stream endpoints.
    Discover {
        responder: DiscoverResponder,
    },
    /// Get Capabilities for the local endpoint `stream_id`.
    GetCapabilities {
        stream_id: StreamEndpointId,
        responder: GetCapabilitiesResponder,
    },
    /// Get All Capabilities for the local endpoint `stream_id`.
    GetAllCapabilities {
        stream_id: StreamEndpointId,
        responder: GetCapabilitiesResponder,
    },
    /// Set Configuration. `local_stream_id` is the first SEID in the command and
    /// `remote_stream_id` the second; `capabilities` are the requested configuration.
    SetConfiguration {
        local_stream_id: StreamEndpointId,
        remote_stream_id: StreamEndpointId,
        capabilities: Vec<ServiceCapability>,
        responder: ConfigureResponder,
    },
    /// Get Configuration for the local endpoint `stream_id`.
    GetConfiguration {
        stream_id: StreamEndpointId,
        responder: GetCapabilitiesResponder,
    },
    /// Reconfigure `local_stream_id`. Parsing guarantees every capability is an
    /// application capability.
    Reconfigure {
        local_stream_id: StreamEndpointId,
        capabilities: Vec<ServiceCapability>,
        responder: ConfigureResponder,
    },
    /// Open the local endpoint `stream_id`.
    Open {
        stream_id: StreamEndpointId,
        responder: SimpleResponder,
    },
    /// Start every endpoint in `stream_ids` (at least one).
    Start {
        stream_ids: Vec<StreamEndpointId>,
        responder: StreamResponder,
    },
    /// Close the local endpoint `stream_id`.
    Close {
        stream_id: StreamEndpointId,
        responder: SimpleResponder,
    },
    /// Suspend every endpoint in `stream_ids` (at least one).
    Suspend {
        stream_ids: Vec<StreamEndpointId>,
        responder: StreamResponder,
    },
    /// Abort the local endpoint `stream_id`.
    Abort {
        stream_id: StreamEndpointId,
        responder: SimpleResponder,
    },
    /// Delay Report for `stream_id`; `delay` is decoded from two big-endian bytes
    /// (units presumably 1/10 ms per the AVDTP spec — TODO confirm).
    DelayReport {
        stream_id: StreamEndpointId,
        delay: u16,
        responder: SimpleResponder,
    },
}
396
/// Builds `Request::$request_variant` from a command body that must contain exactly one
/// SEID byte, pairing it with a `$responder_type` for `$signal` on `$peer` with label `$id`.
/// A body of any other length yields `RequestInvalid(BadLength)`.
macro_rules! parse_one_seid {
    ($body:ident, $signal:ident, $peer:ident, $id:ident, $request_variant:ident, $responder_type:ident) => {
        match $body {
            [seid] => Ok(Request::$request_variant {
                stream_id: StreamEndpointId::from_msg(seid),
                responder: $responder_type { signal: $signal, peer: $peer.clone(), id: $id },
            }),
            _ => Err(Error::RequestInvalid(ErrorCode::BadLength)),
        }
    };
}
409
410impl Request {
411 fn get_req_seids(body: &[u8]) -> Result<Vec<StreamEndpointId>> {
412 if body.len() < 1 {
413 return Err(Error::RequestInvalid(ErrorCode::BadLength));
414 }
415 Ok(body.iter().map(&StreamEndpointId::from_msg).collect())
416 }
417
418 fn get_req_capabilities(encoded: &[u8]) -> Result<Vec<ServiceCapability>> {
419 if encoded.len() < 2 {
420 return Err(Error::RequestInvalid(ErrorCode::BadLength));
421 }
422 let mut caps = vec![];
423 let mut loc = 0;
424 while loc < encoded.len() {
425 let cap = match ServiceCapability::decode(&encoded[loc..]) {
426 Ok(cap) => cap,
427 Err(Error::RequestInvalid(code)) => {
428 return Err(Error::RequestInvalidExtra(code, encoded[loc]));
429 }
430 Err(e) => return Err(e),
431 };
432 loc += cap.encoded_len();
433 caps.push(cap);
434 }
435 Ok(caps)
436 }
437
438 fn parse(
439 peer: &Arc<PeerInner>,
440 id: TxLabel,
441 signal: SignalIdentifier,
442 body: &[u8],
443 ) -> Result<Request> {
444 match signal {
445 SignalIdentifier::Discover => {
446 if body.len() > 0 {
448 return Err(Error::RequestInvalid(ErrorCode::BadLength));
449 }
450 Ok(Request::Discover { responder: DiscoverResponder { peer: peer.clone(), id } })
451 }
452 SignalIdentifier::GetCapabilities => {
453 parse_one_seid!(body, signal, peer, id, GetCapabilities, GetCapabilitiesResponder)
454 }
455 SignalIdentifier::GetAllCapabilities => parse_one_seid!(
456 body,
457 signal,
458 peer,
459 id,
460 GetAllCapabilities,
461 GetCapabilitiesResponder
462 ),
463 SignalIdentifier::SetConfiguration => {
464 if body.len() < 4 {
465 return Err(Error::RequestInvalid(ErrorCode::BadLength));
466 }
467 let requested = Request::get_req_capabilities(&body[2..])?;
468 Ok(Request::SetConfiguration {
469 local_stream_id: StreamEndpointId::from_msg(&body[0]),
470 remote_stream_id: StreamEndpointId::from_msg(&body[1]),
471 capabilities: requested,
472 responder: ConfigureResponder { signal, peer: peer.clone(), id },
473 })
474 }
475 SignalIdentifier::GetConfiguration => {
476 parse_one_seid!(body, signal, peer, id, GetConfiguration, GetCapabilitiesResponder)
477 }
478 SignalIdentifier::Reconfigure => {
479 if body.len() < 3 {
480 return Err(Error::RequestInvalid(ErrorCode::BadLength));
481 }
482 let requested = Request::get_req_capabilities(&body[1..])?;
483 match requested.iter().find(|x| !x.is_application()) {
484 Some(x) => {
485 return Err(Error::RequestInvalidExtra(
486 ErrorCode::InvalidCapabilities,
487 (&x.category()).into(),
488 ));
489 }
490 None => (),
491 };
492 Ok(Request::Reconfigure {
493 local_stream_id: StreamEndpointId::from_msg(&body[0]),
494 capabilities: requested,
495 responder: ConfigureResponder { signal, peer: peer.clone(), id },
496 })
497 }
498 SignalIdentifier::Open => {
499 parse_one_seid!(body, signal, peer, id, Open, SimpleResponder)
500 }
501 SignalIdentifier::Start => {
502 let seids = Request::get_req_seids(body)?;
503 Ok(Request::Start {
504 stream_ids: seids,
505 responder: StreamResponder { signal, peer: peer.clone(), id },
506 })
507 }
508 SignalIdentifier::Close => {
509 parse_one_seid!(body, signal, peer, id, Close, SimpleResponder)
510 }
511 SignalIdentifier::Suspend => {
512 let seids = Request::get_req_seids(body)?;
513 Ok(Request::Suspend {
514 stream_ids: seids,
515 responder: StreamResponder { signal, peer: peer.clone(), id },
516 })
517 }
518 SignalIdentifier::Abort => {
519 parse_one_seid!(body, signal, peer, id, Abort, SimpleResponder)
520 }
521 SignalIdentifier::DelayReport => {
522 if body.len() != 3 {
523 return Err(Error::RequestInvalid(ErrorCode::BadLength));
524 }
525 let delay_arr: [u8; 2] = [body[1], body[2]];
526 let delay = u16::from_be_bytes(delay_arr);
527 Ok(Request::DelayReport {
528 stream_id: StreamEndpointId::from_msg(&body[0]),
529 delay,
530 responder: SimpleResponder { signal, peer: peer.clone(), id },
531 })
532 }
533 _ => Err(Error::UnimplementedMessage),
534 }
535 }
536}
537
538#[derive(Debug)]
540pub struct RequestStream {
541 inner: Arc<PeerInner>,
542 terminated: bool,
543}
544
545impl RequestStream {
546 fn new(inner: Arc<PeerInner>) -> Self {
547 Self { inner, terminated: false }
548 }
549}
550
551impl Unpin for RequestStream {}
552
553impl Stream for RequestStream {
554 type Item = Result<Request>;
555
556 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
557 Poll::Ready(match ready!(self.inner.poll_recv_request(cx)) {
558 Ok(UnparsedRequest(SignalingHeader { label, signal, .. }, body)) => {
559 match Request::parse(&self.inner, label, signal, &body) {
560 Err(Error::RequestInvalid(code)) => {
561 self.inner.send_reject(label, signal, code)?;
562 return Poll::Pending;
563 }
564 Err(Error::RequestInvalidExtra(code, extra)) => {
565 self.inner.send_reject_params(label, signal, &[extra, u8::from(&code)])?;
566 return Poll::Pending;
567 }
568 Err(Error::UnimplementedMessage) => {
569 self.inner.send_reject(label, signal, ErrorCode::NotSupportedCommand)?;
570 return Poll::Pending;
571 }
572 x => Some(x),
573 }
574 }
575 Err(Error::PeerDisconnected) => {
576 self.terminated = true;
577 None
578 }
579 Err(e) => Some(Err(e)),
580 })
581 }
582}
583
584impl FusedStream for RequestStream {
585 fn is_terminated(&self) -> bool {
586 self.terminated
587 }
588}
589
590impl Drop for RequestStream {
591 fn drop(&mut self) {
592 self.inner.incoming_requests.lock().listener = RequestListener::None;
593 self.inner.wake_any();
594 }
595}
596
597#[derive(Debug)]
599pub struct SimpleResponse {}
600
601impl Decodable for SimpleResponse {
602 type Error = Error;
603
604 fn decode(from: &[u8]) -> Result<Self> {
605 if from.len() > 0 {
606 return Err(Error::InvalidMessage);
607 }
608 Ok(SimpleResponse {})
609 }
610}
611
612impl Into<()> for SimpleResponse {
613 fn into(self) -> () {
614 ()
615 }
616}
617
618#[derive(Debug)]
619struct DiscoverResponse {
620 endpoints: Vec<StreamInformation>,
621}
622
623impl Decodable for DiscoverResponse {
624 type Error = Error;
625
626 fn decode(from: &[u8]) -> Result<Self> {
627 let mut endpoints = Vec::<StreamInformation>::new();
628 let mut idx = 0;
629 while idx < from.len() {
630 let endpoint = StreamInformation::decode(&from[idx..])?;
631 idx += endpoint.encoded_len();
632 endpoints.push(endpoint);
633 }
634 Ok(DiscoverResponse { endpoints })
635 }
636}
637
638impl Into<Vec<StreamInformation>> for DiscoverResponse {
639 fn into(self) -> Vec<StreamInformation> {
640 self.endpoints
641 }
642}
643
644#[derive(Debug)]
645pub struct DiscoverResponder {
646 peer: Arc<PeerInner>,
647 id: TxLabel,
648}
649
650impl DiscoverResponder {
651 pub fn send(self, endpoints: &[StreamInformation]) -> Result<()> {
655 if endpoints.len() == 0 {
656 return Err(Error::Encoding);
658 }
659 let mut params = vec![0 as u8; endpoints.len() * endpoints[0].encoded_len()];
660 let mut idx = 0;
661 for endpoint in endpoints {
662 endpoint.encode(&mut params[idx..idx + endpoint.encoded_len()])?;
663 idx += endpoint.encoded_len();
664 }
665 self.peer.send_response(self.id, SignalIdentifier::Discover, ¶ms)
666 }
667
668 pub fn reject(self, error_code: ErrorCode) -> Result<()> {
669 self.peer.send_reject(self.id, SignalIdentifier::Discover, error_code)
670 }
671}
672
673#[derive(Debug)]
674pub struct GetCapabilitiesResponder {
675 peer: Arc<PeerInner>,
676 signal: SignalIdentifier,
677 id: TxLabel,
678}
679
680impl GetCapabilitiesResponder {
681 pub fn send(self, capabilities: &[ServiceCapability]) -> Result<()> {
682 let included_iter = capabilities.iter().filter(|x| x.in_response(self.signal));
683 let reply_len = included_iter.clone().fold(0, |a, b| a + b.encoded_len());
684 let mut reply = vec![0 as u8; reply_len];
685 let mut pos = 0;
686 for capability in included_iter {
687 let size = capability.encoded_len();
688 capability.encode(&mut reply[pos..pos + size])?;
689 pos += size;
690 }
691 self.peer.send_response(self.id, self.signal, &reply)
692 }
693
694 pub fn reject(self, error_code: ErrorCode) -> Result<()> {
695 self.peer.send_reject(self.id, self.signal, error_code)
696 }
697}
698
699#[derive(Debug)]
700struct GetCapabilitiesResponse {
701 capabilities: Vec<ServiceCapability>,
702}
703
704impl Decodable for GetCapabilitiesResponse {
705 type Error = Error;
706
707 fn decode(from: &[u8]) -> Result<Self> {
708 let mut capabilities = Vec::<ServiceCapability>::new();
709 let mut idx = 0;
710 while idx < from.len() {
711 match ServiceCapability::decode(&from[idx..]) {
712 Ok(capability) => {
713 idx = idx + capability.encoded_len();
714 capabilities.push(capability);
715 }
716 Err(_) => {
717 info!(
722 "GetCapabilitiesResponse decode: Capability {:?} not supported.",
723 from[idx]
724 );
725 let length_of_capability = from[idx + 1] as usize;
726 idx = idx + 2 + length_of_capability;
727 }
728 }
729 }
730 Ok(GetCapabilitiesResponse { capabilities })
731 }
732}
733
734impl Into<Vec<ServiceCapability>> for GetCapabilitiesResponse {
735 fn into(self) -> Vec<ServiceCapability> {
736 self.capabilities
737 }
738}
739
740#[derive(Debug)]
741pub struct SimpleResponder {
742 peer: Arc<PeerInner>,
743 signal: SignalIdentifier,
744 id: TxLabel,
745}
746
747impl SimpleResponder {
748 pub fn send(self) -> Result<()> {
749 self.peer.send_response(self.id, self.signal, &[])
750 }
751
752 pub fn reject(self, error_code: ErrorCode) -> Result<()> {
753 self.peer.send_reject(self.id, self.signal, error_code)
754 }
755}
756
757#[derive(Debug)]
758pub struct StreamResponder {
759 peer: Arc<PeerInner>,
760 signal: SignalIdentifier,
761 id: TxLabel,
762}
763
764impl StreamResponder {
765 pub fn send(self) -> Result<()> {
766 self.peer.send_response(self.id, self.signal, &[])
767 }
768
769 pub fn reject(self, stream_id: &StreamEndpointId, error_code: ErrorCode) -> Result<()> {
770 self.peer.send_reject_params(
771 self.id,
772 self.signal,
773 &[stream_id.to_msg(), u8::from(&error_code)],
774 )
775 }
776}
777
778#[derive(Debug)]
779pub struct ConfigureResponder {
780 peer: Arc<PeerInner>,
781 signal: SignalIdentifier,
782 id: TxLabel,
783}
784
785impl ConfigureResponder {
786 pub fn send(self) -> Result<()> {
787 self.peer.send_response(self.id, self.signal, &[])
788 }
789
790 pub fn reject(self, category: ServiceCategory, error_code: ErrorCode) -> Result<()> {
791 self.peer.send_reject_params(
792 self.id,
793 self.signal,
794 &[u8::from(&category), u8::from(&error_code)],
795 )
796 }
797}
798
799#[derive(Debug)]
800struct UnparsedRequest(SignalingHeader, Vec<u8>);
801
802impl UnparsedRequest {
803 fn new(header: SignalingHeader, body: Vec<u8>) -> UnparsedRequest {
804 UnparsedRequest(header, body)
805 }
806}
807
808#[derive(Debug, Default)]
809struct RequestQueue {
810 listener: RequestListener,
811 queue: VecDeque<UnparsedRequest>,
812}
813
/// Tracks whether a `RequestStream` has been taken and, if so, how to wake it.
#[derive(Debug)]
enum RequestListener {
    /// No request stream is currently taken.
    None,
    /// A request stream has been taken but has not yet registered a waker.
    New,
    /// The request stream is waiting; this waker is woken when a request is queued.
    Some(Waker),
}

impl Default for RequestListener {
    /// Starts with no request stream taken.
    fn default() -> Self {
        Self::None
    }
}
829
/// The state of one outstanding command, stored at the index of its transaction label.
#[derive(Debug)]
enum ResponseWaiter {
    /// A label was reserved, but its response future has not polled yet.
    WillPoll,
    /// The response future is waiting; wake this waker when the response arrives.
    Waiting(Waker),
    /// The complete response packet (header included) arrived and awaits collection.
    Received(Vec<u8>),
    /// The response future was dropped; discard the response when it arrives.
    Discard,
}

impl ResponseWaiter {
    /// Returns `true` if the response packet has arrived.
    fn is_received(&self) -> bool {
        matches!(self, ResponseWaiter::Received(_))
    }

    /// Consumes the waiter and returns the received response packet.
    ///
    /// # Panics
    ///
    /// Panics if the waiter is not `Received`.
    fn unwrap_received(self) -> Vec<u8> {
        match self {
            ResponseWaiter::Received(buf) => buf,
            _ => panic!("expected received buf"),
        }
    }
}
863
864fn decode_signaling_response<D: Decodable<Error = Error>>(
865 expected_signal: SignalIdentifier,
866 buf: Vec<u8>,
867) -> Result<D> {
868 let header = SignalingHeader::decode(buf.as_slice())?;
869 if header.signal() != expected_signal {
870 return Err(Error::InvalidHeader);
871 }
872 let params = &buf[header.encoded_len()..];
873 match header.message_type {
874 SignalingMessageType::ResponseAccept => D::decode(params),
875 SignalingMessageType::GeneralReject | SignalingMessageType::ResponseReject => {
876 Err(RemoteReject::from_params(header.signal(), params).into())
877 }
878 SignalingMessageType::Command => unreachable!(),
879 }
880}
881
882#[derive(Debug)]
884pub struct CommandResponse {
885 id: TxLabel,
886 inner: Option<Arc<PeerInner>>,
888}
889
890impl Unpin for CommandResponse {}
891
892impl Future for CommandResponse {
893 type Output = Result<Vec<u8>>;
894 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
895 let this = &mut *self;
896 let res;
897 {
898 let client = this.inner.as_ref().ok_or(Error::AlreadyReceived)?;
899 res = client.poll_recv_response(&this.id, cx);
900 }
901
902 if let Poll::Ready(Ok(_)) = res {
903 let inner = this.inner.take().expect("CommandResponse polled after completion");
904 inner.wake_any();
905 }
906
907 res
908 }
909}
910
911impl FusedFuture for CommandResponse {
912 fn is_terminated(&self) -> bool {
913 self.inner.is_none()
914 }
915}
916
917impl Drop for CommandResponse {
918 fn drop(&mut self) {
919 if let Some(inner) = &self.inner {
920 inner.remove_response_interest(&self.id);
921 inner.wake_any();
922 }
923 }
924}
925
/// State shared by a `Peer`, its `RequestStream` and every outstanding `CommandResponse`.
#[derive(Debug)]
struct PeerInner {
    /// The L2CAP signaling channel to the remote peer.
    signaling: Channel,

    /// One slot per outstanding outgoing command. The slot index is the command's
    /// transaction label.
    response_waiters: Mutex<Slab<ResponseWaiter>>,

    /// Commands received from the remote peer that have not yet been handed to the
    /// request stream, plus that stream's listener state.
    incoming_requests: Mutex<RequestQueue>,
}

impl PeerInner {
    /// Reserves a transaction label for a new outgoing command.
    ///
    /// The label is the index of a newly inserted `ResponseWaiter::WillPoll` slot. If that
    /// index is not a valid `TxLabel`, the slot is removed again and the conversion error
    /// is returned.
    fn add_response_waiter(&self) -> Result<TxLabel> {
        let key = self.response_waiters.lock().insert(ResponseWaiter::WillPoll);
        let id = TxLabel::try_from(key as u8);
        if id.is_err() {
            warn!("Transaction IDs are exhausted");
            let _ = self.response_waiters.lock().remove(key);
        }
        id
    }

    /// Withdraws interest in the response to the command labelled `id`.
    ///
    /// An already-received response is freed immediately. Otherwise the slot is marked
    /// `Discard`, so `recv_all` drops the response and frees the slot when it arrives.
    /// Panics (via slab indexing) if no slot exists for `id`.
    /// NOTE(review): if the remote never answers (e.g. after a timeout), a `Discard` slot is
    /// never freed and its label stays reserved — confirm this is acceptable.
    fn remove_response_interest(&self, id: &TxLabel) {
        let mut lock = self.response_waiters.lock();
        let idx = usize::from(id);
        if lock[idx].is_received() {
            let _ = lock.remove(idx);
        } else {
            lock[idx] = ResponseWaiter::Discard;
        }
    }

    /// Polls for the next command received from the remote peer.
    ///
    /// Drains the signaling channel first, then pops the oldest queued request. When the
    /// queue is empty, the current task's waker becomes the request listener. It then
    /// returns `Error::PeerDisconnected` if the channel has closed, and `Pending` otherwise.
    fn poll_recv_request(&self, cx: &mut Context<'_>) -> Poll<Result<UnparsedRequest>> {
        let is_closed = self.recv_all(cx)?;

        let mut lock = self.incoming_requests.lock();

        if let Some(request) = lock.queue.pop_front() {
            Poll::Ready(Ok(request))
        } else {
            lock.listener = RequestListener::Some(cx.waker().clone());
            if is_closed {
                Poll::Ready(Err(Error::PeerDisconnected))
            } else {
                Poll::Pending
            }
        }
    }

    /// Polls for the response to the command labelled `label`.
    ///
    /// Drains the signaling channel first. A received response is removed from its slot
    /// and returned. Otherwise the current task's waker is stored in the slot; the result
    /// is `Error::PeerDisconnected` if the channel has closed, and `Pending` otherwise.
    /// Panics if `label` has no slot.
    fn poll_recv_response(&self, label: &TxLabel, cx: &mut Context<'_>) -> Poll<Result<Vec<u8>>> {
        let is_closed = self.recv_all(cx)?;

        let mut waiters = self.response_waiters.lock();
        let idx = usize::from(label);
        if waiters.get(idx).expect("Polled unregistered waiter").is_received() {
            // Removing the slot frees the label for reuse.
            let buf = waiters.remove(idx).unwrap_received();
            Poll::Ready(Ok(buf))
        } else {
            *waiters.get_mut(idx).expect("Polled unregistered waiter") =
                ResponseWaiter::Waiting(cx.waker().clone());

            if is_closed {
                Poll::Ready(Err(Error::PeerDisconnected))
            } else {
                Poll::Pending
            }
        }
    }

    /// Reads every datagram currently available on the signaling channel and routes it.
    ///
    /// - Commands are queued for the request stream, waking its listener if one is waiting.
    /// - A response goes into the slot for its label, waking any waiting task. If that slot
    ///   was marked `Discard`, the response is dropped and the slot freed; a response with
    ///   no slot is dropped with a warning.
    /// - A header with an invalid signal ID gets a General Reject; any other undecodable
    ///   packet is logged and dropped. Empty datagrams are skipped.
    ///
    /// Returns `Ok(true)` if the channel reported `PEER_CLOSED`, and `Ok(false)` once the
    /// channel returns `Pending`. Returns `Err` on a read failure or a failed General Reject
    /// send.
    fn recv_all(&self, cx: &mut Context<'_>) -> Result<bool> {
        loop {
            let mut next_packet = Vec::new();
            let packet_size = match self.signaling.poll_datagram(cx, &mut next_packet) {
                Poll::Ready(Err(zx::Status::PEER_CLOSED)) => {
                    trace!("Signaling peer closed");
                    return Ok(true);
                }
                Poll::Ready(Err(e)) => return Err(Error::PeerRead(e)),
                Poll::Pending => return Ok(false),
                Poll::Ready(Ok(size)) => size,
            };
            if packet_size == 0 {
                continue;
            }
            let header = match SignalingHeader::decode(next_packet.as_slice()) {
                // The label and raw signal ID were recovered, so the sender can be told the
                // signal is unknown.
                Err(Error::InvalidSignalId(label, id)) => {
                    self.send_general_reject(label, id)?;
                    continue;
                }
                Err(_) => {
                    info!("received unrejectable message");
                    continue;
                }
                Ok(x) => x,
            };
            if header.is_command() {
                // Queue the command body (parameters only) for the request stream.
                let mut lock = self.incoming_requests.lock();
                let body = next_packet.split_off(header.encoded_len());
                lock.queue.push_back(UnparsedRequest::new(header, body));
                if let RequestListener::Some(ref waker) = lock.listener {
                    waker.wake_by_ref();
                }
            } else {
                // Store the whole packet (header included) for the matching command.
                let mut waiters = self.response_waiters.lock();
                let idx = usize::from(&header.label());
                if let Some(&ResponseWaiter::Discard) = waiters.get(idx) {
                    let _ = waiters.remove(idx);
                } else if let Some(entry) = waiters.get_mut(idx) {
                    let old_entry = mem::replace(entry, ResponseWaiter::Received(next_packet));
                    if let ResponseWaiter::Waiting(waker) = old_entry {
                        waker.wake();
                    }
                } else {
                    warn!("response for {:?} we did not send, dropping", header.label());
                }
            }
        }
    }

    /// Wakes one task blocked on this peer: the first command response in the `Waiting`
    /// state if any, otherwise the request stream's listener if it is waiting.
    ///
    /// Called when a response future or the request stream completes or is dropped —
    /// presumably so that some other task resumes polling the shared signaling channel,
    /// since every poll path drains it through `recv_all`.
    fn wake_any(&self) {
        {
            let lock = self.response_waiters.lock();
            for (_, response_waiter) in lock.iter() {
                if let ResponseWaiter::Waiting(waker) = response_waiter {
                    waker.wake_by_ref();
                    return;
                }
            }
        }
        {
            let lock = self.incoming_requests.lock();
            if let RequestListener::Some(waker) = &lock.listener {
                waker.wake_by_ref();
                return;
            }
        }
    }

    /// Sends a General Reject for a command whose signal ID could not be decoded.
    ///
    /// Byte 0 is `label` in the upper four bits OR'd with `0x01` (presumably the
    /// GeneralReject message type in a single packet — confirm against `SignalingHeader`);
    /// byte 1 echoes the signal ID masked to its low six bits.
    fn send_general_reject(&self, label: TxLabel, invalid_signal_id: u8) -> Result<()> {
        let packet: &[u8; 2] = &[u8::from(&label) << 4 | 0x01, invalid_signal_id & 0x3F];
        self.send_signal(packet)
    }

    /// Sends an accept response for `signal` with label `label` and parameters `params`.
    fn send_response(&self, label: TxLabel, signal: SignalIdentifier, params: &[u8]) -> Result<()> {
        let header = SignalingHeader::new(label, signal, SignalingMessageType::ResponseAccept);
        let mut packet = vec![0 as u8; header.encoded_len() + params.len()];
        header.encode(packet.as_mut_slice())?;
        packet[header.encoded_len()..].clone_from_slice(params);
        self.send_signal(&packet)
    }

    /// Sends a reject response whose only parameter is `error_code`.
    fn send_reject(
        &self,
        label: TxLabel,
        signal: SignalIdentifier,
        error_code: ErrorCode,
    ) -> Result<()> {
        self.send_reject_params(label, signal, &[u8::from(&error_code)])
    }

    /// Sends a reject response for `signal` with label `label` and raw parameters `params`.
    fn send_reject_params(
        &self,
        label: TxLabel,
        signal: SignalIdentifier,
        params: &[u8],
    ) -> Result<()> {
        let header = SignalingHeader::new(label, signal, SignalingMessageType::ResponseReject);
        let mut packet = vec![0 as u8; header.encoded_len() + params.len()];
        header.encode(packet.as_mut_slice())?;
        packet[header.encoded_len()..].clone_from_slice(params);
        self.send_signal(&packet)
    }

    /// Writes `data` to the signaling channel, mapping failures to `Error::PeerWrite`.
    fn send_signal(&self, data: &[u8]) -> Result<()> {
        let _ = self.signaling.write(data).map_err(|x| Error::PeerWrite(x))?;
        Ok(())
    }
}