vsock_service_lib/
service.rs

1// Copyright 2018 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5// This module contains the bulk of the logic for connecting user applications to a
6// vsock driver.
7//
8// Handling user requests is complicated as there are multiple communication channels
9// involved. For example a request to 'connect' will result in sending a message
10// to the driver over the single DeviceProxy. If this returns with success then
11// eventually a message will come over the single Callbacks stream indicating
12// whether the remote accepted or rejected.
13//
14// Fundamentally then there needs to be mutual exclusion in accessing DeviceProxy,
15// and de-multiplexing of incoming messages on the Callbacks stream. There are
// two high-level options for doing this.
17//  1. Force a single task event driver model. This would mean that additional
18//     asynchronous executions are never spawned, and any use of await! or otherwise
//     blocking with additional futures requires collecting futures in future sets
20//     or having custom polling logic etc. Whilst this is probably the most resource
21//     efficient it restricts the service to be single task forever by its design,
//     is harder to reason about as it cannot be written very idiomatically with futures
23//     and is even more complicated to avoid blocking other requests whilst waiting
24//     on responses from the driver.
25//  2. Allow multiple asynchronous executions and use some form of message passing
26//     and mutual exclusion checking to handle DeviceProxy access and sharing access
27//     to the Callbacks stream. Potentially more resource intensive with unnecessary
28//     refcells etc, but allows for the potential to have actual concurrent execution
29//     and is much simpler to write the logic.
// The chosen option is (2) and the access to DeviceProxy is handled with an Rc<RefCell<State>>,
31// and de-multiplexing of the Callbacks is done by registering an event whilst holding
32// the refcell, and having a single asynchronous task that is dedicated to converting
33// incoming Callbacks to signaling registered events.
34
35use crate::{addr, port};
36use anyhow::{format_err, Context as _};
37use fidl::endpoints;
38use fidl::endpoints::{ControlHandle, RequestStream};
39use fidl_fuchsia_hardware_vsock::{
40    CallbacksMarker, CallbacksRequest, CallbacksRequestStream, DeviceProxy, VMADDR_CID_HOST,
41    VMADDR_CID_LOCAL,
42};
43use fidl_fuchsia_vsock::{
44    AcceptorProxy, ConnectionRequest, ConnectionRequestStream, ConnectionTransport,
45    ConnectorRequest, ConnectorRequestStream, ListenerControlHandle, ListenerRequest,
46    ListenerRequestStream, SIGNAL_STREAM_INCOMING,
47};
48use fuchsia_async as fasync;
49use futures::channel::{mpsc, oneshot};
50use futures::{future, select, Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
51use std::cell::{Ref, RefCell, RefMut};
52use std::collections::{HashMap, VecDeque};
53use std::convert::Infallible;
54use std::ops::Deref;
55use std::pin::Pin;
56use std::rc::Rc;
57use std::task::{Context, Poll};
58use thiserror::Error;
59
// Signal raised on a listener channel's peer while its accept queue holds pending
// connections, and cleared again once the queue drains (see `SocketContext`). Built from
// the FIDL `SIGNAL_STREAM_INCOMING` bits; as this is a `const` initializer the `unwrap`
// is evaluated at compile time.
const ZXIO_SIGNAL_INCOMING: zx::Signals = zx::Signals::from_bits(SIGNAL_STREAM_INCOMING).unwrap();
61
// Context id (CID) component of a vsock address.
type Cid = u32;
// Port component of a vsock address.
type Port = u32;
// A (CID, port) pair. Used as the key for registered listeners and to identify ports
// handed out by the per-CID port trackers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct Addr(Cid, Port);
66
// The kinds of driver notification a task can wait on for a given connection address.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum EventType {
    // Signaled when the driver delivers a `Shutdown` callback; the sender is dropped
    // (without being signaled) when an `Rst` callback arrives.
    Shutdown,
    // Signaled when the driver delivers a `Response` callback.
    Response,
}
72
// Key identifying one pending notification in `State::events`: which kind of callback is
// awaited and for which connection address.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct Event {
    // Kind of callback being waited for.
    action: EventType,
    // Connection the callback must refer to.
    addr: addr::Vsock,
}
78
// A piece of bookkeeping to remove from `State`; see `State::deregister`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
enum Deregister {
    // Remove a pending event from the events map.
    Event(Event),
    // Remove the listener registered for this address.
    Listen(Addr),
    // Return this port to the port tracker of its CID.
    Port(Addr),
}
85
86#[derive(Error, Debug)]
87enum Error {
88    #[error("Driver returned failure status {}", _0)]
89    Driver(#[source] zx::Status),
90    #[error("All ephemeral ports are allocated")]
91    OutOfPorts,
92    #[error("Addr has already been bound")]
93    AlreadyBound,
94    #[error("Connection refused by remote")]
95    ConnectionRefused,
96    #[error("Error whilst communication with client")]
97    ClientCommunication(#[source] anyhow::Error),
98    #[error("Error whilst communication with client")]
99    DriverCommunication(#[source] anyhow::Error),
100    #[error("Driver reset the connection")]
101    ConnectionReset,
102    #[error("There are no more connections in the accept queue")]
103    NoConnectionsInQueue,
104}
105
106impl From<oneshot::Canceled> for Error {
107    fn from(_: oneshot::Canceled) -> Error {
108        Error::ConnectionReset
109    }
110}
111
112impl Error {
113    pub fn into_status(&self) -> zx::Status {
114        match self {
115            Error::Driver(status) => *status,
116            Error::OutOfPorts => zx::Status::NO_RESOURCES,
117            Error::AlreadyBound => zx::Status::ALREADY_BOUND,
118            Error::ConnectionRefused => zx::Status::UNAVAILABLE,
119            Error::ClientCommunication(err) | Error::DriverCommunication(err) => {
120                *err.downcast_ref::<zx::Status>().unwrap_or(&zx::Status::INTERNAL)
121            }
122            Error::ConnectionReset => zx::Status::PEER_CLOSED,
123            Error::NoConnectionsInQueue => zx::Status::SHOULD_WAIT,
124        }
125    }
126    pub fn is_comm_failure(&self) -> bool {
127        match self {
128            Error::ClientCommunication(_) | Error::DriverCommunication(_) => true,
129            _ => false,
130        }
131    }
132}
133
134fn map_driver_result(result: Result<Result<(), i32>, fidl::Error>) -> Result<(), Error> {
135    result
136        .map_err(|x| Error::DriverCommunication(x.into()))?
137        .map_err(|e| Error::Driver(zx::Status::from_raw(e)))
138}
139
// Mutable state behind a `SocketContext`.
struct SocketContextState {
    // Address this socket is bound to; incoming requests are checked against it.
    port: Addr,
    // Connection requests received from the driver and not yet accepted.
    accept_queue: VecDeque<addr::Vsock>,
    // `None` until `listen` has been called; afterwards the number of further requests
    // that may still be queued (decremented on push, incremented on pop).
    backlog: Option<u32>,
    // Control handle of the listener channel, used to signal its peer.
    control: ListenerControlHandle,
    // Whether `ZXIO_SIGNAL_INCOMING` is currently raised on the peer.
    signaled: bool,
}
147
/// Cheaply clonable, shared handle to the state of one bound listener socket.
#[derive(Clone)]
pub struct SocketContext(Rc<RefCell<SocketContextState>>);
150
151impl SocketContext {
152    fn new(port: Addr, control: ListenerControlHandle) -> SocketContext {
153        SocketContext(Rc::new(RefCell::new(SocketContextState {
154            port,
155            accept_queue: VecDeque::new(),
156            backlog: None,
157            signaled: false,
158            control,
159        })))
160    }
161
162    fn listen(&self, backlog: u32) -> Result<(), Error> {
163        let mut ctx = self.0.borrow_mut();
164        if ctx.backlog.is_some() {
165            return Err(Error::AlreadyBound);
166        }
167        // TODO: Update listener?
168        ctx.backlog = Some(backlog);
169        Ok(())
170    }
171
172    fn push_addr(&self, addr: addr::Vsock) -> bool {
173        let mut ctx = self.0.borrow_mut();
174        if Addr(addr.remote_cid, addr.local_port) != ctx.port {
175            panic!("request address doesn't match local socket address");
176        }
177        let Some(ref mut backlog) = ctx.backlog else {
178            panic!("pushing address when not yet bound");
179        };
180        if *backlog == 0 {
181            return false;
182        }
183        *backlog -= 1;
184        ctx.accept_queue.push_back(addr);
185        if ctx.signaled == false {
186            let _ = ctx.control.signal_peer(zx::Signals::empty(), ZXIO_SIGNAL_INCOMING);
187            ctx.signaled = true
188        }
189        return true;
190    }
191
192    fn pop_addr(&self) -> Option<addr::Vsock> {
193        let mut ctx = self.0.borrow_mut();
194        if let Some(addr) = ctx.accept_queue.pop_front() {
195            let Some(ref mut backlog) = ctx.backlog else {
196                return None;
197            };
198            *backlog += 1;
199            if ctx.accept_queue.len() == 0 {
200                let _ = ctx.control.signal_peer(ZXIO_SIGNAL_INCOMING, zx::Signals::empty());
201                ctx.signaled = false;
202            }
203            Some(addr)
204        } else {
205            None
206        }
207    }
208
209    fn port(&self) -> Addr {
210        self.0.borrow_mut().port
211    }
212}
213
// What is registered for an address in `State::listeners`.
enum Listener {
    // The address is reserved but not listening yet; incoming requests are reset.
    Bound,
    // Incoming requests are forwarded over this channel to a `ListenStream`.
    Channel(mpsc::UnboundedSender<addr::Vsock>),
    // Incoming requests are pushed onto the socket's accept queue.
    Queue(SocketContext),
}
219
// Shared service state; all access goes through the `RefCell` inside `Vsock`.
struct State {
    // Driver connection for the guest device, if one was supplied.
    guest_vsock_device: Option<DeviceProxy>,
    // Driver connection for the loopback device, if one was supplied.
    loopback_vsock_device: Option<DeviceProxy>,
    // CID reported by the guest device, or `VMADDR_CID_LOCAL` when there is none.
    local_cid: Cid,
    // Senders for events that tasks are currently waiting on.
    events: HashMap<Event, oneshot::Sender<()>>,
    // Port allocation tracker per CID.
    used_ports: HashMap<Cid, port::Tracker>,
    // Registered listeners keyed by (CID, local port).
    listeners: HashMap<Addr, Listener>,
    // Tasks spawned by the service (listeners, connections, resets).
    tasks: fasync::TaskGroup,
}
229
230impl State {
231    fn device(&self, cid: u32) -> Option<&DeviceProxy> {
232        match (cid, &self.guest_vsock_device, &self.loopback_vsock_device) {
233            (VMADDR_CID_LOCAL, _, Some(loopback)) => Some(&loopback),
234            (VMADDR_CID_HOST, Some(guest), _) => Some(&guest),
235            (VMADDR_CID_HOST, None, Some(loopback)) => Some(&loopback),
236            (cid, None, Some(loopback)) if cid == self.local_cid => Some(&loopback),
237            _ => None,
238        }
239    }
240}
241
/// Cheaply clonable handle to the vsock service. All clones share the same `State`.
#[derive(Clone)]
pub struct Vsock(Rc<RefCell<State>>);
244
245impl Vsock {
    /// Creates a new vsock service on top of the given guest and/or loopback devices.
    ///
    /// Each supplied `DeviceProxy` is started with its own `Callbacks` channel. When a guest
    /// device is present its CID is queried and used as the local CID; otherwise the local
    /// CID is `VMADDR_CID_LOCAL`.
    ///
    /// The creation is asynchronous due to need to invoke methods on the given devices. On
    /// success a pair of `Self, impl Future<Result<_, Error>>` is returned. The `impl Future`
    /// drives one callback loop per started device, processing the messages from that device.
    /// This future needs to be evaluated for other methods on the returned `Self` to complete
    /// successfully. Unless a fatal error occurs the future will never yield a result and
    /// will execute infinitely.
    pub async fn new(
        guest_vsock_device: Option<DeviceProxy>,
        loopback_vsock_device: Option<DeviceProxy>,
    ) -> Result<(Self, impl Future<Output = Result<Vec<Infallible>, anyhow::Error>>), anyhow::Error>
    {
        // Server ends of the callback channels handed to the devices that were started.
        let mut server_streams = Vec::new();
        let mut start_device = |device: &DeviceProxy| {
            let (callbacks_client, callbacks_server) =
                endpoints::create_endpoints::<CallbacksMarker>();
            server_streams.push(callbacks_server.into_stream());

            device.start(callbacks_client).map(map_driver_result).err_into::<anyhow::Error>()
        };
        let mut local_cid = VMADDR_CID_LOCAL;
        if let Some(ref device) = guest_vsock_device {
            start_device(device).await.context("Failed to start guest device")?;
            local_cid = device.get_cid().await?;
        }
        if let Some(ref device) = loopback_vsock_device {
            start_device(device).await.context("Failed to start loopback device")?;
        }
        let service = State {
            guest_vsock_device,
            loopback_vsock_device,
            local_cid,
            events: HashMap::new(),
            used_ports: HashMap::new(),
            listeners: HashMap::new(),
            tasks: fasync::TaskGroup::new(),
        };

        let service = Vsock(Rc::new(RefCell::new(service)));
        // One callback loop per started device; each holds its own clone of the service.
        let callback_loops: Vec<_> = server_streams
            .into_iter()
            .map(|stream| service.clone().run_callbacks(stream))
            .collect();

        Ok((service, future::try_join_all(callback_loops)))
    }
292    async fn run_callbacks(
293        self,
294        mut callbacks: CallbacksRequestStream,
295    ) -> Result<Infallible, anyhow::Error> {
296        while let Some(Ok(cb)) = callbacks.next().await {
297            self.borrow_mut().do_callback(cb);
298        }
299        // The only way to get here is if our callbacks stream ended, since our notifications
300        // cannot disconnect as we are holding a reference to them in |service|.
301        Err(format_err!("Driver disconnected"))
302    }
303
304    // Spawns a new asynchronous task for listening for incoming connections on a port.
305    fn start_listener(
306        &self,
307        acceptor: fidl::endpoints::ClientEnd<fidl_fuchsia_vsock::AcceptorMarker>,
308        local_port: u32,
309    ) -> Result<(), Error> {
310        let acceptor = acceptor.into_proxy();
311        let stream = self.listen_port(local_port)?;
312        self.borrow_mut().tasks.local(
313            self.clone()
314                .run_connection_listener(stream, acceptor)
315                .unwrap_or_else(|err| log::warn!("Error {} running connection listener", err)),
316        );
317        Ok(())
318    }
319
320    // Spawns a new asynchronous task for listening for incoming connections on a port.
321    fn start_listener2(
322        &self,
323        listener: fidl::endpoints::ServerEnd<fidl_fuchsia_vsock::ListenerMarker>,
324        port: Addr,
325    ) -> Result<(), Error> {
326        let stream = listener.into_stream();
327        self.bind_port(port.clone())?;
328        self.borrow_mut().tasks.local(
329            self.clone()
330                .run_connection_listener2(stream, port)
331                .unwrap_or_else(|err| log::warn!("Error {} running connection listener", err)),
332        );
333        Ok(())
334    }
335
336    // Handles a single incoming client request.
337    async fn handle_request(&self, request: ConnectorRequest) -> Result<(), Error> {
338        match request {
339            ConnectorRequest::Connect { remote_cid, remote_port, con, responder } => responder
340                .send(
341                    self.make_connection(remote_cid, remote_port, con)
342                        .await
343                        .map_err(|e| e.into_status().into_raw()),
344                ),
345            ConnectorRequest::Listen { local_port, acceptor, responder } => responder.send(
346                self.start_listener(acceptor, local_port).map_err(|e| e.into_status().into_raw()),
347            ),
348            ConnectorRequest::Bind { remote_cid, local_port, listener, responder } => responder
349                .send(
350                    self.start_listener2(listener, Addr(remote_cid, local_port))
351                        .map_err(|e| e.into_status().into_raw()),
352                ),
353            ConnectorRequest::GetCid { responder } => responder.send(self.borrow().local_cid),
354        }
355        .map_err(|e| Error::ClientCommunication(e.into()))
356    }
357
358    /// Evaluates messages on a `ConnectorRequestStream` until completion or error
359    ///
360    /// Takes ownership of a `RequestStream` that is most likely created from a `ServicesServer`
361    /// and processes any incoming requests on it.
362    pub async fn run_client_connection(self, request: ConnectorRequestStream) {
363        let self_ref = &self;
364        let fut = request
365            .map_err(|err| Error::ClientCommunication(err.into()))
366            // TODO: The parallel limit of 4 is currently invented with no basis and should
367            // made something more sensible.
368            .try_for_each_concurrent(4, |request| {
369                self_ref
370                    .handle_request(request)
371                    .or_else(|e| future::ready(if e.is_comm_failure() { Err(e) } else { Ok(()) }))
372            });
373        if let Err(e) = fut.await {
374            log::info!("Failed to handle request {}", e);
375        }
376    }
377    fn alloc_ephemeral_port(self, cid: Cid) -> Option<AllocatedPort> {
378        let p = self.borrow_mut().used_ports.entry(cid).or_default().allocate();
379        p.map(|p| AllocatedPort { port: Addr(cid, p), service: self })
380    }
381    // Creates a `ListenStream` that will retrieve raw incoming connection requests.
382    // These requests come from the device via the run_callbacks future.
383    fn listen_port(&self, port: u32) -> Result<ListenStream, Error> {
384        if port::is_ephemeral(port) {
385            log::info!("Rejecting request to listen on ephemeral port {}", port);
386            return Err(Error::ConnectionRefused);
387        }
388        match self.borrow_mut().listeners.entry(Addr(VMADDR_CID_HOST, port)) {
389            std::collections::hash_map::Entry::Vacant(entry) => {
390                let (sender, receiver) = mpsc::unbounded();
391                let listen =
392                    ListenStream { local_port: port, service: self.clone(), stream: receiver };
393                entry.insert(Listener::Channel(sender));
394                Ok(listen)
395            }
396            _ => {
397                log::info!("Attempt to listen on already bound port {}", port);
398                Err(Error::AlreadyBound)
399            }
400        }
401    }
402
403    fn bind_port(&self, port: Addr) -> Result<(), Error> {
404        if port::is_ephemeral(port.1) {
405            log::info!("Rejecting request to listen on ephemeral port {}", port.1);
406            return Err(Error::ConnectionRefused);
407        }
408        match self.borrow_mut().listeners.entry(port) {
409            std::collections::hash_map::Entry::Vacant(entry) => {
410                entry.insert(Listener::Bound);
411                Ok(())
412            }
413            _ => {
414                log::info!("Attempt to listen on already bound port {:?}", port);
415                Err(Error::AlreadyBound)
416            }
417        }
418    }
419
420    // Helper for inserting an event into the events hashmap
421    fn register_event(&self, event: Event) -> Result<OneshotEvent, Error> {
422        match self.borrow_mut().events.entry(event) {
423            std::collections::hash_map::Entry::Vacant(entry) => {
424                let (sender, receiver) = oneshot::channel();
425                let event = OneshotEvent {
426                    event: Some(entry.key().clone()),
427                    service: self.clone(),
428                    oneshot: receiver,
429                };
430                entry.insert(sender);
431                Ok(event)
432            }
433            _ => Err(Error::AlreadyBound),
434        }
435    }
436
437    // These helpers are wrappers around sending a message to the device, and creating events that
438    // will be signaled by the run_callbacks future when it receives a message from the device.
439    fn send_request(
440        &self,
441        addr: &addr::Vsock,
442        data: zx::Socket,
443    ) -> Result<impl Future<Output = Result<(OneshotEvent, OneshotEvent), Error>> + 'static, Error>
444    {
445        let shutdown_callback =
446            self.register_event(Event { action: EventType::Shutdown, addr: addr.clone() })?;
447        let response_callback =
448            self.register_event(Event { action: EventType::Response, addr: addr.clone() })?;
449
450        let send_request_fut = self.borrow_mut().send_request(&addr, data);
451
452        Ok(async move {
453            send_request_fut.await?;
454            Ok((shutdown_callback, response_callback))
455        })
456    }
457    fn send_response(
458        &self,
459        addr: &addr::Vsock,
460        data: zx::Socket,
461    ) -> Result<impl Future<Output = Result<OneshotEvent, Error>> + 'static, Error> {
462        let shutdown_callback =
463            self.register_event(Event { action: EventType::Shutdown, addr: addr.clone() })?;
464
465        let send_request_fut = self.borrow_mut().send_response(&addr, data);
466
467        Ok(async move {
468            send_request_fut.await?;
469            Ok(shutdown_callback)
470        })
471    }
472
    // Runs a connected socket until completion.
    //
    // Waits for whichever happens first, the shutdown event or the next message on the
    // client's `Connection` channel, handles that one occurrence and returns:
    //  - shutdown event signaled: a RST is sent to the driver,
    //  - shutdown event cancelled: `Error::ConnectionReset` is returned,
    //  - client requested shutdown: a shutdown is sent to the driver and the shutdown
    //    event is then awaited,
    //  - client channel failed or closed: a RST is sent to the driver.
    //
    // `_port` is never read; it is only held so that an allocated ephemeral port stays
    // reserved until this future is dropped.
    async fn run_connection<ShutdownFut>(
        self,
        addr: addr::Vsock,
        shutdown_event: ShutdownFut,
        mut requests: ConnectionRequestStream,
        _port: Option<AllocatedPort>,
    ) -> Result<(), Error>
    where
        ShutdownFut:
            Future<Output = Result<(), futures::channel::oneshot::Canceled>> + std::marker::Unpin,
    {
        let mut shutdown_event = shutdown_event.fuse();
        select! {
            shutdown_event = shutdown_event => {
                let fut = future::ready(shutdown_event)
                    .err_into()
                    .and_then(|()| self.borrow_mut().send_rst(&addr));
                return fut.await;
            },
            request = requests.next() => {
                match request {
                    Some(Ok(ConnectionRequest::Shutdown{control_handle: _control_handle})) => {
                        // The state borrow is released at the end of the `let` statement,
                        // before the future is awaited.
                        let fut =
                            self.borrow_mut().send_shutdown(&addr)
                                // Wait to either receive the RST for the client or to be
                                // shut down for some other reason
                                .and_then(|()| shutdown_event.err_into());
                        return fut.await;
                    },
                    // Generate a RST for a non graceful client disconnect.
                    Some(Err(e)) => {
                        let fut = self.borrow_mut().send_rst(&addr);
                        fut.await?;
                        return Err(Error::ClientCommunication(e.into()));
                    },
                    // The client closed its channel without requesting a shutdown.
                    None => {
                        let fut = self.borrow_mut().send_rst(&addr);
                        return fut.await;
                    },
                }
            }
        }
    }
517
518    fn listen(&self, socket: &SocketContext, backlog: u32) -> Result<(), Error> {
519        socket.listen(backlog)?;
520        // Replace "bound" listener with a socket accept queue.
521        match self.borrow_mut().listeners.entry(socket.port()) {
522            std::collections::hash_map::Entry::Vacant(_) => {
523                // We should be in bound state. Something went wrong if we end up here.
524                log::warn!("Expected listener to be in bound state, but listener not found!");
525                return Err(Error::AlreadyBound);
526            }
527            std::collections::hash_map::Entry::Occupied(mut entry) => {
528                if !matches!(entry.get(), Listener::Bound) {
529                    // Listen was probably already called. The call to socket.listen should
530                    // probably already have failed in this case.
531                    log::warn!("Listen called multiple times.");
532                    return Err(Error::AlreadyBound);
533                }
534                entry.insert(Listener::Queue(socket.clone()));
535            }
536        };
537
538        Ok(())
539    }
540
541    async fn accept(
542        &self,
543        socket: &SocketContext,
544        con: ConnectionTransport,
545    ) -> Result<addr::Vsock, Error> {
546        if let Some(addr) = socket.pop_addr() {
547            let data = con.data;
548            let con = con.con.into_stream();
549            let shutdown_event = self.send_response(&addr, data)?.await?;
550            self.borrow_mut().tasks.local(
551                self.clone()
552                    .run_connection(addr.clone(), shutdown_event, con, None)
553                    .map_err(|err| log::warn!("Error {} whilst running connection", err))
554                    .map(|_| ()),
555            );
556            // TODO: check if we want want to return the local port for the connection or the local
557            // port which the request came over.
558            Ok(addr)
559        } else {
560            Err(Error::NoConnectionsInQueue)
561        }
562    }
563
564    // Handles a single incoming client request.
565    async fn handle_listener_request(
566        &self,
567        socket: &SocketContext,
568        request: ListenerRequest,
569    ) -> Result<(), Error> {
570        match request {
571            ListenerRequest::Listen { backlog, responder } => {
572                responder.send(self.listen(socket, backlog).map_err(|e| e.into_status().into_raw()))
573            }
574            ListenerRequest::Accept { con, responder } => match self.accept(socket, con).await {
575                Ok(addr) => responder.send(Ok(&addr)),
576                Err(e) => responder.send(Err(e.into_status().into_raw())),
577            },
578        }
579        .map_err(|e| Error::ClientCommunication(e.into()))
580    }
581
582    async fn run_connection_listener2(
583        self,
584        request: ListenerRequestStream,
585        port: Addr,
586    ) -> Result<(), Error> {
587        let socket = SocketContext::new(port, request.control_handle());
588        let self_ref = &self;
589        let fut = request
590            .map_err(|err| Error::ClientCommunication(err.into()))
591            .try_for_each_concurrent(None, |request| {
592                self_ref
593                    .handle_listener_request(&socket, request)
594                    .or_else(|e| future::ready(if e.is_comm_failure() { Err(e) } else { Ok(()) }))
595            });
596        if let Err(e) = fut.await {
597            log::info!("Failed to handle request {}", e);
598        }
599        self.deregister(Deregister::Listen(socket.port()));
600        Ok(())
601    }
602
    // Waits for incoming connections on the given `ListenStream`, checks with the
    // user via the `acceptor` if it should be accepted, and if so spawns a new
    // asynchronous task to run the connection.
    //
    // Connections the acceptor declines (by returning no transport) are answered with a
    // RST. The loop ends with an error as soon as the acceptor call, sending the response
    // or sending the RST fails.
    async fn run_connection_listener(
        self,
        incoming: ListenStream,
        acceptor: AcceptorProxy,
    ) -> Result<(), Error> {
        incoming
            // Ask the acceptor about each request, keeping the address alongside the answer.
            .then(|addr| acceptor.accept(&*addr.clone()).map_ok(|maybe_con| (maybe_con, addr)))
            .map_err(|e| Error::ClientCommunication(e.into()))
            .try_for_each(|(maybe_con, addr)| async {
                match maybe_con {
                    Some(con) => {
                        let data = con.data;
                        let con = con.con.into_stream();
                        let shutdown_event = self.send_response(&addr, data)?.await?;
                        self.borrow_mut().tasks.local(
                            self.clone()
                                .run_connection(addr, shutdown_event, con, None)
                                .map_err(|err| {
                                    log::warn!("Error {} whilst running connection", err)
                                })
                                .map(|_| ()),
                        );
                        Ok(())
                    }
                    None => {
                        // The state borrow ends with the `let`, before the await.
                        let fut = self.borrow_mut().send_rst(&addr);
                        fut.await
                    }
                }
            })
            .await
    }
638
    // Attempts to connect to the given remote cid/port. If successful spawns a new
    // asynchronous task to run the connection until completion.
    //
    // Returns the ephemeral local port allocated for the connection. The port stays
    // reserved for as long as the spawned connection task holds it; on any failure it is
    // released when this function returns.
    async fn make_connection(
        &self,
        remote_cid: u32,
        remote_port: u32,
        con: ConnectionTransport,
    ) -> Result<u32, Error> {
        let data = con.data;
        let con = con.con.into_stream();
        let port = self.clone().alloc_ephemeral_port(remote_cid).ok_or(Error::OutOfPorts)?;
        let port_value = port.port.1;
        let addr = addr::Vsock::new(port_value, remote_port, remote_cid);
        let (shutdown_event, response_event) = self.send_request(&addr, data)?.await?;
        // Fused so that the same future can be handed on to `run_connection` after the
        // select below.
        let mut shutdown_event = shutdown_event.fuse();
        select! {
            _shutdown_event = shutdown_event => {
                // Getting a RST here just indicates a rejection and
                // not any underlying issues.
                return Err(Error::ConnectionRefused);
            },
            response_event = response_event.fuse() => response_event?,
        }

        self.borrow_mut().tasks.local(
            self.clone()
                .run_connection(addr, shutdown_event, con, Some(port))
                .unwrap_or_else(|err| log::warn!("Error {} whilst running connection", err)),
        );
        Ok(port_value)
    }
670
    /// Mutably borrow the wrapped value.
    ///
    /// Panics if the state is currently borrowed, so the returned guard must not be kept
    /// alive across calls that borrow the state again.
    fn borrow_mut(&self) -> RefMut<'_, State> {
        self.0.borrow_mut()
    }
675
    /// Immutably borrow the wrapped value. Panics if the state is mutably borrowed.
    fn borrow(&self) -> Ref<'_, State> {
        self.0.borrow()
    }
679
    // Deregisters the specified event, listener or port from the shared state. Takes a
    // mutable borrow of the state for the duration of the call.
    fn deregister(&self, event: Deregister) {
        self.borrow_mut().deregister(event);
    }
684}
685
686impl State {
687    // Remove the `event` from the `events` `HashMap`
688    fn deregister(&mut self, event: Deregister) {
689        match event {
690            Deregister::Event(e) => {
691                self.events.remove(&e);
692            }
693            Deregister::Listen(a) => {
694                self.listeners.remove(&a);
695            }
696            Deregister::Port(p) => {
697                self.used_ports.get_mut(&p.0).unwrap().free(p.1);
698            }
699        }
700    }
701
702    // Wrappers around device functions with nicer type signatures
703    fn send_request(
704        &mut self,
705        addr: &addr::Vsock,
706        data: zx::Socket,
707    ) -> impl Future<Output = Result<(), Error>> {
708        let result = self
709            .device(addr.remote_cid)
710            .ok_or(Error::ConnectionRefused)
711            .and_then(|device| Ok(device.send_request(&addr.clone(), data).map(map_driver_result)));
712        async { result?.await }
713    }
714    fn send_response(
715        &mut self,
716        addr: &addr::Vsock,
717        data: zx::Socket,
718    ) -> impl Future<Output = Result<(), Error>> {
719        let result =
720            self.device(addr.remote_cid).ok_or(Error::ConnectionRefused).and_then(|device| {
721                Ok(device.send_response(&addr.clone(), data).map(map_driver_result))
722            });
723        async { result?.await }
724    }
725    fn send_rst(
726        &mut self,
727        addr: &addr::Vsock,
728    ) -> impl Future<Output = Result<(), Error>> + 'static {
729        let result = self
730            .device(addr.remote_cid)
731            .ok_or(Error::ConnectionRefused)
732            .and_then(|device| Ok(device.send_rst(&addr.clone()).map(map_driver_result)));
733        async { result?.await }
734    }
735    fn send_shutdown(
736        &mut self,
737        addr: &addr::Vsock,
738    ) -> impl Future<Output = Result<(), Error>> + 'static {
739        let result = self
740            .device(addr.remote_cid)
741            .ok_or(Error::ConnectionRefused)
742            .and_then(|device| Ok(device.send_shutdown(&addr.clone()).map(map_driver_result)));
743        async { result?.await }
744    }
745
746    // Processes a single callback from the `device`. This is intended to be used by
747    // `Vsock::run_callbacks`
748    fn do_callback(&mut self, callback: CallbacksRequest) {
749        match callback {
750            CallbacksRequest::Response { addr, control_handle: _control_handle } => {
751                self.events
752                    .remove(&Event { action: EventType::Response, addr: addr::Vsock::from(addr) })
753                    .map(|channel| channel.send(()));
754            }
755            CallbacksRequest::Rst { addr, control_handle: _control_handle } => {
756                self.events
757                    .remove(&Event { action: EventType::Shutdown, addr: addr::Vsock::from(addr) });
758            }
759            CallbacksRequest::Request { addr, control_handle: _control_handle } => {
760                let addr = addr::Vsock::from(addr);
761                let reset = |state: &mut State| {
762                    let task = state.send_rst(&addr).map(|_| ());
763                    state.tasks.local(task);
764                };
765                match self.listeners.get(&Addr(addr.remote_cid, addr.local_port)) {
766                    Some(Listener::Bound) => {
767                        log::warn!(
768                            "Request on port {} denied due to socket only bound, not yet listening",
769                            addr.local_port
770                        );
771                        reset(self);
772                    }
773                    Some(Listener::Channel(sender)) => {
774                        let _ = sender.unbounded_send(addr.clone());
775                    }
776                    Some(Listener::Queue(socket)) => {
777                        if !socket.push_addr(addr.clone()) {
778                            log::warn!(
779                                "Request on port {} denied due to full backlog",
780                                addr.local_port
781                            );
782                            reset(self);
783                        }
784                    }
785                    None => {
786                        log::warn!("Request on port {} with no listener", addr.local_port);
787                        reset(self);
788                    }
789                }
790            }
791            CallbacksRequest::Shutdown { addr, control_handle: _control_handle } => {
792                self.events
793                    .remove(&Event { action: EventType::Shutdown, addr: addr::Vsock::from(addr) })
794                    .map(|channel| channel.send(()));
795            }
796            CallbacksRequest::TransportReset { new_cid: _new_cid, responder } => {
797                self.events.clear();
798                let _ = responder.send();
799            }
800        }
801    }
802}
803
// Guard for a port handed out by `Vsock::alloc_ephemeral_port`. Dropping it returns the
// port to the tracker of its CID.
struct AllocatedPort {
    // Service whose state owns the port tracker.
    service: Vsock,
    // The allocated (CID, port) pair.
    port: Addr,
}
808
809impl Deref for AllocatedPort {
810    type Target = Addr;
811
812    fn deref(&self) -> &Addr {
813        &self.port
814    }
815}
816
// Frees the port on drop. This mutably borrows the service state, so the guard must not
// be dropped while that state is already borrowed.
impl Drop for AllocatedPort {
    fn drop(&mut self) {
        self.service.deregister(Deregister::Port(self.port));
    }
}
822
// Future that resolves when a registered event is signaled, or with `Canceled` when its
// sender is dropped. If it is dropped before resolving it removes its own registration.
struct OneshotEvent {
    // Key of the registration; `None` once the registration no longer needs removing.
    event: Option<Event>,
    // Service whose state holds the registration.
    service: Vsock,
    // Receiving half of the channel whose sender is stored in `State::events`.
    oneshot: oneshot::Receiver<()>,
}
828
829impl Drop for OneshotEvent {
830    fn drop(&mut self) {
831        self.event.take().map(|e| self.service.deregister(Deregister::Event(e)));
832    }
833}
834
835impl Future for OneshotEvent {
836    type Output = <oneshot::Receiver<()> as Future>::Output;
837
838    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
839        match self.oneshot.poll_unpin(cx) {
840            Poll::Ready(x) => {
841                // Take the event so that we don't try to deregister it later,
842                // as by having sent the message we just received the callbacks
843                // task will already have removed it
844                self.event.take();
845                Poll::Ready(x)
846            }
847            p => p,
848        }
849    }
850}
851
// Stream of incoming connection requests for a port registered via `Vsock::listen_port`.
// Dropping it removes the listener registration for the port.
struct ListenStream {
    // Local port this stream listens on.
    local_port: Port,
    // Service whose state holds the listener registration.
    service: Vsock,
    // Receiving half of the channel that `State::do_callback` forwards requests to.
    stream: mpsc::UnboundedReceiver<addr::Vsock>,
}
857
// Removes the listener registered under `VMADDR_CID_HOST` for this port. This mutably
// borrows the service state, so the stream must not be dropped while that state is
// already borrowed.
impl Drop for ListenStream {
    fn drop(&mut self) {
        self.service.deregister(Deregister::Listen(Addr(VMADDR_CID_HOST, self.local_port)));
    }
}
863
// Yields the addresses of incoming connection requests by polling the inner receiver.
impl Stream for ListenStream {
    type Item = addr::Vsock;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.stream.poll_next_unpin(cx)
    }
}
870}