Skip to main content

bt_avdtp/
lib.rs

1// Copyright 2018 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use fuchsia_async::{DurationExt, OnTimeout, TimeoutExt};
6use fuchsia_bluetooth::types::Channel;
7use fuchsia_sync::Mutex;
8use futures::future::{FusedFuture, MaybeDone};
9use futures::stream::{FusedStream, Stream, StreamExt};
10use futures::task::{Context, Poll, Waker};
11use futures::{Future, FutureExt, TryFutureExt, ready};
12use log::{info, trace, warn};
13use packet_encoding::{Decodable, Encodable};
14use slab::Slab;
15use std::collections::VecDeque;
16use std::marker::PhantomData;
17use std::mem;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::sync::atomic::{AtomicBool, Ordering};
21use zx::{self as zx, MonotonicDuration};
22
23mod rtp;
24mod stream_endpoint;
25#[cfg(test)]
26mod tests;
27mod types;
28
29use crate::types::{SignalIdentifier, SignalingHeader, SignalingMessageType, TxLabel};
30
31pub use crate::rtp::{RtpError, RtpHeader};
32pub use crate::stream_endpoint::{
33    MediaStream, StreamEndpoint, StreamEndpointUpdateCallback, StreamState,
34};
35pub use crate::types::{
36    ContentProtectionType, EndpointType, Error, ErrorCode, MediaCodecType, MediaType, RemoteReject,
37    Result, ServiceCapability, ServiceCategory, StreamEndpointId, StreamInformation,
38};
39
40/// An AVDTP signaling peer can send commands to another peer, receive requests and send responses.
41/// Media transport is not handled by this peer.
42///
43/// Requests from the distant peer are delivered through the request stream available through
44/// take_request_stream().  Only one RequestStream can be active at a time.  Only valid requests
45/// are sent to the request stream - invalid formats are automatically rejected.
46///
47/// Responses are sent using responders that are included in the request stream from the connected
48/// peer.
49#[derive(Debug, Clone)]
50pub struct Peer {
51    inner: Arc<PeerInner>,
52}
53
54impl Peer {
55    /// Create a new peer from a signaling channel.
56    pub fn new(signaling: Channel) -> Self {
57        Self {
58            inner: Arc::new(PeerInner {
59                signaling: Mutex::new(signaling),
60                signaling_channel_closed: AtomicBool::new(false),
61                response_waiters: Mutex::new(Slab::<ResponseWaiter>::new()),
62                incoming_requests: Mutex::<RequestQueue>::default(),
63            }),
64        }
65    }
66
67    /// Take the event listener for this peer. Panics if the stream is already
68    /// held.
69    #[track_caller]
70    pub fn take_request_stream(&self) -> RequestStream {
71        {
72            let mut lock = self.inner.incoming_requests.lock();
73            if let RequestListener::None = lock.listener {
74                lock.listener = RequestListener::New;
75            } else {
76                panic!("Request stream has already been taken");
77            }
78        }
79
80        RequestStream::new(self.inner.clone())
81    }
82
83    /// Send a Stream End Point Discovery (Sec 8.6) command to the remote peer.
84    /// Asynchronously returns a the reply in a vector of endpoint information.
85    /// Error will be RemoteRejected if the remote peer rejected the command.
86    pub fn discover(&self) -> impl Future<Output = Result<Vec<StreamInformation>>> + use<> {
87        self.send_command::<DiscoverResponse>(SignalIdentifier::Discover, &[]).ok_into()
88    }
89
90    /// Send a Get Capabilities (Sec 8.7) command to the remote peer for the
91    /// given `stream_id`.
92    /// Asynchronously returns the reply which contains the ServiceCapabilities
93    /// reported.
94    /// In general, Get All Capabilities should be preferred to this command if is supported.
95    /// Error will be RemoteRejected if the remote peer rejects the command.
96    pub fn get_capabilities(
97        &self,
98        stream_id: &StreamEndpointId,
99    ) -> impl Future<Output = Result<Vec<ServiceCapability>>> + use<> {
100        let stream_params = &[stream_id.to_msg()];
101        self.send_command::<GetCapabilitiesResponse>(
102            SignalIdentifier::GetCapabilities,
103            stream_params,
104        )
105        .ok_into()
106    }
107
108    /// Send a Get All Capabilities (Sec 8.8) command to the remote peer for the
109    /// given `stream_id`.
110    /// Asynchronously returns the reply which contains the ServiceCapabilities
111    /// reported.
112    /// Error will be RemoteRejected if the remote peer rejects the command.
113    pub fn get_all_capabilities(
114        &self,
115        stream_id: &StreamEndpointId,
116    ) -> impl Future<Output = Result<Vec<ServiceCapability>>> + use<> {
117        let stream_params = &[stream_id.to_msg()];
118        self.send_command::<GetCapabilitiesResponse>(
119            SignalIdentifier::GetAllCapabilities,
120            stream_params,
121        )
122        .ok_into()
123    }
124
125    /// Send a Stream Configuration (Sec 8.9) command to the remote peer for the
126    /// given remote `stream_id`, communicating the association to a local
127    /// `local_stream_id` and the required stream `capabilities`.
128    /// Panics if `capabilities` is empty.
129    /// Error will be RemoteRejected if the remote refused.
130    /// ServiceCategory will be set on RemoteReject with the indicated issue category.
131    pub fn set_configuration(
132        &self,
133        stream_id: &StreamEndpointId,
134        local_stream_id: &StreamEndpointId,
135        capabilities: &[ServiceCapability],
136    ) -> impl Future<Output = Result<()>> + use<> {
137        assert!(!capabilities.is_empty(), "must set at least one capability");
138        let mut params: Vec<u8> = vec![0; capabilities.iter().fold(2, |a, x| a + x.encoded_len())];
139        params[0] = stream_id.to_msg();
140        params[1] = local_stream_id.to_msg();
141        let mut idx = 2;
142        for capability in capabilities {
143            if let Err(e) = capability.encode(&mut params[idx..]) {
144                return futures::future::err(e).left_future();
145            }
146            idx += capability.encoded_len();
147        }
148        self.send_command::<SimpleResponse>(SignalIdentifier::SetConfiguration, &params)
149            .ok_into()
150            .right_future()
151    }
152
153    /// Send a Get Stream Configuration (Sec 8.10) command to the remote peer
154    /// for the given remote `stream_id`.
155    /// Asynchronously returns the set of ServiceCapabilities previously
156    /// configured between these two peers.
157    /// Error will be RemoteRejected if the remote peer rejects this command.
158    pub fn get_configuration(
159        &self,
160        stream_id: &StreamEndpointId,
161    ) -> impl Future<Output = Result<Vec<ServiceCapability>>> + use<> {
162        let stream_params = &[stream_id.to_msg()];
163        self.send_command::<GetCapabilitiesResponse>(
164            SignalIdentifier::GetConfiguration,
165            stream_params,
166        )
167        .ok_into()
168    }
169
170    /// Send a Stream Reconfigure (Sec 8.11) command to the remote peer for the
171    /// given remote `stream_id`, to reconfigure the Application Service
172    /// capabilities in `capabilities`.
173    /// Note: Per the spec, only the Media Codec and Content Protection
174    /// capabilities will be accepted in this command.
175    /// Panics if there are no capabilities to configure.
176    /// Error will be RemoteRejected if the remote refused.
177    /// ServiceCategory will be set on RemoteReject with the indicated issue category.
178    pub fn reconfigure(
179        &self,
180        stream_id: &StreamEndpointId,
181        capabilities: &[ServiceCapability],
182    ) -> impl Future<Output = Result<()>> + use<> {
183        assert!(!capabilities.is_empty(), "must set at least one capability");
184        let mut params: Vec<u8> = vec![0; capabilities.iter().fold(1, |a, x| a + x.encoded_len())];
185        params[0] = stream_id.to_msg();
186        let mut idx = 1;
187        for capability in capabilities {
188            if !capability.is_application() {
189                return futures::future::err(Error::Encoding).left_future();
190            }
191            if let Err(e) = capability.encode(&mut params[idx..]) {
192                return futures::future::err(e).left_future();
193            }
194            idx += capability.encoded_len();
195        }
196        self.send_command::<SimpleResponse>(SignalIdentifier::Reconfigure, &params)
197            .ok_into()
198            .right_future()
199    }
200
201    /// Send a Open Stream Command (Sec 8.12) to the remote peer for the given
202    /// `stream_id`.
203    /// Error will be RemoteRejected if the remote peer rejects the command.
204    pub fn open(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> + use<> {
205        let stream_params = &[stream_id.to_msg()];
206        self.send_command::<SimpleResponse>(SignalIdentifier::Open, stream_params).ok_into()
207    }
208
209    /// Send a Start Stream Command (Sec 8.13) to the remote peer for all the streams in
210    /// `stream_ids`.
211    /// Returns Ok(()) if the command is accepted, and RemoteStreamRejected with the stream
212    /// endpoint id and error code reported by the remote if the remote signals a failure.
213    pub fn start(
214        &self,
215        stream_ids: &[StreamEndpointId],
216    ) -> impl Future<Output = Result<()>> + use<> {
217        let mut stream_params = Vec::with_capacity(stream_ids.len());
218        for stream_id in stream_ids {
219            stream_params.push(stream_id.to_msg());
220        }
221        self.send_command::<SimpleResponse>(SignalIdentifier::Start, &stream_params).ok_into()
222    }
223
224    /// Send a Close Stream Command (Sec 8.14) to the remote peer for the given `stream_id`.
225    /// Error will be RemoteRejected if the remote peer rejects the command.
226    pub fn close(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> + use<> {
227        let stream_params = &[stream_id.to_msg()];
228        let response: CommandResponseFut<SimpleResponse> =
229            self.send_command::<SimpleResponse>(SignalIdentifier::Close, stream_params);
230        response.ok_into()
231    }
232
233    /// Send a Suspend Command (Sec 8.15) to the remote peer for all the streams in `stream_ids`.
234    /// Error will be RemoteRejected if the remote refused, with the stream endpoint identifier
235    /// indicated by the remote set in the RemoteReject.
236    pub fn suspend(
237        &self,
238        stream_ids: &[StreamEndpointId],
239    ) -> impl Future<Output = Result<()>> + use<> {
240        let mut stream_params = Vec::with_capacity(stream_ids.len());
241        for stream_id in stream_ids {
242            stream_params.push(stream_id.to_msg());
243        }
244        let response: CommandResponseFut<SimpleResponse> =
245            self.send_command::<SimpleResponse>(SignalIdentifier::Suspend, &stream_params);
246        response.ok_into()
247    }
248
249    /// Send an Abort (Sec 8.16) to the remote peer for the given `stream_id`.
250    /// Returns Ok(()) if the command is accepted, and Err(Timeout) if the remote
251    /// timed out.  The remote peer is not allowed to reject this command, and
252    /// commands that have invalid `stream_id` will timeout instead.
253    pub fn abort(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> + use<> {
254        let stream_params = &[stream_id.to_msg()];
255        self.send_command::<SimpleResponse>(SignalIdentifier::Abort, stream_params).ok_into()
256    }
257
258    /// Send a Delay Report (Sec 8.19) to the remote peer for the given `stream_id`.
259    /// `delay` is in tenths of milliseconds.
260    /// Error will be RemoteRejected if the remote peer rejects the command.
261    pub fn delay_report(
262        &self,
263        stream_id: &StreamEndpointId,
264        delay: u16,
265    ) -> impl Future<Output = Result<()>> + use<> {
266        let delay_bytes: [u8; 2] = delay.to_be_bytes();
267        let params = &[stream_id.to_msg(), delay_bytes[0], delay_bytes[1]];
268        self.send_command::<SimpleResponse>(SignalIdentifier::DelayReport, params).ok_into()
269    }
270
271    /// The maximum amount of time we will wait for a response to a signaling command.
272    const RTX_SIG_TIMER_MS: i64 = 3000;
273    const COMMAND_TIMEOUT: MonotonicDuration =
274        MonotonicDuration::from_millis(Peer::RTX_SIG_TIMER_MS);
275
276    /// Sends a signal on the channel and receive a future that will complete
277    /// when we get the expected response.
278    fn send_command<D: Decodable<Error = Error>>(
279        &self,
280        signal: SignalIdentifier,
281        payload: &[u8],
282    ) -> CommandResponseFut<D> {
283        let send_result = (|| {
284            let id = self.inner.add_response_waiter()?;
285            let header = SignalingHeader::new(id, signal, SignalingMessageType::Command);
286            let mut buf = vec![0; header.encoded_len()];
287            header.encode(buf.as_mut_slice())?;
288            buf.extend_from_slice(payload);
289            self.inner.send_signal(buf.as_slice())?;
290            Ok(header)
291        })();
292
293        CommandResponseFut::new(send_result, self.inner.clone())
294    }
295}
296
297/// A future representing the result of a AVDTP command. Decodes the response when it arrives.
298struct CommandResponseFut<D: Decodable> {
299    id: SignalIdentifier,
300    fut: Pin<Box<MaybeDone<OnTimeout<CommandResponse, fn() -> Result<Vec<u8>>>>>>,
301    _phantom: PhantomData<D>,
302}
303
304impl<D: Decodable> Unpin for CommandResponseFut<D> {}
305
306impl<D: Decodable<Error = Error>> CommandResponseFut<D> {
307    fn new(send_result: Result<SignalingHeader>, inner: Arc<PeerInner>) -> Self {
308        let header = match send_result {
309            Err(e) => {
310                return Self {
311                    id: SignalIdentifier::Abort,
312                    fut: Box::pin(MaybeDone::Done(Err(e))),
313                    _phantom: PhantomData,
314                };
315            }
316            Ok(header) => header,
317        };
318        let response = CommandResponse { id: header.label(), inner: Some(inner) };
319        let err_timeout: fn() -> Result<Vec<u8>> = || Err(Error::Timeout);
320        let timedout_fut = response.on_timeout(Peer::COMMAND_TIMEOUT.after_now(), err_timeout);
321
322        Self {
323            id: header.signal(),
324            fut: Box::pin(futures::future::maybe_done(timedout_fut)),
325            _phantom: PhantomData,
326        }
327    }
328}
329
330impl<D: Decodable<Error = Error>> Future for CommandResponseFut<D> {
331    type Output = Result<D>;
332
333    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
334        ready!(self.fut.poll_unpin(cx));
335        Poll::Ready(
336            self.fut
337                .as_mut()
338                .take_output()
339                .unwrap_or(Err(Error::AlreadyReceived))
340                .and_then(|buf| decode_signaling_response(self.id, buf)),
341        )
342    }
343}
344
345/// A request from the connected peer.
346/// Each variant of this includes a responder which implements two functions:
347///  - send(...) will send a response with the information provided.
348///  - reject(ErrorCode) will send an reject response with the given error code.
349#[derive(Debug)]
350pub enum Request {
351    Discover {
352        responder: DiscoverResponder,
353    },
354    GetCapabilities {
355        stream_id: StreamEndpointId,
356        responder: GetCapabilitiesResponder,
357    },
358    GetAllCapabilities {
359        stream_id: StreamEndpointId,
360        responder: GetCapabilitiesResponder,
361    },
362    SetConfiguration {
363        local_stream_id: StreamEndpointId,
364        remote_stream_id: StreamEndpointId,
365        capabilities: Vec<ServiceCapability>,
366        responder: ConfigureResponder,
367    },
368    GetConfiguration {
369        stream_id: StreamEndpointId,
370        responder: GetCapabilitiesResponder,
371    },
372    Reconfigure {
373        local_stream_id: StreamEndpointId,
374        capabilities: Vec<ServiceCapability>,
375        responder: ConfigureResponder,
376    },
377    Open {
378        stream_id: StreamEndpointId,
379        responder: SimpleResponder,
380    },
381    Start {
382        stream_ids: Vec<StreamEndpointId>,
383        responder: StreamResponder,
384    },
385    Close {
386        stream_id: StreamEndpointId,
387        responder: SimpleResponder,
388    },
389    Suspend {
390        stream_ids: Vec<StreamEndpointId>,
391        responder: StreamResponder,
392    },
393    Abort {
394        stream_id: StreamEndpointId,
395        responder: SimpleResponder,
396    },
397    DelayReport {
398        stream_id: StreamEndpointId,
399        delay: u16,
400        responder: SimpleResponder,
401    }, // TODO(jamuraa): add the rest of the requests
402}
403
/// Parses a request body that must be exactly one SEID byte into
/// `Request::$variant { stream_id, responder }`, evaluating to a `Result<Request>`.
///
/// A body of any other length yields `RequestInvalid(BadLength)`.
macro_rules! parse_one_seid {
    ($body:ident, $signal:ident, $peer:ident, $id:ident, $variant:ident, $responder:ident) => {
        match $body {
            [seid] => Ok(Request::$variant {
                stream_id: StreamEndpointId::from_msg(seid),
                responder: $responder { signal: $signal, peer: $peer.clone(), id: $id },
            }),
            _ => Err(Error::RequestInvalid(ErrorCode::BadLength)),
        }
    };
}
416
417impl Request {
418    fn get_req_seids(body: &[u8]) -> Result<Vec<StreamEndpointId>> {
419        if body.len() < 1 {
420            return Err(Error::RequestInvalid(ErrorCode::BadLength));
421        }
422        Ok(body.iter().map(&StreamEndpointId::from_msg).collect())
423    }
424
425    fn get_req_capabilities(encoded: &[u8]) -> Result<Vec<ServiceCapability>> {
426        if encoded.len() < 2 {
427            return Err(Error::RequestInvalid(ErrorCode::BadLength));
428        }
429        let mut caps = vec![];
430        let mut loc = 0;
431        while loc < encoded.len() {
432            let cap = match ServiceCapability::decode(&encoded[loc..]) {
433                Ok(cap) => cap,
434                Err(Error::RequestInvalid(code)) => {
435                    return Err(Error::RequestInvalidExtra(code, encoded[loc]));
436                }
437                Err(e) => return Err(e),
438            };
439            loc += cap.encoded_len();
440            caps.push(cap);
441        }
442        Ok(caps)
443    }
444
445    fn parse(
446        peer: &Arc<PeerInner>,
447        id: TxLabel,
448        signal: SignalIdentifier,
449        body: &[u8],
450    ) -> Result<Request> {
451        match signal {
452            SignalIdentifier::Discover => {
453                // Discover Request has no body (Sec 8.6.1)
454                if body.len() > 0 {
455                    return Err(Error::RequestInvalid(ErrorCode::BadLength));
456                }
457                Ok(Request::Discover { responder: DiscoverResponder { peer: peer.clone(), id } })
458            }
459            SignalIdentifier::GetCapabilities => {
460                parse_one_seid!(body, signal, peer, id, GetCapabilities, GetCapabilitiesResponder)
461            }
462            SignalIdentifier::GetAllCapabilities => parse_one_seid!(
463                body,
464                signal,
465                peer,
466                id,
467                GetAllCapabilities,
468                GetCapabilitiesResponder
469            ),
470            SignalIdentifier::SetConfiguration => {
471                if body.len() < 4 {
472                    return Err(Error::RequestInvalid(ErrorCode::BadLength));
473                }
474                let requested = Request::get_req_capabilities(&body[2..])?;
475                Ok(Request::SetConfiguration {
476                    local_stream_id: StreamEndpointId::from_msg(&body[0]),
477                    remote_stream_id: StreamEndpointId::from_msg(&body[1]),
478                    capabilities: requested,
479                    responder: ConfigureResponder { signal, peer: peer.clone(), id },
480                })
481            }
482            SignalIdentifier::GetConfiguration => {
483                parse_one_seid!(body, signal, peer, id, GetConfiguration, GetCapabilitiesResponder)
484            }
485            SignalIdentifier::Reconfigure => {
486                if body.len() < 3 {
487                    return Err(Error::RequestInvalid(ErrorCode::BadLength));
488                }
489                let requested = Request::get_req_capabilities(&body[1..])?;
490                match requested.iter().find(|x| !x.is_application()) {
491                    Some(x) => {
492                        return Err(Error::RequestInvalidExtra(
493                            ErrorCode::InvalidCapabilities,
494                            (&x.category()).into(),
495                        ));
496                    }
497                    None => (),
498                };
499                Ok(Request::Reconfigure {
500                    local_stream_id: StreamEndpointId::from_msg(&body[0]),
501                    capabilities: requested,
502                    responder: ConfigureResponder { signal, peer: peer.clone(), id },
503                })
504            }
505            SignalIdentifier::Open => {
506                parse_one_seid!(body, signal, peer, id, Open, SimpleResponder)
507            }
508            SignalIdentifier::Start => {
509                let seids = Request::get_req_seids(body)?;
510                Ok(Request::Start {
511                    stream_ids: seids,
512                    responder: StreamResponder { signal, peer: peer.clone(), id },
513                })
514            }
515            SignalIdentifier::Close => {
516                parse_one_seid!(body, signal, peer, id, Close, SimpleResponder)
517            }
518            SignalIdentifier::Suspend => {
519                let seids = Request::get_req_seids(body)?;
520                Ok(Request::Suspend {
521                    stream_ids: seids,
522                    responder: StreamResponder { signal, peer: peer.clone(), id },
523                })
524            }
525            SignalIdentifier::Abort => {
526                parse_one_seid!(body, signal, peer, id, Abort, SimpleResponder)
527            }
528            SignalIdentifier::DelayReport => {
529                if body.len() != 3 {
530                    return Err(Error::RequestInvalid(ErrorCode::BadLength));
531                }
532                let delay_arr: [u8; 2] = [body[1], body[2]];
533                let delay = u16::from_be_bytes(delay_arr);
534                Ok(Request::DelayReport {
535                    stream_id: StreamEndpointId::from_msg(&body[0]),
536                    delay,
537                    responder: SimpleResponder { signal, peer: peer.clone(), id },
538                })
539            }
540            _ => Err(Error::UnimplementedMessage),
541        }
542    }
543}
544
545/// A stream of requests from the remote peer.
546#[derive(Debug)]
547pub struct RequestStream {
548    inner: Arc<PeerInner>,
549    terminated: bool,
550}
551
552impl RequestStream {
553    fn new(inner: Arc<PeerInner>) -> Self {
554        Self { inner, terminated: false }
555    }
556}
557
558impl Unpin for RequestStream {}
559
560impl Stream for RequestStream {
561    type Item = Result<Request>;
562
563    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
564        Poll::Ready(match ready!(self.inner.poll_recv_request(cx)) {
565            Ok(UnparsedRequest(SignalingHeader { label, signal, .. }, body)) => {
566                match Request::parse(&self.inner, label, signal, &body) {
567                    Err(Error::RequestInvalid(code)) => {
568                        self.inner.send_reject(label, signal, code)?;
569                        return Poll::Pending;
570                    }
571                    Err(Error::RequestInvalidExtra(code, extra)) => {
572                        self.inner.send_reject_params(label, signal, &[extra, u8::from(&code)])?;
573                        return Poll::Pending;
574                    }
575                    Err(Error::UnimplementedMessage) => {
576                        self.inner.send_reject(label, signal, ErrorCode::NotSupportedCommand)?;
577                        return Poll::Pending;
578                    }
579                    x => Some(x),
580                }
581            }
582            Err(Error::PeerDisconnected) => {
583                self.terminated = true;
584                None
585            }
586            Err(e) => Some(Err(e)),
587        })
588    }
589}
590
591impl FusedStream for RequestStream {
592    fn is_terminated(&self) -> bool {
593        self.terminated
594    }
595}
596
597impl Drop for RequestStream {
598    fn drop(&mut self) {
599        self.inner.incoming_requests.lock().listener = RequestListener::None;
600        self.inner.wake_any();
601    }
602}
603
604// Simple responses have no body data.
605#[derive(Debug)]
606pub struct SimpleResponse {}
607
608impl Decodable for SimpleResponse {
609    type Error = Error;
610
611    fn decode(from: &[u8]) -> Result<Self> {
612        if from.len() > 0 {
613            return Err(Error::InvalidMessage);
614        }
615        Ok(SimpleResponse {})
616    }
617}
618
619impl Into<()> for SimpleResponse {
620    fn into(self) -> () {
621        ()
622    }
623}
624
625#[derive(Debug)]
626struct DiscoverResponse {
627    endpoints: Vec<StreamInformation>,
628}
629
630impl Decodable for DiscoverResponse {
631    type Error = Error;
632
633    fn decode(from: &[u8]) -> Result<Self> {
634        let mut endpoints = Vec::<StreamInformation>::new();
635        let mut idx = 0;
636        while idx < from.len() {
637            let endpoint = StreamInformation::decode(&from[idx..])?;
638            idx += endpoint.encoded_len();
639            endpoints.push(endpoint);
640        }
641        Ok(DiscoverResponse { endpoints })
642    }
643}
644
645impl Into<Vec<StreamInformation>> for DiscoverResponse {
646    fn into(self) -> Vec<StreamInformation> {
647        self.endpoints
648    }
649}
650
651#[derive(Debug)]
652pub struct DiscoverResponder {
653    peer: Arc<PeerInner>,
654    id: TxLabel,
655}
656
657impl DiscoverResponder {
658    /// Sends the response to a discovery request.
659    /// At least one endpoint must be present.
660    /// Will result in a Error::PeerWrite if the distant peer is disconnected.
661    pub fn send(self, endpoints: &[StreamInformation]) -> Result<()> {
662        if endpoints.len() == 0 {
663            // There shall be at least one SEP in a response (Sec 8.6.2)
664            return Err(Error::Encoding);
665        }
666        let mut params = vec![0 as u8; endpoints.len() * endpoints[0].encoded_len()];
667        let mut idx = 0;
668        for endpoint in endpoints {
669            endpoint.encode(&mut params[idx..idx + endpoint.encoded_len()])?;
670            idx += endpoint.encoded_len();
671        }
672        self.peer.send_response(self.id, SignalIdentifier::Discover, &params)
673    }
674
675    pub fn reject(self, error_code: ErrorCode) -> Result<()> {
676        self.peer.send_reject(self.id, SignalIdentifier::Discover, error_code)
677    }
678}
679
680#[derive(Debug)]
681pub struct GetCapabilitiesResponder {
682    peer: Arc<PeerInner>,
683    signal: SignalIdentifier,
684    id: TxLabel,
685}
686
687impl GetCapabilitiesResponder {
688    pub fn send(self, capabilities: &[ServiceCapability]) -> Result<()> {
689        let included_iter = capabilities.iter().filter(|x| x.in_response(self.signal));
690        let reply_len = included_iter.clone().fold(0, |a, b| a + b.encoded_len());
691        let mut reply = vec![0 as u8; reply_len];
692        let mut pos = 0;
693        for capability in included_iter {
694            let size = capability.encoded_len();
695            capability.encode(&mut reply[pos..pos + size])?;
696            pos += size;
697        }
698        self.peer.send_response(self.id, self.signal, &reply)
699    }
700
701    pub fn reject(self, error_code: ErrorCode) -> Result<()> {
702        self.peer.send_reject(self.id, self.signal, error_code)
703    }
704}
705
706#[derive(Debug)]
707struct GetCapabilitiesResponse {
708    capabilities: Vec<ServiceCapability>,
709}
710
711impl Decodable for GetCapabilitiesResponse {
712    type Error = Error;
713
714    fn decode(from: &[u8]) -> Result<Self> {
715        let mut capabilities = Vec::<ServiceCapability>::new();
716        let mut idx = 0;
717        while idx < from.len() {
718            match ServiceCapability::decode(&from[idx..]) {
719                Ok(capability) => {
720                    idx = idx + capability.encoded_len();
721                    capabilities.push(capability);
722                }
723                Err(_) => {
724                    // The capability length of the invalid capability can be nonzero.
725                    // Advance `idx` by the payload amount, but don't push the invalid capability.
726                    // Increment by 1 byte for ServiceCategory, 1 byte for payload length,
727                    // `length_of_capability` bytes for capability length.
728                    info!(
729                        "GetCapabilitiesResponse decode: Capability {:?} not supported.",
730                        from[idx]
731                    );
732                    let length_of_capability = from[idx + 1] as usize;
733                    idx = idx + 2 + length_of_capability;
734                }
735            }
736        }
737        Ok(GetCapabilitiesResponse { capabilities })
738    }
739}
740
741impl Into<Vec<ServiceCapability>> for GetCapabilitiesResponse {
742    fn into(self) -> Vec<ServiceCapability> {
743        self.capabilities
744    }
745}
746
747#[derive(Debug)]
748pub struct SimpleResponder {
749    peer: Arc<PeerInner>,
750    signal: SignalIdentifier,
751    id: TxLabel,
752}
753
754impl SimpleResponder {
755    pub fn send(self) -> Result<()> {
756        self.peer.send_response(self.id, self.signal, &[])
757    }
758
759    pub fn reject(self, error_code: ErrorCode) -> Result<()> {
760        self.peer.send_reject(self.id, self.signal, error_code)
761    }
762}
763
764#[derive(Debug)]
765pub struct StreamResponder {
766    peer: Arc<PeerInner>,
767    signal: SignalIdentifier,
768    id: TxLabel,
769}
770
771impl StreamResponder {
772    pub fn send(self) -> Result<()> {
773        self.peer.send_response(self.id, self.signal, &[])
774    }
775
776    pub fn reject(self, stream_id: &StreamEndpointId, error_code: ErrorCode) -> Result<()> {
777        self.peer.send_reject_params(
778            self.id,
779            self.signal,
780            &[stream_id.to_msg(), u8::from(&error_code)],
781        )
782    }
783}
784
785#[derive(Debug)]
786pub struct ConfigureResponder {
787    peer: Arc<PeerInner>,
788    signal: SignalIdentifier,
789    id: TxLabel,
790}
791
792impl ConfigureResponder {
793    pub fn send(self) -> Result<()> {
794        self.peer.send_response(self.id, self.signal, &[])
795    }
796
797    pub fn reject(self, category: ServiceCategory, error_code: ErrorCode) -> Result<()> {
798        self.peer.send_reject_params(
799            self.id,
800            self.signal,
801            &[u8::from(&category), u8::from(&error_code)],
802        )
803    }
804}
805
806#[derive(Debug)]
807struct UnparsedRequest(SignalingHeader, Vec<u8>);
808
809impl UnparsedRequest {
810    fn new(header: SignalingHeader, body: Vec<u8>) -> UnparsedRequest {
811        UnparsedRequest(header, body)
812    }
813}
814
815#[derive(Debug, Default)]
816struct RequestQueue {
817    listener: RequestListener,
818    queue: VecDeque<UnparsedRequest>,
819}
820
/// The state of interest in incoming requests, kept in `RequestQueue::listener`.
///
/// Defaults to `None`. The default is derived with `#[default]` instead of a
/// hand-written `Default` impl.
#[derive(Debug, Default)]
enum RequestListener {
    /// No one is listening.
    #[default]
    None,
    /// Someone wants to listen but hasn't polled.
    New,
    /// Someone is listening, and can be woken with the waker.
    Some(Waker),
}
836
/// An enum representing an interest in the response to a command.
#[derive(Debug)]
enum ResponseWaiter {
    /// A new waiter which hasn't been polled yet.
    WillPoll,
    /// A task waiting for a response, which can be woken with the waker.
    Waiting(Waker),
    /// A response that has been received, stored here until it's polled, at
    /// which point it will be decoded. Holds the whole packet, header included.
    Received(Vec<u8>),
    /// It's still waiting on the response, but the receiver has decided they
    /// don't care and we'll throw it out.
    Discard,
}

impl ResponseWaiter {
    /// Returns true if a response packet has been stored in this waiter.
    fn is_received(&self) -> bool {
        matches!(self, ResponseWaiter::Received(_))
    }

    /// Consumes the waiter and returns the stored response packet.
    ///
    /// # Panics
    ///
    /// Panics if the waiter is not in the `Received` state; callers check
    /// `is_received` first.
    fn unwrap_received(self) -> Vec<u8> {
        match self {
            ResponseWaiter::Received(buf) => buf,
            _ => panic!("expected received buf"),
        }
    }
}
862
863fn decode_signaling_response<D: Decodable<Error = Error>>(
864    expected_signal: SignalIdentifier,
865    buf: Vec<u8>,
866) -> Result<D> {
867    let header = SignalingHeader::decode(buf.as_slice())?;
868    if header.signal() != expected_signal {
869        return Err(Error::InvalidHeader);
870    }
871    let params = &buf[header.encoded_len()..];
872    match header.message_type {
873        SignalingMessageType::ResponseAccept => D::decode(params),
874        SignalingMessageType::GeneralReject | SignalingMessageType::ResponseReject => {
875            Err(RemoteReject::from_params(header.signal(), params).into())
876        }
877        SignalingMessageType::Command => unreachable!(),
878    }
879}
880
881/// A future that polls for the response to a command we sent.
882#[derive(Debug)]
883pub struct CommandResponse {
884    id: TxLabel,
885    // Some(x) if we're still waiting on the response.
886    inner: Option<Arc<PeerInner>>,
887}
888
889impl Unpin for CommandResponse {}
890
891impl Future for CommandResponse {
892    type Output = Result<Vec<u8>>;
893    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
894        let this = &mut *self;
895        let res;
896        {
897            let client = this.inner.as_ref().ok_or(Error::AlreadyReceived)?;
898            res = client.poll_recv_response(&this.id, cx);
899        }
900
901        if let Poll::Ready(Ok(_)) = res {
902            let inner = this.inner.take().expect("CommandResponse polled after completion");
903            inner.wake_any();
904        }
905
906        res
907    }
908}
909
910impl FusedFuture for CommandResponse {
911    fn is_terminated(&self) -> bool {
912        self.inner.is_none()
913    }
914}
915
916impl Drop for CommandResponse {
917    fn drop(&mut self) {
918        if let Some(inner) = &self.inner {
919            inner.remove_response_interest(&self.id);
920            inner.wake_any();
921        }
922    }
923}
924
925#[derive(Debug)]
926struct PeerInner {
927    /// The signaling channel
928    signaling: Mutex<Channel>,
929    signaling_channel_closed: AtomicBool,
930
931    /// A map of transaction ids that have been sent but the response has not
932    /// been received and/or processed yet.
933    ///
934    /// Waiters are added with `add_response_waiter` and get removed when they are
935    /// polled or they are removed with `remove_waiter`
936    response_waiters: Mutex<Slab<ResponseWaiter>>,
937
938    /// A queue of requests that have been received and are waiting to
939    /// be responded to, along with the waker for the task that has
940    /// taken the request receiver (if it exists)
941    incoming_requests: Mutex<RequestQueue>,
942}
943
944impl PeerInner {
945    /// Add a response waiter, and return a id that can be used to send the
946    /// transaction.  Responses then can be received using poll_recv_response
947    fn add_response_waiter(&self) -> Result<TxLabel> {
948        let key = self.response_waiters.lock().insert(ResponseWaiter::WillPoll);
949        let id = TxLabel::try_from(key as u8);
950        if id.is_err() {
951            warn!("Transaction IDs are exhausted");
952            let _ = self.response_waiters.lock().remove(key);
953        }
954        id
955    }
956
957    /// When a waiter isn't interested in the response anymore, we need to just
958    /// throw it out.  This is called when the response future is dropped.
959    fn remove_response_interest(&self, id: &TxLabel) {
960        let mut lock = self.response_waiters.lock();
961        let idx = usize::from(id);
962        if lock[idx].is_received() {
963            let _ = lock.remove(idx);
964        } else {
965            lock[idx] = ResponseWaiter::Discard;
966        }
967    }
968
969    // Attempts to receive a new request by processing all packets on the socket.
970    // Resolves to an unprocessed request (header, body) if one was received.
971    // Resolves to an error if there was an error reading from the socket or if the peer
972    // disconnected.
973    fn poll_recv_request(&self, cx: &mut Context<'_>) -> Poll<Result<UnparsedRequest>> {
974        let is_closed = self.recv_all(cx)?;
975
976        let mut lock = self.incoming_requests.lock();
977
978        if let Some(request) = lock.queue.pop_front() {
979            Poll::Ready(Ok(request))
980        } else {
981            lock.listener = RequestListener::Some(cx.waker().clone());
982            if is_closed { Poll::Ready(Err(Error::PeerDisconnected)) } else { Poll::Pending }
983        }
984    }
985
986    // Attempts to receive a response to a request by processing all packets on the socket.
987    // Resolves to the bytes in the response body if one was received.
988    // Resolves to an error if there was an error reading from the socket, if the peer
989    // disconnected, or if the |label| is not being waited on.
990    fn poll_recv_response(&self, label: &TxLabel, cx: &mut Context<'_>) -> Poll<Result<Vec<u8>>> {
991        let is_closed = self.recv_all(cx)?;
992
993        let mut waiters = self.response_waiters.lock();
994        let idx = usize::from(label);
995        // We expect() below because the label above came from an internally-created object,
996        // so the waiters should always exist in the map.
997        if waiters.get(idx).expect("Polled unregistered waiter").is_received() {
998            // We got our response.
999            let buf = waiters.remove(idx).unwrap_received();
1000            Poll::Ready(Ok(buf))
1001        } else {
1002            // Set the waker to be notified when a response shows up.
1003            *waiters.get_mut(idx).expect("Polled unregistered waiter") =
1004                ResponseWaiter::Waiting(cx.waker().clone());
1005
1006            if is_closed { Poll::Ready(Err(Error::PeerDisconnected)) } else { Poll::Pending }
1007        }
1008    }
1009
1010    /// Poll for any packets on the signaling socket
1011    /// Returns whether the channel was closed, or an Error::PeerRead or Error::PeerWrite
1012    /// if there was a problem communicating on the socket.
1013    fn recv_all(&self, cx: &mut Context<'_>) -> Result<bool> {
1014        if self.signaling_channel_closed.load(Ordering::Relaxed) {
1015            return Ok(true);
1016        }
1017        loop {
1018            let mut next_packet = {
1019                let mut signaling = self.signaling.lock();
1020                match signaling.poll_next_unpin(cx) {
1021                    Poll::Ready(Some(Ok(packet))) => packet,
1022                    Poll::Ready(Some(Err(zx::Status::PEER_CLOSED))) | Poll::Ready(None) => {
1023                        trace!("Signaling peer closed");
1024                        self.signaling_channel_closed.store(true, Ordering::Relaxed);
1025                        return Ok(true);
1026                    }
1027                    Poll::Ready(Some(Err(e))) => return Err(Error::PeerRead(e)),
1028                    Poll::Pending => return Ok(false),
1029                }
1030            };
1031
1032            // Detects General Reject condition and sends the response back.
1033            // On other headers with errors, sends BAD_HEADER to the peer
1034            // and attempts to continue.
1035            let header = match SignalingHeader::decode(next_packet.as_slice()) {
1036                Err(Error::InvalidSignalId(label, id)) => {
1037                    self.send_general_reject(label, id)?;
1038                    continue;
1039                }
1040                Err(_) => {
1041                    // Only possible other return is OutOfRange
1042                    // Returned only when the packet is too small, can't make a meaningful reject.
1043                    info!("received unrejectable message");
1044                    continue;
1045                }
1046                Ok(x) => x,
1047            };
1048            // Commands from the remote get translated into requests.
1049            if header.is_command() {
1050                let mut lock = self.incoming_requests.lock();
1051                let body = next_packet.split_off(header.encoded_len());
1052                lock.queue.push_back(UnparsedRequest::new(header, body));
1053                if let RequestListener::Some(ref waker) = lock.listener {
1054                    waker.wake_by_ref();
1055                }
1056            } else {
1057                // Should be a response to a command we sent
1058                let mut waiters = self.response_waiters.lock();
1059                let idx = usize::from(&header.label());
1060                if let Some(&ResponseWaiter::Discard) = waiters.get(idx) {
1061                    let _ = waiters.remove(idx);
1062                } else if let Some(entry) = waiters.get_mut(idx) {
1063                    let old_entry = mem::replace(entry, ResponseWaiter::Received(next_packet));
1064                    if let ResponseWaiter::Waiting(waker) = old_entry {
1065                        waker.wake();
1066                    }
1067                } else {
1068                    warn!("response for {:?} we did not send, dropping", header.label());
1069                }
1070                // Note: we drop any TxLabel response we are not waiting for
1071            }
1072        }
1073    }
1074
1075    // Wakes up an arbitrary task that has begun polling on the channel so that
1076    // it will call recv_all and be registered as the new channel reader.
1077    fn wake_any(&self) {
1078        // Try to wake up response waiters first, rather than the event listener.
1079        // The event listener is a stream, and so could be between poll_nexts,
1080        // Response waiters should always be actively polled once
1081        // they've begun being polled on a task.
1082        {
1083            let lock = self.response_waiters.lock();
1084            for (_, response_waiter) in lock.iter() {
1085                if let ResponseWaiter::Waiting(waker) = response_waiter {
1086                    waker.wake_by_ref();
1087                    return;
1088                }
1089            }
1090        }
1091        {
1092            let lock = self.incoming_requests.lock();
1093            if let RequestListener::Some(waker) = &lock.listener {
1094                waker.wake_by_ref();
1095                return;
1096            }
1097        }
1098    }
1099
1100    // Build and send a General Reject message (Section 8.18)
1101    fn send_general_reject(&self, label: TxLabel, invalid_signal_id: u8) -> Result<()> {
1102        // Build the packet ourselves rather than make SignalingHeader build an packet with an
1103        // invalid signal id.
1104        let packet: &[u8; 2] = &[u8::from(&label) << 4 | 0x01, invalid_signal_id & 0x3F];
1105        self.send_signal(packet)
1106    }
1107
1108    fn send_response(&self, label: TxLabel, signal: SignalIdentifier, params: &[u8]) -> Result<()> {
1109        let header = SignalingHeader::new(label, signal, SignalingMessageType::ResponseAccept);
1110        let mut packet = vec![0 as u8; header.encoded_len() + params.len()];
1111        header.encode(packet.as_mut_slice())?;
1112        packet[header.encoded_len()..].clone_from_slice(params);
1113        self.send_signal(&packet)
1114    }
1115
1116    fn send_reject(
1117        &self,
1118        label: TxLabel,
1119        signal: SignalIdentifier,
1120        error_code: ErrorCode,
1121    ) -> Result<()> {
1122        self.send_reject_params(label, signal, &[u8::from(&error_code)])
1123    }
1124
1125    fn send_reject_params(
1126        &self,
1127        label: TxLabel,
1128        signal: SignalIdentifier,
1129        params: &[u8],
1130    ) -> Result<()> {
1131        let header = SignalingHeader::new(label, signal, SignalingMessageType::ResponseReject);
1132        let mut packet = vec![0 as u8; header.encoded_len() + params.len()];
1133        header.encode(packet.as_mut_slice())?;
1134        packet[header.encoded_len()..].clone_from_slice(params);
1135        self.send_signal(&packet)
1136    }
1137
1138    fn send_signal(&self, data: &[u8]) -> Result<()> {
1139        let _ = self.signaling.lock().write(data).map_err(|x| Error::PeerWrite(x))?;
1140        Ok(())
1141    }
1142}