_usb_vsock_service_driver_rustc/
vsock_service.rs

// Copyright 2025 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5use fidl_fuchsia_hardware_vsock::{self as vsock, Addr, CallbacksProxy};
6use fuchsia_async::{Scope, Socket};
7use futures::channel::mpsc;
8use futures::{StreamExt, TryStreamExt};
9use log::{debug, error, info};
10use std::io::Error;
11use std::sync::atomic::{AtomicU32, Ordering};
12use std::sync::{self, Arc, Weak};
13use usb_vsock::{Address, Connection, PacketBuffer};
14use zx::Status;
15
16use crate::ConnectionRequest;
17
18/// Implements the fuchsia.hardware.vsock service against a [`Connection`].
19pub struct VsockService<B> {
20    connection: sync::Mutex<Option<Weak<Connection<B>>>>,
21    callback: CallbacksProxy,
22    current_cid: AtomicU32,
23    scope: Scope,
24}
25
26impl<B: PacketBuffer> VsockService<B> {
27    /// Waits for the start message from the client and returns a constructed [`VsockService`]
28    pub async fn wait_for_start(
29        incoming_connections: mpsc::Receiver<ConnectionRequest>,
30        requests: &mut vsock::DeviceRequestStream,
31    ) -> Result<Self, Error> {
32        use vsock::DeviceRequest::*;
33
34        let scope = Scope::new_with_name("vsock-service");
35        let Some(req) = requests.try_next().await.map_err(Error::other)? else {
36            return Err(Error::other(
37                "vsock client connected and disconnected without sending start message",
38            ));
39        };
40
41        match req {
42            Start { cb, responder } => {
43                info!("Client callback set for vsock client");
44                let connection = Default::default();
45                let callback = cb.into_proxy();
46                // since we aren't connected to a host yet, we claim to be cid 3 (the first
47                // non-reserved cid). The host can override this when a new connection is
48                // established.
49                let current_cid = AtomicU32::new(3);
50                scope.spawn(Self::run_incoming_loop(incoming_connections, callback.clone()));
51                responder.send(Ok(())).map_err(Error::other)?;
52                Ok(Self { connection, callback, scope, current_cid })
53            }
54            other => {
55                Err(Error::other(format!("unexpected message before start message: {other:?}")))
56            }
57        }
58    }
59
60    pub fn current_cid(&self) -> u32 {
61        self.current_cid.load(Ordering::Relaxed)
62    }
63
64    /// Set the current connection to be used by the vsock service server.
65    ///
66    /// # Panics
67    ///
68    /// Panics if the current socket is already set.
69    pub async fn set_connection(&self, conn: Arc<Connection<B>>, cid: u32) {
70        self.current_cid.store(cid, Ordering::Relaxed);
71        self.callback.transport_reset(cid).await.unwrap_or_else(log_callback_error);
72        let mut current = self.connection.lock().unwrap();
73        if current.as_ref().and_then(Weak::upgrade).is_some() {
74            panic!("Can only have one active connection set at a time");
75        }
76        current.replace(Arc::downgrade(&conn));
77    }
78
79    /// Gets the current connection if one is set.
80    fn get_connection(&self) -> Option<Arc<Connection<B>>> {
81        self.connection.lock().unwrap().as_ref().and_then(Weak::upgrade)
82    }
83
84    async fn send_request(&self, addr: Addr, data: zx::Socket) -> Result<(), Status> {
85        log::trace!("sending request for connection to address {addr:?}");
86        let cb = self.callback.clone();
87        let connection = self.get_connection();
88        let device_cid = self.current_cid();
89        self.scope.spawn(async move {
90            let Some(connection) = connection else {
91                // immediately reject a connection request if we don't have a usb connection to
92                // put it on
93                debug!("connection to {addr:?} rejected due to lack of a host connection");
94                cb.rst(&addr).unwrap_or_else(log_callback_error);
95                return;
96            };
97            let status = match connection
98                .connect(from_fidl_addr(device_cid, addr), Socket::from_socket(data))
99                .await
100            {
101                Ok(status) => status,
102                Err(err) => {
103                    // connection failed
104                    debug!("Connection request failed to connect with err {err:?}");
105                    cb.rst(&addr).unwrap_or_else(log_callback_error);
106                    return;
107                }
108            };
109            cb.response(&addr).unwrap_or_else(log_callback_error);
110            status.wait_for_close().await.ok();
111            cb.rst(&addr).unwrap_or_else(log_callback_error);
112        });
113        Ok(())
114    }
115
116    async fn send_shutdown(&self, addr: Addr) -> Result<(), Status> {
117        if let Some(connection) = self.get_connection() {
118            connection.close(&from_fidl_addr(self.current_cid(), addr)).await;
119        } else {
120            // this connection can't exist so just tell the caller that it was reset.
121            self.callback.rst(&addr).unwrap_or_else(log_callback_error);
122        }
123        Ok(())
124    }
125
126    async fn send_rst(&self, addr: Addr) -> Result<(), Status> {
127        if let Some(connection) = self.get_connection() {
128            connection.reset(&from_fidl_addr(self.current_cid(), addr)).await.ok();
129        }
130        Ok(())
131    }
132
133    async fn send_response(&self, addr: Addr, data: zx::Socket) -> Result<(), Status> {
134        // We cheat here and reconstitute the ConnectionRequest ourselves rather than try to thread
135        // it through the state machine. Since the main client of this particular api should be
136        // keeping track on its own, and we will ignore accepts of unknown addresses, this should be
137        // fine.
138        let address = from_fidl_addr(self.current_cid(), addr);
139        let request = ConnectionRequest::new(address.clone());
140        let Some(connection) = self.get_connection() else {
141            error!("Tried to accept connection for {address:?} on usb connection that is not open");
142            return Err(Status::BAD_STATE);
143        };
144        connection.accept(request, Socket::from_socket(data)).await.map_err(|err| {
145            error!("Failed to accept connection for {address:?}: {err:?}");
146            Err(Status::ADDRESS_UNREACHABLE)
147        })?;
148
149        Ok(())
150    }
151
152    async fn run_incoming_loop(
153        mut incoming_connections: mpsc::Receiver<ConnectionRequest>,
154        proxy: CallbacksProxy,
155    ) {
156        loop {
157            let Some(next) = incoming_connections.next().await else {
158                return;
159            };
160            if let Err(err) = proxy.request(&from_vsock_addr(*next.address())) {
161                error!("Error calling callback for incoming connection request: {err:?}");
162                return;
163            }
164        }
165    }
166
167    /// Runs the request loop for [`vsock::DeviceRequest`] against whatever the current [`Connection`]
168    /// is.
169    pub async fn run(&self, mut requests: vsock::DeviceRequestStream) -> Result<(), Error> {
170        use vsock::DeviceRequest::*;
171
172        while let Some(req) = requests.try_next().await.map_err(Error::other)? {
173            match req {
174                start @ Start { .. } => {
175                    return Err(Error::other(format!(
176                        "unexpected start message after one was already sent {start:?}"
177                    )))
178                }
179                SendRequest { addr, data, responder } => responder
180                    .send(self.send_request(addr, data).await.map_err(Status::into_raw))
181                    .map_err(Error::other)?,
182                SendShutdown { addr, responder } => responder
183                    .send(self.send_shutdown(addr).await.map_err(Status::into_raw))
184                    .map_err(Error::other)?,
185                SendRst { addr, responder } => responder
186                    .send(self.send_rst(addr).await.map_err(Status::into_raw))
187                    .map_err(Error::other)?,
188                SendResponse { addr, data, responder } => responder
189                    .send(self.send_response(addr, data).await.map_err(Status::into_raw))
190                    .map_err(Error::other)?,
191                GetCid { responder } => responder.send(self.current_cid()).map_err(Error::other)?,
192            }
193        }
194        Ok(())
195    }
196}
197
198fn log_callback_error<E: std::error::Error>(err: E) {
199    error!("Error sending callback to vsock client: {err:?}")
200}
201
202fn from_fidl_addr(device_cid: u32, value: Addr) -> Address {
203    Address {
204        device_cid,
205        host_cid: value.remote_cid,
206        device_port: value.local_port,
207        host_port: value.remote_port,
208    }
209}
210
211/// Leaves [`Address::device_cid`] blank, to be filled in by the caller
212fn from_vsock_addr(value: Address) -> Addr {
213    Addr { local_port: value.device_port, remote_cid: value.host_cid, remote_port: value.host_port }
214}