1use fuchsia_async::{DurationExt, OnTimeout, TimeoutExt};
6use fuchsia_bluetooth::types::Channel;
7use fuchsia_sync::Mutex;
8use futures::future::{FusedFuture, MaybeDone};
9use futures::stream::{FusedStream, Stream, StreamExt};
10use futures::task::{Context, Poll, Waker};
11use futures::{Future, FutureExt, TryFutureExt, ready};
12use log::{info, trace, warn};
13use packet_encoding::{Decodable, Encodable};
14use slab::Slab;
15use std::collections::VecDeque;
16use std::marker::PhantomData;
17use std::mem;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::sync::atomic::{AtomicBool, Ordering};
21use zx::{self as zx, MonotonicDuration};
22
23mod rtp;
24mod stream_endpoint;
25#[cfg(test)]
26mod tests;
27mod types;
28
29use crate::types::{SignalIdentifier, SignalingHeader, SignalingMessageType, TxLabel};
30
31pub use crate::rtp::{RtpError, RtpHeader};
32pub use crate::stream_endpoint::{
33 MediaStream, StreamEndpoint, StreamEndpointUpdateCallback, StreamState,
34};
35pub use crate::types::{
36 ContentProtectionType, EndpointType, Error, ErrorCode, MediaCodecType, MediaType, RemoteReject,
37 Result, ServiceCapability, ServiceCategory, StreamEndpointId, StreamInformation,
38};
39
/// Handle to a remote peer reached over a signaling `Channel`.
///
/// Clones share the same reference-counted `PeerInner`, so every clone talks
/// over the same signaling channel and sees the same pending transactions.
#[derive(Debug, Clone)]
pub struct Peer {
    /// Shared state: the signaling channel, response waiters and queued requests.
    inner: Arc<PeerInner>,
}
53
54impl Peer {
55 pub fn new(signaling: Channel) -> Self {
57 Self {
58 inner: Arc::new(PeerInner {
59 signaling: Mutex::new(signaling),
60 signaling_channel_closed: AtomicBool::new(false),
61 response_waiters: Mutex::new(Slab::<ResponseWaiter>::new()),
62 incoming_requests: Mutex::<RequestQueue>::default(),
63 }),
64 }
65 }
66
67 #[track_caller]
70 pub fn take_request_stream(&self) -> RequestStream {
71 {
72 let mut lock = self.inner.incoming_requests.lock();
73 if let RequestListener::None = lock.listener {
74 lock.listener = RequestListener::New;
75 } else {
76 panic!("Request stream has already been taken");
77 }
78 }
79
80 RequestStream::new(self.inner.clone())
81 }
82
83 pub fn discover(&self) -> impl Future<Output = Result<Vec<StreamInformation>>> + use<> {
87 self.send_command::<DiscoverResponse>(SignalIdentifier::Discover, &[]).ok_into()
88 }
89
90 pub fn get_capabilities(
97 &self,
98 stream_id: &StreamEndpointId,
99 ) -> impl Future<Output = Result<Vec<ServiceCapability>>> + use<> {
100 let stream_params = &[stream_id.to_msg()];
101 self.send_command::<GetCapabilitiesResponse>(
102 SignalIdentifier::GetCapabilities,
103 stream_params,
104 )
105 .ok_into()
106 }
107
108 pub fn get_all_capabilities(
114 &self,
115 stream_id: &StreamEndpointId,
116 ) -> impl Future<Output = Result<Vec<ServiceCapability>>> + use<> {
117 let stream_params = &[stream_id.to_msg()];
118 self.send_command::<GetCapabilitiesResponse>(
119 SignalIdentifier::GetAllCapabilities,
120 stream_params,
121 )
122 .ok_into()
123 }
124
125 pub fn set_configuration(
132 &self,
133 stream_id: &StreamEndpointId,
134 local_stream_id: &StreamEndpointId,
135 capabilities: &[ServiceCapability],
136 ) -> impl Future<Output = Result<()>> + use<> {
137 assert!(!capabilities.is_empty(), "must set at least one capability");
138 let mut params: Vec<u8> = vec![0; capabilities.iter().fold(2, |a, x| a + x.encoded_len())];
139 params[0] = stream_id.to_msg();
140 params[1] = local_stream_id.to_msg();
141 let mut idx = 2;
142 for capability in capabilities {
143 if let Err(e) = capability.encode(&mut params[idx..]) {
144 return futures::future::err(e).left_future();
145 }
146 idx += capability.encoded_len();
147 }
148 self.send_command::<SimpleResponse>(SignalIdentifier::SetConfiguration, ¶ms)
149 .ok_into()
150 .right_future()
151 }
152
153 pub fn get_configuration(
159 &self,
160 stream_id: &StreamEndpointId,
161 ) -> impl Future<Output = Result<Vec<ServiceCapability>>> + use<> {
162 let stream_params = &[stream_id.to_msg()];
163 self.send_command::<GetCapabilitiesResponse>(
164 SignalIdentifier::GetConfiguration,
165 stream_params,
166 )
167 .ok_into()
168 }
169
170 pub fn reconfigure(
179 &self,
180 stream_id: &StreamEndpointId,
181 capabilities: &[ServiceCapability],
182 ) -> impl Future<Output = Result<()>> + use<> {
183 assert!(!capabilities.is_empty(), "must set at least one capability");
184 let mut params: Vec<u8> = vec![0; capabilities.iter().fold(1, |a, x| a + x.encoded_len())];
185 params[0] = stream_id.to_msg();
186 let mut idx = 1;
187 for capability in capabilities {
188 if !capability.is_application() {
189 return futures::future::err(Error::Encoding).left_future();
190 }
191 if let Err(e) = capability.encode(&mut params[idx..]) {
192 return futures::future::err(e).left_future();
193 }
194 idx += capability.encoded_len();
195 }
196 self.send_command::<SimpleResponse>(SignalIdentifier::Reconfigure, ¶ms)
197 .ok_into()
198 .right_future()
199 }
200
201 pub fn open(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> + use<> {
205 let stream_params = &[stream_id.to_msg()];
206 self.send_command::<SimpleResponse>(SignalIdentifier::Open, stream_params).ok_into()
207 }
208
209 pub fn start(
214 &self,
215 stream_ids: &[StreamEndpointId],
216 ) -> impl Future<Output = Result<()>> + use<> {
217 let mut stream_params = Vec::with_capacity(stream_ids.len());
218 for stream_id in stream_ids {
219 stream_params.push(stream_id.to_msg());
220 }
221 self.send_command::<SimpleResponse>(SignalIdentifier::Start, &stream_params).ok_into()
222 }
223
224 pub fn close(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> + use<> {
227 let stream_params = &[stream_id.to_msg()];
228 let response: CommandResponseFut<SimpleResponse> =
229 self.send_command::<SimpleResponse>(SignalIdentifier::Close, stream_params);
230 response.ok_into()
231 }
232
233 pub fn suspend(
237 &self,
238 stream_ids: &[StreamEndpointId],
239 ) -> impl Future<Output = Result<()>> + use<> {
240 let mut stream_params = Vec::with_capacity(stream_ids.len());
241 for stream_id in stream_ids {
242 stream_params.push(stream_id.to_msg());
243 }
244 let response: CommandResponseFut<SimpleResponse> =
245 self.send_command::<SimpleResponse>(SignalIdentifier::Suspend, &stream_params);
246 response.ok_into()
247 }
248
249 pub fn abort(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> + use<> {
254 let stream_params = &[stream_id.to_msg()];
255 self.send_command::<SimpleResponse>(SignalIdentifier::Abort, stream_params).ok_into()
256 }
257
258 pub fn delay_report(
262 &self,
263 stream_id: &StreamEndpointId,
264 delay: u16,
265 ) -> impl Future<Output = Result<()>> + use<> {
266 let delay_bytes: [u8; 2] = delay.to_be_bytes();
267 let params = &[stream_id.to_msg(), delay_bytes[0], delay_bytes[1]];
268 self.send_command::<SimpleResponse>(SignalIdentifier::DelayReport, params).ok_into()
269 }
270
271 const RTX_SIG_TIMER_MS: i64 = 3000;
273 const COMMAND_TIMEOUT: MonotonicDuration =
274 MonotonicDuration::from_millis(Peer::RTX_SIG_TIMER_MS);
275
276 fn send_command<D: Decodable<Error = Error>>(
279 &self,
280 signal: SignalIdentifier,
281 payload: &[u8],
282 ) -> CommandResponseFut<D> {
283 let send_result = (|| {
284 let id = self.inner.add_response_waiter()?;
285 let header = SignalingHeader::new(id, signal, SignalingMessageType::Command);
286 let mut buf = vec![0; header.encoded_len()];
287 header.encode(buf.as_mut_slice())?;
288 buf.extend_from_slice(payload);
289 self.inner.send_signal(buf.as_slice())?;
290 Ok(header)
291 })();
292
293 CommandResponseFut::new(send_result, self.inner.clone())
294 }
295}
296
/// Future for the decoded response to a command, resolving to `Result<D>`.
struct CommandResponseFut<D: Decodable> {
    /// Signal of the command that was sent; the response must carry the same one.
    id: SignalIdentifier,
    /// Raw response bytes, a timeout error, or an already-known send error.
    fut: Pin<Box<MaybeDone<OnTimeout<CommandResponse, fn() -> Result<Vec<u8>>>>>>,
    /// Ties the future to the response type it decodes into.
    _phantom: PhantomData<D>,
}
303
// The inner future is boxed and pinned separately, so this wrapper is declared `Unpin`.
impl<D: Decodable> Unpin for CommandResponseFut<D> {}
305
306impl<D: Decodable<Error = Error>> CommandResponseFut<D> {
307 fn new(send_result: Result<SignalingHeader>, inner: Arc<PeerInner>) -> Self {
308 let header = match send_result {
309 Err(e) => {
310 return Self {
311 id: SignalIdentifier::Abort,
312 fut: Box::pin(MaybeDone::Done(Err(e))),
313 _phantom: PhantomData,
314 };
315 }
316 Ok(header) => header,
317 };
318 let response = CommandResponse { id: header.label(), inner: Some(inner) };
319 let err_timeout: fn() -> Result<Vec<u8>> = || Err(Error::Timeout);
320 let timedout_fut = response.on_timeout(Peer::COMMAND_TIMEOUT.after_now(), err_timeout);
321
322 Self {
323 id: header.signal(),
324 fut: Box::pin(futures::future::maybe_done(timedout_fut)),
325 _phantom: PhantomData,
326 }
327 }
328}
329
330impl<D: Decodable<Error = Error>> Future for CommandResponseFut<D> {
331 type Output = Result<D>;
332
333 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
334 ready!(self.fut.poll_unpin(cx));
335 Poll::Ready(
336 self.fut
337 .as_mut()
338 .take_output()
339 .unwrap_or(Err(Error::AlreadyReceived))
340 .and_then(|buf| decode_signaling_response(self.id, buf)),
341 )
342 }
343}
344
/// A request received from the remote peer, produced by `Request::parse`.
///
/// Each variant carries a responder used to accept or reject the request.
#[derive(Debug)]
pub enum Request {
    /// Parsed from a Discover command, which must have an empty body.
    Discover {
        responder: DiscoverResponder,
    },
    /// Parsed from a GetCapabilities command with a one-byte body.
    GetCapabilities {
        stream_id: StreamEndpointId,
        responder: GetCapabilitiesResponder,
    },
    /// Parsed from a GetAllCapabilities command with a one-byte body.
    GetAllCapabilities {
        stream_id: StreamEndpointId,
        responder: GetCapabilitiesResponder,
    },
    /// Parsed from a SetConfiguration command.
    SetConfiguration {
        /// Taken from the first body byte.
        local_stream_id: StreamEndpointId,
        /// Taken from the second body byte.
        remote_stream_id: StreamEndpointId,
        /// Decoded from the rest of the body.
        capabilities: Vec<ServiceCapability>,
        responder: ConfigureResponder,
    },
    /// Parsed from a GetConfiguration command with a one-byte body.
    GetConfiguration {
        stream_id: StreamEndpointId,
        responder: GetCapabilitiesResponder,
    },
    /// Parsed from a Reconfigure command; every capability satisfies `is_application()`.
    Reconfigure {
        /// Taken from the first body byte.
        local_stream_id: StreamEndpointId,
        /// Decoded from the rest of the body.
        capabilities: Vec<ServiceCapability>,
        responder: ConfigureResponder,
    },
    /// Parsed from an Open command with a one-byte body.
    Open {
        stream_id: StreamEndpointId,
        responder: SimpleResponder,
    },
    /// Parsed from a Start command; one id per body byte, at least one.
    Start {
        stream_ids: Vec<StreamEndpointId>,
        responder: StreamResponder,
    },
    /// Parsed from a Close command with a one-byte body.
    Close {
        stream_id: StreamEndpointId,
        responder: SimpleResponder,
    },
    /// Parsed from a Suspend command; one id per body byte, at least one.
    Suspend {
        stream_ids: Vec<StreamEndpointId>,
        responder: StreamResponder,
    },
    /// Parsed from an Abort command with a one-byte body.
    Abort {
        stream_id: StreamEndpointId,
        responder: SimpleResponder,
    },
    /// Parsed from a DelayReport command with a three-byte body.
    DelayReport {
        stream_id: StreamEndpointId,
        /// Big-endian value of the second and third body bytes.
        delay: u16,
        responder: SimpleResponder,
    },
}
403
// Expands to an expression that parses a request whose body is exactly one
// stream endpoint id. It yields `Err(BadLength)` for any other body length;
// otherwise it yields `Ok` with the named `Request` variant, holding the id
// and a responder of the given type (which must have `signal`, `peer` and
// `id` fields).
macro_rules! parse_one_seid {
    ($body:ident, $signal:ident, $peer:ident, $id:ident, $request_variant:ident, $responder_type:ident) => {
        if $body.len() != 1 {
            Err(Error::RequestInvalid(ErrorCode::BadLength))
        } else {
            Ok(Request::$request_variant {
                stream_id: StreamEndpointId::from_msg(&$body[0]),
                responder: $responder_type { signal: $signal, peer: $peer.clone(), id: $id },
            })
        }
    };
}
416
417impl Request {
418 fn get_req_seids(body: &[u8]) -> Result<Vec<StreamEndpointId>> {
419 if body.len() < 1 {
420 return Err(Error::RequestInvalid(ErrorCode::BadLength));
421 }
422 Ok(body.iter().map(&StreamEndpointId::from_msg).collect())
423 }
424
425 fn get_req_capabilities(encoded: &[u8]) -> Result<Vec<ServiceCapability>> {
426 if encoded.len() < 2 {
427 return Err(Error::RequestInvalid(ErrorCode::BadLength));
428 }
429 let mut caps = vec![];
430 let mut loc = 0;
431 while loc < encoded.len() {
432 let cap = match ServiceCapability::decode(&encoded[loc..]) {
433 Ok(cap) => cap,
434 Err(Error::RequestInvalid(code)) => {
435 return Err(Error::RequestInvalidExtra(code, encoded[loc]));
436 }
437 Err(e) => return Err(e),
438 };
439 loc += cap.encoded_len();
440 caps.push(cap);
441 }
442 Ok(caps)
443 }
444
445 fn parse(
446 peer: &Arc<PeerInner>,
447 id: TxLabel,
448 signal: SignalIdentifier,
449 body: &[u8],
450 ) -> Result<Request> {
451 match signal {
452 SignalIdentifier::Discover => {
453 if body.len() > 0 {
455 return Err(Error::RequestInvalid(ErrorCode::BadLength));
456 }
457 Ok(Request::Discover { responder: DiscoverResponder { peer: peer.clone(), id } })
458 }
459 SignalIdentifier::GetCapabilities => {
460 parse_one_seid!(body, signal, peer, id, GetCapabilities, GetCapabilitiesResponder)
461 }
462 SignalIdentifier::GetAllCapabilities => parse_one_seid!(
463 body,
464 signal,
465 peer,
466 id,
467 GetAllCapabilities,
468 GetCapabilitiesResponder
469 ),
470 SignalIdentifier::SetConfiguration => {
471 if body.len() < 4 {
472 return Err(Error::RequestInvalid(ErrorCode::BadLength));
473 }
474 let requested = Request::get_req_capabilities(&body[2..])?;
475 Ok(Request::SetConfiguration {
476 local_stream_id: StreamEndpointId::from_msg(&body[0]),
477 remote_stream_id: StreamEndpointId::from_msg(&body[1]),
478 capabilities: requested,
479 responder: ConfigureResponder { signal, peer: peer.clone(), id },
480 })
481 }
482 SignalIdentifier::GetConfiguration => {
483 parse_one_seid!(body, signal, peer, id, GetConfiguration, GetCapabilitiesResponder)
484 }
485 SignalIdentifier::Reconfigure => {
486 if body.len() < 3 {
487 return Err(Error::RequestInvalid(ErrorCode::BadLength));
488 }
489 let requested = Request::get_req_capabilities(&body[1..])?;
490 match requested.iter().find(|x| !x.is_application()) {
491 Some(x) => {
492 return Err(Error::RequestInvalidExtra(
493 ErrorCode::InvalidCapabilities,
494 (&x.category()).into(),
495 ));
496 }
497 None => (),
498 };
499 Ok(Request::Reconfigure {
500 local_stream_id: StreamEndpointId::from_msg(&body[0]),
501 capabilities: requested,
502 responder: ConfigureResponder { signal, peer: peer.clone(), id },
503 })
504 }
505 SignalIdentifier::Open => {
506 parse_one_seid!(body, signal, peer, id, Open, SimpleResponder)
507 }
508 SignalIdentifier::Start => {
509 let seids = Request::get_req_seids(body)?;
510 Ok(Request::Start {
511 stream_ids: seids,
512 responder: StreamResponder { signal, peer: peer.clone(), id },
513 })
514 }
515 SignalIdentifier::Close => {
516 parse_one_seid!(body, signal, peer, id, Close, SimpleResponder)
517 }
518 SignalIdentifier::Suspend => {
519 let seids = Request::get_req_seids(body)?;
520 Ok(Request::Suspend {
521 stream_ids: seids,
522 responder: StreamResponder { signal, peer: peer.clone(), id },
523 })
524 }
525 SignalIdentifier::Abort => {
526 parse_one_seid!(body, signal, peer, id, Abort, SimpleResponder)
527 }
528 SignalIdentifier::DelayReport => {
529 if body.len() != 3 {
530 return Err(Error::RequestInvalid(ErrorCode::BadLength));
531 }
532 let delay_arr: [u8; 2] = [body[1], body[2]];
533 let delay = u16::from_be_bytes(delay_arr);
534 Ok(Request::DelayReport {
535 stream_id: StreamEndpointId::from_msg(&body[0]),
536 delay,
537 responder: SimpleResponder { signal, peer: peer.clone(), id },
538 })
539 }
540 _ => Err(Error::UnimplementedMessage),
541 }
542 }
543}
544
/// Stream of requests received from the remote peer, obtained from
/// `Peer::take_request_stream`.
#[derive(Debug)]
pub struct RequestStream {
    /// Shared peer state the requests are read from.
    inner: Arc<PeerInner>,
    /// Set once the peer has disconnected and the stream has yielded `None`.
    terminated: bool,
}
551
552impl RequestStream {
553 fn new(inner: Arc<PeerInner>) -> Self {
554 Self { inner, terminated: false }
555 }
556}
557
// Explicitly declares the stream `Unpin`.
impl Unpin for RequestStream {}
559
560impl Stream for RequestStream {
561 type Item = Result<Request>;
562
563 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
564 Poll::Ready(match ready!(self.inner.poll_recv_request(cx)) {
565 Ok(UnparsedRequest(SignalingHeader { label, signal, .. }, body)) => {
566 match Request::parse(&self.inner, label, signal, &body) {
567 Err(Error::RequestInvalid(code)) => {
568 self.inner.send_reject(label, signal, code)?;
569 return Poll::Pending;
570 }
571 Err(Error::RequestInvalidExtra(code, extra)) => {
572 self.inner.send_reject_params(label, signal, &[extra, u8::from(&code)])?;
573 return Poll::Pending;
574 }
575 Err(Error::UnimplementedMessage) => {
576 self.inner.send_reject(label, signal, ErrorCode::NotSupportedCommand)?;
577 return Poll::Pending;
578 }
579 x => Some(x),
580 }
581 }
582 Err(Error::PeerDisconnected) => {
583 self.terminated = true;
584 None
585 }
586 Err(e) => Some(Err(e)),
587 })
588 }
589}
590
impl FusedStream for RequestStream {
    /// Reports whether the stream has already ended because the peer disconnected.
    fn is_terminated(&self) -> bool {
        self.terminated
    }
}
596
impl Drop for RequestStream {
    /// Releases the request listener so the stream can be taken again, then
    /// wakes another waiter (if any) via `wake_any`.
    fn drop(&mut self) {
        self.inner.incoming_requests.lock().listener = RequestListener::None;
        self.inner.wake_any();
    }
}
603
/// Response that carries no parameters.
#[derive(Debug)]
pub struct SimpleResponse {}
607
608impl Decodable for SimpleResponse {
609 type Error = Error;
610
611 fn decode(from: &[u8]) -> Result<Self> {
612 if from.len() > 0 {
613 return Err(Error::InvalidMessage);
614 }
615 Ok(SimpleResponse {})
616 }
617}
618
619impl Into<()> for SimpleResponse {
620 fn into(self) -> () {
621 ()
622 }
623}
624
/// Decoded parameters of a Discover response.
#[derive(Debug)]
struct DiscoverResponse {
    /// Stream endpoints listed in the response, in wire order.
    endpoints: Vec<StreamInformation>,
}
629
630impl Decodable for DiscoverResponse {
631 type Error = Error;
632
633 fn decode(from: &[u8]) -> Result<Self> {
634 let mut endpoints = Vec::<StreamInformation>::new();
635 let mut idx = 0;
636 while idx < from.len() {
637 let endpoint = StreamInformation::decode(&from[idx..])?;
638 idx += endpoint.encoded_len();
639 endpoints.push(endpoint);
640 }
641 Ok(DiscoverResponse { endpoints })
642 }
643}
644
645impl Into<Vec<StreamInformation>> for DiscoverResponse {
646 fn into(self) -> Vec<StreamInformation> {
647 self.endpoints
648 }
649}
650
/// Responder for a `Request::Discover`.
#[derive(Debug)]
pub struct DiscoverResponder {
    /// Peer state used to send the reply.
    peer: Arc<PeerInner>,
    /// Transaction label of the request being answered.
    id: TxLabel,
}
656
657impl DiscoverResponder {
658 pub fn send(self, endpoints: &[StreamInformation]) -> Result<()> {
662 if endpoints.len() == 0 {
663 return Err(Error::Encoding);
665 }
666 let mut params = vec![0 as u8; endpoints.len() * endpoints[0].encoded_len()];
667 let mut idx = 0;
668 for endpoint in endpoints {
669 endpoint.encode(&mut params[idx..idx + endpoint.encoded_len()])?;
670 idx += endpoint.encoded_len();
671 }
672 self.peer.send_response(self.id, SignalIdentifier::Discover, ¶ms)
673 }
674
675 pub fn reject(self, error_code: ErrorCode) -> Result<()> {
676 self.peer.send_reject(self.id, SignalIdentifier::Discover, error_code)
677 }
678}
679
/// Responder for requests answered with a list of capabilities
/// (GetCapabilities, GetAllCapabilities and GetConfiguration).
#[derive(Debug)]
pub struct GetCapabilitiesResponder {
    /// Peer state used to send the reply.
    peer: Arc<PeerInner>,
    /// Signal of the request being answered; echoed in the reply.
    signal: SignalIdentifier,
    /// Transaction label of the request being answered.
    id: TxLabel,
}
686
687impl GetCapabilitiesResponder {
688 pub fn send(self, capabilities: &[ServiceCapability]) -> Result<()> {
689 let included_iter = capabilities.iter().filter(|x| x.in_response(self.signal));
690 let reply_len = included_iter.clone().fold(0, |a, b| a + b.encoded_len());
691 let mut reply = vec![0 as u8; reply_len];
692 let mut pos = 0;
693 for capability in included_iter {
694 let size = capability.encoded_len();
695 capability.encode(&mut reply[pos..pos + size])?;
696 pos += size;
697 }
698 self.peer.send_response(self.id, self.signal, &reply)
699 }
700
701 pub fn reject(self, error_code: ErrorCode) -> Result<()> {
702 self.peer.send_reject(self.id, self.signal, error_code)
703 }
704}
705
/// Decoded parameters of a response that lists capabilities.
#[derive(Debug)]
struct GetCapabilitiesResponse {
    /// Capabilities that decoded successfully, in wire order.
    capabilities: Vec<ServiceCapability>,
}
710
711impl Decodable for GetCapabilitiesResponse {
712 type Error = Error;
713
714 fn decode(from: &[u8]) -> Result<Self> {
715 let mut capabilities = Vec::<ServiceCapability>::new();
716 let mut idx = 0;
717 while idx < from.len() {
718 match ServiceCapability::decode(&from[idx..]) {
719 Ok(capability) => {
720 idx = idx + capability.encoded_len();
721 capabilities.push(capability);
722 }
723 Err(_) => {
724 info!(
729 "GetCapabilitiesResponse decode: Capability {:?} not supported.",
730 from[idx]
731 );
732 let length_of_capability = from[idx + 1] as usize;
733 idx = idx + 2 + length_of_capability;
734 }
735 }
736 }
737 Ok(GetCapabilitiesResponse { capabilities })
738 }
739}
740
741impl Into<Vec<ServiceCapability>> for GetCapabilitiesResponse {
742 fn into(self) -> Vec<ServiceCapability> {
743 self.capabilities
744 }
745}
746
/// Responder for requests whose acceptance carries no parameters and whose
/// rejection carries only an error code.
#[derive(Debug)]
pub struct SimpleResponder {
    /// Peer state used to send the reply.
    peer: Arc<PeerInner>,
    /// Signal of the request being answered; echoed in the reply.
    signal: SignalIdentifier,
    /// Transaction label of the request being answered.
    id: TxLabel,
}
753
754impl SimpleResponder {
755 pub fn send(self) -> Result<()> {
756 self.peer.send_response(self.id, self.signal, &[])
757 }
758
759 pub fn reject(self, error_code: ErrorCode) -> Result<()> {
760 self.peer.send_reject(self.id, self.signal, error_code)
761 }
762}
763
/// Responder for Start and Suspend requests, whose rejection names the
/// offending stream endpoint.
#[derive(Debug)]
pub struct StreamResponder {
    /// Peer state used to send the reply.
    peer: Arc<PeerInner>,
    /// Signal of the request being answered; echoed in the reply.
    signal: SignalIdentifier,
    /// Transaction label of the request being answered.
    id: TxLabel,
}
770
771impl StreamResponder {
772 pub fn send(self) -> Result<()> {
773 self.peer.send_response(self.id, self.signal, &[])
774 }
775
776 pub fn reject(self, stream_id: &StreamEndpointId, error_code: ErrorCode) -> Result<()> {
777 self.peer.send_reject_params(
778 self.id,
779 self.signal,
780 &[stream_id.to_msg(), u8::from(&error_code)],
781 )
782 }
783}
784
/// Responder for SetConfiguration and Reconfigure requests, whose rejection
/// names the offending service category.
#[derive(Debug)]
pub struct ConfigureResponder {
    /// Peer state used to send the reply.
    peer: Arc<PeerInner>,
    /// Signal of the request being answered; echoed in the reply.
    signal: SignalIdentifier,
    /// Transaction label of the request being answered.
    id: TxLabel,
}
791
792impl ConfigureResponder {
793 pub fn send(self) -> Result<()> {
794 self.peer.send_response(self.id, self.signal, &[])
795 }
796
797 pub fn reject(self, category: ServiceCategory, error_code: ErrorCode) -> Result<()> {
798 self.peer.send_reject_params(
799 self.id,
800 self.signal,
801 &[u8::from(&category), u8::from(&error_code)],
802 )
803 }
804}
805
/// A received command that has not been parsed yet: its decoded header and
/// the bytes that followed the header.
#[derive(Debug)]
struct UnparsedRequest(SignalingHeader, Vec<u8>);
808
809impl UnparsedRequest {
810 fn new(header: SignalingHeader, body: Vec<u8>) -> UnparsedRequest {
811 UnparsedRequest(header, body)
812 }
813}
814
/// Commands received from the peer that are waiting to be parsed, plus the
/// state of whoever is listening for them.
#[derive(Debug, Default)]
struct RequestQueue {
    /// Whether a request stream exists and, if it has polled, how to wake it.
    listener: RequestListener,
    /// Received commands in arrival order.
    queue: VecDeque<UnparsedRequest>,
}
820
/// State of the listener for incoming requests.
#[derive(Debug)]
enum RequestListener {
    /// No request stream is outstanding; this is the default.
    None,
    /// A request stream has been taken but has not registered a waker yet.
    New,
    /// A request stream has polled and left this waker to be woken.
    Some(Waker),
}
830
831impl Default for RequestListener {
832 fn default() -> Self {
833 RequestListener::None
834 }
835}
836
/// State of one outstanding command awaiting its response.
#[derive(Debug)]
enum ResponseWaiter {
    /// Slot reserved when the command is sent; nothing has polled for it yet.
    WillPoll,
    /// The response future has polled and left this waker to be woken.
    Waiting(Waker),
    /// The response packet has arrived and is waiting to be picked up.
    Received(Vec<u8>),
    /// The response future was dropped before a response arrived; the slot is
    /// removed when a response with this label is received.
    Discard,
}
851
852impl ResponseWaiter {
853 fn is_received(&self) -> bool {
855 if let ResponseWaiter::Received(_) = self { true } else { false }
856 }
857
858 fn unwrap_received(self) -> Vec<u8> {
859 if let ResponseWaiter::Received(buf) = self { buf } else { panic!("expected received buf") }
860 }
861}
862
863fn decode_signaling_response<D: Decodable<Error = Error>>(
864 expected_signal: SignalIdentifier,
865 buf: Vec<u8>,
866) -> Result<D> {
867 let header = SignalingHeader::decode(buf.as_slice())?;
868 if header.signal() != expected_signal {
869 return Err(Error::InvalidHeader);
870 }
871 let params = &buf[header.encoded_len()..];
872 match header.message_type {
873 SignalingMessageType::ResponseAccept => D::decode(params),
874 SignalingMessageType::GeneralReject | SignalingMessageType::ResponseReject => {
875 Err(RemoteReject::from_params(header.signal(), params).into())
876 }
877 SignalingMessageType::Command => unreachable!(),
878 }
879}
880
/// Future for the raw response packet of one outstanding command.
#[derive(Debug)]
pub struct CommandResponse {
    /// Transaction label of the command whose response is awaited.
    id: TxLabel,
    /// Peer state to poll; set to `None` once the response has been received.
    inner: Option<Arc<PeerInner>>,
}
888
// Explicitly declares the future `Unpin`.
impl Unpin for CommandResponse {}
890
891impl Future for CommandResponse {
892 type Output = Result<Vec<u8>>;
893 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
894 let this = &mut *self;
895 let res;
896 {
897 let client = this.inner.as_ref().ok_or(Error::AlreadyReceived)?;
898 res = client.poll_recv_response(&this.id, cx);
899 }
900
901 if let Poll::Ready(Ok(_)) = res {
902 let inner = this.inner.take().expect("CommandResponse polled after completion");
903 inner.wake_any();
904 }
905
906 res
907 }
908}
909
impl FusedFuture for CommandResponse {
    /// Reports whether the response has already been received, i.e. the peer
    /// reference has been released.
    fn is_terminated(&self) -> bool {
        self.inner.is_none()
    }
}
915
916impl Drop for CommandResponse {
917 fn drop(&mut self) {
918 if let Some(inner) = &self.inner {
919 inner.remove_response_interest(&self.id);
920 inner.wake_any();
921 }
922 }
923}
924
/// State shared by a `Peer`, its clones, its request stream and its
/// outstanding command futures.
#[derive(Debug)]
struct PeerInner {
    /// Channel that signaling packets are read from and written to.
    signaling: Mutex<Channel>,
    /// Set once the channel has been seen to close; later reads are skipped.
    signaling_channel_closed: AtomicBool,

    /// One slot per outstanding command; the slot key is used as the
    /// command's transaction label.
    response_waiters: Mutex<Slab<ResponseWaiter>>,

    /// Commands received from the peer, and the listener waiting for them.
    incoming_requests: Mutex<RequestQueue>,
}
943
impl PeerInner {
    /// Reserves a response-waiter slot and returns its key as a transaction label.
    ///
    /// If the key cannot be converted to a `TxLabel`, the slot is removed again
    /// and the conversion error is returned.
    fn add_response_waiter(&self) -> Result<TxLabel> {
        let key = self.response_waiters.lock().insert(ResponseWaiter::WillPoll);
        // NOTE(review): `as u8` truncates; presumably keys stay small because
        // failed reservations are removed right away. Confirm.
        let id = TxLabel::try_from(key as u8);
        if id.is_err() {
            warn!("Transaction IDs are exhausted");
            let _ = self.response_waiters.lock().remove(key);
        }
        id
    }

    /// Withdraws interest in the response for `id`.
    ///
    /// A response that has already arrived is dropped with its slot. Otherwise
    /// the slot is marked `Discard` so it is freed when the response arrives.
    /// Indexing panics if no slot exists for `id`.
    fn remove_response_interest(&self, id: &TxLabel) {
        let mut lock = self.response_waiters.lock();
        let idx = usize::from(id);
        if lock[idx].is_received() {
            let _ = lock.remove(idx);
        } else {
            lock[idx] = ResponseWaiter::Discard;
        }
    }

    /// Reads everything available from the channel, then returns the oldest
    /// queued command if there is one.
    ///
    /// With an empty queue, the caller's waker is stored as the request
    /// listener and the result is `Pending`, or `PeerDisconnected` if the
    /// channel has closed.
    fn poll_recv_request(&self, cx: &mut Context<'_>) -> Poll<Result<UnparsedRequest>> {
        let is_closed = self.recv_all(cx)?;

        let mut lock = self.incoming_requests.lock();

        if let Some(request) = lock.queue.pop_front() {
            Poll::Ready(Ok(request))
        } else {
            lock.listener = RequestListener::Some(cx.waker().clone());
            if is_closed { Poll::Ready(Err(Error::PeerDisconnected)) } else { Poll::Pending }
        }
    }

    /// Reads everything available from the channel, then returns the response
    /// for `label` if it has arrived, removing its slot.
    ///
    /// Otherwise the caller's waker is stored in the slot and the result is
    /// `Pending`, or `PeerDisconnected` if the channel has closed. Panics if
    /// no slot exists for `label`.
    fn poll_recv_response(&self, label: &TxLabel, cx: &mut Context<'_>) -> Poll<Result<Vec<u8>>> {
        let is_closed = self.recv_all(cx)?;

        let mut waiters = self.response_waiters.lock();
        let idx = usize::from(label);
        if waiters.get(idx).expect("Polled unregistered waiter").is_received() {
            let buf = waiters.remove(idx).unwrap_received();
            Poll::Ready(Ok(buf))
        } else {
            *waiters.get_mut(idx).expect("Polled unregistered waiter") =
                ResponseWaiter::Waiting(cx.waker().clone());

            if is_closed { Poll::Ready(Err(Error::PeerDisconnected)) } else { Poll::Pending }
        }
    }

    /// Drains every packet currently readable from the channel and dispatches
    /// each one to the request queue or to its response waiter.
    ///
    /// Returns `Ok(true)` when the channel is closed and `Ok(false)` when it
    /// is open with nothing more to read right now. A read error other than
    /// peer-closed is returned as `Error::PeerRead`.
    fn recv_all(&self, cx: &mut Context<'_>) -> Result<bool> {
        if self.signaling_channel_closed.load(Ordering::Relaxed) {
            return Ok(true);
        }
        loop {
            let mut next_packet = {
                // The channel lock is held only for this single poll.
                let mut signaling = self.signaling.lock();
                match signaling.poll_next_unpin(cx) {
                    Poll::Ready(Some(Ok(packet))) => packet,
                    Poll::Ready(Some(Err(zx::Status::PEER_CLOSED))) | Poll::Ready(None) => {
                        trace!("Signaling peer closed");
                        self.signaling_channel_closed.store(true, Ordering::Relaxed);
                        return Ok(true);
                    }
                    Poll::Ready(Some(Err(e))) => return Err(Error::PeerRead(e)),
                    Poll::Pending => return Ok(false),
                }
            };

            let header = match SignalingHeader::decode(next_packet.as_slice()) {
                // The label could be read but the signal id is invalid:
                // answer with a general reject and move on.
                Err(Error::InvalidSignalId(label, id)) => {
                    self.send_general_reject(label, id)?;
                    continue;
                }
                // Any other header error: the packet is dropped without a reply.
                Err(_) => {
                    info!("received unrejectable message");
                    continue;
                }
                Ok(x) => x,
            };
            if header.is_command() {
                // Queue the command body (everything after the header) and
                // wake the request listener if one has registered a waker.
                let mut lock = self.incoming_requests.lock();
                let body = next_packet.split_off(header.encoded_len());
                lock.queue.push_back(UnparsedRequest::new(header, body));
                if let RequestListener::Some(ref waker) = lock.listener {
                    waker.wake_by_ref();
                }
            } else {
                // A response: hand the whole packet (header included) to the
                // waiter registered under its label.
                let mut waiters = self.response_waiters.lock();
                let idx = usize::from(&header.label());
                if let Some(&ResponseWaiter::Discard) = waiters.get(idx) {
                    // Nobody is interested any more; just free the slot.
                    let _ = waiters.remove(idx);
                } else if let Some(entry) = waiters.get_mut(idx) {
                    let old_entry = mem::replace(entry, ResponseWaiter::Received(next_packet));
                    if let ResponseWaiter::Waiting(waker) = old_entry {
                        waker.wake();
                    }
                } else {
                    warn!("response for {:?} we did not send, dropping", header.label());
                }
            }
        }
    }

    /// Wakes at most one registered waker: the first response waiter in the
    /// `Waiting` state, or failing that the request listener.
    fn wake_any(&self) {
        {
            let lock = self.response_waiters.lock();
            for (_, response_waiter) in lock.iter() {
                if let ResponseWaiter::Waiting(waker) = response_waiter {
                    waker.wake_by_ref();
                    return;
                }
            }
        }
        {
            let lock = self.incoming_requests.lock();
            if let RequestListener::Some(waker) = &lock.listener {
                waker.wake_by_ref();
                return;
            }
        }
    }

    /// Sends a two-byte general reject for a packet whose signal id was invalid.
    ///
    /// The first byte carries `label` in its high nibble with `0x01` in the low
    /// bits; the second carries the low six bits of `invalid_signal_id`.
    fn send_general_reject(&self, label: TxLabel, invalid_signal_id: u8) -> Result<()> {
        let packet: &[u8; 2] = &[u8::from(&label) << 4 | 0x01, invalid_signal_id & 0x3F];
        self.send_signal(packet)
    }

    /// Sends a `ResponseAccept` for `signal` under `label`, followed by `params`.
    fn send_response(&self, label: TxLabel, signal: SignalIdentifier, params: &[u8]) -> Result<()> {
        let header = SignalingHeader::new(label, signal, SignalingMessageType::ResponseAccept);
        let mut packet = vec![0 as u8; header.encoded_len() + params.len()];
        header.encode(packet.as_mut_slice())?;
        packet[header.encoded_len()..].clone_from_slice(params);
        self.send_signal(&packet)
    }

    /// Sends a `ResponseReject` whose only parameter is `error_code`.
    fn send_reject(
        &self,
        label: TxLabel,
        signal: SignalIdentifier,
        error_code: ErrorCode,
    ) -> Result<()> {
        self.send_reject_params(label, signal, &[u8::from(&error_code)])
    }

    /// Sends a `ResponseReject` for `signal` under `label`, followed by `params`.
    fn send_reject_params(
        &self,
        label: TxLabel,
        signal: SignalIdentifier,
        params: &[u8],
    ) -> Result<()> {
        let header = SignalingHeader::new(label, signal, SignalingMessageType::ResponseReject);
        let mut packet = vec![0 as u8; header.encoded_len() + params.len()];
        header.encode(packet.as_mut_slice())?;
        packet[header.encoded_len()..].clone_from_slice(params);
        self.send_signal(&packet)
    }

    /// Writes `data` to the signaling channel, mapping a write failure to
    /// `Error::PeerWrite`.
    fn send_signal(&self, data: &[u8]) -> Result<()> {
        let _ = self.signaling.lock().write(data).map_err(|x| Error::PeerWrite(x))?;
        Ok(())
    }
}