1use super::super::handle::ReadValue;
9use super::super::stream::{
10 Frame, StreamReader, StreamReaderBinder, StreamWriter, StreamWriterBinder,
11};
12use super::super::{
13 Proxy, ProxyTransferInitiationReceiver, Proxyable, ProxyableRW, RemoveFromProxyTable,
14 StreamRefSender,
15};
16use crate::labels::{NodeId, TransferKey};
17use crate::peer::{FramedStreamReader, FramedStreamWriter};
18use anyhow::{bail, format_err, Context as _, Error};
19use futures::future::Either;
20use futures::prelude::*;
21use std::sync::{Arc, Mutex};
22use zx_status;
23
/// Describes how the proxy loop should be wound down.
///
/// `stream_to_handle` builds one of these when it stops reading frames and sends it
/// through the one-shot join channel; `handle_to_stream` receives it and performs the
/// matching follow-up. Every variant carries the `StreamReader` that
/// `stream_to_handle` had been reading from, so the receiver takes over that stream.
#[derive(Debug)]
enum FinishProxyLoopAction<Hdl: Proxyable> {
    /// Sent when the transfer-initiation receiver yields
    /// `RemoveFromProxyTable::InitiateTransfer`; `handle_to_stream` reacts by calling
    /// `super::xfer::initiate` with these fields.
    InitiateTransfer {
        paired_handle: fidl::Handle,
        drain_stream: FramedStreamWriter,
        stream_ref_sender: StreamRefSender,
        stream_reader: StreamReader<Hdl::Message>,
    },
    /// Sent when a `Frame::BeginTransfer` arrives on the stream; `handle_to_stream`
    /// reacts by calling `super::xfer::follow` with these fields.
    FollowTransfer {
        initiate_transfer: ProxyTransferInitiationReceiver,
        new_destination_node: NodeId,
        transfer_key: TransferKey,
        stream_reader: StreamReader<Hdl::Message>,
    },
    /// Sent when writing to the handle fails or a `Frame::Shutdown` arrives;
    /// `handle_to_stream` reacts by calling `join_shutdown` with `result`.
    Shutdown {
        result: Result<(), zx_status::Status>,
        stream_reader: StreamReader<Hdl::Message>,
    },
}
46
/// Sending half of the join channel: a thin wrapper around a one-shot sender of
/// `FinishProxyLoopAction`, so at most one action can ever be delivered.
struct FinishProxyLoopSender<Hdl: Proxyable> {
    /// Underlying one-shot channel sender; consumed when an action is sent.
    chan: futures::channel::oneshot::Sender<FinishProxyLoopAction<Hdl>>,
}
/// Receiving half of the join channel created by `new_task_joiner`.
type FinishProxyLoopReceiver<Hdl> = futures::channel::oneshot::Receiver<FinishProxyLoopAction<Hdl>>;
51
52impl<Hdl: 'static + Proxyable> FinishProxyLoopSender<Hdl> {
53 fn and_then(self, then: FinishProxyLoopAction<Hdl>) -> Result<(), Error> {
54 Ok(self.chan.send(then).map_err(|_| format_err!("Join channel broken"))?)
55 }
56
57 fn and_then_initiate(
59 self,
60 paired_handle: fidl::Handle,
61 drain_stream: FramedStreamWriter,
62 stream_ref_sender: StreamRefSender,
63 stream_reader: StreamReader<Hdl::Message>,
64 ) -> Result<(), Error> {
65 self.and_then(FinishProxyLoopAction::InitiateTransfer {
66 paired_handle,
67 drain_stream,
68 stream_ref_sender,
69 stream_reader,
70 })
71 }
72
73 fn and_then_follow(
75 self,
76 initiate_transfer: ProxyTransferInitiationReceiver,
77 new_destination_node: NodeId,
78 transfer_key: TransferKey,
79 stream_reader: StreamReader<Hdl::Message>,
80 ) -> Result<(), Error> {
81 self.and_then(FinishProxyLoopAction::FollowTransfer {
82 initiate_transfer,
83 new_destination_node,
84 transfer_key,
85 stream_reader,
86 })
87 }
88
89 fn and_then_shutdown(
90 self,
91 result: Result<(), zx_status::Status>,
92 stream_reader: StreamReader<Hdl::Message>,
93 ) -> Result<(), Error> {
94 self.and_then(FinishProxyLoopAction::Shutdown { result, stream_reader })
95 }
96}
97
98fn new_task_joiner<Hdl: Proxyable>() -> (FinishProxyLoopSender<Hdl>, FinishProxyLoopReceiver<Hdl>) {
99 let (tx, rx) = futures::channel::oneshot::channel();
100 (FinishProxyLoopSender { chan: tx }, rx)
101}
102
/// Optional process-wide callback, initially unset. `set_proxy_drop_event_handler`
/// installs it and `run_main_loop` calls it with the combined result of its two
/// stream tasks.
static PROXY_DROP_EVENT: Mutex<Option<Box<dyn Fn(&Result<(), Error>) + 'static + Send>>> =
    Mutex::new(None);
106
107pub fn set_proxy_drop_event_handler(handler: impl Fn(&Result<(), Error>) + 'static + Send) {
110 *PROXY_DROP_EVENT.lock().unwrap() = Some(Box::new(handler));
111}
112
113pub(crate) async fn run_main_loop<Hdl: 'static + for<'a> ProxyableRW<'a>>(
115 proxy: Arc<Proxy<Hdl>>,
116 initiate_transfer: ProxyTransferInitiationReceiver,
117 stream_writer: FramedStreamWriter,
118 initial_stream_reader: Option<FramedStreamReader>,
119 stream_reader: FramedStreamReader,
120) -> Result<(), Error> {
121 assert!(Arc::strong_count(&proxy) == 1);
122 let (tx_join, rx_join) = new_task_joiner();
123 let hdl = proxy.hdl();
124 let mut stream_writer = stream_writer.bind(hdl);
125 let initial_stream_reader = initial_stream_reader.map(|s| s.bind(hdl));
126 let mut stream_reader = stream_reader.bind(hdl);
127 let res = futures::future::try_join(
128 async {
129 if !stream_reader.is_initiator() {
130 stream_reader.expect_hello().await?;
131 } else {
132 stream_writer.send_hello().await?;
133 }
134 Ok::<(), Error>(())
135 },
136 async {
137 if let Some(initial_stream_reader) = initial_stream_reader {
138 drain(proxy.clone(), initial_stream_reader).await?;
139 }
140 Ok(())
141 },
142 )
143 .await;
144
145 if let Err(e) = res {
146 Arc::try_unwrap(proxy).unwrap().close_with_reason(format!("{e:?}"));
147 return Err(e);
148 }
149
150 let mut my_proxy = Some(Arc::clone(&proxy));
151
152 let take_proxy = || {
153 my_proxy = None;
154 };
155
156 let res = futures::future::try_join(
157 stream_to_handle(proxy.clone(), initiate_transfer, stream_reader, tx_join)
158 .map_err(|e| e.context("stream_to_handle")),
159 handle_to_stream(proxy, stream_writer, rx_join, take_proxy)
160 .map_err(|e| e.context("handle_to_stream")),
161 )
162 .map_ok(drop)
163 .await;
164
165 if let Some(cb) = &*PROXY_DROP_EVENT.lock().unwrap() {
166 cb(&res)
167 }
168 if let Err(e) = res {
169 if let Some(proxy) = my_proxy {
170 Arc::try_unwrap(proxy).unwrap().close_with_reason(format!("{e:?}"));
171 }
172 Err(e)
173 } else {
174 Ok(())
175 }
176}
177
/// Forwards what is read from the local handle onto the outgoing stream until the
/// join channel delivers a `FinishProxyLoopAction` or reading from the handle fails.
///
/// Once an action has been received, `take_proxy` is called, the proxy is taken out
/// of its `Arc`, and the action is carried out: `super::xfer::initiate`,
/// `super::xfer::follow`, or `join_shutdown`.
///
/// # Errors
///
/// Returns an error if sending on the stream fails, if the handle read fails with a
/// status other than `PEER_CLOSED`, if the proxy `Arc` is still shared when it has to
/// be unwrapped ("Proxy should be isolated"), or if the final action itself fails.
///
/// # Panics
///
/// Panics (via `unreachable!`) if the join channel reports `Canceled`.
async fn handle_to_stream<Hdl: 'static + for<'a> ProxyableRW<'a>>(
    proxy: Arc<Proxy<Hdl>>,
    mut stream: StreamWriter<Hdl::Message>,
    mut finish_proxy_loop: FinishProxyLoopReceiver<Hdl>,
    take_proxy: impl FnOnce(),
) -> Result<(), Error> {
    // Scratch message that each handle read fills in and `send_data` then transmits.
    let mut message = Default::default();
    let finish_proxy_loop_action = loop {
        // Race the join channel against the next read from the handle.
        let sr =
            futures::future::select(&mut finish_proxy_loop, proxy.read_from_handle(&mut message))
                .await;
        match sr {
            // The other task asked us to finish: leave the loop with its action.
            Either::Left((finish_proxy_loop_action, _)) => {
                break finish_proxy_loop_action;
            }
            Either::Right((Err(zx_status::Status::PEER_CLOSED), _)) => {
                // The handle's peer closed. If an action is already waiting on the
                // join channel, prefer it; otherwise report a clean shutdown.
                if let Some(finish_proxy_loop_action) = finish_proxy_loop.now_or_never() {
                    break finish_proxy_loop_action;
                }
                stream.send_shutdown(Ok(())).await.context("send_shutdown")?;
                return Ok(());
            }
            // Any other read failure is forwarded to the peer and then returned.
            Either::Right((Err(x), _)) => {
                stream.send_shutdown(Err(x)).await.context("send_shutdown")?;
                return Err(x).context("read_from_handle");
            }
            Either::Right((Ok(ReadValue::Message), _)) => {
                // NOTE(review): presumably `sr` must be dropped here so the borrow of
                // `message` tied to the select result ends before `send_data` takes
                // `&mut message` — confirm.
                drop(sr);
                stream.send_data(&mut message).await.context("send_data")?;
            }
            Either::Right((Ok(ReadValue::SignalUpdate(signal_update)), _)) => {
                stream.send_signal(signal_update).await.context("send_signal")?;
            }
        };
    };
    // Let the caller release its reference so the Arc below can be unwrapped.
    take_proxy();
    let proxy = Arc::try_unwrap(proxy).map_err(|_| format_err!("Proxy should be isolated"))?;
    match finish_proxy_loop_action {
        Ok(FinishProxyLoopAction::InitiateTransfer {
            paired_handle,
            drain_stream,
            stream_ref_sender,
            stream_reader,
        }) => {
            super::xfer::initiate(
                proxy,
                paired_handle,
                stream,
                stream_reader,
                drain_stream,
                stream_ref_sender,
            )
            .await
        }
        Ok(FinishProxyLoopAction::FollowTransfer {
            initiate_transfer,
            new_destination_node,
            transfer_key,
            stream_reader,
        }) => {
            super::xfer::follow(
                proxy,
                initiate_transfer,
                stream,
                new_destination_node,
                transfer_key,
                stream_reader,
            )
            .await
        }
        Ok(FinishProxyLoopAction::Shutdown { result, stream_reader }) => {
            join_shutdown(proxy, stream, stream_reader, result).await
        }
        // Treated as impossible: the sender going away without sending an action.
        Err(futures::channel::oneshot::Canceled) => unreachable!(),
    }
}
255
256async fn join_shutdown<Hdl: 'static + Proxyable>(
257 proxy: Proxy<Hdl>,
258 stream_writer: StreamWriter<Hdl::Message>,
259 stream_reader: StreamReader<Hdl::Message>,
260 result: Result<(), zx_status::Status>,
261) -> Result<(), Error> {
262 stream_writer.send_shutdown(result).await?;
263 let _ = stream_reader.expect_shutdown(Ok(())).await;
264 proxy.close_with_reason(format!("Proxy shut down (result: {result:?})"));
265 Ok(())
266}
267
268async fn drain<Hdl: 'static + for<'a> ProxyableRW<'a>>(
269 proxy: Arc<Proxy<Hdl>>,
270 mut drain_stream: StreamReader<Hdl::Message>,
271) -> Result<(), Error> {
272 loop {
273 let frame = drain_stream.next().await?;
274 match frame {
275 Frame::Data(message) => proxy.write_to_handle(message).await?,
276 Frame::SignalUpdate(signal_update) => proxy.apply_signal_update(signal_update)?,
277 Frame::EndTransfer => break,
278 Frame::BeginTransfer(_, _) => bail!("BeginTransfer on drain stream"),
279 Frame::AckTransfer => bail!("AckTransfer on drain stream"),
280 Frame::Hello => bail!("Hello frame disallowed on drain streams"),
281 Frame::Shutdown(r) => bail!("Stream shutdown during drain: {:?}", r),
282 }
283 }
284 Ok(())
285}
286
287async fn stream_to_handle<Hdl: 'static + for<'a> ProxyableRW<'a>>(
288 proxy: Arc<Proxy<Hdl>>,
289 mut initiate_transfer: ProxyTransferInitiationReceiver,
290 mut stream: StreamReader<Hdl::Message>,
291 finish_proxy_loop: FinishProxyLoopSender<Hdl>,
292) -> Result<(), Error> {
293 let removed_from_proxy_table = loop {
294 let frame = match futures::future::select(&mut initiate_transfer, stream.next()).await {
295 Either::Left((removed_from_proxy_table, _)) => {
296 break removed_from_proxy_table;
299 }
300 Either::Right((frame, _)) => frame.context("stream.next()")?,
301 };
302 match frame {
303 Frame::Data(message) => {
304 if let Err(e) = proxy.write_to_handle(message).await {
305 let _ = finish_proxy_loop.and_then_shutdown(Err(e), stream);
306 match e {
307 zx_status::Status::PEER_CLOSED => {
308 return Ok(());
309 }
310 _ => return Err(e).context("write_to_handle"),
311 }
312 }
313 }
314 Frame::SignalUpdate(signal_update) => proxy.apply_signal_update(signal_update)?,
315 Frame::BeginTransfer(new_destination_node, transfer_key) => {
316 return finish_proxy_loop
317 .and_then_follow(initiate_transfer, new_destination_node, transfer_key, stream)
318 .context("finish_proxy_loop")
319 }
320 Frame::EndTransfer => bail!("Received EndTransfer on a regular stream"),
321 Frame::AckTransfer => bail!("Received AckTransfer before sending a BeginTransfer"),
322 Frame::Hello => bail!("Hello frame received after stream established"),
323 Frame::Shutdown(result) => {
324 let _ = finish_proxy_loop.and_then_shutdown(result, stream);
325 return result.context("Remote shutdown");
326 }
327 }
328 };
329 match removed_from_proxy_table {
330 Err(e) => Err(e.into()),
331 Ok(RemoveFromProxyTable::Dropped) => unreachable!(),
332 Ok(RemoveFromProxyTable::InitiateTransfer {
333 paired_handle,
334 drain_stream,
335 stream_ref_sender,
336 }) => Ok(finish_proxy_loop.and_then_initiate(
337 paired_handle,
338 drain_stream,
339 stream_ref_sender,
340 stream,
341 )?),
342 }
343}