bt_avdtp/
lib.rs

1// Copyright 2018 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use fuchsia_async::{DurationExt, OnTimeout, TimeoutExt};
6use fuchsia_bluetooth::types::Channel;
7use fuchsia_sync::Mutex;
8use futures::future::{FusedFuture, MaybeDone};
9use futures::stream::{FusedStream, Stream};
10use futures::task::{Context, Poll, Waker};
11use futures::{ready, Future, FutureExt, TryFutureExt};
12use log::{info, trace, warn};
13use packet_encoding::{Decodable, Encodable};
14use slab::Slab;
15use std::collections::VecDeque;
16use std::marker::PhantomData;
17use std::mem;
18use std::pin::Pin;
19use std::sync::Arc;
20use zx::{self as zx, MonotonicDuration};
21
22#[cfg(test)]
23mod tests;
24
25mod rtp;
26mod stream_endpoint;
27mod types;
28
29use crate::types::{SignalIdentifier, SignalingHeader, SignalingMessageType, TxLabel};
30
31pub use crate::rtp::{RtpError, RtpHeader};
32pub use crate::stream_endpoint::{
33    MediaStream, StreamEndpoint, StreamEndpointUpdateCallback, StreamState,
34};
35pub use crate::types::{
36    ContentProtectionType, EndpointType, Error, ErrorCode, MediaCodecType, MediaType, RemoteReject,
37    Result, ServiceCapability, ServiceCategory, StreamEndpointId, StreamInformation,
38};
39
/// An AVDTP signaling peer can send commands to another peer, receive requests and send responses.
/// Media transport is not handled by this peer.
///
/// Requests from the distant peer are delivered through the request stream available through
/// take_request_stream().  Only one RequestStream can be active at a time.  Only valid requests
/// are sent to the request stream - invalid formats are automatically rejected.
///
/// Responses are sent using responders that are included in the request stream from the connected
/// peer.
///
/// Clones of a `Peer` share the same underlying state, since the state is held in an `Arc`.
#[derive(Debug, Clone)]
pub struct Peer {
    /// State shared between this handle, its clones, the request stream and any responders.
    inner: Arc<PeerInner>,
}
53
54impl Peer {
55    /// Create a new peer from a signaling channel.
56    pub fn new(signaling: Channel) -> Self {
57        Self {
58            inner: Arc::new(PeerInner {
59                signaling,
60                response_waiters: Mutex::new(Slab::<ResponseWaiter>::new()),
61                incoming_requests: Mutex::<RequestQueue>::default(),
62            }),
63        }
64    }
65
66    /// Take the event listener for this peer. Panics if the stream is already
67    /// held.
68    #[track_caller]
69    pub fn take_request_stream(&self) -> RequestStream {
70        {
71            let mut lock = self.inner.incoming_requests.lock();
72            if let RequestListener::None = lock.listener {
73                lock.listener = RequestListener::New;
74            } else {
75                panic!("Request stream has already been taken");
76            }
77        }
78
79        RequestStream::new(self.inner.clone())
80    }
81
82    /// Send a Stream End Point Discovery (Sec 8.6) command to the remote peer.
83    /// Asynchronously returns a the reply in a vector of endpoint information.
84    /// Error will be RemoteRejected if the remote peer rejected the command.
85    pub fn discover(&self) -> impl Future<Output = Result<Vec<StreamInformation>>> {
86        self.send_command::<DiscoverResponse>(SignalIdentifier::Discover, &[]).ok_into()
87    }
88
89    /// Send a Get Capabilities (Sec 8.7) command to the remote peer for the
90    /// given `stream_id`.
91    /// Asynchronously returns the reply which contains the ServiceCapabilities
92    /// reported.
93    /// In general, Get All Capabilities should be preferred to this command if is supported.
94    /// Error will be RemoteRejected if the remote peer rejects the command.
95    pub fn get_capabilities(
96        &self,
97        stream_id: &StreamEndpointId,
98    ) -> impl Future<Output = Result<Vec<ServiceCapability>>> {
99        let stream_params = &[stream_id.to_msg()];
100        self.send_command::<GetCapabilitiesResponse>(
101            SignalIdentifier::GetCapabilities,
102            stream_params,
103        )
104        .ok_into()
105    }
106
107    /// Send a Get All Capabilities (Sec 8.8) command to the remote peer for the
108    /// given `stream_id`.
109    /// Asynchronously returns the reply which contains the ServiceCapabilities
110    /// reported.
111    /// Error will be RemoteRejected if the remote peer rejects the command.
112    pub fn get_all_capabilities(
113        &self,
114        stream_id: &StreamEndpointId,
115    ) -> impl Future<Output = Result<Vec<ServiceCapability>>> {
116        let stream_params = &[stream_id.to_msg()];
117        self.send_command::<GetCapabilitiesResponse>(
118            SignalIdentifier::GetAllCapabilities,
119            stream_params,
120        )
121        .ok_into()
122    }
123
124    /// Send a Stream Configuration (Sec 8.9) command to the remote peer for the
125    /// given remote `stream_id`, communicating the association to a local
126    /// `local_stream_id` and the required stream `capabilities`.
127    /// Panics if `capabilities` is empty.
128    /// Error will be RemoteRejected if the remote refused.
129    /// ServiceCategory will be set on RemoteReject with the indicated issue category.
130    pub fn set_configuration(
131        &self,
132        stream_id: &StreamEndpointId,
133        local_stream_id: &StreamEndpointId,
134        capabilities: &[ServiceCapability],
135    ) -> impl Future<Output = Result<()>> {
136        assert!(!capabilities.is_empty(), "must set at least one capability");
137        let mut params: Vec<u8> = vec![0; capabilities.iter().fold(2, |a, x| a + x.encoded_len())];
138        params[0] = stream_id.to_msg();
139        params[1] = local_stream_id.to_msg();
140        let mut idx = 2;
141        for capability in capabilities {
142            if let Err(e) = capability.encode(&mut params[idx..]) {
143                return futures::future::err(e).left_future();
144            }
145            idx += capability.encoded_len();
146        }
147        self.send_command::<SimpleResponse>(SignalIdentifier::SetConfiguration, &params)
148            .ok_into()
149            .right_future()
150    }
151
152    /// Send a Get Stream Configuration (Sec 8.10) command to the remote peer
153    /// for the given remote `stream_id`.
154    /// Asynchronously returns the set of ServiceCapabilities previously
155    /// configured between these two peers.
156    /// Error will be RemoteRejected if the remote peer rejects this command.
157    pub fn get_configuration(
158        &self,
159        stream_id: &StreamEndpointId,
160    ) -> impl Future<Output = Result<Vec<ServiceCapability>>> {
161        let stream_params = &[stream_id.to_msg()];
162        self.send_command::<GetCapabilitiesResponse>(
163            SignalIdentifier::GetConfiguration,
164            stream_params,
165        )
166        .ok_into()
167    }
168
169    /// Send a Stream Reconfigure (Sec 8.11) command to the remote peer for the
170    /// given remote `stream_id`, to reconfigure the Application Service
171    /// capabilities in `capabilities`.
172    /// Note: Per the spec, only the Media Codec and Content Protection
173    /// capabilities will be accepted in this command.
174    /// Panics if there are no capabilities to configure.
175    /// Error will be RemoteRejected if the remote refused.
176    /// ServiceCategory will be set on RemoteReject with the indicated issue category.
177    pub fn reconfigure(
178        &self,
179        stream_id: &StreamEndpointId,
180        capabilities: &[ServiceCapability],
181    ) -> impl Future<Output = Result<()>> {
182        assert!(!capabilities.is_empty(), "must set at least one capability");
183        let mut params: Vec<u8> = vec![0; capabilities.iter().fold(1, |a, x| a + x.encoded_len())];
184        params[0] = stream_id.to_msg();
185        let mut idx = 1;
186        for capability in capabilities {
187            if !capability.is_application() {
188                return futures::future::err(Error::Encoding).left_future();
189            }
190            if let Err(e) = capability.encode(&mut params[idx..]) {
191                return futures::future::err(e).left_future();
192            }
193            idx += capability.encoded_len();
194        }
195        self.send_command::<SimpleResponse>(SignalIdentifier::Reconfigure, &params)
196            .ok_into()
197            .right_future()
198    }
199
200    /// Send a Open Stream Command (Sec 8.12) to the remote peer for the given
201    /// `stream_id`.
202    /// Error will be RemoteRejected if the remote peer rejects the command.
203    pub fn open(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> {
204        let stream_params = &[stream_id.to_msg()];
205        self.send_command::<SimpleResponse>(SignalIdentifier::Open, stream_params).ok_into()
206    }
207
208    /// Send a Start Stream Command (Sec 8.13) to the remote peer for all the streams in
209    /// `stream_ids`.
210    /// Returns Ok(()) if the command is accepted, and RemoteStreamRejected with the stream
211    /// endpoint id and error code reported by the remote if the remote signals a failure.
212    pub fn start(&self, stream_ids: &[StreamEndpointId]) -> impl Future<Output = Result<()>> {
213        let mut stream_params = Vec::with_capacity(stream_ids.len());
214        for stream_id in stream_ids {
215            stream_params.push(stream_id.to_msg());
216        }
217        self.send_command::<SimpleResponse>(SignalIdentifier::Start, &stream_params).ok_into()
218    }
219
220    /// Send a Close Stream Command (Sec 8.14) to the remote peer for the given `stream_id`.
221    /// Error will be RemoteRejected if the remote peer rejects the command.
222    pub fn close(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> {
223        let stream_params = &[stream_id.to_msg()];
224        let response: CommandResponseFut<SimpleResponse> =
225            self.send_command::<SimpleResponse>(SignalIdentifier::Close, stream_params);
226        response.ok_into()
227    }
228
229    /// Send a Suspend Command (Sec 8.15) to the remote peer for all the streams in `stream_ids`.
230    /// Error will be RemoteRejected if the remote refused, with the stream endpoint identifier
231    /// indicated by the remote set in the RemoteReject.
232    pub fn suspend(&self, stream_ids: &[StreamEndpointId]) -> impl Future<Output = Result<()>> {
233        let mut stream_params = Vec::with_capacity(stream_ids.len());
234        for stream_id in stream_ids {
235            stream_params.push(stream_id.to_msg());
236        }
237        let response: CommandResponseFut<SimpleResponse> =
238            self.send_command::<SimpleResponse>(SignalIdentifier::Suspend, &stream_params);
239        response.ok_into()
240    }
241
242    /// Send an Abort (Sec 8.16) to the remote peer for the given `stream_id`.
243    /// Returns Ok(()) if the command is accepted, and Err(Timeout) if the remote
244    /// timed out.  The remote peer is not allowed to reject this command, and
245    /// commands that have invalid `stream_id` will timeout instead.
246    pub fn abort(&self, stream_id: &StreamEndpointId) -> impl Future<Output = Result<()>> {
247        let stream_params = &[stream_id.to_msg()];
248        self.send_command::<SimpleResponse>(SignalIdentifier::Abort, stream_params).ok_into()
249    }
250
251    /// Send a Delay Report (Sec 8.19) to the remote peer for the given `stream_id`.
252    /// `delay` is in tenths of milliseconds.
253    /// Error will be RemoteRejected if the remote peer rejects the command.
254    pub fn delay_report(
255        &self,
256        stream_id: &StreamEndpointId,
257        delay: u16,
258    ) -> impl Future<Output = Result<()>> {
259        let delay_bytes: [u8; 2] = delay.to_be_bytes();
260        let params = &[stream_id.to_msg(), delay_bytes[0], delay_bytes[1]];
261        self.send_command::<SimpleResponse>(SignalIdentifier::DelayReport, params).ok_into()
262    }
263
    /// The maximum amount of time we will wait for a response to a signaling command,
    /// in milliseconds.
    const RTX_SIG_TIMER_MS: i64 = 3000;
    /// `RTX_SIG_TIMER_MS` as a duration. `CommandResponseFut::new` measures it from the moment
    /// the response future is created; on expiry the future resolves to `Error::Timeout`.
    const COMMAND_TIMEOUT: MonotonicDuration =
        MonotonicDuration::from_millis(Peer::RTX_SIG_TIMER_MS);
268
269    /// Sends a signal on the channel and receive a future that will complete
270    /// when we get the expected response.
271    fn send_command<D: Decodable<Error = Error>>(
272        &self,
273        signal: SignalIdentifier,
274        payload: &[u8],
275    ) -> CommandResponseFut<D> {
276        let send_result = (|| {
277            let id = self.inner.add_response_waiter()?;
278            let header = SignalingHeader::new(id, signal, SignalingMessageType::Command);
279            let mut buf = vec![0; header.encoded_len()];
280            header.encode(buf.as_mut_slice())?;
281            buf.extend_from_slice(payload);
282            self.inner.send_signal(buf.as_slice())?;
283            Ok(header)
284        })();
285
286        CommandResponseFut::new(send_result, self.inner.clone())
287    }
288}
289
/// A future representing the result of an AVDTP command. Decodes the response when it arrives.
struct CommandResponseFut<D: Decodable> {
    /// The signal the command was sent with. The response header must carry the same signal,
    /// which is checked when the response is decoded.
    id: SignalIdentifier,
    /// Yields the raw response bytes, or `Error::Timeout` if the command timeout fires first.
    /// Wrapped in `MaybeDone` so the output can be taken out once it is ready.
    fut: Pin<Box<MaybeDone<OnTimeout<CommandResponse, fn() -> Result<Vec<u8>>>>>>,
    /// Records the response type `D` without storing a value of it.
    _phantom: PhantomData<D>,
}

// The inner future is already boxed and pinned, so this wrapper can be `Unpin` for any `D`.
impl<D: Decodable> Unpin for CommandResponseFut<D> {}
298
299impl<D: Decodable<Error = Error>> CommandResponseFut<D> {
300    fn new(send_result: Result<SignalingHeader>, inner: Arc<PeerInner>) -> Self {
301        let header = match send_result {
302            Err(e) => {
303                return Self {
304                    id: SignalIdentifier::Abort,
305                    fut: Box::pin(MaybeDone::Done(Err(e))),
306                    _phantom: PhantomData,
307                }
308            }
309            Ok(header) => header,
310        };
311        let response = CommandResponse { id: header.label(), inner: Some(inner) };
312        let err_timeout: fn() -> Result<Vec<u8>> = || Err(Error::Timeout);
313        let timedout_fut = response.on_timeout(Peer::COMMAND_TIMEOUT.after_now(), err_timeout);
314
315        Self {
316            id: header.signal(),
317            fut: Box::pin(futures::future::maybe_done(timedout_fut)),
318            _phantom: PhantomData,
319        }
320    }
321}
322
323impl<D: Decodable<Error = Error>> Future for CommandResponseFut<D> {
324    type Output = Result<D>;
325
326    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
327        ready!(self.fut.poll_unpin(cx));
328        Poll::Ready(
329            self.fut
330                .as_mut()
331                .take_output()
332                .unwrap_or(Err(Error::AlreadyReceived))
333                .and_then(|buf| decode_signaling_response(self.id, buf)),
334        )
335    }
336}
337
/// A request from the connected peer.
/// Each variant of this includes a responder which implements two functions:
///  - send(...) will send a response with the information provided.
///  - reject(...) will send a reject response with the given error code (some responders
///    also take the stream id or service category the rejection refers to).
#[derive(Debug)]
pub enum Request {
    /// Stream End Point Discovery. The request carries no parameters.
    Discover {
        responder: DiscoverResponder,
    },
    /// Get Capabilities for the endpoint `stream_id`.
    GetCapabilities {
        stream_id: StreamEndpointId,
        responder: GetCapabilitiesResponder,
    },
    /// Get All Capabilities for the endpoint `stream_id`.
    GetAllCapabilities {
        stream_id: StreamEndpointId,
        responder: GetCapabilitiesResponder,
    },
    /// Set Configuration. `local_stream_id` is decoded from the first octet of the request,
    /// `remote_stream_id` from the second, and `capabilities` from the remainder.
    SetConfiguration {
        local_stream_id: StreamEndpointId,
        remote_stream_id: StreamEndpointId,
        capabilities: Vec<ServiceCapability>,
        responder: ConfigureResponder,
    },
    /// Get Configuration for the endpoint `stream_id`.
    GetConfiguration {
        stream_id: StreamEndpointId,
        responder: GetCapabilitiesResponder,
    },
    /// Reconfigure the endpoint `local_stream_id`. Every capability here is an application
    /// capability; requests containing any other kind are rejected during parsing.
    Reconfigure {
        local_stream_id: StreamEndpointId,
        capabilities: Vec<ServiceCapability>,
        responder: ConfigureResponder,
    },
    /// Open the stream `stream_id`.
    Open {
        stream_id: StreamEndpointId,
        responder: SimpleResponder,
    },
    /// Start every stream in `stream_ids` (at least one).
    Start {
        stream_ids: Vec<StreamEndpointId>,
        responder: StreamResponder,
    },
    /// Close the stream `stream_id`.
    Close {
        stream_id: StreamEndpointId,
        responder: SimpleResponder,
    },
    /// Suspend every stream in `stream_ids` (at least one).
    Suspend {
        stream_ids: Vec<StreamEndpointId>,
        responder: StreamResponder,
    },
    /// Abort the stream `stream_id`.
    Abort {
        stream_id: StreamEndpointId,
        responder: SimpleResponder,
    },
    /// Delay report for `stream_id`. `delay` is decoded big-endian from the two octets that
    /// follow the stream id.
    DelayReport {
        stream_id: StreamEndpointId,
        delay: u16,
        responder: SimpleResponder,
    }, // TODO(jamuraa): add the rest of the requests
}
396
/// Expands to an expression that parses a request whose body is exactly one stream endpoint id.
///
/// Evaluates to `Ok(Request::$request_variant { .. })` holding the decoded id and a
/// `$responder_type` built from `$signal`, `$peer` and `$id`, or to
/// `Err(Error::RequestInvalid(ErrorCode::BadLength))` when `$body` is not exactly one byte.
macro_rules! parse_one_seid {
    ($body:ident, $signal:ident, $peer:ident, $id:ident, $request_variant:ident, $responder_type:ident) => {
        if $body.len() == 1 {
            Ok(Request::$request_variant {
                stream_id: StreamEndpointId::from_msg(&$body[0]),
                responder: $responder_type { signal: $signal, peer: $peer.clone(), id: $id },
            })
        } else {
            Err(Error::RequestInvalid(ErrorCode::BadLength))
        }
    };
}
409
410impl Request {
411    fn get_req_seids(body: &[u8]) -> Result<Vec<StreamEndpointId>> {
412        if body.len() < 1 {
413            return Err(Error::RequestInvalid(ErrorCode::BadLength));
414        }
415        Ok(body.iter().map(&StreamEndpointId::from_msg).collect())
416    }
417
418    fn get_req_capabilities(encoded: &[u8]) -> Result<Vec<ServiceCapability>> {
419        if encoded.len() < 2 {
420            return Err(Error::RequestInvalid(ErrorCode::BadLength));
421        }
422        let mut caps = vec![];
423        let mut loc = 0;
424        while loc < encoded.len() {
425            let cap = match ServiceCapability::decode(&encoded[loc..]) {
426                Ok(cap) => cap,
427                Err(Error::RequestInvalid(code)) => {
428                    return Err(Error::RequestInvalidExtra(code, encoded[loc]));
429                }
430                Err(e) => return Err(e),
431            };
432            loc += cap.encoded_len();
433            caps.push(cap);
434        }
435        Ok(caps)
436    }
437
438    fn parse(
439        peer: &Arc<PeerInner>,
440        id: TxLabel,
441        signal: SignalIdentifier,
442        body: &[u8],
443    ) -> Result<Request> {
444        match signal {
445            SignalIdentifier::Discover => {
446                // Discover Request has no body (Sec 8.6.1)
447                if body.len() > 0 {
448                    return Err(Error::RequestInvalid(ErrorCode::BadLength));
449                }
450                Ok(Request::Discover { responder: DiscoverResponder { peer: peer.clone(), id } })
451            }
452            SignalIdentifier::GetCapabilities => {
453                parse_one_seid!(body, signal, peer, id, GetCapabilities, GetCapabilitiesResponder)
454            }
455            SignalIdentifier::GetAllCapabilities => parse_one_seid!(
456                body,
457                signal,
458                peer,
459                id,
460                GetAllCapabilities,
461                GetCapabilitiesResponder
462            ),
463            SignalIdentifier::SetConfiguration => {
464                if body.len() < 4 {
465                    return Err(Error::RequestInvalid(ErrorCode::BadLength));
466                }
467                let requested = Request::get_req_capabilities(&body[2..])?;
468                Ok(Request::SetConfiguration {
469                    local_stream_id: StreamEndpointId::from_msg(&body[0]),
470                    remote_stream_id: StreamEndpointId::from_msg(&body[1]),
471                    capabilities: requested,
472                    responder: ConfigureResponder { signal, peer: peer.clone(), id },
473                })
474            }
475            SignalIdentifier::GetConfiguration => {
476                parse_one_seid!(body, signal, peer, id, GetConfiguration, GetCapabilitiesResponder)
477            }
478            SignalIdentifier::Reconfigure => {
479                if body.len() < 3 {
480                    return Err(Error::RequestInvalid(ErrorCode::BadLength));
481                }
482                let requested = Request::get_req_capabilities(&body[1..])?;
483                match requested.iter().find(|x| !x.is_application()) {
484                    Some(x) => {
485                        return Err(Error::RequestInvalidExtra(
486                            ErrorCode::InvalidCapabilities,
487                            (&x.category()).into(),
488                        ));
489                    }
490                    None => (),
491                };
492                Ok(Request::Reconfigure {
493                    local_stream_id: StreamEndpointId::from_msg(&body[0]),
494                    capabilities: requested,
495                    responder: ConfigureResponder { signal, peer: peer.clone(), id },
496                })
497            }
498            SignalIdentifier::Open => {
499                parse_one_seid!(body, signal, peer, id, Open, SimpleResponder)
500            }
501            SignalIdentifier::Start => {
502                let seids = Request::get_req_seids(body)?;
503                Ok(Request::Start {
504                    stream_ids: seids,
505                    responder: StreamResponder { signal, peer: peer.clone(), id },
506                })
507            }
508            SignalIdentifier::Close => {
509                parse_one_seid!(body, signal, peer, id, Close, SimpleResponder)
510            }
511            SignalIdentifier::Suspend => {
512                let seids = Request::get_req_seids(body)?;
513                Ok(Request::Suspend {
514                    stream_ids: seids,
515                    responder: StreamResponder { signal, peer: peer.clone(), id },
516                })
517            }
518            SignalIdentifier::Abort => {
519                parse_one_seid!(body, signal, peer, id, Abort, SimpleResponder)
520            }
521            SignalIdentifier::DelayReport => {
522                if body.len() != 3 {
523                    return Err(Error::RequestInvalid(ErrorCode::BadLength));
524                }
525                let delay_arr: [u8; 2] = [body[1], body[2]];
526                let delay = u16::from_be_bytes(delay_arr);
527                Ok(Request::DelayReport {
528                    stream_id: StreamEndpointId::from_msg(&body[0]),
529                    delay,
530                    responder: SimpleResponder { signal, peer: peer.clone(), id },
531                })
532            }
533            _ => Err(Error::UnimplementedMessage),
534        }
535    }
536}
537
538/// A stream of requests from the remote peer.
539#[derive(Debug)]
540pub struct RequestStream {
541    inner: Arc<PeerInner>,
542    terminated: bool,
543}
544
545impl RequestStream {
546    fn new(inner: Arc<PeerInner>) -> Self {
547        Self { inner, terminated: false }
548    }
549}
550
551impl Unpin for RequestStream {}
552
553impl Stream for RequestStream {
554    type Item = Result<Request>;
555
556    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
557        Poll::Ready(match ready!(self.inner.poll_recv_request(cx)) {
558            Ok(UnparsedRequest(SignalingHeader { label, signal, .. }, body)) => {
559                match Request::parse(&self.inner, label, signal, &body) {
560                    Err(Error::RequestInvalid(code)) => {
561                        self.inner.send_reject(label, signal, code)?;
562                        return Poll::Pending;
563                    }
564                    Err(Error::RequestInvalidExtra(code, extra)) => {
565                        self.inner.send_reject_params(label, signal, &[extra, u8::from(&code)])?;
566                        return Poll::Pending;
567                    }
568                    Err(Error::UnimplementedMessage) => {
569                        self.inner.send_reject(label, signal, ErrorCode::NotSupportedCommand)?;
570                        return Poll::Pending;
571                    }
572                    x => Some(x),
573                }
574            }
575            Err(Error::PeerDisconnected) => {
576                self.terminated = true;
577                None
578            }
579            Err(e) => Some(Err(e)),
580        })
581    }
582}
583
584impl FusedStream for RequestStream {
585    fn is_terminated(&self) -> bool {
586        self.terminated
587    }
588}
589
590impl Drop for RequestStream {
591    fn drop(&mut self) {
592        self.inner.incoming_requests.lock().listener = RequestListener::None;
593        self.inner.wake_any();
594    }
595}
596
597// Simple responses have no body data.
598#[derive(Debug)]
599pub struct SimpleResponse {}
600
601impl Decodable for SimpleResponse {
602    type Error = Error;
603
604    fn decode(from: &[u8]) -> Result<Self> {
605        if from.len() > 0 {
606            return Err(Error::InvalidMessage);
607        }
608        Ok(SimpleResponse {})
609    }
610}
611
612impl Into<()> for SimpleResponse {
613    fn into(self) -> () {
614        ()
615    }
616}
617
618#[derive(Debug)]
619struct DiscoverResponse {
620    endpoints: Vec<StreamInformation>,
621}
622
623impl Decodable for DiscoverResponse {
624    type Error = Error;
625
626    fn decode(from: &[u8]) -> Result<Self> {
627        let mut endpoints = Vec::<StreamInformation>::new();
628        let mut idx = 0;
629        while idx < from.len() {
630            let endpoint = StreamInformation::decode(&from[idx..])?;
631            idx += endpoint.encoded_len();
632            endpoints.push(endpoint);
633        }
634        Ok(DiscoverResponse { endpoints })
635    }
636}
637
638impl Into<Vec<StreamInformation>> for DiscoverResponse {
639    fn into(self) -> Vec<StreamInformation> {
640        self.endpoints
641    }
642}
643
644#[derive(Debug)]
645pub struct DiscoverResponder {
646    peer: Arc<PeerInner>,
647    id: TxLabel,
648}
649
650impl DiscoverResponder {
651    /// Sends the response to a discovery request.
652    /// At least one endpoint must be present.
653    /// Will result in a Error::PeerWrite if the distant peer is disconnected.
654    pub fn send(self, endpoints: &[StreamInformation]) -> Result<()> {
655        if endpoints.len() == 0 {
656            // There shall be at least one SEP in a response (Sec 8.6.2)
657            return Err(Error::Encoding);
658        }
659        let mut params = vec![0 as u8; endpoints.len() * endpoints[0].encoded_len()];
660        let mut idx = 0;
661        for endpoint in endpoints {
662            endpoint.encode(&mut params[idx..idx + endpoint.encoded_len()])?;
663            idx += endpoint.encoded_len();
664        }
665        self.peer.send_response(self.id, SignalIdentifier::Discover, &params)
666    }
667
668    pub fn reject(self, error_code: ErrorCode) -> Result<()> {
669        self.peer.send_reject(self.id, SignalIdentifier::Discover, error_code)
670    }
671}
672
673#[derive(Debug)]
674pub struct GetCapabilitiesResponder {
675    peer: Arc<PeerInner>,
676    signal: SignalIdentifier,
677    id: TxLabel,
678}
679
680impl GetCapabilitiesResponder {
681    pub fn send(self, capabilities: &[ServiceCapability]) -> Result<()> {
682        let included_iter = capabilities.iter().filter(|x| x.in_response(self.signal));
683        let reply_len = included_iter.clone().fold(0, |a, b| a + b.encoded_len());
684        let mut reply = vec![0 as u8; reply_len];
685        let mut pos = 0;
686        for capability in included_iter {
687            let size = capability.encoded_len();
688            capability.encode(&mut reply[pos..pos + size])?;
689            pos += size;
690        }
691        self.peer.send_response(self.id, self.signal, &reply)
692    }
693
694    pub fn reject(self, error_code: ErrorCode) -> Result<()> {
695        self.peer.send_reject(self.id, self.signal, error_code)
696    }
697}
698
699#[derive(Debug)]
700struct GetCapabilitiesResponse {
701    capabilities: Vec<ServiceCapability>,
702}
703
704impl Decodable for GetCapabilitiesResponse {
705    type Error = Error;
706
707    fn decode(from: &[u8]) -> Result<Self> {
708        let mut capabilities = Vec::<ServiceCapability>::new();
709        let mut idx = 0;
710        while idx < from.len() {
711            match ServiceCapability::decode(&from[idx..]) {
712                Ok(capability) => {
713                    idx = idx + capability.encoded_len();
714                    capabilities.push(capability);
715                }
716                Err(_) => {
717                    // The capability length of the invalid capability can be nonzero.
718                    // Advance `idx` by the payload amount, but don't push the invalid capability.
719                    // Increment by 1 byte for ServiceCategory, 1 byte for payload length,
720                    // `length_of_capability` bytes for capability length.
721                    info!(
722                        "GetCapabilitiesResponse decode: Capability {:?} not supported.",
723                        from[idx]
724                    );
725                    let length_of_capability = from[idx + 1] as usize;
726                    idx = idx + 2 + length_of_capability;
727                }
728            }
729        }
730        Ok(GetCapabilitiesResponse { capabilities })
731    }
732}
733
734impl Into<Vec<ServiceCapability>> for GetCapabilitiesResponse {
735    fn into(self) -> Vec<ServiceCapability> {
736        self.capabilities
737    }
738}
739
740#[derive(Debug)]
741pub struct SimpleResponder {
742    peer: Arc<PeerInner>,
743    signal: SignalIdentifier,
744    id: TxLabel,
745}
746
747impl SimpleResponder {
748    pub fn send(self) -> Result<()> {
749        self.peer.send_response(self.id, self.signal, &[])
750    }
751
752    pub fn reject(self, error_code: ErrorCode) -> Result<()> {
753        self.peer.send_reject(self.id, self.signal, error_code)
754    }
755}
756
757#[derive(Debug)]
758pub struct StreamResponder {
759    peer: Arc<PeerInner>,
760    signal: SignalIdentifier,
761    id: TxLabel,
762}
763
764impl StreamResponder {
765    pub fn send(self) -> Result<()> {
766        self.peer.send_response(self.id, self.signal, &[])
767    }
768
769    pub fn reject(self, stream_id: &StreamEndpointId, error_code: ErrorCode) -> Result<()> {
770        self.peer.send_reject_params(
771            self.id,
772            self.signal,
773            &[stream_id.to_msg(), u8::from(&error_code)],
774        )
775    }
776}
777
778#[derive(Debug)]
779pub struct ConfigureResponder {
780    peer: Arc<PeerInner>,
781    signal: SignalIdentifier,
782    id: TxLabel,
783}
784
785impl ConfigureResponder {
786    pub fn send(self) -> Result<()> {
787        self.peer.send_response(self.id, self.signal, &[])
788    }
789
790    pub fn reject(self, category: ServiceCategory, error_code: ErrorCode) -> Result<()> {
791        self.peer.send_reject_params(
792            self.id,
793            self.signal,
794            &[u8::from(&category), u8::from(&error_code)],
795        )
796    }
797}
798
/// A received request that has not been parsed yet: its signaling header and the body bytes
/// that `Request::parse` interprets.
#[derive(Debug)]
struct UnparsedRequest(SignalingHeader, Vec<u8>);

impl UnparsedRequest {
    /// Pairs a decoded `header` with the request `body`.
    fn new(header: SignalingHeader, body: Vec<u8>) -> UnparsedRequest {
        UnparsedRequest(header, body)
    }
}
807
/// Incoming requests together with the state of whoever is listening for them.
/// The default value has no listener and an empty queue.
#[derive(Debug, Default)]
struct RequestQueue {
    /// Whether a request stream exists, and how to wake it.
    listener: RequestListener,
    /// Requests that have not been parsed yet.
    queue: VecDeque<UnparsedRequest>,
}
813
/// The state of the (at most one) listener for incoming requests.
#[derive(Debug)]
enum RequestListener {
    /// No one is listening.
    None,
    /// Someone wants to listen but hasn't polled.
    New,
    /// Someone is listening, and can be woken with the waker.
    Some(Waker),
}

impl Default for RequestListener {
    /// A peer starts out with no listener.
    fn default() -> Self {
        Self::None
    }
}
829
/// An enum representing an interest in the response to a command.
#[derive(Debug)]
enum ResponseWaiter {
    /// A new waiter which hasn't been polled yet.
    /// This is the state inserted by `add_response_waiter`.
    WillPoll,
    /// A task waiting for a response, which can be woken with the waker.
    /// Set by `poll_recv_response` when the response has not arrived yet.
    Waiting(Waker),
    /// A response that has been received, stored here until it's polled, at
    /// which point it will be decoded.
    /// The buffer is the whole received packet, signaling header included.
    Received(Vec<u8>),
    /// It's still waiting on the response, but the receiver has decided they
    /// don't care and we'll throw it out.
    /// Set by `remove_response_interest`; the entry is removed when the
    /// response eventually arrives.
    Discard,
}
844
845impl ResponseWaiter {
846    /// Check if a message has been received.
847    fn is_received(&self) -> bool {
848        if let ResponseWaiter::Received(_) = self {
849            true
850        } else {
851            false
852        }
853    }
854
855    fn unwrap_received(self) -> Vec<u8> {
856        if let ResponseWaiter::Received(buf) = self {
857            buf
858        } else {
859            panic!("expected received buf")
860        }
861    }
862}
863
864fn decode_signaling_response<D: Decodable<Error = Error>>(
865    expected_signal: SignalIdentifier,
866    buf: Vec<u8>,
867) -> Result<D> {
868    let header = SignalingHeader::decode(buf.as_slice())?;
869    if header.signal() != expected_signal {
870        return Err(Error::InvalidHeader);
871    }
872    let params = &buf[header.encoded_len()..];
873    match header.message_type {
874        SignalingMessageType::ResponseAccept => D::decode(params),
875        SignalingMessageType::GeneralReject | SignalingMessageType::ResponseReject => {
876            Err(RemoteReject::from_params(header.signal(), params).into())
877        }
878        SignalingMessageType::Command => unreachable!(),
879    }
880}
881
/// A future that polls for the response to a command we sent.
///
/// Resolves to the raw bytes stored for the transaction, or to an error.
/// Dropping it before a response has been yielded tells the peer to discard
/// the response.
#[derive(Debug)]
pub struct CommandResponse {
    /// Transaction label of the command whose response is awaited.
    id: TxLabel,
    // Some(x) if we're still waiting on the response.
    // Set to None once a response has been successfully yielded by `poll`.
    inner: Option<Arc<PeerInner>>,
}
889
// `poll` accesses the future through `&mut *self`, which requires `Unpin`.
// The struct holds only a label and an optional `Arc`, so there is no
// self-referential state.
impl Unpin for CommandResponse {}
891
892impl Future for CommandResponse {
893    type Output = Result<Vec<u8>>;
894    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
895        let this = &mut *self;
896        let res;
897        {
898            let client = this.inner.as_ref().ok_or(Error::AlreadyReceived)?;
899            res = client.poll_recv_response(&this.id, cx);
900        }
901
902        if let Poll::Ready(Ok(_)) = res {
903            let inner = this.inner.take().expect("CommandResponse polled after completion");
904            inner.wake_any();
905        }
906
907        res
908    }
909}
910
impl FusedFuture for CommandResponse {
    /// The future is terminated once `inner` has been cleared, which `poll`
    /// does after yielding a successful response. A poll that resolved to an
    /// error does not clear `inner`, so the future is not reported as
    /// terminated in that case.
    fn is_terminated(&self) -> bool {
        self.inner.is_none()
    }
}
916
917impl Drop for CommandResponse {
918    fn drop(&mut self) {
919        if let Some(inner) = &self.inner {
920            inner.remove_response_interest(&self.id);
921            inner.wake_any();
922        }
923    }
924}
925
/// Shared state behind a `Peer`: the signaling channel plus the bookkeeping
/// for outstanding commands and received requests.
#[derive(Debug)]
struct PeerInner {
    /// The signaling channel
    signaling: Channel,

    /// A map of transaction ids that have been sent but the response has not
    /// been received and/or processed yet.
    ///
    /// Waiters are added with `add_response_waiter`. They are removed by
    /// `poll_recv_response` when the received response is handed over, by
    /// `remove_response_interest` when the response was already received, or
    /// by `recv_all` when the response to a discarded waiter arrives.
    response_waiters: Mutex<Slab<ResponseWaiter>>,

    /// A queue of requests that have been received and are waiting to
    /// be responded to, along with the waker for the task that has
    /// taken the request receiver (if it exists)
    incoming_requests: Mutex<RequestQueue>,
}
943
944impl PeerInner {
945    /// Add a response waiter, and return a id that can be used to send the
946    /// transaction.  Responses then can be received using poll_recv_response
947    fn add_response_waiter(&self) -> Result<TxLabel> {
948        let key = self.response_waiters.lock().insert(ResponseWaiter::WillPoll);
949        let id = TxLabel::try_from(key as u8);
950        if id.is_err() {
951            warn!("Transaction IDs are exhausted");
952            let _ = self.response_waiters.lock().remove(key);
953        }
954        id
955    }
956
957    /// When a waiter isn't interested in the response anymore, we need to just
958    /// throw it out.  This is called when the response future is dropped.
959    fn remove_response_interest(&self, id: &TxLabel) {
960        let mut lock = self.response_waiters.lock();
961        let idx = usize::from(id);
962        if lock[idx].is_received() {
963            let _ = lock.remove(idx);
964        } else {
965            lock[idx] = ResponseWaiter::Discard;
966        }
967    }
968
969    // Attempts to receive a new request by processing all packets on the socket.
970    // Resolves to an unprocessed request (header, body) if one was received.
971    // Resolves to an error if there was an error reading from the socket or if the peer
972    // disconnected.
973    fn poll_recv_request(&self, cx: &mut Context<'_>) -> Poll<Result<UnparsedRequest>> {
974        let is_closed = self.recv_all(cx)?;
975
976        let mut lock = self.incoming_requests.lock();
977
978        if let Some(request) = lock.queue.pop_front() {
979            Poll::Ready(Ok(request))
980        } else {
981            lock.listener = RequestListener::Some(cx.waker().clone());
982            if is_closed {
983                Poll::Ready(Err(Error::PeerDisconnected))
984            } else {
985                Poll::Pending
986            }
987        }
988    }
989
990    // Attempts to receive a response to a request by processing all packets on the socket.
991    // Resolves to the bytes in the response body if one was received.
992    // Resolves to an error if there was an error reading from the socket, if the peer
993    // disconnected, or if the |label| is not being waited on.
994    fn poll_recv_response(&self, label: &TxLabel, cx: &mut Context<'_>) -> Poll<Result<Vec<u8>>> {
995        let is_closed = self.recv_all(cx)?;
996
997        let mut waiters = self.response_waiters.lock();
998        let idx = usize::from(label);
999        // We expect() below because the label above came from an internally-created object,
1000        // so the waiters should always exist in the map.
1001        if waiters.get(idx).expect("Polled unregistered waiter").is_received() {
1002            // We got our response.
1003            let buf = waiters.remove(idx).unwrap_received();
1004            Poll::Ready(Ok(buf))
1005        } else {
1006            // Set the waker to be notified when a response shows up.
1007            *waiters.get_mut(idx).expect("Polled unregistered waiter") =
1008                ResponseWaiter::Waiting(cx.waker().clone());
1009
1010            if is_closed {
1011                Poll::Ready(Err(Error::PeerDisconnected))
1012            } else {
1013                Poll::Pending
1014            }
1015        }
1016    }
1017
    /// Poll for any packets on the signaling socket, dispatching each one:
    /// commands are queued as requests, everything else is treated as the
    /// response to a command we sent.
    ///
    /// Returns `Ok(true)` if the channel was closed by the peer, `Ok(false)`
    /// once no more packets are available right now, or an Error::PeerRead or
    /// Error::PeerWrite if there was a problem communicating on the socket.
    fn recv_all(&self, cx: &mut Context<'_>) -> Result<bool> {
        loop {
            // A fresh buffer per packet: it is moved into the request queue or
            // the response waiter below.
            let mut next_packet = Vec::new();
            let packet_size = match self.signaling.poll_datagram(cx, &mut next_packet) {
                Poll::Ready(Err(zx::Status::PEER_CLOSED)) => {
                    trace!("Signaling peer closed");
                    return Ok(true);
                }
                Poll::Ready(Err(e)) => return Err(Error::PeerRead(e)),
                Poll::Pending => return Ok(false),
                Poll::Ready(Ok(size)) => size,
            };
            // Empty packets carry nothing to dispatch.
            if packet_size == 0 {
                continue;
            }
            // Detects General Reject condition and sends the response back.
            // On other headers with errors, sends BAD_HEADER to the peer
            // and attempts to continue.
            let header = match SignalingHeader::decode(next_packet.as_slice()) {
                Err(Error::InvalidSignalId(label, id)) => {
                    // A failure to send the reject aborts the whole receive.
                    self.send_general_reject(label, id)?;
                    continue;
                }
                Err(_) => {
                    // Only possible other return is OutOfRange
                    // Returned only when the packet is too small, can't make a meaningful reject.
                    info!("received unrejectable message");
                    continue;
                }
                Ok(x) => x,
            };
            // Commands from the remote get translated into requests.
            if header.is_command() {
                let mut lock = self.incoming_requests.lock();
                // Keep only the bytes after the header as the request body.
                let body = next_packet.split_off(header.encoded_len());
                lock.queue.push_back(UnparsedRequest::new(header, body));
                if let RequestListener::Some(ref waker) = lock.listener {
                    waker.wake_by_ref();
                }
            } else {
                // Should be a response to a command we sent
                let mut waiters = self.response_waiters.lock();
                let idx = usize::from(&header.label());
                if let Some(&ResponseWaiter::Discard) = waiters.get(idx) {
                    // The receiver gave up on this response; free the slot.
                    let _ = waiters.remove(idx);
                } else if let Some(entry) = waiters.get_mut(idx) {
                    // Store the whole packet (header included) and wake the
                    // task waiting on it, if any.
                    let old_entry = mem::replace(entry, ResponseWaiter::Received(next_packet));
                    if let ResponseWaiter::Waiting(waker) = old_entry {
                        waker.wake();
                    }
                } else {
                    warn!("response for {:?} we did not send, dropping", header.label());
                }
                // Note: we drop any TxLabel response we are not waiting for
            }
        }
    }
1078
1079    // Wakes up an arbitrary task that has begun polling on the channel so that
1080    // it will call recv_all and be registered as the new channel reader.
1081    fn wake_any(&self) {
1082        // Try to wake up response waiters first, rather than the event listener.
1083        // The event listener is a stream, and so could be between poll_nexts,
1084        // Response waiters should always be actively polled once
1085        // they've begun being polled on a task.
1086        {
1087            let lock = self.response_waiters.lock();
1088            for (_, response_waiter) in lock.iter() {
1089                if let ResponseWaiter::Waiting(waker) = response_waiter {
1090                    waker.wake_by_ref();
1091                    return;
1092                }
1093            }
1094        }
1095        {
1096            let lock = self.incoming_requests.lock();
1097            if let RequestListener::Some(waker) = &lock.listener {
1098                waker.wake_by_ref();
1099                return;
1100            }
1101        }
1102    }
1103
1104    // Build and send a General Reject message (Section 8.18)
1105    fn send_general_reject(&self, label: TxLabel, invalid_signal_id: u8) -> Result<()> {
1106        // Build the packet ourselves rather than make SignalingHeader build an packet with an
1107        // invalid signal id.
1108        let packet: &[u8; 2] = &[u8::from(&label) << 4 | 0x01, invalid_signal_id & 0x3F];
1109        self.send_signal(packet)
1110    }
1111
1112    fn send_response(&self, label: TxLabel, signal: SignalIdentifier, params: &[u8]) -> Result<()> {
1113        let header = SignalingHeader::new(label, signal, SignalingMessageType::ResponseAccept);
1114        let mut packet = vec![0 as u8; header.encoded_len() + params.len()];
1115        header.encode(packet.as_mut_slice())?;
1116        packet[header.encoded_len()..].clone_from_slice(params);
1117        self.send_signal(&packet)
1118    }
1119
1120    fn send_reject(
1121        &self,
1122        label: TxLabel,
1123        signal: SignalIdentifier,
1124        error_code: ErrorCode,
1125    ) -> Result<()> {
1126        self.send_reject_params(label, signal, &[u8::from(&error_code)])
1127    }
1128
1129    fn send_reject_params(
1130        &self,
1131        label: TxLabel,
1132        signal: SignalIdentifier,
1133        params: &[u8],
1134    ) -> Result<()> {
1135        let header = SignalingHeader::new(label, signal, SignalingMessageType::ResponseReject);
1136        let mut packet = vec![0 as u8; header.encoded_len() + params.len()];
1137        header.encode(packet.as_mut_slice())?;
1138        packet[header.encoded_len()..].clone_from_slice(params);
1139        self.send_signal(&packet)
1140    }
1141
1142    fn send_signal(&self, data: &[u8]) -> Result<()> {
1143        let _ = self.signaling.write(data).map_err(|x| Error::PeerWrite(x))?;
1144        Ok(())
1145    }
1146}