bt_avdtp/
lib.rs

1// Copyright 2018 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use fuchsia_async::{DurationExt, OnTimeout, TimeoutExt};
6use fuchsia_bluetooth::types::Channel;
7use fuchsia_sync::Mutex;
8use futures::future::{FusedFuture, MaybeDone};
9use futures::stream::{FusedStream, Stream};
10use futures::task::{Context, Poll, Waker};
11use futures::{Future, FutureExt, TryFutureExt, ready};
12use log::{info, trace, warn};
13use packet_encoding::{Decodable, Encodable};
14use slab::Slab;
15use std::collections::VecDeque;
16use std::marker::PhantomData;
17use std::mem;
18use std::pin::Pin;
19use std::sync::Arc;
20use zx::{self as zx, MonotonicDuration};
21
22#[cfg(test)]
23mod tests;
24
25mod rtp;
26mod stream_endpoint;
27mod types;
28
29use crate::types::{SignalIdentifier, SignalingHeader, SignalingMessageType, TxLabel};
30
31pub use crate::rtp::{RtpError, RtpHeader};
32pub use crate::stream_endpoint::{
33    MediaStream, StreamEndpoint, StreamEndpointUpdateCallback, StreamState,
34};
35pub use crate::types::{
36    ContentProtectionType, EndpointType, Error, ErrorCode, MediaCodecType, MediaType, RemoteReject,
37    Result, ServiceCapability, ServiceCategory, StreamEndpointId, StreamInformation,
38};
39
40/// An AVDTP signaling peer can send commands to another peer, receive requests and send responses.
41/// Media transport is not handled by this peer.
42///
43/// Requests from the distant peer are delivered through the request stream available through
44/// take_request_stream().  Only one RequestStream can be active at a time.  Only valid requests
45/// are sent to the request stream - invalid formats are automatically rejected.
46///
47/// Responses are sent using responders that are included in the request stream from the connected
48/// peer.
49#[derive(Debug, Clone)]
50pub struct Peer {
51    inner: Arc<PeerInner>,
52}
53
54impl Peer {
55    /// Create a new peer from a signaling channel.
56    pub fn new(signaling: Channel) -> Self {
57        Self {
58            inner: Arc::new(PeerInner {
59                signaling,
60                response_waiters: Mutex::new(Slab::<ResponseWaiter>::new()),
61                incoming_requests: Mutex::<RequestQueue>::default(),
62            }),
63        }
64    }
65
66    /// Take the event listener for this peer. Panics if the stream is already
67    /// held.
68    #[track_caller]
69    pub fn take_request_stream(&self) -> RequestStream {
70        {
71            let mut lock = self.inner.incoming_requests.lock();
72            if let RequestListener::None = lock.listener {
73                lock.listener = RequestListener::New;
74            } else {
75                panic!("Request stream has already been taken");
76            }
77        }
78
79        RequestStream::new(self.inner.clone())
80    }
81
82    /// Send a Stream End Point Discovery (Sec 8.6) command to the remote peer.
83    /// Asynchronously returns a the reply in a vector of endpoint information.
84    /// Error will be RemoteRejected if the remote peer rejected the command.
85    pub fn discover(&self) -> impl Future<Output = Result<Vec<StreamInformation>>> + use<> {
86        self.send_command::<DiscoverResponse>(SignalIdentifier::Discover, &[]).ok_into()
87    }
88
89    /// Send a Get Capabilities (Sec 8.7) command to the remote peer for the
90    /// given `stream_id`.
91    /// Asynchronously returns the reply which contains the ServiceCapabilities
92    /// reported.
93    /// In general, Get All Capabilities should be preferred to this command if is supported.
94    /// Error will be RemoteRejected if the remote peer rejects the command.
95    pub fn get_capabilities(
96        &self,
97        stream_id: &StreamEndpointId,
98    ) -> impl Future<Output = Result<Vec<ServiceCapability>>> + use<> {
99        let stream_params = &[stream_id.to_msg()];
100        self.send_command::<GetCapabilitiesResponse>(
101            SignalIdentifier::GetCapabilities,
102            stream_params,
103        )
104        .ok_into()
105    }
106
107    /// Send a Get All Capabilities (Sec 8.8) command to the remote peer for the
108    /// given `stream_id`.
109    /// Asynchronously returns the reply which contains the ServiceCapabilities
110    /// reported.
111    /// Error will be RemoteRejected if the remote peer rejects the command.
112    pub fn get_all_capabilities(
113        &self,
114        stream_id: &StreamEndpointId,
115    ) -> impl Future<Output = Result<Vec<ServiceCapability>>> + use<> {
116        let stream_params = &[stream_id.to_msg()];
117        self.send_command::<GetCapabilitiesResponse>(
118            SignalIdentifier::GetAllCapabilities,
119            stream_params,
120        )
121        .ok_into()
122    }
123
124    /// Send a Stream Configuration (Sec 8.9) command to the remote peer for the
125    /// given remote `stream_id`, communicating the association to a local
126    /// `local_stream_id` and the required stream `capabilities`.
127    /// Panics if `capabilities` is empty.
128    /// Error will be RemoteRejected if the remote refused.
129    /// ServiceCategory will be set on RemoteReject with the indicated issue category.
130    pub fn set_configuration(
131        &self,
132        stream_id: &StreamEndpointId,
133        local_stream_id: &StreamEndpointId,
134        capabilities: &[ServiceCapability],
135    ) -> impl Future<Output = Result<()>> + use<> {
136        assert!(!capabilities.is_empty(), "must set at least one capability");
137        let mut params: Vec<u8> = vec![0; capabilities.iter().fold(2, |a, x| a + x.encoded_len())];
138        params[0] = stream_id.to_msg();
139        params[1] = local_stream_id.to_msg();
140        let mut idx = 2;
141        for capability in capabilities {
142            if let Err(e) = capability.encode(&mut params[idx..]) {
143                return futures::future::err(e).left_future();
144            }
145            idx += capability.encoded_len();
146        }
147        self.send_command::<SimpleResponse>(SignalIdentifier::SetConfiguration, &params)
148            .ok_into()
149            .right_future()
150    }
151
152    /// Send a Get Stream Configuration (Sec 8.10) command to the remote peer
153    /// for the given remote `stream_id`.
154    /// Asynchronously returns the set of ServiceCapabilities previously
155    /// configured between these two peers.
156    /// Error will be RemoteRejected if the remote peer rejects this command.
157    pub fn get_configuration(
158        &self,
159        stream_id: &StreamEndpointId,
160    ) -> impl Future<Output = Result<Vec<ServiceCapability>>> + use<> {
161        let stream_params = &[stream_id.to_msg()];
162        self.send_command::<GetCapabilitiesResponse>(
163            SignalIdentifier::GetConfiguration,
164            stream_params,
165        )
166        .ok_into()
167    }
168
169    /// Send a Stream Reconfigure (Sec 8.11) command to the remote peer for the
170    /// given remote `stream_id`, to reconfigure the Application Service
171    /// capabilities in `capabilities`.
172    /// Note: Per the spec, only the Media Codec and Content Protection
173    /// capabilities will be accepted in this command.
174    /// Panics if there are no capabilities to configure.
175    /// Error will be RemoteRejected if the remote refused.
176    /// ServiceCategory will be set on RemoteReject with the indicated issue category.
177    pub fn reconfigure(
178        &self,
179        stream_id: &StreamEndpointId,
180        capabilities: &[ServiceCapability],
181    ) -> impl Future<Output = Result<()>> + use<> {
182        assert!(!capabilities.is_empty(), "must set at least one capability");
183        let mut params: Vec<u8> = vec![0; capabilities.iter().fold(1, |a, x| a + x.encoded_len())];
184        params[0] = stream_id.to_msg();
185        let mut idx = 1;
186        for capability in capabilities {
187            if !capability.is_application() {
188                return futures::future::err(Error::Encoding).left_future();
189            }
190            if let Err(e) = capability.encode(&mut params[idx..]) {
191                return futures::future::err(e).left_future();
192            }
193            idx += capability.encoded_len();
194        }
195        self.send_command::<SimpleResponse>(SignalIdentifier::Reconfigure, &params)
196            .ok_into()
197            .right_future()
198    }
199
200    /// Send a Open Stream Command (Sec 8.12) to the remote peer for the given
201    /// `stream_id`.
202    /// Error will be RemoteRejected if the remote peer rejects the command.
203    pub fn open(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> + use<> {
204        let stream_params = &[stream_id.to_msg()];
205        self.send_command::<SimpleResponse>(SignalIdentifier::Open, stream_params).ok_into()
206    }
207
208    /// Send a Start Stream Command (Sec 8.13) to the remote peer for all the streams in
209    /// `stream_ids`.
210    /// Returns Ok(()) if the command is accepted, and RemoteStreamRejected with the stream
211    /// endpoint id and error code reported by the remote if the remote signals a failure.
212    pub fn start(
213        &self,
214        stream_ids: &[StreamEndpointId],
215    ) -> impl Future<Output = Result<()>> + use<> {
216        let mut stream_params = Vec::with_capacity(stream_ids.len());
217        for stream_id in stream_ids {
218            stream_params.push(stream_id.to_msg());
219        }
220        self.send_command::<SimpleResponse>(SignalIdentifier::Start, &stream_params).ok_into()
221    }
222
223    /// Send a Close Stream Command (Sec 8.14) to the remote peer for the given `stream_id`.
224    /// Error will be RemoteRejected if the remote peer rejects the command.
225    pub fn close(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> + use<> {
226        let stream_params = &[stream_id.to_msg()];
227        let response: CommandResponseFut<SimpleResponse> =
228            self.send_command::<SimpleResponse>(SignalIdentifier::Close, stream_params);
229        response.ok_into()
230    }
231
232    /// Send a Suspend Command (Sec 8.15) to the remote peer for all the streams in `stream_ids`.
233    /// Error will be RemoteRejected if the remote refused, with the stream endpoint identifier
234    /// indicated by the remote set in the RemoteReject.
235    pub fn suspend(
236        &self,
237        stream_ids: &[StreamEndpointId],
238    ) -> impl Future<Output = Result<()>> + use<> {
239        let mut stream_params = Vec::with_capacity(stream_ids.len());
240        for stream_id in stream_ids {
241            stream_params.push(stream_id.to_msg());
242        }
243        let response: CommandResponseFut<SimpleResponse> =
244            self.send_command::<SimpleResponse>(SignalIdentifier::Suspend, &stream_params);
245        response.ok_into()
246    }
247
248    /// Send an Abort (Sec 8.16) to the remote peer for the given `stream_id`.
249    /// Returns Ok(()) if the command is accepted, and Err(Timeout) if the remote
250    /// timed out.  The remote peer is not allowed to reject this command, and
251    /// commands that have invalid `stream_id` will timeout instead.
252    pub fn abort(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> + use<> {
253        let stream_params = &[stream_id.to_msg()];
254        self.send_command::<SimpleResponse>(SignalIdentifier::Abort, stream_params).ok_into()
255    }
256
257    /// Send a Delay Report (Sec 8.19) to the remote peer for the given `stream_id`.
258    /// `delay` is in tenths of milliseconds.
259    /// Error will be RemoteRejected if the remote peer rejects the command.
260    pub fn delay_report(
261        &self,
262        stream_id: &StreamEndpointId,
263        delay: u16,
264    ) -> impl Future<Output = Result<()>> + use<> {
265        let delay_bytes: [u8; 2] = delay.to_be_bytes();
266        let params = &[stream_id.to_msg(), delay_bytes[0], delay_bytes[1]];
267        self.send_command::<SimpleResponse>(SignalIdentifier::DelayReport, params).ok_into()
268    }
269
270    /// The maximum amount of time we will wait for a response to a signaling command.
271    const RTX_SIG_TIMER_MS: i64 = 3000;
272    const COMMAND_TIMEOUT: MonotonicDuration =
273        MonotonicDuration::from_millis(Peer::RTX_SIG_TIMER_MS);
274
275    /// Sends a signal on the channel and receive a future that will complete
276    /// when we get the expected response.
277    fn send_command<D: Decodable<Error = Error>>(
278        &self,
279        signal: SignalIdentifier,
280        payload: &[u8],
281    ) -> CommandResponseFut<D> {
282        let send_result = (|| {
283            let id = self.inner.add_response_waiter()?;
284            let header = SignalingHeader::new(id, signal, SignalingMessageType::Command);
285            let mut buf = vec![0; header.encoded_len()];
286            header.encode(buf.as_mut_slice())?;
287            buf.extend_from_slice(payload);
288            self.inner.send_signal(buf.as_slice())?;
289            Ok(header)
290        })();
291
292        CommandResponseFut::new(send_result, self.inner.clone())
293    }
294}
295
296/// A future representing the result of a AVDTP command. Decodes the response when it arrives.
297struct CommandResponseFut<D: Decodable> {
298    id: SignalIdentifier,
299    fut: Pin<Box<MaybeDone<OnTimeout<CommandResponse, fn() -> Result<Vec<u8>>>>>>,
300    _phantom: PhantomData<D>,
301}
302
303impl<D: Decodable> Unpin for CommandResponseFut<D> {}
304
305impl<D: Decodable<Error = Error>> CommandResponseFut<D> {
306    fn new(send_result: Result<SignalingHeader>, inner: Arc<PeerInner>) -> Self {
307        let header = match send_result {
308            Err(e) => {
309                return Self {
310                    id: SignalIdentifier::Abort,
311                    fut: Box::pin(MaybeDone::Done(Err(e))),
312                    _phantom: PhantomData,
313                };
314            }
315            Ok(header) => header,
316        };
317        let response = CommandResponse { id: header.label(), inner: Some(inner) };
318        let err_timeout: fn() -> Result<Vec<u8>> = || Err(Error::Timeout);
319        let timedout_fut = response.on_timeout(Peer::COMMAND_TIMEOUT.after_now(), err_timeout);
320
321        Self {
322            id: header.signal(),
323            fut: Box::pin(futures::future::maybe_done(timedout_fut)),
324            _phantom: PhantomData,
325        }
326    }
327}
328
329impl<D: Decodable<Error = Error>> Future for CommandResponseFut<D> {
330    type Output = Result<D>;
331
332    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
333        ready!(self.fut.poll_unpin(cx));
334        Poll::Ready(
335            self.fut
336                .as_mut()
337                .take_output()
338                .unwrap_or(Err(Error::AlreadyReceived))
339                .and_then(|buf| decode_signaling_response(self.id, buf)),
340        )
341    }
342}
343
344/// A request from the connected peer.
345/// Each variant of this includes a responder which implements two functions:
346///  - send(...) will send a response with the information provided.
347///  - reject(ErrorCode) will send an reject response with the given error code.
348#[derive(Debug)]
349pub enum Request {
350    Discover {
351        responder: DiscoverResponder,
352    },
353    GetCapabilities {
354        stream_id: StreamEndpointId,
355        responder: GetCapabilitiesResponder,
356    },
357    GetAllCapabilities {
358        stream_id: StreamEndpointId,
359        responder: GetCapabilitiesResponder,
360    },
361    SetConfiguration {
362        local_stream_id: StreamEndpointId,
363        remote_stream_id: StreamEndpointId,
364        capabilities: Vec<ServiceCapability>,
365        responder: ConfigureResponder,
366    },
367    GetConfiguration {
368        stream_id: StreamEndpointId,
369        responder: GetCapabilitiesResponder,
370    },
371    Reconfigure {
372        local_stream_id: StreamEndpointId,
373        capabilities: Vec<ServiceCapability>,
374        responder: ConfigureResponder,
375    },
376    Open {
377        stream_id: StreamEndpointId,
378        responder: SimpleResponder,
379    },
380    Start {
381        stream_ids: Vec<StreamEndpointId>,
382        responder: StreamResponder,
383    },
384    Close {
385        stream_id: StreamEndpointId,
386        responder: SimpleResponder,
387    },
388    Suspend {
389        stream_ids: Vec<StreamEndpointId>,
390        responder: StreamResponder,
391    },
392    Abort {
393        stream_id: StreamEndpointId,
394        responder: SimpleResponder,
395    },
396    DelayReport {
397        stream_id: StreamEndpointId,
398        delay: u16,
399        responder: SimpleResponder,
400    }, // TODO(jamuraa): add the rest of the requests
401}
402
/// Parses a request whose body is exactly one stream endpoint id into
/// `Request::$request_variant`, answering with `$responder_type`. Any other body length is a
/// `BadLength` request error.
macro_rules! parse_one_seid {
    ($body:ident, $signal:ident, $peer:ident, $id:ident, $request_variant:ident, $responder_type:ident) => {
        match $body {
            [seid] => Ok(Request::$request_variant {
                stream_id: StreamEndpointId::from_msg(seid),
                responder: $responder_type { signal: $signal, peer: $peer.clone(), id: $id },
            }),
            _ => Err(Error::RequestInvalid(ErrorCode::BadLength)),
        }
    };
}
415
416impl Request {
417    fn get_req_seids(body: &[u8]) -> Result<Vec<StreamEndpointId>> {
418        if body.len() < 1 {
419            return Err(Error::RequestInvalid(ErrorCode::BadLength));
420        }
421        Ok(body.iter().map(&StreamEndpointId::from_msg).collect())
422    }
423
424    fn get_req_capabilities(encoded: &[u8]) -> Result<Vec<ServiceCapability>> {
425        if encoded.len() < 2 {
426            return Err(Error::RequestInvalid(ErrorCode::BadLength));
427        }
428        let mut caps = vec![];
429        let mut loc = 0;
430        while loc < encoded.len() {
431            let cap = match ServiceCapability::decode(&encoded[loc..]) {
432                Ok(cap) => cap,
433                Err(Error::RequestInvalid(code)) => {
434                    return Err(Error::RequestInvalidExtra(code, encoded[loc]));
435                }
436                Err(e) => return Err(e),
437            };
438            loc += cap.encoded_len();
439            caps.push(cap);
440        }
441        Ok(caps)
442    }
443
444    fn parse(
445        peer: &Arc<PeerInner>,
446        id: TxLabel,
447        signal: SignalIdentifier,
448        body: &[u8],
449    ) -> Result<Request> {
450        match signal {
451            SignalIdentifier::Discover => {
452                // Discover Request has no body (Sec 8.6.1)
453                if body.len() > 0 {
454                    return Err(Error::RequestInvalid(ErrorCode::BadLength));
455                }
456                Ok(Request::Discover { responder: DiscoverResponder { peer: peer.clone(), id } })
457            }
458            SignalIdentifier::GetCapabilities => {
459                parse_one_seid!(body, signal, peer, id, GetCapabilities, GetCapabilitiesResponder)
460            }
461            SignalIdentifier::GetAllCapabilities => parse_one_seid!(
462                body,
463                signal,
464                peer,
465                id,
466                GetAllCapabilities,
467                GetCapabilitiesResponder
468            ),
469            SignalIdentifier::SetConfiguration => {
470                if body.len() < 4 {
471                    return Err(Error::RequestInvalid(ErrorCode::BadLength));
472                }
473                let requested = Request::get_req_capabilities(&body[2..])?;
474                Ok(Request::SetConfiguration {
475                    local_stream_id: StreamEndpointId::from_msg(&body[0]),
476                    remote_stream_id: StreamEndpointId::from_msg(&body[1]),
477                    capabilities: requested,
478                    responder: ConfigureResponder { signal, peer: peer.clone(), id },
479                })
480            }
481            SignalIdentifier::GetConfiguration => {
482                parse_one_seid!(body, signal, peer, id, GetConfiguration, GetCapabilitiesResponder)
483            }
484            SignalIdentifier::Reconfigure => {
485                if body.len() < 3 {
486                    return Err(Error::RequestInvalid(ErrorCode::BadLength));
487                }
488                let requested = Request::get_req_capabilities(&body[1..])?;
489                match requested.iter().find(|x| !x.is_application()) {
490                    Some(x) => {
491                        return Err(Error::RequestInvalidExtra(
492                            ErrorCode::InvalidCapabilities,
493                            (&x.category()).into(),
494                        ));
495                    }
496                    None => (),
497                };
498                Ok(Request::Reconfigure {
499                    local_stream_id: StreamEndpointId::from_msg(&body[0]),
500                    capabilities: requested,
501                    responder: ConfigureResponder { signal, peer: peer.clone(), id },
502                })
503            }
504            SignalIdentifier::Open => {
505                parse_one_seid!(body, signal, peer, id, Open, SimpleResponder)
506            }
507            SignalIdentifier::Start => {
508                let seids = Request::get_req_seids(body)?;
509                Ok(Request::Start {
510                    stream_ids: seids,
511                    responder: StreamResponder { signal, peer: peer.clone(), id },
512                })
513            }
514            SignalIdentifier::Close => {
515                parse_one_seid!(body, signal, peer, id, Close, SimpleResponder)
516            }
517            SignalIdentifier::Suspend => {
518                let seids = Request::get_req_seids(body)?;
519                Ok(Request::Suspend {
520                    stream_ids: seids,
521                    responder: StreamResponder { signal, peer: peer.clone(), id },
522                })
523            }
524            SignalIdentifier::Abort => {
525                parse_one_seid!(body, signal, peer, id, Abort, SimpleResponder)
526            }
527            SignalIdentifier::DelayReport => {
528                if body.len() != 3 {
529                    return Err(Error::RequestInvalid(ErrorCode::BadLength));
530                }
531                let delay_arr: [u8; 2] = [body[1], body[2]];
532                let delay = u16::from_be_bytes(delay_arr);
533                Ok(Request::DelayReport {
534                    stream_id: StreamEndpointId::from_msg(&body[0]),
535                    delay,
536                    responder: SimpleResponder { signal, peer: peer.clone(), id },
537                })
538            }
539            _ => Err(Error::UnimplementedMessage),
540        }
541    }
542}
543
544/// A stream of requests from the remote peer.
545#[derive(Debug)]
546pub struct RequestStream {
547    inner: Arc<PeerInner>,
548    terminated: bool,
549}
550
551impl RequestStream {
552    fn new(inner: Arc<PeerInner>) -> Self {
553        Self { inner, terminated: false }
554    }
555}
556
557impl Unpin for RequestStream {}
558
559impl Stream for RequestStream {
560    type Item = Result<Request>;
561
562    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
563        Poll::Ready(match ready!(self.inner.poll_recv_request(cx)) {
564            Ok(UnparsedRequest(SignalingHeader { label, signal, .. }, body)) => {
565                match Request::parse(&self.inner, label, signal, &body) {
566                    Err(Error::RequestInvalid(code)) => {
567                        self.inner.send_reject(label, signal, code)?;
568                        return Poll::Pending;
569                    }
570                    Err(Error::RequestInvalidExtra(code, extra)) => {
571                        self.inner.send_reject_params(label, signal, &[extra, u8::from(&code)])?;
572                        return Poll::Pending;
573                    }
574                    Err(Error::UnimplementedMessage) => {
575                        self.inner.send_reject(label, signal, ErrorCode::NotSupportedCommand)?;
576                        return Poll::Pending;
577                    }
578                    x => Some(x),
579                }
580            }
581            Err(Error::PeerDisconnected) => {
582                self.terminated = true;
583                None
584            }
585            Err(e) => Some(Err(e)),
586        })
587    }
588}
589
590impl FusedStream for RequestStream {
591    fn is_terminated(&self) -> bool {
592        self.terminated
593    }
594}
595
596impl Drop for RequestStream {
597    fn drop(&mut self) {
598        self.inner.incoming_requests.lock().listener = RequestListener::None;
599        self.inner.wake_any();
600    }
601}
602
603// Simple responses have no body data.
604#[derive(Debug)]
605pub struct SimpleResponse {}
606
607impl Decodable for SimpleResponse {
608    type Error = Error;
609
610    fn decode(from: &[u8]) -> Result<Self> {
611        if from.len() > 0 {
612            return Err(Error::InvalidMessage);
613        }
614        Ok(SimpleResponse {})
615    }
616}
617
618impl Into<()> for SimpleResponse {
619    fn into(self) -> () {
620        ()
621    }
622}
623
624#[derive(Debug)]
625struct DiscoverResponse {
626    endpoints: Vec<StreamInformation>,
627}
628
629impl Decodable for DiscoverResponse {
630    type Error = Error;
631
632    fn decode(from: &[u8]) -> Result<Self> {
633        let mut endpoints = Vec::<StreamInformation>::new();
634        let mut idx = 0;
635        while idx < from.len() {
636            let endpoint = StreamInformation::decode(&from[idx..])?;
637            idx += endpoint.encoded_len();
638            endpoints.push(endpoint);
639        }
640        Ok(DiscoverResponse { endpoints })
641    }
642}
643
644impl Into<Vec<StreamInformation>> for DiscoverResponse {
645    fn into(self) -> Vec<StreamInformation> {
646        self.endpoints
647    }
648}
649
650#[derive(Debug)]
651pub struct DiscoverResponder {
652    peer: Arc<PeerInner>,
653    id: TxLabel,
654}
655
656impl DiscoverResponder {
657    /// Sends the response to a discovery request.
658    /// At least one endpoint must be present.
659    /// Will result in a Error::PeerWrite if the distant peer is disconnected.
660    pub fn send(self, endpoints: &[StreamInformation]) -> Result<()> {
661        if endpoints.len() == 0 {
662            // There shall be at least one SEP in a response (Sec 8.6.2)
663            return Err(Error::Encoding);
664        }
665        let mut params = vec![0 as u8; endpoints.len() * endpoints[0].encoded_len()];
666        let mut idx = 0;
667        for endpoint in endpoints {
668            endpoint.encode(&mut params[idx..idx + endpoint.encoded_len()])?;
669            idx += endpoint.encoded_len();
670        }
671        self.peer.send_response(self.id, SignalIdentifier::Discover, &params)
672    }
673
674    pub fn reject(self, error_code: ErrorCode) -> Result<()> {
675        self.peer.send_reject(self.id, SignalIdentifier::Discover, error_code)
676    }
677}
678
679#[derive(Debug)]
680pub struct GetCapabilitiesResponder {
681    peer: Arc<PeerInner>,
682    signal: SignalIdentifier,
683    id: TxLabel,
684}
685
686impl GetCapabilitiesResponder {
687    pub fn send(self, capabilities: &[ServiceCapability]) -> Result<()> {
688        let included_iter = capabilities.iter().filter(|x| x.in_response(self.signal));
689        let reply_len = included_iter.clone().fold(0, |a, b| a + b.encoded_len());
690        let mut reply = vec![0 as u8; reply_len];
691        let mut pos = 0;
692        for capability in included_iter {
693            let size = capability.encoded_len();
694            capability.encode(&mut reply[pos..pos + size])?;
695            pos += size;
696        }
697        self.peer.send_response(self.id, self.signal, &reply)
698    }
699
700    pub fn reject(self, error_code: ErrorCode) -> Result<()> {
701        self.peer.send_reject(self.id, self.signal, error_code)
702    }
703}
704
705#[derive(Debug)]
706struct GetCapabilitiesResponse {
707    capabilities: Vec<ServiceCapability>,
708}
709
710impl Decodable for GetCapabilitiesResponse {
711    type Error = Error;
712
713    fn decode(from: &[u8]) -> Result<Self> {
714        let mut capabilities = Vec::<ServiceCapability>::new();
715        let mut idx = 0;
716        while idx < from.len() {
717            match ServiceCapability::decode(&from[idx..]) {
718                Ok(capability) => {
719                    idx = idx + capability.encoded_len();
720                    capabilities.push(capability);
721                }
722                Err(_) => {
723                    // The capability length of the invalid capability can be nonzero.
724                    // Advance `idx` by the payload amount, but don't push the invalid capability.
725                    // Increment by 1 byte for ServiceCategory, 1 byte for payload length,
726                    // `length_of_capability` bytes for capability length.
727                    info!(
728                        "GetCapabilitiesResponse decode: Capability {:?} not supported.",
729                        from[idx]
730                    );
731                    let length_of_capability = from[idx + 1] as usize;
732                    idx = idx + 2 + length_of_capability;
733                }
734            }
735        }
736        Ok(GetCapabilitiesResponse { capabilities })
737    }
738}
739
740impl Into<Vec<ServiceCapability>> for GetCapabilitiesResponse {
741    fn into(self) -> Vec<ServiceCapability> {
742        self.capabilities
743    }
744}
745
746#[derive(Debug)]
747pub struct SimpleResponder {
748    peer: Arc<PeerInner>,
749    signal: SignalIdentifier,
750    id: TxLabel,
751}
752
753impl SimpleResponder {
754    pub fn send(self) -> Result<()> {
755        self.peer.send_response(self.id, self.signal, &[])
756    }
757
758    pub fn reject(self, error_code: ErrorCode) -> Result<()> {
759        self.peer.send_reject(self.id, self.signal, error_code)
760    }
761}
762
763#[derive(Debug)]
764pub struct StreamResponder {
765    peer: Arc<PeerInner>,
766    signal: SignalIdentifier,
767    id: TxLabel,
768}
769
770impl StreamResponder {
771    pub fn send(self) -> Result<()> {
772        self.peer.send_response(self.id, self.signal, &[])
773    }
774
775    pub fn reject(self, stream_id: &StreamEndpointId, error_code: ErrorCode) -> Result<()> {
776        self.peer.send_reject_params(
777            self.id,
778            self.signal,
779            &[stream_id.to_msg(), u8::from(&error_code)],
780        )
781    }
782}
783
784#[derive(Debug)]
785pub struct ConfigureResponder {
786    peer: Arc<PeerInner>,
787    signal: SignalIdentifier,
788    id: TxLabel,
789}
790
791impl ConfigureResponder {
792    pub fn send(self) -> Result<()> {
793        self.peer.send_response(self.id, self.signal, &[])
794    }
795
796    pub fn reject(self, category: ServiceCategory, error_code: ErrorCode) -> Result<()> {
797        self.peer.send_reject_params(
798            self.id,
799            self.signal,
800            &[u8::from(&category), u8::from(&error_code)],
801        )
802    }
803}
804
805#[derive(Debug)]
806struct UnparsedRequest(SignalingHeader, Vec<u8>);
807
808impl UnparsedRequest {
809    fn new(header: SignalingHeader, body: Vec<u8>) -> UnparsedRequest {
810        UnparsedRequest(header, body)
811    }
812}
813
814#[derive(Debug, Default)]
815struct RequestQueue {
816    listener: RequestListener,
817    queue: VecDeque<UnparsedRequest>,
818}
819
/// Tracks whether a task is listening for incoming requests and how to wake it.
///
/// Defaults to [`RequestListener::None`].
#[derive(Debug, Default)]
enum RequestListener {
    /// No one is listening.
    #[default]
    None,
    /// Someone wants to listen but hasn't polled.
    New,
    /// Someone is listening, and can be woken with the waker.
    Some(Waker),
}
835
/// An enum representing an interest in the response to a command.
#[derive(Debug)]
enum ResponseWaiter {
    /// A new waiter which hasn't been polled yet.
    WillPoll,
    /// A task waiting for a response, which can be woken with the waker.
    Waiting(Waker),
    /// A response that has been received, stored here until it's polled, at
    /// which point it will be decoded.
    Received(Vec<u8>),
    /// It's still waiting on the response, but the receiver has decided they
    /// don't care and we'll throw it out.
    Discard,
}

impl ResponseWaiter {
    /// Returns true if a response has arrived and is waiting to be taken.
    fn is_received(&self) -> bool {
        matches!(self, ResponseWaiter::Received(_))
    }

    /// Consumes the waiter and returns the received response bytes.
    ///
    /// # Panics
    ///
    /// Panics if the waiter is not [`ResponseWaiter::Received`]; check with
    /// [`ResponseWaiter::is_received`] first.
    fn unwrap_received(self) -> Vec<u8> {
        let ResponseWaiter::Received(buf) = self else {
            panic!("expected received buf");
        };
        buf
    }
}
861
862fn decode_signaling_response<D: Decodable<Error = Error>>(
863    expected_signal: SignalIdentifier,
864    buf: Vec<u8>,
865) -> Result<D> {
866    let header = SignalingHeader::decode(buf.as_slice())?;
867    if header.signal() != expected_signal {
868        return Err(Error::InvalidHeader);
869    }
870    let params = &buf[header.encoded_len()..];
871    match header.message_type {
872        SignalingMessageType::ResponseAccept => D::decode(params),
873        SignalingMessageType::GeneralReject | SignalingMessageType::ResponseReject => {
874            Err(RemoteReject::from_params(header.signal(), params).into())
875        }
876        SignalingMessageType::Command => unreachable!(),
877    }
878}
879
880/// A future that polls for the response to a command we sent.
881#[derive(Debug)]
882pub struct CommandResponse {
883    id: TxLabel,
884    // Some(x) if we're still waiting on the response.
885    inner: Option<Arc<PeerInner>>,
886}
887
888impl Unpin for CommandResponse {}
889
890impl Future for CommandResponse {
891    type Output = Result<Vec<u8>>;
892    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
893        let this = &mut *self;
894        let res;
895        {
896            let client = this.inner.as_ref().ok_or(Error::AlreadyReceived)?;
897            res = client.poll_recv_response(&this.id, cx);
898        }
899
900        if let Poll::Ready(Ok(_)) = res {
901            let inner = this.inner.take().expect("CommandResponse polled after completion");
902            inner.wake_any();
903        }
904
905        res
906    }
907}
908
909impl FusedFuture for CommandResponse {
910    fn is_terminated(&self) -> bool {
911        self.inner.is_none()
912    }
913}
914
915impl Drop for CommandResponse {
916    fn drop(&mut self) {
917        if let Some(inner) = &self.inner {
918            inner.remove_response_interest(&self.id);
919            inner.wake_any();
920        }
921    }
922}
923
924#[derive(Debug)]
925struct PeerInner {
926    /// The signaling channel
927    signaling: Channel,
928
929    /// A map of transaction ids that have been sent but the response has not
930    /// been received and/or processed yet.
931    ///
932    /// Waiters are added with `add_response_waiter` and get removed when they are
933    /// polled or they are removed with `remove_waiter`
934    response_waiters: Mutex<Slab<ResponseWaiter>>,
935
936    /// A queue of requests that have been received and are waiting to
937    /// be responded to, along with the waker for the task that has
938    /// taken the request receiver (if it exists)
939    incoming_requests: Mutex<RequestQueue>,
940}
941
942impl PeerInner {
943    /// Add a response waiter, and return a id that can be used to send the
944    /// transaction.  Responses then can be received using poll_recv_response
945    fn add_response_waiter(&self) -> Result<TxLabel> {
946        let key = self.response_waiters.lock().insert(ResponseWaiter::WillPoll);
947        let id = TxLabel::try_from(key as u8);
948        if id.is_err() {
949            warn!("Transaction IDs are exhausted");
950            let _ = self.response_waiters.lock().remove(key);
951        }
952        id
953    }
954
955    /// When a waiter isn't interested in the response anymore, we need to just
956    /// throw it out.  This is called when the response future is dropped.
957    fn remove_response_interest(&self, id: &TxLabel) {
958        let mut lock = self.response_waiters.lock();
959        let idx = usize::from(id);
960        if lock[idx].is_received() {
961            let _ = lock.remove(idx);
962        } else {
963            lock[idx] = ResponseWaiter::Discard;
964        }
965    }
966
967    // Attempts to receive a new request by processing all packets on the socket.
968    // Resolves to an unprocessed request (header, body) if one was received.
969    // Resolves to an error if there was an error reading from the socket or if the peer
970    // disconnected.
971    fn poll_recv_request(&self, cx: &mut Context<'_>) -> Poll<Result<UnparsedRequest>> {
972        let is_closed = self.recv_all(cx)?;
973
974        let mut lock = self.incoming_requests.lock();
975
976        if let Some(request) = lock.queue.pop_front() {
977            Poll::Ready(Ok(request))
978        } else {
979            lock.listener = RequestListener::Some(cx.waker().clone());
980            if is_closed { Poll::Ready(Err(Error::PeerDisconnected)) } else { Poll::Pending }
981        }
982    }
983
984    // Attempts to receive a response to a request by processing all packets on the socket.
985    // Resolves to the bytes in the response body if one was received.
986    // Resolves to an error if there was an error reading from the socket, if the peer
987    // disconnected, or if the |label| is not being waited on.
988    fn poll_recv_response(&self, label: &TxLabel, cx: &mut Context<'_>) -> Poll<Result<Vec<u8>>> {
989        let is_closed = self.recv_all(cx)?;
990
991        let mut waiters = self.response_waiters.lock();
992        let idx = usize::from(label);
993        // We expect() below because the label above came from an internally-created object,
994        // so the waiters should always exist in the map.
995        if waiters.get(idx).expect("Polled unregistered waiter").is_received() {
996            // We got our response.
997            let buf = waiters.remove(idx).unwrap_received();
998            Poll::Ready(Ok(buf))
999        } else {
1000            // Set the waker to be notified when a response shows up.
1001            *waiters.get_mut(idx).expect("Polled unregistered waiter") =
1002                ResponseWaiter::Waiting(cx.waker().clone());
1003
1004            if is_closed { Poll::Ready(Err(Error::PeerDisconnected)) } else { Poll::Pending }
1005        }
1006    }
1007
1008    /// Poll for any packets on the signaling socket
1009    /// Returns whether the channel was closed, or an Error::PeerRead or Error::PeerWrite
1010    /// if there was a problem communicating on the socket.
1011    fn recv_all(&self, cx: &mut Context<'_>) -> Result<bool> {
1012        loop {
1013            let mut next_packet = Vec::new();
1014            let packet_size = match self.signaling.poll_datagram(cx, &mut next_packet) {
1015                Poll::Ready(Err(zx::Status::PEER_CLOSED)) => {
1016                    trace!("Signaling peer closed");
1017                    return Ok(true);
1018                }
1019                Poll::Ready(Err(e)) => return Err(Error::PeerRead(e)),
1020                Poll::Pending => return Ok(false),
1021                Poll::Ready(Ok(size)) => size,
1022            };
1023            if packet_size == 0 {
1024                continue;
1025            }
1026            // Detects General Reject condition and sends the response back.
1027            // On other headers with errors, sends BAD_HEADER to the peer
1028            // and attempts to continue.
1029            let header = match SignalingHeader::decode(next_packet.as_slice()) {
1030                Err(Error::InvalidSignalId(label, id)) => {
1031                    self.send_general_reject(label, id)?;
1032                    continue;
1033                }
1034                Err(_) => {
1035                    // Only possible other return is OutOfRange
1036                    // Returned only when the packet is too small, can't make a meaningful reject.
1037                    info!("received unrejectable message");
1038                    continue;
1039                }
1040                Ok(x) => x,
1041            };
1042            // Commands from the remote get translated into requests.
1043            if header.is_command() {
1044                let mut lock = self.incoming_requests.lock();
1045                let body = next_packet.split_off(header.encoded_len());
1046                lock.queue.push_back(UnparsedRequest::new(header, body));
1047                if let RequestListener::Some(ref waker) = lock.listener {
1048                    waker.wake_by_ref();
1049                }
1050            } else {
1051                // Should be a response to a command we sent
1052                let mut waiters = self.response_waiters.lock();
1053                let idx = usize::from(&header.label());
1054                if let Some(&ResponseWaiter::Discard) = waiters.get(idx) {
1055                    let _ = waiters.remove(idx);
1056                } else if let Some(entry) = waiters.get_mut(idx) {
1057                    let old_entry = mem::replace(entry, ResponseWaiter::Received(next_packet));
1058                    if let ResponseWaiter::Waiting(waker) = old_entry {
1059                        waker.wake();
1060                    }
1061                } else {
1062                    warn!("response for {:?} we did not send, dropping", header.label());
1063                }
1064                // Note: we drop any TxLabel response we are not waiting for
1065            }
1066        }
1067    }
1068
1069    // Wakes up an arbitrary task that has begun polling on the channel so that
1070    // it will call recv_all and be registered as the new channel reader.
1071    fn wake_any(&self) {
1072        // Try to wake up response waiters first, rather than the event listener.
1073        // The event listener is a stream, and so could be between poll_nexts,
1074        // Response waiters should always be actively polled once
1075        // they've begun being polled on a task.
1076        {
1077            let lock = self.response_waiters.lock();
1078            for (_, response_waiter) in lock.iter() {
1079                if let ResponseWaiter::Waiting(waker) = response_waiter {
1080                    waker.wake_by_ref();
1081                    return;
1082                }
1083            }
1084        }
1085        {
1086            let lock = self.incoming_requests.lock();
1087            if let RequestListener::Some(waker) = &lock.listener {
1088                waker.wake_by_ref();
1089                return;
1090            }
1091        }
1092    }
1093
1094    // Build and send a General Reject message (Section 8.18)
1095    fn send_general_reject(&self, label: TxLabel, invalid_signal_id: u8) -> Result<()> {
1096        // Build the packet ourselves rather than make SignalingHeader build an packet with an
1097        // invalid signal id.
1098        let packet: &[u8; 2] = &[u8::from(&label) << 4 | 0x01, invalid_signal_id & 0x3F];
1099        self.send_signal(packet)
1100    }
1101
1102    fn send_response(&self, label: TxLabel, signal: SignalIdentifier, params: &[u8]) -> Result<()> {
1103        let header = SignalingHeader::new(label, signal, SignalingMessageType::ResponseAccept);
1104        let mut packet = vec![0 as u8; header.encoded_len() + params.len()];
1105        header.encode(packet.as_mut_slice())?;
1106        packet[header.encoded_len()..].clone_from_slice(params);
1107        self.send_signal(&packet)
1108    }
1109
1110    fn send_reject(
1111        &self,
1112        label: TxLabel,
1113        signal: SignalIdentifier,
1114        error_code: ErrorCode,
1115    ) -> Result<()> {
1116        self.send_reject_params(label, signal, &[u8::from(&error_code)])
1117    }
1118
1119    fn send_reject_params(
1120        &self,
1121        label: TxLabel,
1122        signal: SignalIdentifier,
1123        params: &[u8],
1124    ) -> Result<()> {
1125        let header = SignalingHeader::new(label, signal, SignalingMessageType::ResponseReject);
1126        let mut packet = vec![0 as u8; header.encoded_len() + params.len()];
1127        header.encode(packet.as_mut_slice())?;
1128        packet[header.encoded_len()..].clone_from_slice(params);
1129        self.send_signal(&packet)
1130    }
1131
1132    fn send_signal(&self, data: &[u8]) -> Result<()> {
1133        let _ = self.signaling.write(data).map_err(|x| Error::PeerWrite(x))?;
1134        Ok(())
1135    }
1136}