overnet_core/proxy/handle/socket.rs

1// Copyright 2020 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use super::signals::Collector;
6use super::{
7    IntoProxied, Message, Proxyable, ProxyableRW, ReadValue, RouterHolder, Serializer, IO,
8};
9use crate::peer::PeerConnRef;
10use anyhow::Error;
11use fidl::{AsHandleRef, AsyncSocket, HandleBased, Peered, Signals};
12use futures::io::{AsyncRead, AsyncWrite};
13use futures::ready;
14use std::pin::Pin;
15use std::task::{Context, Poll};
16use zx_status;
17
18pub(crate) struct Socket {
19    socket: AsyncSocket,
20}
21
22impl std::fmt::Debug for Socket {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        "Socket".fmt(f)
25    }
26}
27
28impl Proxyable for Socket {
29    type Message = SocketMessage;
30
31    fn from_fidl_handle(hdl: fidl::Handle) -> Result<Self, Error> {
32        Ok(fidl::Socket::from_handle(hdl).into_proxied()?)
33    }
34
35    fn into_fidl_handle(self) -> Result<fidl::Handle, Error> {
36        Ok(self.socket.into_zx_socket().into_handle())
37    }
38
39    fn signal_peer(&self, clear: Signals, set: Signals) -> Result<(), Error> {
40        self.socket.as_ref().signal_peer(clear, set)?;
41        Ok(())
42    }
43}
44
45impl<'a> ProxyableRW<'a> for Socket {
46    type Reader = SocketReader<'a>;
47    type Writer = SocketWriter;
48}
49
50impl IntoProxied for fidl::Socket {
51    type Proxied = Socket;
52    fn into_proxied(self) -> Result<Socket, Error> {
53        // TODO(https://fxbug.dev/418249087): Handle socket disposition changes
54        // better. By using the async implementation in fuchsia_async, we
55        // observe closed on read or write whenever the socket is in a
56        // half-closed state. Which is proxied by overnet as a PEER_CLOSED
57        // instead. We should intercept the disposition signals there and send
58        // them over the proxy.
59        Ok(Socket { socket: AsyncSocket::from_socket(self) })
60    }
61}
62
63pub(crate) struct SocketReader<'a> {
64    collector: Collector<'a>,
65}
66
67impl<'a> IO<'a> for SocketReader<'a> {
68    type Proxyable = Socket;
69    type Output = ReadValue;
70    fn new() -> Self {
71        SocketReader { collector: Default::default() }
72    }
73    fn poll_io(
74        &mut self,
75        msg: &mut SocketMessage,
76        socket: &'a Socket,
77        fut_ctx: &mut Context<'_>,
78    ) -> Poll<Result<ReadValue, zx_status::Status>> {
79        const MIN_READ_LEN: usize = 65536;
80        if msg.0.len() < MIN_READ_LEN {
81            msg.0.resize(MIN_READ_LEN, 0u8);
82        }
83        let read_result = (|| {
84            let n = ready!(Pin::new(&mut &socket.socket).poll_read(fut_ctx, &mut msg.0))?;
85            if n == 0 {
86                return Poll::Ready(Err(zx_status::Status::PEER_CLOSED));
87            }
88            msg.0.truncate(n);
89            Poll::Ready(Ok(()))
90        })();
91        self.collector.after_read(fut_ctx, socket.socket.as_handle_ref(), read_result, false)
92    }
93}
94
95pub(crate) struct SocketWriter;
96
97impl IO<'_> for SocketWriter {
98    type Proxyable = Socket;
99    type Output = ();
100    fn new() -> Self {
101        SocketWriter
102    }
103    fn poll_io(
104        &mut self,
105        msg: &mut SocketMessage,
106        socket: &Socket,
107        fut_ctx: &mut Context<'_>,
108    ) -> Poll<Result<(), zx_status::Status>> {
109        while !msg.0.is_empty() {
110            let n = ready!(Pin::new(&mut &socket.socket).poll_write(fut_ctx, &msg.0))?;
111            msg.0.drain(..n);
112        }
113        Poll::Ready(Ok(()))
114    }
115}
116
117#[derive(Default, PartialEq)]
118pub(crate) struct SocketMessage(Vec<u8>);
119
120impl std::fmt::Debug for SocketMessage {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        self.0.fmt(f)
123    }
124}
125
126impl Message for SocketMessage {
127    type Parser = SocketMessageSerializer;
128    type Serializer = SocketMessageSerializer;
129}
130
131#[derive(Debug)]
132pub(crate) struct SocketMessageSerializer;
133
134impl Serializer for SocketMessageSerializer {
135    type Message = SocketMessage;
136    fn new() -> SocketMessageSerializer {
137        SocketMessageSerializer
138    }
139    fn poll_ser(
140        &mut self,
141        msg: &mut SocketMessage,
142        bytes: &mut Vec<u8>,
143        _: PeerConnRef<'_>,
144        _: &mut RouterHolder<'_>,
145        _: &mut Context<'_>,
146    ) -> Poll<Result<(), Error>> {
147        std::mem::swap(bytes, &mut msg.0);
148        Poll::Ready(Ok(()))
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use futures::AsyncReadExt as _;

    /// Pushes twice the kernel buffer size through the proxied writer. The write
    /// can only finish if `SocketWriter` correctly resumes after partial writes
    /// to the zircon socket.
    #[fuchsia::test]
    async fn stream_socket_partial_write() {
        const KERNEL_BUF_SIZE: usize = 257024;
        const EXPECTED_DATA: u8 = 0xff;
        const EXPECTED_LEN: usize = KERNEL_BUF_SIZE * 2;

        let (tx, rx) = fidl::Socket::create_stream();
        let proxied = tx.into_proxied().expect("create proxied socket");

        let mut writer = SocketWriter::new();
        let mut outgoing = SocketMessage(vec![EXPECTED_DATA; EXPECTED_LEN]);
        let write_all = async move {
            futures::future::poll_fn(|cx| writer.poll_io(&mut outgoing, &proxied, cx))
                .await
                .expect("write to socket")
        };
        fuchsia_async::Task::spawn(write_all).detach();

        let mut reader_end = fuchsia_async::Socket::from_socket(rx);
        let mut received = vec![0u8; EXPECTED_LEN];
        reader_end.read_exact(&mut received).await.expect("read from socket");
        assert_eq!(received, vec![EXPECTED_DATA; EXPECTED_LEN]);
    }
}