1use fuchsia_async::{DurationExt, OnTimeout, TimeoutExt};
6use fuchsia_bluetooth::types::Channel;
7use fuchsia_sync::Mutex;
8use futures::future::{FusedFuture, MaybeDone};
9use futures::stream::{FusedStream, Stream};
10use futures::task::{Context, Poll, Waker};
11use futures::{Future, FutureExt, TryFutureExt, ready};
12use log::{info, trace, warn};
13use packet_encoding::{Decodable, Encodable};
14use slab::Slab;
15use std::collections::VecDeque;
16use std::marker::PhantomData;
17use std::mem;
18use std::pin::Pin;
19use std::sync::Arc;
20use zx::{self as zx, MonotonicDuration};
21
22#[cfg(test)]
23mod tests;
24
25mod rtp;
26mod stream_endpoint;
27mod types;
28
29use crate::types::{SignalIdentifier, SignalingHeader, SignalingMessageType, TxLabel};
30
31pub use crate::rtp::{RtpError, RtpHeader};
32pub use crate::stream_endpoint::{
33 MediaStream, StreamEndpoint, StreamEndpointUpdateCallback, StreamState,
34};
35pub use crate::types::{
36 ContentProtectionType, EndpointType, Error, ErrorCode, MediaCodecType, MediaType, RemoteReject,
37 Result, ServiceCapability, ServiceCategory, StreamEndpointId, StreamInformation,
38};
39
40#[derive(Debug, Clone)]
50pub struct Peer {
51 inner: Arc<PeerInner>,
52}
53
54impl Peer {
55 pub fn new(signaling: Channel) -> Self {
57 Self {
58 inner: Arc::new(PeerInner {
59 signaling,
60 response_waiters: Mutex::new(Slab::<ResponseWaiter>::new()),
61 incoming_requests: Mutex::<RequestQueue>::default(),
62 }),
63 }
64 }
65
66 #[track_caller]
69 pub fn take_request_stream(&self) -> RequestStream {
70 {
71 let mut lock = self.inner.incoming_requests.lock();
72 if let RequestListener::None = lock.listener {
73 lock.listener = RequestListener::New;
74 } else {
75 panic!("Request stream has already been taken");
76 }
77 }
78
79 RequestStream::new(self.inner.clone())
80 }
81
82 pub fn discover(&self) -> impl Future<Output = Result<Vec<StreamInformation>>> + use<> {
86 self.send_command::<DiscoverResponse>(SignalIdentifier::Discover, &[]).ok_into()
87 }
88
89 pub fn get_capabilities(
96 &self,
97 stream_id: &StreamEndpointId,
98 ) -> impl Future<Output = Result<Vec<ServiceCapability>>> + use<> {
99 let stream_params = &[stream_id.to_msg()];
100 self.send_command::<GetCapabilitiesResponse>(
101 SignalIdentifier::GetCapabilities,
102 stream_params,
103 )
104 .ok_into()
105 }
106
107 pub fn get_all_capabilities(
113 &self,
114 stream_id: &StreamEndpointId,
115 ) -> impl Future<Output = Result<Vec<ServiceCapability>>> + use<> {
116 let stream_params = &[stream_id.to_msg()];
117 self.send_command::<GetCapabilitiesResponse>(
118 SignalIdentifier::GetAllCapabilities,
119 stream_params,
120 )
121 .ok_into()
122 }
123
124 pub fn set_configuration(
131 &self,
132 stream_id: &StreamEndpointId,
133 local_stream_id: &StreamEndpointId,
134 capabilities: &[ServiceCapability],
135 ) -> impl Future<Output = Result<()>> + use<> {
136 assert!(!capabilities.is_empty(), "must set at least one capability");
137 let mut params: Vec<u8> = vec![0; capabilities.iter().fold(2, |a, x| a + x.encoded_len())];
138 params[0] = stream_id.to_msg();
139 params[1] = local_stream_id.to_msg();
140 let mut idx = 2;
141 for capability in capabilities {
142 if let Err(e) = capability.encode(&mut params[idx..]) {
143 return futures::future::err(e).left_future();
144 }
145 idx += capability.encoded_len();
146 }
147 self.send_command::<SimpleResponse>(SignalIdentifier::SetConfiguration, ¶ms)
148 .ok_into()
149 .right_future()
150 }
151
152 pub fn get_configuration(
158 &self,
159 stream_id: &StreamEndpointId,
160 ) -> impl Future<Output = Result<Vec<ServiceCapability>>> + use<> {
161 let stream_params = &[stream_id.to_msg()];
162 self.send_command::<GetCapabilitiesResponse>(
163 SignalIdentifier::GetConfiguration,
164 stream_params,
165 )
166 .ok_into()
167 }
168
169 pub fn reconfigure(
178 &self,
179 stream_id: &StreamEndpointId,
180 capabilities: &[ServiceCapability],
181 ) -> impl Future<Output = Result<()>> + use<> {
182 assert!(!capabilities.is_empty(), "must set at least one capability");
183 let mut params: Vec<u8> = vec![0; capabilities.iter().fold(1, |a, x| a + x.encoded_len())];
184 params[0] = stream_id.to_msg();
185 let mut idx = 1;
186 for capability in capabilities {
187 if !capability.is_application() {
188 return futures::future::err(Error::Encoding).left_future();
189 }
190 if let Err(e) = capability.encode(&mut params[idx..]) {
191 return futures::future::err(e).left_future();
192 }
193 idx += capability.encoded_len();
194 }
195 self.send_command::<SimpleResponse>(SignalIdentifier::Reconfigure, ¶ms)
196 .ok_into()
197 .right_future()
198 }
199
200 pub fn open(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> + use<> {
204 let stream_params = &[stream_id.to_msg()];
205 self.send_command::<SimpleResponse>(SignalIdentifier::Open, stream_params).ok_into()
206 }
207
208 pub fn start(
213 &self,
214 stream_ids: &[StreamEndpointId],
215 ) -> impl Future<Output = Result<()>> + use<> {
216 let mut stream_params = Vec::with_capacity(stream_ids.len());
217 for stream_id in stream_ids {
218 stream_params.push(stream_id.to_msg());
219 }
220 self.send_command::<SimpleResponse>(SignalIdentifier::Start, &stream_params).ok_into()
221 }
222
223 pub fn close(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> + use<> {
226 let stream_params = &[stream_id.to_msg()];
227 let response: CommandResponseFut<SimpleResponse> =
228 self.send_command::<SimpleResponse>(SignalIdentifier::Close, stream_params);
229 response.ok_into()
230 }
231
232 pub fn suspend(
236 &self,
237 stream_ids: &[StreamEndpointId],
238 ) -> impl Future<Output = Result<()>> + use<> {
239 let mut stream_params = Vec::with_capacity(stream_ids.len());
240 for stream_id in stream_ids {
241 stream_params.push(stream_id.to_msg());
242 }
243 let response: CommandResponseFut<SimpleResponse> =
244 self.send_command::<SimpleResponse>(SignalIdentifier::Suspend, &stream_params);
245 response.ok_into()
246 }
247
248 pub fn abort(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> + use<> {
253 let stream_params = &[stream_id.to_msg()];
254 self.send_command::<SimpleResponse>(SignalIdentifier::Abort, stream_params).ok_into()
255 }
256
257 pub fn delay_report(
261 &self,
262 stream_id: &StreamEndpointId,
263 delay: u16,
264 ) -> impl Future<Output = Result<()>> + use<> {
265 let delay_bytes: [u8; 2] = delay.to_be_bytes();
266 let params = &[stream_id.to_msg(), delay_bytes[0], delay_bytes[1]];
267 self.send_command::<SimpleResponse>(SignalIdentifier::DelayReport, params).ok_into()
268 }
269
270 const RTX_SIG_TIMER_MS: i64 = 3000;
272 const COMMAND_TIMEOUT: MonotonicDuration =
273 MonotonicDuration::from_millis(Peer::RTX_SIG_TIMER_MS);
274
275 fn send_command<D: Decodable<Error = Error>>(
278 &self,
279 signal: SignalIdentifier,
280 payload: &[u8],
281 ) -> CommandResponseFut<D> {
282 let send_result = (|| {
283 let id = self.inner.add_response_waiter()?;
284 let header = SignalingHeader::new(id, signal, SignalingMessageType::Command);
285 let mut buf = vec![0; header.encoded_len()];
286 header.encode(buf.as_mut_slice())?;
287 buf.extend_from_slice(payload);
288 self.inner.send_signal(buf.as_slice())?;
289 Ok(header)
290 })();
291
292 CommandResponseFut::new(send_result, self.inner.clone())
293 }
294}
295
296struct CommandResponseFut<D: Decodable> {
298 id: SignalIdentifier,
299 fut: Pin<Box<MaybeDone<OnTimeout<CommandResponse, fn() -> Result<Vec<u8>>>>>>,
300 _phantom: PhantomData<D>,
301}
302
303impl<D: Decodable> Unpin for CommandResponseFut<D> {}
304
305impl<D: Decodable<Error = Error>> CommandResponseFut<D> {
306 fn new(send_result: Result<SignalingHeader>, inner: Arc<PeerInner>) -> Self {
307 let header = match send_result {
308 Err(e) => {
309 return Self {
310 id: SignalIdentifier::Abort,
311 fut: Box::pin(MaybeDone::Done(Err(e))),
312 _phantom: PhantomData,
313 };
314 }
315 Ok(header) => header,
316 };
317 let response = CommandResponse { id: header.label(), inner: Some(inner) };
318 let err_timeout: fn() -> Result<Vec<u8>> = || Err(Error::Timeout);
319 let timedout_fut = response.on_timeout(Peer::COMMAND_TIMEOUT.after_now(), err_timeout);
320
321 Self {
322 id: header.signal(),
323 fut: Box::pin(futures::future::maybe_done(timedout_fut)),
324 _phantom: PhantomData,
325 }
326 }
327}
328
329impl<D: Decodable<Error = Error>> Future for CommandResponseFut<D> {
330 type Output = Result<D>;
331
332 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
333 ready!(self.fut.poll_unpin(cx));
334 Poll::Ready(
335 self.fut
336 .as_mut()
337 .take_output()
338 .unwrap_or(Err(Error::AlreadyReceived))
339 .and_then(|buf| decode_signaling_response(self.id, buf)),
340 )
341 }
342}
343
/// A command received from the remote peer, produced by [`RequestStream`].
///
/// Each variant carries the parsed parameters and a responder used to accept or reject
/// the command.
#[derive(Debug)]
pub enum Request {
    /// Request for the list of local stream endpoints.
    Discover {
        responder: DiscoverResponder,
    },
    /// Request for the capabilities of local endpoint `stream_id`.
    GetCapabilities {
        stream_id: StreamEndpointId,
        responder: GetCapabilitiesResponder,
    },
    /// Request for all capabilities of local endpoint `stream_id`.
    GetAllCapabilities {
        stream_id: StreamEndpointId,
        responder: GetCapabilitiesResponder,
    },
    /// Request to configure local endpoint `local_stream_id` for use with the remote
    /// endpoint `remote_stream_id`.
    SetConfiguration {
        local_stream_id: StreamEndpointId,
        remote_stream_id: StreamEndpointId,
        capabilities: Vec<ServiceCapability>,
        responder: ConfigureResponder,
    },
    /// Request for the current configuration of local endpoint `stream_id`.
    GetConfiguration {
        stream_id: StreamEndpointId,
        responder: GetCapabilitiesResponder,
    },
    /// Request to reconfigure local endpoint `local_stream_id`.
    ///
    /// Parsing guarantees every capability satisfies `is_application()`.
    Reconfigure {
        local_stream_id: StreamEndpointId,
        capabilities: Vec<ServiceCapability>,
        responder: ConfigureResponder,
    },
    /// Request to open local endpoint `stream_id`.
    Open {
        stream_id: StreamEndpointId,
        responder: SimpleResponder,
    },
    /// Request to start one or more local endpoints.
    Start {
        stream_ids: Vec<StreamEndpointId>,
        responder: StreamResponder,
    },
    /// Request to close local endpoint `stream_id`.
    Close {
        stream_id: StreamEndpointId,
        responder: SimpleResponder,
    },
    /// Request to suspend one or more local endpoints.
    Suspend {
        stream_ids: Vec<StreamEndpointId>,
        responder: StreamResponder,
    },
    /// Request to abort local endpoint `stream_id`.
    Abort {
        stream_id: StreamEndpointId,
        responder: SimpleResponder,
    },
    /// Delay report for endpoint `stream_id`; `delay` was decoded big-endian.
    DelayReport {
        stream_id: StreamEndpointId,
        delay: u16,
        responder: SimpleResponder,
    },
}
402
/// Parses a request body that must consist of exactly one stream endpoint id.
///
/// Expands to `Ok(Request::$request_variant { .. })` with a `$responder_type` built from
/// `$signal`, `$peer` and `$id`. A body of any other length yields
/// `Err(Error::RequestInvalid(ErrorCode::BadLength))`.
macro_rules! parse_one_seid {
    ($body:ident, $signal:ident, $peer:ident, $id:ident, $request_variant:ident, $responder_type:ident) => {
        match $body {
            [seid] => Ok(Request::$request_variant {
                stream_id: StreamEndpointId::from_msg(seid),
                responder: $responder_type { signal: $signal, peer: $peer.clone(), id: $id },
            }),
            _ => Err(Error::RequestInvalid(ErrorCode::BadLength)),
        }
    };
}
415
416impl Request {
417 fn get_req_seids(body: &[u8]) -> Result<Vec<StreamEndpointId>> {
418 if body.len() < 1 {
419 return Err(Error::RequestInvalid(ErrorCode::BadLength));
420 }
421 Ok(body.iter().map(&StreamEndpointId::from_msg).collect())
422 }
423
424 fn get_req_capabilities(encoded: &[u8]) -> Result<Vec<ServiceCapability>> {
425 if encoded.len() < 2 {
426 return Err(Error::RequestInvalid(ErrorCode::BadLength));
427 }
428 let mut caps = vec![];
429 let mut loc = 0;
430 while loc < encoded.len() {
431 let cap = match ServiceCapability::decode(&encoded[loc..]) {
432 Ok(cap) => cap,
433 Err(Error::RequestInvalid(code)) => {
434 return Err(Error::RequestInvalidExtra(code, encoded[loc]));
435 }
436 Err(e) => return Err(e),
437 };
438 loc += cap.encoded_len();
439 caps.push(cap);
440 }
441 Ok(caps)
442 }
443
444 fn parse(
445 peer: &Arc<PeerInner>,
446 id: TxLabel,
447 signal: SignalIdentifier,
448 body: &[u8],
449 ) -> Result<Request> {
450 match signal {
451 SignalIdentifier::Discover => {
452 if body.len() > 0 {
454 return Err(Error::RequestInvalid(ErrorCode::BadLength));
455 }
456 Ok(Request::Discover { responder: DiscoverResponder { peer: peer.clone(), id } })
457 }
458 SignalIdentifier::GetCapabilities => {
459 parse_one_seid!(body, signal, peer, id, GetCapabilities, GetCapabilitiesResponder)
460 }
461 SignalIdentifier::GetAllCapabilities => parse_one_seid!(
462 body,
463 signal,
464 peer,
465 id,
466 GetAllCapabilities,
467 GetCapabilitiesResponder
468 ),
469 SignalIdentifier::SetConfiguration => {
470 if body.len() < 4 {
471 return Err(Error::RequestInvalid(ErrorCode::BadLength));
472 }
473 let requested = Request::get_req_capabilities(&body[2..])?;
474 Ok(Request::SetConfiguration {
475 local_stream_id: StreamEndpointId::from_msg(&body[0]),
476 remote_stream_id: StreamEndpointId::from_msg(&body[1]),
477 capabilities: requested,
478 responder: ConfigureResponder { signal, peer: peer.clone(), id },
479 })
480 }
481 SignalIdentifier::GetConfiguration => {
482 parse_one_seid!(body, signal, peer, id, GetConfiguration, GetCapabilitiesResponder)
483 }
484 SignalIdentifier::Reconfigure => {
485 if body.len() < 3 {
486 return Err(Error::RequestInvalid(ErrorCode::BadLength));
487 }
488 let requested = Request::get_req_capabilities(&body[1..])?;
489 match requested.iter().find(|x| !x.is_application()) {
490 Some(x) => {
491 return Err(Error::RequestInvalidExtra(
492 ErrorCode::InvalidCapabilities,
493 (&x.category()).into(),
494 ));
495 }
496 None => (),
497 };
498 Ok(Request::Reconfigure {
499 local_stream_id: StreamEndpointId::from_msg(&body[0]),
500 capabilities: requested,
501 responder: ConfigureResponder { signal, peer: peer.clone(), id },
502 })
503 }
504 SignalIdentifier::Open => {
505 parse_one_seid!(body, signal, peer, id, Open, SimpleResponder)
506 }
507 SignalIdentifier::Start => {
508 let seids = Request::get_req_seids(body)?;
509 Ok(Request::Start {
510 stream_ids: seids,
511 responder: StreamResponder { signal, peer: peer.clone(), id },
512 })
513 }
514 SignalIdentifier::Close => {
515 parse_one_seid!(body, signal, peer, id, Close, SimpleResponder)
516 }
517 SignalIdentifier::Suspend => {
518 let seids = Request::get_req_seids(body)?;
519 Ok(Request::Suspend {
520 stream_ids: seids,
521 responder: StreamResponder { signal, peer: peer.clone(), id },
522 })
523 }
524 SignalIdentifier::Abort => {
525 parse_one_seid!(body, signal, peer, id, Abort, SimpleResponder)
526 }
527 SignalIdentifier::DelayReport => {
528 if body.len() != 3 {
529 return Err(Error::RequestInvalid(ErrorCode::BadLength));
530 }
531 let delay_arr: [u8; 2] = [body[1], body[2]];
532 let delay = u16::from_be_bytes(delay_arr);
533 Ok(Request::DelayReport {
534 stream_id: StreamEndpointId::from_msg(&body[0]),
535 delay,
536 responder: SimpleResponder { signal, peer: peer.clone(), id },
537 })
538 }
539 _ => Err(Error::UnimplementedMessage),
540 }
541 }
542}
543
544#[derive(Debug)]
546pub struct RequestStream {
547 inner: Arc<PeerInner>,
548 terminated: bool,
549}
550
551impl RequestStream {
552 fn new(inner: Arc<PeerInner>) -> Self {
553 Self { inner, terminated: false }
554 }
555}
556
557impl Unpin for RequestStream {}
558
559impl Stream for RequestStream {
560 type Item = Result<Request>;
561
562 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
563 Poll::Ready(match ready!(self.inner.poll_recv_request(cx)) {
564 Ok(UnparsedRequest(SignalingHeader { label, signal, .. }, body)) => {
565 match Request::parse(&self.inner, label, signal, &body) {
566 Err(Error::RequestInvalid(code)) => {
567 self.inner.send_reject(label, signal, code)?;
568 return Poll::Pending;
569 }
570 Err(Error::RequestInvalidExtra(code, extra)) => {
571 self.inner.send_reject_params(label, signal, &[extra, u8::from(&code)])?;
572 return Poll::Pending;
573 }
574 Err(Error::UnimplementedMessage) => {
575 self.inner.send_reject(label, signal, ErrorCode::NotSupportedCommand)?;
576 return Poll::Pending;
577 }
578 x => Some(x),
579 }
580 }
581 Err(Error::PeerDisconnected) => {
582 self.terminated = true;
583 None
584 }
585 Err(e) => Some(Err(e)),
586 })
587 }
588}
589
590impl FusedStream for RequestStream {
591 fn is_terminated(&self) -> bool {
592 self.terminated
593 }
594}
595
596impl Drop for RequestStream {
597 fn drop(&mut self) {
598 self.inner.incoming_requests.lock().listener = RequestListener::None;
599 self.inner.wake_any();
600 }
601}
602
603#[derive(Debug)]
605pub struct SimpleResponse {}
606
607impl Decodable for SimpleResponse {
608 type Error = Error;
609
610 fn decode(from: &[u8]) -> Result<Self> {
611 if from.len() > 0 {
612 return Err(Error::InvalidMessage);
613 }
614 Ok(SimpleResponse {})
615 }
616}
617
618impl Into<()> for SimpleResponse {
619 fn into(self) -> () {
620 ()
621 }
622}
623
624#[derive(Debug)]
625struct DiscoverResponse {
626 endpoints: Vec<StreamInformation>,
627}
628
629impl Decodable for DiscoverResponse {
630 type Error = Error;
631
632 fn decode(from: &[u8]) -> Result<Self> {
633 let mut endpoints = Vec::<StreamInformation>::new();
634 let mut idx = 0;
635 while idx < from.len() {
636 let endpoint = StreamInformation::decode(&from[idx..])?;
637 idx += endpoint.encoded_len();
638 endpoints.push(endpoint);
639 }
640 Ok(DiscoverResponse { endpoints })
641 }
642}
643
644impl Into<Vec<StreamInformation>> for DiscoverResponse {
645 fn into(self) -> Vec<StreamInformation> {
646 self.endpoints
647 }
648}
649
650#[derive(Debug)]
651pub struct DiscoverResponder {
652 peer: Arc<PeerInner>,
653 id: TxLabel,
654}
655
656impl DiscoverResponder {
657 pub fn send(self, endpoints: &[StreamInformation]) -> Result<()> {
661 if endpoints.len() == 0 {
662 return Err(Error::Encoding);
664 }
665 let mut params = vec![0 as u8; endpoints.len() * endpoints[0].encoded_len()];
666 let mut idx = 0;
667 for endpoint in endpoints {
668 endpoint.encode(&mut params[idx..idx + endpoint.encoded_len()])?;
669 idx += endpoint.encoded_len();
670 }
671 self.peer.send_response(self.id, SignalIdentifier::Discover, ¶ms)
672 }
673
674 pub fn reject(self, error_code: ErrorCode) -> Result<()> {
675 self.peer.send_reject(self.id, SignalIdentifier::Discover, error_code)
676 }
677}
678
679#[derive(Debug)]
680pub struct GetCapabilitiesResponder {
681 peer: Arc<PeerInner>,
682 signal: SignalIdentifier,
683 id: TxLabel,
684}
685
686impl GetCapabilitiesResponder {
687 pub fn send(self, capabilities: &[ServiceCapability]) -> Result<()> {
688 let included_iter = capabilities.iter().filter(|x| x.in_response(self.signal));
689 let reply_len = included_iter.clone().fold(0, |a, b| a + b.encoded_len());
690 let mut reply = vec![0 as u8; reply_len];
691 let mut pos = 0;
692 for capability in included_iter {
693 let size = capability.encoded_len();
694 capability.encode(&mut reply[pos..pos + size])?;
695 pos += size;
696 }
697 self.peer.send_response(self.id, self.signal, &reply)
698 }
699
700 pub fn reject(self, error_code: ErrorCode) -> Result<()> {
701 self.peer.send_reject(self.id, self.signal, error_code)
702 }
703}
704
705#[derive(Debug)]
706struct GetCapabilitiesResponse {
707 capabilities: Vec<ServiceCapability>,
708}
709
710impl Decodable for GetCapabilitiesResponse {
711 type Error = Error;
712
713 fn decode(from: &[u8]) -> Result<Self> {
714 let mut capabilities = Vec::<ServiceCapability>::new();
715 let mut idx = 0;
716 while idx < from.len() {
717 match ServiceCapability::decode(&from[idx..]) {
718 Ok(capability) => {
719 idx = idx + capability.encoded_len();
720 capabilities.push(capability);
721 }
722 Err(_) => {
723 info!(
728 "GetCapabilitiesResponse decode: Capability {:?} not supported.",
729 from[idx]
730 );
731 let length_of_capability = from[idx + 1] as usize;
732 idx = idx + 2 + length_of_capability;
733 }
734 }
735 }
736 Ok(GetCapabilitiesResponse { capabilities })
737 }
738}
739
740impl Into<Vec<ServiceCapability>> for GetCapabilitiesResponse {
741 fn into(self) -> Vec<ServiceCapability> {
742 self.capabilities
743 }
744}
745
746#[derive(Debug)]
747pub struct SimpleResponder {
748 peer: Arc<PeerInner>,
749 signal: SignalIdentifier,
750 id: TxLabel,
751}
752
753impl SimpleResponder {
754 pub fn send(self) -> Result<()> {
755 self.peer.send_response(self.id, self.signal, &[])
756 }
757
758 pub fn reject(self, error_code: ErrorCode) -> Result<()> {
759 self.peer.send_reject(self.id, self.signal, error_code)
760 }
761}
762
763#[derive(Debug)]
764pub struct StreamResponder {
765 peer: Arc<PeerInner>,
766 signal: SignalIdentifier,
767 id: TxLabel,
768}
769
770impl StreamResponder {
771 pub fn send(self) -> Result<()> {
772 self.peer.send_response(self.id, self.signal, &[])
773 }
774
775 pub fn reject(self, stream_id: &StreamEndpointId, error_code: ErrorCode) -> Result<()> {
776 self.peer.send_reject_params(
777 self.id,
778 self.signal,
779 &[stream_id.to_msg(), u8::from(&error_code)],
780 )
781 }
782}
783
784#[derive(Debug)]
785pub struct ConfigureResponder {
786 peer: Arc<PeerInner>,
787 signal: SignalIdentifier,
788 id: TxLabel,
789}
790
791impl ConfigureResponder {
792 pub fn send(self) -> Result<()> {
793 self.peer.send_response(self.id, self.signal, &[])
794 }
795
796 pub fn reject(self, category: ServiceCategory, error_code: ErrorCode) -> Result<()> {
797 self.peer.send_reject_params(
798 self.id,
799 self.signal,
800 &[u8::from(&category), u8::from(&error_code)],
801 )
802 }
803}
804
805#[derive(Debug)]
806struct UnparsedRequest(SignalingHeader, Vec<u8>);
807
808impl UnparsedRequest {
809 fn new(header: SignalingHeader, body: Vec<u8>) -> UnparsedRequest {
810 UnparsedRequest(header, body)
811 }
812}
813
/// Commands received from the remote peer that have not yet been delivered to the request
/// stream, together with that stream's registration state.
#[derive(Debug, Default)]
struct RequestQueue {
    /// Whether a request stream exists and, if it is waiting, the waker to notify.
    listener: RequestListener,
    /// Undelivered commands, oldest first (pushed at the back, popped from the front).
    queue: VecDeque<UnparsedRequest>,
}
819
/// Registration state of a peer's request stream.
#[derive(Debug, Default)]
enum RequestListener {
    /// No request stream has been taken (or the last one was dropped).
    #[default]
    None,
    /// A request stream has been taken but has not polled yet.
    New,
    /// The request stream is waiting; wake this waker when a command is queued.
    Some(Waker),
}
835
/// State of the response slot for one outstanding command.
#[derive(Debug)]
enum ResponseWaiter {
    /// Reserved when the command is sent; nothing has polled for the response yet.
    WillPoll,
    /// A task is waiting for the response; wake it when the packet arrives.
    Waiting(Waker),
    /// The raw response packet (header included) has arrived and not yet been taken.
    Received(Vec<u8>),
    /// The waiting future was dropped; a response that arrives later is thrown away and
    /// the slot freed.
    Discard,
}

impl ResponseWaiter {
    /// Returns true if the response packet has arrived.
    fn is_received(&self) -> bool {
        matches!(self, ResponseWaiter::Received(_))
    }

    /// Returns the received packet.
    ///
    /// # Panics
    ///
    /// Panics if the response has not been received.
    fn unwrap_received(self) -> Vec<u8> {
        match self {
            ResponseWaiter::Received(buf) => buf,
            _ => panic!("expected received buf"),
        }
    }
}
860}
861
862fn decode_signaling_response<D: Decodable<Error = Error>>(
863 expected_signal: SignalIdentifier,
864 buf: Vec<u8>,
865) -> Result<D> {
866 let header = SignalingHeader::decode(buf.as_slice())?;
867 if header.signal() != expected_signal {
868 return Err(Error::InvalidHeader);
869 }
870 let params = &buf[header.encoded_len()..];
871 match header.message_type {
872 SignalingMessageType::ResponseAccept => D::decode(params),
873 SignalingMessageType::GeneralReject | SignalingMessageType::ResponseReject => {
874 Err(RemoteReject::from_params(header.signal(), params).into())
875 }
876 SignalingMessageType::Command => unreachable!(),
877 }
878}
879
880#[derive(Debug)]
882pub struct CommandResponse {
883 id: TxLabel,
884 inner: Option<Arc<PeerInner>>,
886}
887
888impl Unpin for CommandResponse {}
889
890impl Future for CommandResponse {
891 type Output = Result<Vec<u8>>;
892 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
893 let this = &mut *self;
894 let res;
895 {
896 let client = this.inner.as_ref().ok_or(Error::AlreadyReceived)?;
897 res = client.poll_recv_response(&this.id, cx);
898 }
899
900 if let Poll::Ready(Ok(_)) = res {
901 let inner = this.inner.take().expect("CommandResponse polled after completion");
902 inner.wake_any();
903 }
904
905 res
906 }
907}
908
909impl FusedFuture for CommandResponse {
910 fn is_terminated(&self) -> bool {
911 self.inner.is_none()
912 }
913}
914
915impl Drop for CommandResponse {
916 fn drop(&mut self) {
917 if let Some(inner) = &self.inner {
918 inner.remove_response_interest(&self.id);
919 inner.wake_any();
920 }
921 }
922}
923
/// State shared by a [`Peer`], its outstanding [`CommandResponse`] futures and its
/// [`RequestStream`].
#[derive(Debug)]
struct PeerInner {
    /// Channel over which signaling packets are read and written.
    signaling: Channel,

    /// Response slots for outstanding commands. A slot's key is its transaction label
    /// (see `add_response_waiter` / `usize::from(label)`).
    response_waiters: Mutex<Slab<ResponseWaiter>>,

    /// Commands received from the remote peer plus the request stream's registration
    /// state.
    incoming_requests: Mutex<RequestQueue>,
}
941
942impl PeerInner {
943 fn add_response_waiter(&self) -> Result<TxLabel> {
946 let key = self.response_waiters.lock().insert(ResponseWaiter::WillPoll);
947 let id = TxLabel::try_from(key as u8);
948 if id.is_err() {
949 warn!("Transaction IDs are exhausted");
950 let _ = self.response_waiters.lock().remove(key);
951 }
952 id
953 }
954
955 fn remove_response_interest(&self, id: &TxLabel) {
958 let mut lock = self.response_waiters.lock();
959 let idx = usize::from(id);
960 if lock[idx].is_received() {
961 let _ = lock.remove(idx);
962 } else {
963 lock[idx] = ResponseWaiter::Discard;
964 }
965 }
966
967 fn poll_recv_request(&self, cx: &mut Context<'_>) -> Poll<Result<UnparsedRequest>> {
972 let is_closed = self.recv_all(cx)?;
973
974 let mut lock = self.incoming_requests.lock();
975
976 if let Some(request) = lock.queue.pop_front() {
977 Poll::Ready(Ok(request))
978 } else {
979 lock.listener = RequestListener::Some(cx.waker().clone());
980 if is_closed { Poll::Ready(Err(Error::PeerDisconnected)) } else { Poll::Pending }
981 }
982 }
983
984 fn poll_recv_response(&self, label: &TxLabel, cx: &mut Context<'_>) -> Poll<Result<Vec<u8>>> {
989 let is_closed = self.recv_all(cx)?;
990
991 let mut waiters = self.response_waiters.lock();
992 let idx = usize::from(label);
993 if waiters.get(idx).expect("Polled unregistered waiter").is_received() {
996 let buf = waiters.remove(idx).unwrap_received();
998 Poll::Ready(Ok(buf))
999 } else {
1000 *waiters.get_mut(idx).expect("Polled unregistered waiter") =
1002 ResponseWaiter::Waiting(cx.waker().clone());
1003
1004 if is_closed { Poll::Ready(Err(Error::PeerDisconnected)) } else { Poll::Pending }
1005 }
1006 }
1007
1008 fn recv_all(&self, cx: &mut Context<'_>) -> Result<bool> {
1012 loop {
1013 let mut next_packet = Vec::new();
1014 let packet_size = match self.signaling.poll_datagram(cx, &mut next_packet) {
1015 Poll::Ready(Err(zx::Status::PEER_CLOSED)) => {
1016 trace!("Signaling peer closed");
1017 return Ok(true);
1018 }
1019 Poll::Ready(Err(e)) => return Err(Error::PeerRead(e)),
1020 Poll::Pending => return Ok(false),
1021 Poll::Ready(Ok(size)) => size,
1022 };
1023 if packet_size == 0 {
1024 continue;
1025 }
1026 let header = match SignalingHeader::decode(next_packet.as_slice()) {
1030 Err(Error::InvalidSignalId(label, id)) => {
1031 self.send_general_reject(label, id)?;
1032 continue;
1033 }
1034 Err(_) => {
1035 info!("received unrejectable message");
1038 continue;
1039 }
1040 Ok(x) => x,
1041 };
1042 if header.is_command() {
1044 let mut lock = self.incoming_requests.lock();
1045 let body = next_packet.split_off(header.encoded_len());
1046 lock.queue.push_back(UnparsedRequest::new(header, body));
1047 if let RequestListener::Some(ref waker) = lock.listener {
1048 waker.wake_by_ref();
1049 }
1050 } else {
1051 let mut waiters = self.response_waiters.lock();
1053 let idx = usize::from(&header.label());
1054 if let Some(&ResponseWaiter::Discard) = waiters.get(idx) {
1055 let _ = waiters.remove(idx);
1056 } else if let Some(entry) = waiters.get_mut(idx) {
1057 let old_entry = mem::replace(entry, ResponseWaiter::Received(next_packet));
1058 if let ResponseWaiter::Waiting(waker) = old_entry {
1059 waker.wake();
1060 }
1061 } else {
1062 warn!("response for {:?} we did not send, dropping", header.label());
1063 }
1064 }
1066 }
1067 }
1068
1069 fn wake_any(&self) {
1072 {
1077 let lock = self.response_waiters.lock();
1078 for (_, response_waiter) in lock.iter() {
1079 if let ResponseWaiter::Waiting(waker) = response_waiter {
1080 waker.wake_by_ref();
1081 return;
1082 }
1083 }
1084 }
1085 {
1086 let lock = self.incoming_requests.lock();
1087 if let RequestListener::Some(waker) = &lock.listener {
1088 waker.wake_by_ref();
1089 return;
1090 }
1091 }
1092 }
1093
1094 fn send_general_reject(&self, label: TxLabel, invalid_signal_id: u8) -> Result<()> {
1096 let packet: &[u8; 2] = &[u8::from(&label) << 4 | 0x01, invalid_signal_id & 0x3F];
1099 self.send_signal(packet)
1100 }
1101
1102 fn send_response(&self, label: TxLabel, signal: SignalIdentifier, params: &[u8]) -> Result<()> {
1103 let header = SignalingHeader::new(label, signal, SignalingMessageType::ResponseAccept);
1104 let mut packet = vec![0 as u8; header.encoded_len() + params.len()];
1105 header.encode(packet.as_mut_slice())?;
1106 packet[header.encoded_len()..].clone_from_slice(params);
1107 self.send_signal(&packet)
1108 }
1109
1110 fn send_reject(
1111 &self,
1112 label: TxLabel,
1113 signal: SignalIdentifier,
1114 error_code: ErrorCode,
1115 ) -> Result<()> {
1116 self.send_reject_params(label, signal, &[u8::from(&error_code)])
1117 }
1118
1119 fn send_reject_params(
1120 &self,
1121 label: TxLabel,
1122 signal: SignalIdentifier,
1123 params: &[u8],
1124 ) -> Result<()> {
1125 let header = SignalingHeader::new(label, signal, SignalingMessageType::ResponseReject);
1126 let mut packet = vec![0 as u8; header.encoded_len() + params.len()];
1127 header.encode(packet.as_mut_slice())?;
1128 packet[header.encoded_len()..].clone_from_slice(params);
1129 self.send_signal(&packet)
1130 }
1131
1132 fn send_signal(&self, data: &[u8]) -> Result<()> {
1133 let _ = self.signaling.write(data).map_err(|x| Error::PeerWrite(x))?;
1134 Ok(())
1135 }
1136}