_usb_vsock_service_driver_rustc/
vsock_service.rs1use fidl_fuchsia_hardware_vsock::{self as vsock, Addr, CallbacksProxy};
6use fuchsia_async::{Scope, Socket};
7use futures::channel::mpsc;
8use futures::{StreamExt, TryStreamExt};
9use log::{debug, error, info};
10use std::io::Error;
11use std::sync::atomic::{AtomicU32, Ordering};
12use std::sync::{self, Arc, Weak};
13use usb_vsock::{Address, Connection, PacketBuffer};
14use zx::Status;
15
16use crate::ConnectionRequest;
17
/// Serves the `fuchsia.hardware.vsock` device protocol to a vsock client on top of a USB vsock
/// [`Connection`].
pub struct VsockService<B> {
    /// Weak handle to the active USB connection. It is `None` until `set_connection` installs
    /// one; because it is weak, the service never keeps a connection alive by itself.
    connection: sync::Mutex<Option<Weak<Connection<B>>>>,
    /// Proxy for the callbacks channel the client handed over in its `Start` message.
    callback: CallbacksProxy,
    /// CID used for the device side of addresses. It starts at 3 in `wait_for_start` and is
    /// overwritten by `set_connection`.
    current_cid: AtomicU32,
    /// Scope that the service's background tasks are spawned on (incoming-request forwarding
    /// and outgoing connection requests).
    scope: Scope,
}
25
26impl<B: PacketBuffer> VsockService<B> {
27 pub async fn wait_for_start(
29 incoming_connections: mpsc::Receiver<ConnectionRequest>,
30 requests: &mut vsock::DeviceRequestStream,
31 ) -> Result<Self, Error> {
32 use vsock::DeviceRequest::*;
33
34 let scope = Scope::new_with_name("vsock-service");
35 let Some(req) = requests.try_next().await.map_err(Error::other)? else {
36 return Err(Error::other(
37 "vsock client connected and disconnected without sending start message",
38 ));
39 };
40
41 match req {
42 Start { cb, responder } => {
43 info!("Client callback set for vsock client");
44 let connection = Default::default();
45 let callback = cb.into_proxy();
46 let current_cid = AtomicU32::new(3);
50 scope.spawn(Self::run_incoming_loop(incoming_connections, callback.clone()));
51 responder.send(Ok(())).map_err(Error::other)?;
52 Ok(Self { connection, callback, scope, current_cid })
53 }
54 other => {
55 Err(Error::other(format!("unexpected message before start message: {other:?}")))
56 }
57 }
58 }
59
60 pub fn current_cid(&self) -> u32 {
61 self.current_cid.load(Ordering::Relaxed)
62 }
63
64 pub async fn set_connection(&self, conn: Arc<Connection<B>>, cid: u32) {
70 self.current_cid.store(cid, Ordering::Relaxed);
71 self.callback.transport_reset(cid).await.unwrap_or_else(log_callback_error);
72 let mut current = self.connection.lock().unwrap();
73 if current.as_ref().and_then(Weak::upgrade).is_some() {
74 panic!("Can only have one active connection set at a time");
75 }
76 current.replace(Arc::downgrade(&conn));
77 }
78
79 fn get_connection(&self) -> Option<Arc<Connection<B>>> {
81 self.connection.lock().unwrap().as_ref().and_then(Weak::upgrade)
82 }
83
84 async fn send_request(&self, addr: Addr, data: zx::Socket) -> Result<(), Status> {
85 log::trace!("sending request for connection to address {addr:?}");
86 let cb = self.callback.clone();
87 let connection = self.get_connection();
88 let device_cid = self.current_cid();
89 self.scope.spawn(async move {
90 let Some(connection) = connection else {
91 debug!("connection to {addr:?} rejected due to lack of a host connection");
94 cb.rst(&addr).unwrap_or_else(log_callback_error);
95 return;
96 };
97 let status = match connection
98 .connect(from_fidl_addr(device_cid, addr), Socket::from_socket(data))
99 .await
100 {
101 Ok(status) => status,
102 Err(err) => {
103 debug!("Connection request failed to connect with err {err:?}");
105 cb.rst(&addr).unwrap_or_else(log_callback_error);
106 return;
107 }
108 };
109 cb.response(&addr).unwrap_or_else(log_callback_error);
110 status.wait_for_close().await.ok();
111 cb.rst(&addr).unwrap_or_else(log_callback_error);
112 });
113 Ok(())
114 }
115
116 async fn send_shutdown(&self, addr: Addr) -> Result<(), Status> {
117 if let Some(connection) = self.get_connection() {
118 connection.close(&from_fidl_addr(self.current_cid(), addr)).await;
119 } else {
120 self.callback.rst(&addr).unwrap_or_else(log_callback_error);
122 }
123 Ok(())
124 }
125
126 async fn send_rst(&self, addr: Addr) -> Result<(), Status> {
127 if let Some(connection) = self.get_connection() {
128 connection.reset(&from_fidl_addr(self.current_cid(), addr)).await.ok();
129 }
130 Ok(())
131 }
132
133 async fn send_response(&self, addr: Addr, data: zx::Socket) -> Result<(), Status> {
134 let address = from_fidl_addr(self.current_cid(), addr);
139 let request = ConnectionRequest::new(address.clone());
140 let Some(connection) = self.get_connection() else {
141 error!("Tried to accept connection for {address:?} on usb connection that is not open");
142 return Err(Status::BAD_STATE);
143 };
144 connection.accept(request, Socket::from_socket(data)).await.map_err(|err| {
145 error!("Failed to accept connection for {address:?}: {err:?}");
146 Err(Status::ADDRESS_UNREACHABLE)
147 })?;
148
149 Ok(())
150 }
151
152 async fn run_incoming_loop(
153 mut incoming_connections: mpsc::Receiver<ConnectionRequest>,
154 proxy: CallbacksProxy,
155 ) {
156 loop {
157 let Some(next) = incoming_connections.next().await else {
158 return;
159 };
160 if let Err(err) = proxy.request(&from_vsock_addr(*next.address())) {
161 error!("Error calling callback for incoming connection request: {err:?}");
162 return;
163 }
164 }
165 }
166
167 pub async fn run(&self, mut requests: vsock::DeviceRequestStream) -> Result<(), Error> {
170 use vsock::DeviceRequest::*;
171
172 while let Some(req) = requests.try_next().await.map_err(Error::other)? {
173 match req {
174 start @ Start { .. } => {
175 return Err(Error::other(format!(
176 "unexpected start message after one was already sent {start:?}"
177 )))
178 }
179 SendRequest { addr, data, responder } => responder
180 .send(self.send_request(addr, data).await.map_err(Status::into_raw))
181 .map_err(Error::other)?,
182 SendShutdown { addr, responder } => responder
183 .send(self.send_shutdown(addr).await.map_err(Status::into_raw))
184 .map_err(Error::other)?,
185 SendRst { addr, responder } => responder
186 .send(self.send_rst(addr).await.map_err(Status::into_raw))
187 .map_err(Error::other)?,
188 SendResponse { addr, data, responder } => responder
189 .send(self.send_response(addr, data).await.map_err(Status::into_raw))
190 .map_err(Error::other)?,
191 GetCid { responder } => responder.send(self.current_cid()).map_err(Error::other)?,
192 }
193 }
194 Ok(())
195 }
196}
197
198fn log_callback_error<E: std::error::Error>(err: E) {
199 error!("Error sending callback to vsock client: {err:?}")
200}
201
202fn from_fidl_addr(device_cid: u32, value: Addr) -> Address {
203 Address {
204 device_cid,
205 host_cid: value.remote_cid,
206 device_port: value.local_port,
207 host_port: value.remote_port,
208 }
209}
210
211fn from_vsock_addr(value: Address) -> Addr {
213 Addr { local_port: value.device_port, remote_cid: value.host_cid, remote_port: value.host_port }
214}