fuchsia_async/net/fuchsia/
udp.rs

1// Copyright 2018 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#![deny(missing_docs)]
6
7use crate::net::EventedFd;
8use futures::future::Future;
9use futures::ready;
10use futures::task::{Context, Poll};
11use std::io;
12use std::net::{self, SocketAddr};
13use std::ops::Deref;
14use std::pin::Pin;
15
/// Builds the error reported when a `socket2::SockAddr` cannot be turned into a
/// [`SocketAddr`], i.e. when its family is neither IPv4 nor IPv6.
fn new_socket_address_conversion_error() -> std::io::Error {
    let message = "socket address is not IPv4 or IPv6";
    io::Error::other(message)
}
19
20/// An I/O object representing a UDP socket.
21///
22/// Like [`std::net::UdpSocket`], a `UdpSocket` represents a socket that is
23/// bound to a local address, and optionally is connected to a remote address.
24#[derive(Debug)]
25pub struct UdpSocket(DatagramSocket);
26
27impl Deref for UdpSocket {
28    type Target = DatagramSocket;
29
30    fn deref(&self) -> &Self::Target {
31        &self.0
32    }
33}
34
35impl UdpSocket {
36    /// Creates an async UDP socket from the given address.
37    ///
38    /// See [`std::net::UdpSocket::bind()`].
39    pub fn bind(addr: &SocketAddr) -> io::Result<UdpSocket> {
40        let socket = net::UdpSocket::bind(addr)?;
41        UdpSocket::from_socket(socket)
42    }
43
44    /// Creates an async UDP socket from a [`std::net::UdpSocket`].
45    pub fn from_socket(socket: net::UdpSocket) -> io::Result<UdpSocket> {
46        let socket: socket2::Socket = socket.into();
47        socket.set_nonblocking(true)?;
48        let evented_fd = unsafe { EventedFd::new(socket)? };
49        Ok(UdpSocket(DatagramSocket(evented_fd)))
50    }
51
52    /// Create a new UDP socket from an existing bound socket.
53    pub fn from_datagram(socket: DatagramSocket) -> io::Result<Self> {
54        let sock: &socket2::Socket = socket.as_ref();
55        if sock.r#type()? != socket2::Type::DGRAM {
56            return Err(io::Error::new(io::ErrorKind::InvalidInput, "socket type is not datagram"));
57        }
58        if sock.protocol()? != Some(socket2::Protocol::UDP) {
59            return Err(io::Error::new(io::ErrorKind::InvalidInput, "socket protocol is not UDP"));
60        }
61        // Maintain the invariant that the socket is bound (or connected).
62        let _: socket2::SockAddr = socket.local_addr()?;
63        Ok(Self(socket))
64    }
65
66    /// Returns the socket address that this socket was created from.
67    pub fn local_addr(&self) -> io::Result<SocketAddr> {
68        self.0
69            .local_addr()
70            .and_then(|sa| sa.as_socket().ok_or_else(new_socket_address_conversion_error))
71    }
72
73    /// Receive a UDP datagram from the socket.
74    ///
75    /// Asynchronous version of [`std::net::UdpSocket::recv_from()`].
76    pub fn recv_from<'a>(&'a self, buf: &'a mut [u8]) -> UdpRecvFrom<'a> {
77        UdpRecvFrom { socket: self, buf }
78    }
79
80    /// Send a UDP datagram via the socket.
81    ///
82    /// Asynchronous version of [`std::net::UdpSocket::send_to()`].
83    pub fn send_to<'a>(&'a self, buf: &'a [u8], addr: SocketAddr) -> SendTo<'a> {
84        SendTo { socket: self, buf, addr: addr.into() }
85    }
86
87    /// Asynchronously send a datagram (possibly split over multiple buffers) via the socket.
88    pub fn send_to_vectored<'a>(
89        &'a self,
90        bufs: &'a [io::IoSlice<'a>],
91        addr: SocketAddr,
92    ) -> SendToVectored<'a> {
93        SendToVectored { socket: self, bufs, addr: addr.into() }
94    }
95}
96
97/// An I/O object representing a datagram socket.
98#[derive(Debug)]
99pub struct DatagramSocket(EventedFd<socket2::Socket>);
100
101impl Deref for DatagramSocket {
102    type Target = EventedFd<socket2::Socket>;
103
104    fn deref(&self) -> &Self::Target {
105        &self.0
106    }
107}
108
109impl DatagramSocket {
110    /// Create a new async datagram socket.
111    pub fn new(domain: socket2::Domain, protocol: Option<socket2::Protocol>) -> io::Result<Self> {
112        let socket = socket2::Socket::new(domain, socket2::Type::DGRAM.nonblocking(), protocol)?;
113        let evented_fd = unsafe { EventedFd::new(socket)? };
114        Ok(Self(evented_fd))
115    }
116
117    /// Create a new async datagram socket from an existing socket.
118    pub fn new_from_socket(socket: socket2::Socket) -> io::Result<Self> {
119        match socket.r#type()? {
120            socket2::Type::DGRAM
121            // SOCK_RAW sockets operate on raw datagrams (e.g. datagrams that
122            // include the frame/packet header). For the purposes of
123            // `DatagramSocket`, their semantics are identical.
124            | socket2::Type::RAW => {
125                socket.set_nonblocking(true)?;
126                let evented_fd = unsafe { EventedFd::new(socket)? };
127                Ok(Self(evented_fd))
128            }
129            _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid socket type.")),
130        }
131    }
132
133    /// Returns the socket address that this socket was created from.
134    pub fn local_addr(&self) -> io::Result<socket2::SockAddr> {
135        self.0.as_ref().local_addr()
136    }
137
138    /// Receive a datagram asynchronously from the socket.
139    ///
140    /// The returned future will resolve with the number of bytes read and the source address of
141    /// the datagram on success.
142    pub fn recv_from<'a>(&'a self, buf: &'a mut [u8]) -> RecvFrom<'a> {
143        RecvFrom { socket: self, buf }
144    }
145
146    /// Attempt to receive a datagram from the socket without blocking.
147    pub fn async_recv_from(
148        &self,
149        buf: &mut [u8],
150        cx: &mut Context<'_>,
151    ) -> Poll<io::Result<(usize, socket2::SockAddr)>> {
152        ready!(EventedFd::poll_readable(&self.0, cx))?;
153        // SAFETY: socket2::Socket::recv_from takes a `&mut [MaybeUninit<u8>]`, so it's necessary to
154        // type-pun `&mut [u8]`. This is safe because the bytes are known to be initialized, and
155        // MaybeUninit's layout is guaranteed to be equivalent to its wrapped type.
156        let buf = unsafe {
157            std::slice::from_raw_parts_mut(
158                buf.as_mut_ptr() as *mut core::mem::MaybeUninit<u8>,
159                buf.len(),
160            )
161        };
162        match self.0.as_ref().recv_from(buf) {
163            Err(e) => {
164                if e.kind() == io::ErrorKind::WouldBlock {
165                    self.0.need_read(cx);
166                    Poll::Pending
167                } else {
168                    Poll::Ready(Err(e))
169                }
170            }
171            Ok((size, addr)) => Poll::Ready(Ok((size, addr))),
172        }
173    }
174
175    /// Send a datagram via the socket to the given address.
176    ///
177    /// The returned future will resolve with the number of bytes sent on success.
178    pub fn send_to<'a>(&'a self, buf: &'a [u8], addr: socket2::SockAddr) -> SendTo<'a> {
179        SendTo { socket: self, buf, addr }
180    }
181
182    /// Attempt to send a datagram via the socket without blocking.
183    pub fn async_send_to(
184        &self,
185        buf: &[u8],
186        addr: &socket2::SockAddr,
187        cx: &mut Context<'_>,
188    ) -> Poll<io::Result<usize>> {
189        ready!(EventedFd::poll_writable(&self.0, cx))?;
190        match self.0.as_ref().send_to(buf, addr) {
191            Err(e) => {
192                if e.kind() == io::ErrorKind::WouldBlock {
193                    self.0.need_write(cx);
194                    Poll::Pending
195                } else {
196                    Poll::Ready(Err(e))
197                }
198            }
199            Ok(size) => Poll::Ready(Ok(size)),
200        }
201    }
202
203    /// Send a datagram (possibly split over multiple buffers) via the socket.
204    pub fn send_to_vectored<'a>(
205        &'a self,
206        bufs: &'a [io::IoSlice<'a>],
207        addr: socket2::SockAddr,
208    ) -> SendToVectored<'a> {
209        SendToVectored { socket: self, bufs, addr }
210    }
211
212    /// Attempt to send a datagram (possibly split over multiple buffers) via the socket without
213    /// blocking.
214    pub fn async_send_to_vectored<'a>(
215        &self,
216        bufs: &'a [io::IoSlice<'a>],
217        addr: &socket2::SockAddr,
218        cx: &mut Context<'_>,
219    ) -> Poll<io::Result<usize>> {
220        ready!(EventedFd::poll_writable(&self.0, cx))?;
221        match self.0.as_ref().send_to_vectored(bufs, addr) {
222            Err(e) => {
223                if e.kind() == io::ErrorKind::WouldBlock {
224                    self.0.need_write(cx);
225                    Poll::Pending
226                } else {
227                    Poll::Ready(Err(e))
228                }
229            }
230            Ok(size) => Poll::Ready(Ok(size)),
231        }
232    }
233
234    /// Sets the value of the `SO_BROADCAST` option for this socket.
235    ///
236    /// When enabled, this socket is allowed to send packets to a broadcast address.
237    pub fn set_broadcast(&self, broadcast: bool) -> io::Result<()> {
238        self.0.as_ref().set_broadcast(broadcast)
239    }
240
241    /// Gets the value of the `SO_BROADCAST` option for this socket.
242    pub fn broadcast(&self) -> io::Result<bool> {
243        self.0.as_ref().broadcast()
244    }
245
246    /// Sets the `SO_BINDTODEVICE` socket option.
247    ///
248    /// If a socket is bound to an interface, only packets received from that particular interface
249    /// are processed by the socket. Note that this only works for some socket types, particularly
250    /// AF_INET sockets.
251    ///
252    /// The binding will be removed if `interface` is `None` or an empty byte slice.
253    pub fn bind_device(&self, interface: Option<&[u8]>) -> io::Result<()> {
254        self.0.as_ref().bind_device(interface)
255    }
256
257    /// Gets the value of the `SO_BINDTODEVICE` socket option.
258    ///
259    /// `Ok(None)` will be returned if the socket option is not set.
260    pub fn device(&self) -> io::Result<Option<Vec<u8>>> {
261        self.0.as_ref().device()
262    }
263}
264
265/// Future returned by [`UdpSocket::recv_from()`].
266#[must_use = "futures do nothing unless you `.await` or poll them"]
267pub struct UdpRecvFrom<'a> {
268    socket: &'a UdpSocket,
269    buf: &'a mut [u8],
270}
271
272impl<'a> Future for UdpRecvFrom<'a> {
273    type Output = io::Result<(usize, SocketAddr)>;
274
275    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
276        let this = &mut *self;
277        let (received, addr) = ready!(this.socket.0.async_recv_from(this.buf, cx))?;
278        Poll::Ready(
279            addr.as_socket()
280                .ok_or_else(new_socket_address_conversion_error)
281                .map(|addr| (received, addr)),
282        )
283    }
284}
285
286/// Future returned by [`DatagramSocket::recv_from()`].
287#[must_use = "futures do nothing unless you `.await` or poll them"]
288pub struct RecvFrom<'a> {
289    socket: &'a DatagramSocket,
290    buf: &'a mut [u8],
291}
292
293impl<'a> Future for RecvFrom<'a> {
294    type Output = io::Result<(usize, socket2::SockAddr)>;
295
296    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
297        let this = &mut *self;
298        let (received, addr) = ready!(this.socket.async_recv_from(this.buf, cx))?;
299        Poll::Ready(Ok((received, addr)))
300    }
301}
302
303/// Future returned by [`DatagramSocket::send_to()`].
304#[must_use = "futures do nothing unless you `.await` or poll them"]
305pub struct SendTo<'a> {
306    socket: &'a DatagramSocket,
307    buf: &'a [u8],
308    addr: socket2::SockAddr,
309}
310
311impl<'a> Future for SendTo<'a> {
312    type Output = io::Result<usize>;
313
314    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
315        self.socket.async_send_to(self.buf, &self.addr, cx)
316    }
317}
318
319/// Future returned by [`DatagramSocket::send_to_vectored()`].
320#[must_use = "futures do nothing unless you `.await` or poll them"]
321pub struct SendToVectored<'a> {
322    socket: &'a DatagramSocket,
323    bufs: &'a [io::IoSlice<'a>],
324    addr: socket2::SockAddr,
325}
326
327impl<'a> Future for SendToVectored<'a> {
328    type Output = io::Result<usize>;
329
330    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
331        self.socket.async_send_to_vectored(self.bufs, &self.addr, cx)
332    }
333}
334
#[cfg(test)]
mod test {
    use super::DatagramSocket;

    /// `new_from_socket` must reject a stream socket with `InvalidInput`.
    #[test]
    fn datagram_socket_new_from_socket() {
        let stream = socket2::Socket::new(socket2::Domain::IPV4, socket2::Type::STREAM, None)
            .expect("failed to create stream socket");
        let Err(err) = DatagramSocket::new_from_socket(stream) else {
            panic!("DatagramSocket created from stream socket succeeded unexpectedly");
        };
        assert!(
            err.kind() == std::io::ErrorKind::InvalidInput,
            "got: {err:?}; want error of kind InvalidInput"
        );
    }
}