1use fuchsia_bluetooth::types::Channel;
6use fuchsia_sync::Mutex;
7
8use futures::ready;
9use futures::stream::{FusedStream, Stream};
10use futures::task::{Context, Poll, Waker};
11use log::{info, trace, warn};
12use packet_encoding::{Decodable, Encodable};
13use slab::Slab;
14use std::collections::VecDeque;
15use std::mem;
16use std::pin::Pin;
17use std::sync::Arc;
18
19#[cfg(test)]
20mod tests;
21
22mod types;
23
24use crate::{Error, Result};
25
26use self::types::AV_REMOTE_PROFILE;
27
28pub use self::types::{Header, MessageType, PacketType, TxLabel};
29
/// Handle to a remote peer reached over a single `Channel`.
///
/// The handle only holds a reference to the shared `PeerInner`; the command stream,
/// response streams and received commands created from it share the same state.
#[derive(Debug)]
pub struct Peer {
    /// State shared with every stream and command created from this peer.
    inner: Arc<PeerInner>,
}
37
/// State shared between a `Peer` and the streams and commands created from it.
#[derive(Debug)]
struct PeerInner {
    /// Channel that packets are read from and written to.
    channel: Channel,

    /// Waiters for outstanding commands. The slab key of a waiter is converted into the
    /// transaction label of the command it belongs to.
    response_waiters: Mutex<Slab<ResponseWaiter>>,

    /// Commands received from the remote peer that have not been handed out yet, together
    /// with the registration state of the command stream.
    incoming_requests: Mutex<CommandQueue>,
}
55
56impl Peer {
57 pub fn new(channel: Channel) -> Self {
59 Self { inner: Arc::new(PeerInner::new(channel)) }
60 }
61
62 pub fn take_command_stream(&self) -> CommandStream {
65 {
66 let mut lock = self.inner.incoming_requests.lock();
67 if let CommandListener::None = lock.listener {
68 lock.listener = CommandListener::New;
69 } else {
70 panic!("Command stream has already been taken");
71 }
72 }
73
74 CommandStream { inner: self.inner.clone() }
75 }
76
77 pub fn send_command(&self, payload: &[u8]) -> Result<CommandResponseStream> {
80 let id = self.inner.add_response_waiter()?;
81 let avctp_header = Header::new(id, AV_REMOTE_PROFILE.clone(), MessageType::Command, false);
82 {
83 self.inner.send_packet(&avctp_header, payload)?;
84 }
85
86 Ok(CommandResponseStream::new(avctp_header.label().clone(), self.inner.clone()))
87 }
88}
89
90impl PeerInner {
91 fn new(channel: Channel) -> Self {
92 Self {
93 channel,
94 response_waiters: Mutex::new(Slab::<ResponseWaiter>::new()),
95 incoming_requests: Mutex::<CommandQueue>::default(),
96 }
97 }
98
99 fn add_response_waiter(&self) -> Result<TxLabel> {
102 let key = self.response_waiters.lock().insert(ResponseWaiter::default());
103 let id = TxLabel::try_from(key as u8);
104 if id.is_err() {
105 warn!("Transaction IDs are exhausted");
106 let _ = self.response_waiters.lock().remove(key);
107 }
108 id
109 }
110
111 fn remove_response_interest(&self, id: &TxLabel) {
114 let mut lock = self.response_waiters.lock();
115 let idx = usize::from(id);
116 let _ = lock.remove(idx);
117 }
118
119 fn poll_recv_request(&self, cx: &mut Context<'_>) -> Poll<Result<Packet>> {
124 let is_closed = self.recv_all(cx)?;
125
126 let mut lock = self.incoming_requests.lock();
127
128 match lock.queue.pop_front() {
129 Some(request) => Poll::Ready(Ok(request)),
130 _ => {
131 if is_closed {
132 Poll::Ready(Err(Error::PeerDisconnected))
133 } else {
134 lock.listener = CommandListener::Some(cx.waker().clone());
136 Poll::Pending
137 }
138 }
139 }
140 }
141
142 fn poll_recv_response(&self, label: &TxLabel, cx: &mut Context<'_>) -> Poll<Result<Packet>> {
147 let is_closed = self.recv_all(cx)?;
148
149 let mut waiters = self.response_waiters.lock();
150 let idx = usize::from(label);
151 let waiter = waiters.get_mut(idx).expect("Polled unregistered waiter");
154 if waiter.has_response() {
155 let packet = waiter.pop_received();
157 Poll::Ready(Ok(packet))
158 } else {
159 if is_closed {
160 Poll::Ready(Err(Error::PeerDisconnected))
161 } else {
162 waiter.listener = ResponseListener::Some(cx.waker().clone());
164 Poll::Pending
165 }
166 }
167 }
168
169 fn recv_all(&self, cx: &mut Context<'_>) -> Result<bool> {
173 let mut buf = Vec::<u8>::new();
174 loop {
175 let packet_size = match self.channel.poll_datagram(cx, &mut buf) {
176 Poll::Ready(Err(zx::Status::PEER_CLOSED)) => {
177 trace!("Peer closed");
178 return Ok(true);
179 }
180 Poll::Ready(Err(e)) => return Err(Error::PeerRead(e)),
181 Poll::Pending => return Ok(false),
182 Poll::Ready(Ok(size)) => size,
183 };
184 if packet_size == 0 {
185 continue;
186 }
187 trace!("received packet {:?}", buf);
188 let avctp_header = match Header::decode(buf.as_slice()) {
192 Err(_) => {
193 info!("received unrejectable message");
196 buf = buf.split_off(packet_size);
197 continue;
198 }
199 Ok(x) => x,
200 };
201
202 if avctp_header.profile_id() != AV_REMOTE_PROFILE {
205 info!("received packet not targeted at remote profile service class");
206 let resp_avct = avctp_header.create_invalid_profile_id_response();
207 self.send_packet(&resp_avct, &[])?;
208 buf = buf.split_off(packet_size);
209 continue;
210 }
211
212 if packet_size == avctp_header.encoded_len() {
213 info!("received incomplete packet");
215 buf = buf.split_off(packet_size);
216 continue;
217 }
218
219 let rest = buf.split_off(packet_size);
220 let body = buf.split_off(avctp_header.encoded_len());
221 match avctp_header.message_type() {
223 MessageType::Command => {
224 let mut lock = self.incoming_requests.lock();
225 lock.queue.push_back(Packet { header: avctp_header, body: body.to_vec() });
226 if let CommandListener::Some(ref waker) = lock.listener {
227 waker.wake_by_ref();
228 }
229 buf = rest;
230 }
231 MessageType::Response => {
232 let mut waiters = self.response_waiters.lock();
234 let idx = usize::from(avctp_header.label());
235
236 if let Some(waiter) = waiters.get_mut(idx) {
237 waiter
238 .queue
239 .push_back(Packet { header: avctp_header, body: body.to_vec() });
240 let old_entry = mem::replace(&mut waiter.listener, ResponseListener::New);
241 if let ResponseListener::Some(waker) = old_entry {
242 waker.wake();
243 }
244 } else {
245 trace!("response for {:?} we did not send, dropping", avctp_header.label());
246 };
247 buf = rest;
248 }
250 }
251 }
252 }
253
254 fn wake_any(&self) {
257 {
262 let lock = self.response_waiters.lock();
263 for (_, response_waiter) in lock.iter() {
264 if let ResponseListener::Some(ref waker) = response_waiter.listener {
265 waker.wake_by_ref();
266 return;
267 }
268 }
269 }
270 {
271 let lock = self.incoming_requests.lock();
272 if let CommandListener::Some(ref waker) = lock.listener {
273 waker.wake_by_ref();
274 return;
275 }
276 }
277 }
278
279 pub fn send_packet(&self, resp_header: &Header, body: &[u8]) -> Result<()> {
280 let mut rbuf = vec![0 as u8; resp_header.encoded_len()];
281 resp_header.encode(&mut rbuf)?;
282 if body.len() > 0 {
283 rbuf.extend_from_slice(body);
284 }
285 let _ = self.channel.write(rbuf.as_slice()).map_err(|x| Error::PeerWrite(x))?;
286 Ok(())
287 }
288}
289
/// Stream of commands received from the remote peer.
///
/// Created by `Peer::take_command_stream`. Dropping the stream resets the command
/// listener so that the stream can be taken again.
#[derive(Debug)]
pub struct CommandStream {
    /// State shared with the `Peer` this stream was taken from.
    inner: Arc<PeerInner>,
}

impl Unpin for CommandStream {}
297
298impl Stream for CommandStream {
299 type Item = Result<Command>;
300
301 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
302 Poll::Ready(match ready!(self.inner.poll_recv_request(cx)) {
303 Ok(Packet { header, body, .. }) => {
304 Some(Ok(Command { peer: self.inner.clone(), avctp_header: header, body }))
305 }
306 Err(Error::PeerDisconnected) => None,
307 Err(e) => Some(Err(e)),
308 })
309 }
310}
311
312impl Drop for CommandStream {
313 fn drop(&mut self) {
314 self.inner.incoming_requests.lock().listener = CommandListener::None;
315 self.inner.wake_any();
316 }
317}
318
/// A command received from the remote peer, which can be answered with `send_response`.
#[derive(Debug)]
pub struct Command {
    /// Shared peer state, used to send the response.
    peer: Arc<PeerInner>,
    /// Header the command arrived with.
    avctp_header: Header,
    /// Bytes that followed the header in the received packet.
    body: Vec<u8>,
}
325
326impl Command {
327 pub fn header(&self) -> &Header {
328 &self.avctp_header
329 }
330
331 pub fn body(&self) -> &[u8] {
332 &self.body[..]
333 }
334
335 pub fn send_response(&self, body: &[u8]) -> Result<()> {
336 let response_header = self.avctp_header.create_response(PacketType::Single);
337 self.peer.send_packet(&response_header, body)
338 }
339}
340
/// A received packet, split into its decoded header and the remaining bytes.
#[derive(Debug)]
pub struct Packet {
    /// Decoded header of the packet.
    header: Header,
    /// Bytes that followed the header.
    body: Vec<u8>,
}
346
347impl Packet {
348 pub fn header(&self) -> &Header {
349 &self.header
350 }
351
352 pub fn body(&self) -> &[u8] {
353 &self.body[..]
354 }
355}
356
/// Commands received from the peer together with the state of the stream consuming them.
#[derive(Debug, Default)]
struct CommandQueue {
    /// Registration state of the command stream.
    listener: CommandListener,
    /// Received commands in arrival order; new commands are pushed to the back and
    /// handed out from the front.
    queue: VecDeque<Packet>,
}
362
/// Registration state of the command stream.
#[derive(Debug)]
enum CommandListener {
    /// No command stream is currently taken.
    None,
    /// A command stream has been taken but has not stored a waker yet.
    New,
    /// Waker to wake when a command arrives.
    Some(Waker),
}

/// The default state is `None`: no command stream has been taken.
impl Default for CommandListener {
    fn default() -> Self {
        Self::None
    }
}
378
/// Per-command state for a transaction that is waiting for responses.
#[derive(Debug, Default)]
struct ResponseWaiter {
    /// Waker of the task waiting for a response, if one is stored.
    listener: ResponseListener,
    /// Responses received for this transaction that have not been handed out yet.
    queue: VecDeque<Packet>,
}
384
/// Wake-up registration of a response waiter.
#[derive(Debug)]
enum ResponseListener {
    /// No waker is stored.
    New,
    /// Waker to wake when a response arrives.
    Some(Waker),
}

/// The default state is `New`: no waker is stored.
impl Default for ResponseListener {
    fn default() -> Self {
        Self::New
    }
}
399
400impl ResponseWaiter {
401 fn has_response(&self) -> bool {
403 !self.queue.is_empty()
404 }
405
406 fn pop_received(&mut self) -> Packet {
407 if !self.has_response() {
408 panic!("expected received buf");
409 }
410 self.queue.pop_front().expect("response listener packet queue is unexpectedly empty")
411 }
412}
413
/// Stream of responses to a single command sent with `Peer::send_command`.
///
/// The transaction stays registered until `complete` is called or the stream is dropped.
#[derive(Debug)]
pub struct CommandResponseStream {
    /// Label of the transaction; `None` once `complete` has unregistered it.
    id: Option<TxLabel>,
    /// State shared with the `Peer` the command was sent from.
    inner: Arc<PeerInner>,
    /// Whether the stream has terminated; reported by `FusedStream::is_terminated`.
    done: bool,
}
423
424impl CommandResponseStream {
425 fn new(id: TxLabel, inner: Arc<PeerInner>) -> CommandResponseStream {
426 CommandResponseStream { id: Some(id), inner, done: false }
427 }
428
429 pub fn complete(&mut self) {
430 if let Some(id) = &self.id {
431 self.inner.remove_response_interest(id);
432 self.id = None;
433 self.done = true;
434 self.inner.wake_any();
435 }
436 }
437}
438
439impl Unpin for CommandResponseStream {}
440
441impl FusedStream for CommandResponseStream {
442 fn is_terminated(&self) -> bool {
443 self.done == true
444 }
445}
446
447impl Stream for CommandResponseStream {
448 type Item = Result<Packet>;
449 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
450 let this = &mut *self;
451 if let Some(id) = &this.id {
452 Poll::Ready(match ready!(this.inner.poll_recv_response(id, cx)) {
453 Ok(packet) => {
454 trace!("received response packet {:?}", packet);
455 if packet.header().is_invalid_profile_id() {
456 Some(Err(Error::InvalidProfileId))
457 } else {
458 Some(Ok(packet))
459 }
460 }
461 Err(Error::PeerDisconnected) => {
462 this.done = true;
463 None
464 }
465 Err(e) => Some(Err(e)),
466 })
467 } else {
468 this.done = true;
469 return Poll::Ready(None);
470 }
471 }
472}
473
/// Dropping the stream ends the transaction in the same way as calling `complete`.
impl Drop for CommandResponseStream {
    fn drop(&mut self) {
        // Does nothing if `complete` has already been called.
        self.complete();
    }
}