overnet_core/proxy/handle/channel.rs

1// Copyright 2020 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use super::{
6    IntoProxied, Message, Proxyable, ProxyableRW, ReadValue, RouterHolder, Serializer, IO,
7};
8use crate::coding::{decode_fidl, encode_fidl};
9use crate::peer::PeerConnRef;
10use anyhow::{Context as _, Error};
11use fidl::{AsHandleRef, AsyncChannel, HandleBased, Peered, Signals};
12use fidl_fuchsia_overnet_protocol::{ZirconChannelMessage, ZirconHandle};
13use futures::prelude::*;
14use futures::ready;
15use std::pin::Pin;
16use std::task::{Context, Poll};
17use zx_status;
18
19pub(crate) struct Channel {
20    chan: AsyncChannel,
21}
22
23impl std::fmt::Debug for Channel {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        self.chan.fmt(f)
26    }
27}
28
29impl Proxyable for Channel {
30    type Message = ChannelMessage;
31
32    fn from_fidl_handle(hdl: fidl::Handle) -> Result<Self, Error> {
33        Ok(fidl::Channel::from_handle(hdl).into_proxied()?)
34    }
35
36    fn into_fidl_handle(self) -> Result<fidl::Handle, Error> {
37        Ok(self.chan.into_zx_channel().into_handle())
38    }
39
40    fn signal_peer(&self, clear: Signals, set: Signals) -> Result<(), Error> {
41        let chan: &fidl::Channel = self.chan.as_ref();
42        chan.signal_peer(clear, set)?;
43        Ok(())
44    }
45
46    #[cfg(not(target_os = "fuchsia"))]
47    fn close_with_reason(self, msg: String) {
48        self.chan.close_with_reason(msg);
49    }
50}
51
52impl<'a> ProxyableRW<'a> for Channel {
53    type Reader = ChannelReader<'a>;
54    type Writer = ChannelWriter;
55}
56
57impl IntoProxied for fidl::Channel {
58    type Proxied = Channel;
59    fn into_proxied(self) -> Result<Channel, Error> {
60        Ok(Channel { chan: AsyncChannel::from_channel(self) })
61    }
62}
63
64pub(crate) struct ChannelReader<'a> {
65    collector: super::signals::Collector<'a>,
66}
67
68impl<'a> IO<'a> for ChannelReader<'a> {
69    type Proxyable = Channel;
70    type Output = ReadValue;
71    fn new() -> ChannelReader<'a> {
72        ChannelReader { collector: Default::default() }
73    }
74    fn poll_io(
75        &mut self,
76        msg: &mut ChannelMessage,
77        channel: &'a Channel,
78        fut_ctx: &mut Context<'_>,
79    ) -> Poll<Result<ReadValue, zx_status::Status>> {
80        let read_result = channel.chan.read(fut_ctx, &mut msg.bytes, &mut msg.handles);
81        self.collector.after_read(fut_ctx, channel.chan.as_handle_ref(), read_result, false)
82    }
83}
84
85pub(crate) struct ChannelWriter;
86
87impl IO<'_> for ChannelWriter {
88    type Proxyable = Channel;
89    type Output = ();
90    fn new() -> ChannelWriter {
91        ChannelWriter
92    }
93    fn poll_io(
94        &mut self,
95        msg: &mut ChannelMessage,
96        channel: &Channel,
97        _: &mut Context<'_>,
98    ) -> Poll<Result<(), zx_status::Status>> {
99        Poll::Ready(Ok(channel.chan.write(&msg.bytes, &mut msg.handles)?))
100    }
101}
102
103#[derive(Default, Debug)]
104pub(crate) struct ChannelMessage {
105    bytes: Vec<u8>,
106    handles: Vec<fidl::Handle>,
107}
108
109impl Message for ChannelMessage {
110    type Parser = ChannelMessageParser;
111    type Serializer = ChannelMessageSerializer;
112}
113
114impl PartialEq for ChannelMessage {
115    fn eq(&self, rhs: &Self) -> bool {
116        if !self.handles.is_empty() {
117            return false;
118        }
119        if !rhs.handles.is_empty() {
120            return false;
121        }
122        return self.bytes == rhs.bytes;
123    }
124}
125
126pub(crate) enum ChannelMessageParser {
127    New,
128    Pending {
129        bytes: Vec<u8>,
130        handles: Pin<Box<dyn 'static + Send + Future<Output = Result<Vec<fidl::Handle>, Error>>>>,
131    },
132    Done,
133}
134
135impl std::fmt::Debug for ChannelMessageParser {
136    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
137        match self {
138            ChannelMessageParser::New => "New",
139            ChannelMessageParser::Pending { .. } => "Pending",
140            ChannelMessageParser::Done => "Done",
141        }
142        .fmt(f)
143    }
144}
145
146impl Serializer for ChannelMessageParser {
147    type Message = ChannelMessage;
148    fn new() -> Self {
149        Self::New
150    }
151    fn poll_ser(
152        &mut self,
153        msg: &mut Self::Message,
154        serialized: &mut Vec<u8>,
155        conn: PeerConnRef<'_>,
156        router: &mut RouterHolder<'_>,
157        fut_ctx: &mut Context<'_>,
158    ) -> Poll<Result<(), Error>> {
159        log::trace!(msg:?, serialized:?, self:?; "ChannelMessageParser::poll_ser",);
160        match self {
161            ChannelMessageParser::New => {
162                let ZirconChannelMessage { mut bytes, handles: unbound_handles } =
163                    decode_fidl(serialized)?;
164                // Special case no handles case to avoid allocation dance
165                if unbound_handles.is_empty() {
166                    msg.handles.clear();
167                    std::mem::swap(&mut msg.bytes, &mut bytes);
168                    *self = ChannelMessageParser::Done;
169                    return Poll::Ready(Ok(()));
170                }
171                let closure_conn = conn.into_peer_conn();
172                let closure_router = router.get()?.clone();
173                *self = ChannelMessageParser::Pending {
174                    bytes,
175                    handles: async move {
176                        let mut handles = Vec::new();
177                        for hdl in unbound_handles.into_iter() {
178                            handles.push(
179                                closure_router
180                                    .clone()
181                                    .recv_proxied(hdl, closure_conn.as_ref())
182                                    .await?,
183                            );
184                        }
185                        Ok(handles)
186                    }
187                    .boxed(),
188                };
189                self.poll_ser(msg, serialized, conn, router, fut_ctx)
190            }
191            ChannelMessageParser::Pending { ref mut bytes, handles } => {
192                let mut handles = ready!(handles.as_mut().poll(fut_ctx))?;
193                std::mem::swap(&mut msg.handles, &mut handles);
194                std::mem::swap(&mut msg.bytes, bytes);
195                *self = ChannelMessageParser::Done;
196                Poll::Ready(Ok(()))
197            }
198            ChannelMessageParser::Done => unreachable!(),
199        }
200    }
201}
202
203pub(crate) enum ChannelMessageSerializer {
204    New,
205    Pending(Pin<Box<dyn 'static + Send + Future<Output = Result<Vec<ZirconHandle>, Error>>>>),
206    Done,
207}
208
209impl Serializer for ChannelMessageSerializer {
210    type Message = ChannelMessage;
211    fn new() -> Self {
212        Self::New
213    }
214    fn poll_ser(
215        &mut self,
216        msg: &mut Self::Message,
217        serialized: &mut Vec<u8>,
218        conn: PeerConnRef<'_>,
219        router: &mut RouterHolder<'_>,
220        fut_ctx: &mut Context<'_>,
221    ) -> Poll<Result<(), Error>> {
222        let self_val = match self {
223            ChannelMessageSerializer::New => "New",
224            ChannelMessageSerializer::Pending { .. } => "Pending",
225            ChannelMessageSerializer::Done => "Done",
226        };
227        log::trace!(msg:?, serialized:?, self = self_val; "ChannelMessageSerializer::poll_ser");
228        match self {
229            ChannelMessageSerializer::New => {
230                let handles = std::mem::replace(&mut msg.handles, Vec::new());
231                // Special case no handles case to avoid allocation dance
232                if handles.is_empty() {
233                    *serialized = encode_fidl(&mut ZirconChannelMessage {
234                        bytes: std::mem::replace(&mut msg.bytes, Vec::new()),
235                        handles: Vec::new(),
236                    })?;
237                    *self = ChannelMessageSerializer::Done;
238                    return Poll::Ready(Ok(()));
239                }
240                let closure_conn = conn.into_peer_conn();
241                let closure_router = router.get()?.clone();
242                *self = ChannelMessageSerializer::Pending(
243                    async move {
244                        let mut send_handles = Vec::new();
245                        for handle in handles {
246                            // save for debugging
247                            let raw_handle = handle.raw_handle();
248                            send_handles.push(
249                                closure_router
250                                    .send_proxied(handle, closure_conn.as_ref())
251                                    .await
252                                    .with_context(|| format!("Sending handle {:?}", raw_handle))?,
253                            );
254                        }
255                        Ok(send_handles)
256                    }
257                    .boxed(),
258                );
259                self.poll_ser(msg, serialized, conn, router, fut_ctx)
260            }
261            ChannelMessageSerializer::Pending(handles) => {
262                let handles = ready!(handles.as_mut().poll(fut_ctx))?;
263                *serialized = encode_fidl(&mut ZirconChannelMessage {
264                    bytes: std::mem::replace(&mut msg.bytes, Vec::new()),
265                    handles,
266                })?;
267                *self = ChannelMessageSerializer::Done;
268                Poll::Ready(Ok(()))
269            }
270            ChannelMessageSerializer::Done => unreachable!(),
271        }
272    }
273}