overnet_core/proxy/run/
main.rs

1// Copyright 2020 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
//! Main loops (and associated spawn functions) for proxying a handle: they move data from one
//! end of the proxy to the other, and call into the sibling `xfer` module (`super::xfer`) once a
//! handle transfer is required.
7
8use super::super::handle::ReadValue;
9use super::super::stream::{
10    Frame, StreamReader, StreamReaderBinder, StreamWriter, StreamWriterBinder,
11};
12use super::super::{
13    Proxy, ProxyTransferInitiationReceiver, Proxyable, ProxyableRW, RemoveFromProxyTable,
14    StreamRefSender,
15};
16use crate::labels::{NodeId, TransferKey};
17use crate::peer::{FramedStreamReader, FramedStreamWriter};
18use anyhow::{bail, format_err, Context as _, Error};
19use futures::future::Either;
20use futures::prelude::*;
21use std::sync::{Arc, Mutex};
22use zx_status;
23
24// We run two tasks to proxy a handle - one to handle handle->stream, the other to handle
25// stream->handle. When we want to perform a transfer operation we end up wanting to think about
26// just one task, so we provide a join operation here.
27#[derive(Debug)]
28enum FinishProxyLoopAction<Hdl: Proxyable> {
29    InitiateTransfer {
30        paired_handle: fidl::Handle,
31        drain_stream: FramedStreamWriter,
32        stream_ref_sender: StreamRefSender,
33        stream_reader: StreamReader<Hdl::Message>,
34    },
35    FollowTransfer {
36        initiate_transfer: ProxyTransferInitiationReceiver,
37        new_destination_node: NodeId,
38        transfer_key: TransferKey,
39        stream_reader: StreamReader<Hdl::Message>,
40    },
41    Shutdown {
42        result: Result<(), zx_status::Status>,
43        stream_reader: StreamReader<Hdl::Message>,
44    },
45}
46
47struct FinishProxyLoopSender<Hdl: Proxyable> {
48    chan: futures::channel::oneshot::Sender<FinishProxyLoopAction<Hdl>>,
49}
50type FinishProxyLoopReceiver<Hdl> = futures::channel::oneshot::Receiver<FinishProxyLoopAction<Hdl>>;
51
52impl<Hdl: 'static + Proxyable> FinishProxyLoopSender<Hdl> {
53    fn and_then(self, then: FinishProxyLoopAction<Hdl>) -> Result<(), Error> {
54        Ok(self.chan.send(then).map_err(|_| format_err!("Join channel broken"))?)
55    }
56
57    // This join is to initiate a new transfer.
58    fn and_then_initiate(
59        self,
60        paired_handle: fidl::Handle,
61        drain_stream: FramedStreamWriter,
62        stream_ref_sender: StreamRefSender,
63        stream_reader: StreamReader<Hdl::Message>,
64    ) -> Result<(), Error> {
65        self.and_then(FinishProxyLoopAction::InitiateTransfer {
66            paired_handle,
67            drain_stream,
68            stream_ref_sender,
69            stream_reader,
70        })
71    }
72
73    // This join is to follow a transfer initiated by the remote end.
74    fn and_then_follow(
75        self,
76        initiate_transfer: ProxyTransferInitiationReceiver,
77        new_destination_node: NodeId,
78        transfer_key: TransferKey,
79        stream_reader: StreamReader<Hdl::Message>,
80    ) -> Result<(), Error> {
81        self.and_then(FinishProxyLoopAction::FollowTransfer {
82            initiate_transfer,
83            new_destination_node,
84            transfer_key,
85            stream_reader,
86        })
87    }
88
89    fn and_then_shutdown(
90        self,
91        result: Result<(), zx_status::Status>,
92        stream_reader: StreamReader<Hdl::Message>,
93    ) -> Result<(), Error> {
94        self.and_then(FinishProxyLoopAction::Shutdown { result, stream_reader })
95    }
96}
97
98fn new_task_joiner<Hdl: Proxyable>() -> (FinishProxyLoopSender<Hdl>, FinishProxyLoopReceiver<Hdl>) {
99    let (tx, rx) = futures::channel::oneshot::channel();
100    (FinishProxyLoopSender { chan: tx }, rx)
101}
102
103/// Store behind [`set_proxy_drop_event_handler`]
104static PROXY_DROP_EVENT: Mutex<Option<Box<dyn Fn(&Result<(), Error>) + 'static + Send>>> =
105    Mutex::new(None);
106
107/// Sets a global callback to call every time a proxy is dropped. It's given a
108/// reference to the error and can be used to send metrics events.
109pub fn set_proxy_drop_event_handler(handler: impl Fn(&Result<(), Error>) + 'static + Send) {
110    *PROXY_DROP_EVENT.lock().unwrap() = Some(Box::new(handler));
111}
112
113// Spawn a proxy (two tasks, one for each direction of proxying).
114pub(crate) async fn run_main_loop<Hdl: 'static + for<'a> ProxyableRW<'a>>(
115    proxy: Arc<Proxy<Hdl>>,
116    initiate_transfer: ProxyTransferInitiationReceiver,
117    stream_writer: FramedStreamWriter,
118    initial_stream_reader: Option<FramedStreamReader>,
119    stream_reader: FramedStreamReader,
120) -> Result<(), Error> {
121    assert!(Arc::strong_count(&proxy) == 1);
122    let (tx_join, rx_join) = new_task_joiner();
123    let hdl = proxy.hdl();
124    let mut stream_writer = stream_writer.bind(hdl);
125    let initial_stream_reader = initial_stream_reader.map(|s| s.bind(hdl));
126    let mut stream_reader = stream_reader.bind(hdl);
127    let res = futures::future::try_join(
128        async {
129            if !stream_reader.is_initiator() {
130                stream_reader.expect_hello().await?;
131            } else {
132                stream_writer.send_hello().await?;
133            }
134            Ok::<(), Error>(())
135        },
136        async {
137            if let Some(initial_stream_reader) = initial_stream_reader {
138                drain(proxy.clone(), initial_stream_reader).await?;
139            }
140            Ok(())
141        },
142    )
143    .await;
144
145    if let Err(e) = res {
146        Arc::try_unwrap(proxy).unwrap().close_with_reason(format!("{e:?}"));
147        return Err(e);
148    }
149
150    let mut my_proxy = Some(Arc::clone(&proxy));
151
152    let take_proxy = || {
153        my_proxy = None;
154    };
155
156    let res = futures::future::try_join(
157        stream_to_handle(proxy.clone(), initiate_transfer, stream_reader, tx_join)
158            .map_err(|e| e.context("stream_to_handle")),
159        handle_to_stream(proxy, stream_writer, rx_join, take_proxy)
160            .map_err(|e| e.context("handle_to_stream")),
161    )
162    .map_ok(drop)
163    .await;
164
165    if let Some(cb) = &*PROXY_DROP_EVENT.lock().unwrap() {
166        cb(&res)
167    }
168    if let Err(e) = res {
169        if let Some(proxy) = my_proxy {
170            Arc::try_unwrap(proxy).unwrap().close_with_reason(format!("{e:?}"));
171        }
172        Err(e)
173    } else {
174        Ok(())
175    }
176}
177
178async fn handle_to_stream<Hdl: 'static + for<'a> ProxyableRW<'a>>(
179    proxy: Arc<Proxy<Hdl>>,
180    mut stream: StreamWriter<Hdl::Message>,
181    mut finish_proxy_loop: FinishProxyLoopReceiver<Hdl>,
182    take_proxy: impl FnOnce(),
183) -> Result<(), Error> {
184    let mut message = Default::default();
185    let finish_proxy_loop_action = loop {
186        let sr =
187            futures::future::select(&mut finish_proxy_loop, proxy.read_from_handle(&mut message))
188                .await;
189        match sr {
190            Either::Left((finish_proxy_loop_action, _)) => {
191                // Note: Proxy guarantees that read_from_handle can be dropped safely without losing data.
192                break finish_proxy_loop_action;
193            }
194            Either::Right((Err(zx_status::Status::PEER_CLOSED), _)) => {
195                if let Some(finish_proxy_loop_action) = finish_proxy_loop.now_or_never() {
196                    break finish_proxy_loop_action;
197                }
198                stream.send_shutdown(Ok(())).await.context("send_shutdown")?;
199                return Ok(());
200            }
201            Either::Right((Err(x), _)) => {
202                stream.send_shutdown(Err(x)).await.context("send_shutdown")?;
203                return Err(x).context("read_from_handle");
204            }
205            Either::Right((Ok(ReadValue::Message), _)) => {
206                drop(sr);
207                stream.send_data(&mut message).await.context("send_data")?;
208            }
209            Either::Right((Ok(ReadValue::SignalUpdate(signal_update)), _)) => {
210                stream.send_signal(signal_update).await.context("send_signal")?;
211            }
212        };
213    };
214    take_proxy();
215    let proxy = Arc::try_unwrap(proxy).map_err(|_| format_err!("Proxy should be isolated"))?;
216    match finish_proxy_loop_action {
217        Ok(FinishProxyLoopAction::InitiateTransfer {
218            paired_handle,
219            drain_stream,
220            stream_ref_sender,
221            stream_reader,
222        }) => {
223            super::xfer::initiate(
224                proxy,
225                paired_handle,
226                stream,
227                stream_reader,
228                drain_stream,
229                stream_ref_sender,
230            )
231            .await
232        }
233        Ok(FinishProxyLoopAction::FollowTransfer {
234            initiate_transfer,
235            new_destination_node,
236            transfer_key,
237            stream_reader,
238        }) => {
239            super::xfer::follow(
240                proxy,
241                initiate_transfer,
242                stream,
243                new_destination_node,
244                transfer_key,
245                stream_reader,
246            )
247            .await
248        }
249        Ok(FinishProxyLoopAction::Shutdown { result, stream_reader }) => {
250            join_shutdown(proxy, stream, stream_reader, result).await
251        }
252        Err(futures::channel::oneshot::Canceled) => unreachable!(),
253    }
254}
255
256async fn join_shutdown<Hdl: 'static + Proxyable>(
257    proxy: Proxy<Hdl>,
258    stream_writer: StreamWriter<Hdl::Message>,
259    stream_reader: StreamReader<Hdl::Message>,
260    result: Result<(), zx_status::Status>,
261) -> Result<(), Error> {
262    stream_writer.send_shutdown(result).await?;
263    let _ = stream_reader.expect_shutdown(Ok(())).await;
264    proxy.close_with_reason(format!("Proxy shut down (result: {result:?})"));
265    Ok(())
266}
267
268async fn drain<Hdl: 'static + for<'a> ProxyableRW<'a>>(
269    proxy: Arc<Proxy<Hdl>>,
270    mut drain_stream: StreamReader<Hdl::Message>,
271) -> Result<(), Error> {
272    loop {
273        let frame = drain_stream.next().await?;
274        match frame {
275            Frame::Data(message) => proxy.write_to_handle(message).await?,
276            Frame::SignalUpdate(signal_update) => proxy.apply_signal_update(signal_update)?,
277            Frame::EndTransfer => break,
278            Frame::BeginTransfer(_, _) => bail!("BeginTransfer on drain stream"),
279            Frame::AckTransfer => bail!("AckTransfer on drain stream"),
280            Frame::Hello => bail!("Hello frame disallowed on drain streams"),
281            Frame::Shutdown(r) => bail!("Stream shutdown during drain: {:?}", r),
282        }
283    }
284    Ok(())
285}
286
287async fn stream_to_handle<Hdl: 'static + for<'a> ProxyableRW<'a>>(
288    proxy: Arc<Proxy<Hdl>>,
289    mut initiate_transfer: ProxyTransferInitiationReceiver,
290    mut stream: StreamReader<Hdl::Message>,
291    finish_proxy_loop: FinishProxyLoopSender<Hdl>,
292) -> Result<(), Error> {
293    let removed_from_proxy_table = loop {
294        let frame = match futures::future::select(&mut initiate_transfer, stream.next()).await {
295            Either::Left((removed_from_proxy_table, _)) => {
296                // Note: StreamReader guarantees it's safe to drop a partial read without
297                // losing data.
298                break removed_from_proxy_table;
299            }
300            Either::Right((frame, _)) => frame.context("stream.next()")?,
301        };
302        match frame {
303            Frame::Data(message) => {
304                if let Err(e) = proxy.write_to_handle(message).await {
305                    let _ = finish_proxy_loop.and_then_shutdown(Err(e), stream);
306                    match e {
307                        zx_status::Status::PEER_CLOSED => {
308                            return Ok(());
309                        }
310                        _ => return Err(e).context("write_to_handle"),
311                    }
312                }
313            }
314            Frame::SignalUpdate(signal_update) => proxy.apply_signal_update(signal_update)?,
315            Frame::BeginTransfer(new_destination_node, transfer_key) => {
316                return finish_proxy_loop
317                    .and_then_follow(initiate_transfer, new_destination_node, transfer_key, stream)
318                    .context("finish_proxy_loop")
319            }
320            Frame::EndTransfer => bail!("Received EndTransfer on a regular stream"),
321            Frame::AckTransfer => bail!("Received AckTransfer before sending a BeginTransfer"),
322            Frame::Hello => bail!("Hello frame received after stream established"),
323            Frame::Shutdown(result) => {
324                let _ = finish_proxy_loop.and_then_shutdown(result, stream);
325                return result.context("Remote shutdown");
326            }
327        }
328    };
329    match removed_from_proxy_table {
330        Err(e) => Err(e.into()),
331        Ok(RemoveFromProxyTable::Dropped) => unreachable!(),
332        Ok(RemoveFromProxyTable::InitiateTransfer {
333            paired_handle,
334            drain_stream,
335            stream_ref_sender,
336        }) => Ok(finish_proxy_loop.and_then_initiate(
337            paired_handle,
338            drain_stream,
339            stream_ref_sender,
340            stream,
341        )?),
342    }
343}