1use core::ffi::c_void;
6use core::ptr::NonNull;
7use std::num::NonZero;
8use std::ops::ControlFlow;
9use std::sync::OnceLock;
10
11use log::{debug, warn};
12use zx::Status;
13
14use fdf::{Channel, DispatcherBuilder, DispatcherRef};
15use fidl_fuchsia_driver_framework::DriverRequest;
16
17use fdf::{AsyncDispatcher, DriverHandle, Message, fdf_handle_t};
18
19use crate::{Driver, DriverContext, DriverError};
20use fdf_sys::fdf_dispatcher_get_current_dispatcher;
21use fidl_fuchsia_driver_framework::DriverStartArgs;
22use fuchsia_async::LocalExecutorBuilder;
23
/// Serves the driver framework's `Driver` protocol (`Start` / `Stop`) for a driver of type `T`.
///
/// Instances are created by [`DriverServer::initialize`] and released by
/// [`DriverServer::destroy`], which hand the server across the C boundary as an opaque pointer.
pub struct DriverServer<T> {
    /// Channel on which `Driver` protocol requests are read and the `Start` response is sent.
    /// It is taken out of the lock (and dropped) once the message loop exits.
    server_handle: OnceLock<Channel<[u8]>>,
    /// The dispatcher that was current when the server was initialized; it is handed to the
    /// driver through its [`DriverContext`] when `Start` is handled.
    root_dispatcher: DispatcherRef<'static>,
    /// The running driver: set by a successful `Start`, taken (and stopped) by `Stop`.
    driver: OnceLock<T>,
}
32
33impl<T: Driver> DriverServer<T> {
34 pub unsafe extern "C" fn initialize(server_handle: fdf_handle_t) -> *mut c_void {
41 let root_dispatcher = NonNull::new(unsafe { fdf_dispatcher_get_current_dispatcher() })
44 .expect("Non-null current dispatcher");
45 let server_handle = OnceLock::from(unsafe {
49 Channel::from_driver_handle(DriverHandle::new_unchecked(
50 NonZero::new(server_handle).expect("valid driver handle"),
51 ))
52 });
53
54 let root_dispatcher = unsafe { DispatcherRef::from_raw(root_dispatcher) };
56 let server_ptr = Box::into_raw(Box::new(Self {
59 server_handle,
60 root_dispatcher: root_dispatcher.clone(),
61 driver: OnceLock::default(),
62 }));
63
64 let server = unsafe { &mut *server_ptr };
69
70 let rust_async_dispatcher = DispatcherBuilder::new()
73 .name("fuchsia-async")
74 .allow_thread_blocking()
75 .create_released()
76 .expect("failure creating blocking dispatcher for rust async");
77 rust_async_dispatcher
80 .post_task_sync(move |status| {
81 let Status::OK = status else { return };
83 fdf_core::override_current_dispatcher(root_dispatcher.clone(), || {
84 let port = zx::Port::create_with_opts(zx::PortOptions::BIND_TO_INTERRUPT);
88 let mut executor = LocalExecutorBuilder::new().port(port).build();
89 executor.run_singlethreaded(async move {
90 server.message_loop(root_dispatcher).await;
91 server.server_handle.take()
96 });
97 });
98 })
99 .expect("failure spawning main event loop for rust async dispatch");
100
101 server_ptr.cast()
105 }
106
107 pub unsafe extern "C" fn destroy(obj: *mut c_void) {
115 let obj: *mut Self = obj.cast();
116 unsafe { drop(Box::from_raw(obj)) }
119 }
120
121 async fn message_loop(&mut self, dispatcher: DispatcherRef<'_>) {
124 loop {
125 let server_handle_lock = self.server_handle.get_mut();
126 let Some(server_handle) = server_handle_lock else {
127 panic!("driver already shut down while message loop was running")
128 };
129 match server_handle.read_bytes(dispatcher.clone()).await {
130 Ok(Some(message)) => {
131 if let ControlFlow::Break(_) = self.handle_message(message).await {
132 return;
134 }
135 }
136 Ok(None) => panic!("unexpected empty message on server channel"),
137 Err(status @ Status::PEER_CLOSED) | Err(status @ Status::UNAVAILABLE) => {
138 warn!(
139 "Driver server channel closed before a stop message with status {status}, exiting main loop early but stop() will not be called."
140 );
141 return;
142 }
143 Err(e) => panic!("unexpected error on server channel {e}"),
144 }
145 }
146 }
147
148 async fn handle_start(&self, start_args: DriverStartArgs) -> Result<(), Status> {
155 let context = DriverContext::new(self.root_dispatcher.clone(), start_args)?;
156 context.start_logging(T::NAME)?;
157
158 log::debug!("driver starting");
159
160 let driver = T::start(context).await.map_err(DriverError::log_to_status)?;
161 self.driver.set(driver).map_err(|_| ()).expect("Driver received start message twice");
162 Ok(())
163 }
164
165 async fn handle_stop(&mut self) {
166 log::debug!("driver stopping");
167 self.driver
168 .take()
169 .expect("received stop message more than once or without successfully starting")
170 .stop()
171 .await;
172 }
173
174 async fn handle_message(&mut self, message: Message<[u8]>) -> ControlFlow<()> {
181 let (_, request) = DriverRequest::read_from_message(message).unwrap();
182 match request {
183 DriverRequest::Start { start_args, responder } => {
184 let res = self.handle_start(start_args).await.map_err(Status::into_raw);
185 let Some(server_handle) = self.server_handle.get() else {
186 panic!("driver shutting down before it was finished starting")
187 };
188 responder.send_response(server_handle, res).unwrap();
189 if res.is_ok() {
190 ControlFlow::Continue(())
191 } else {
192 debug!("driver failed to start, exiting main loop");
193 ControlFlow::Break(())
194 }
195 }
196 DriverRequest::Stop {} => {
197 self.handle_stop().await;
198 ControlFlow::Break(())
199 }
200 _ => panic!("Unknown message on server channel"),
201 }
202 }
203
204 pub(crate) fn testing_get_driver(&self) -> Option<&'_ T> {
207 self.driver.get()
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DriverError;

    use fdf::{CurrentDispatcher, OnDispatcher};
    use fdf_env::test::spawn_in_driver;
    use fidl_next_fuchsia_driver_framework::DriverClientHandler;
    use zx::Status;

    use fdf::Channel;
    use fidl_next::{ClientDispatcher, ClientEnd};

    #[derive(Default)]
    struct TestDriver {
        _not_empty: bool,
    }

    impl Driver for TestDriver {
        const NAME: &str = "test_driver";

        async fn start(context: DriverContext) -> Result<Self, DriverError> {
            let DriverContext { root_dispatcher, start_args, .. } = context;
            println!("created new test driver on dispatcher: {root_dispatcher:?}");
            println!("driver start message: {start_args:?}");
            Ok(Self::default())
        }
        async fn stop(&self) {
            println!("driver stop message");
        }
    }

    crate::driver_register!(TestDriver);

    /// Client-side handler with no event handling; used only to drive the client dispatcher.
    #[derive(Debug)]
    struct DriverClient;
    impl DriverClientHandler for DriverClient {}

    /// How the driver under test is expected to answer the `Start` request.
    enum ExpectedStart {
        /// `Start` succeeds; the harness then sends `Stop`.
        Succeeds,
        /// `Start` fails with this status; the server exits without a `Stop`.
        FailsWith(Status),
    }

    /// Runs one full driver-server lifecycle inside a test driver environment.
    ///
    /// The steps are:
    /// 1. Create a channel pair and spawn a `Driver` protocol client on one end.
    /// 2. Hand the other end to `initialize`, which must return a non-zero server pointer as
    ///    `usize`.
    /// 3. Send `Start` and check the result against `expected`. On success, also send `Stop`.
    /// 4. Wait for the client to see the channel close, then pass the server to `destroy`.
    fn run_server_lifecycle<I, D>(
        name: &'static str,
        initialize: I,
        destroy: D,
        expected: ExpectedStart,
    ) where
        I: FnOnce(fdf_handle_t) -> usize + Send + 'static,
        D: FnOnce(usize) + Send + 'static,
    {
        let (server_chan, client_chan) = Channel::<[fidl_next::Chunk]>::create();

        spawn_in_driver(name, async move {
            let client_end: ClientEnd<fidl_next_fuchsia_driver_framework::Driver, _> =
                ClientEnd::from_untyped(fdf_fidl::DriverChannel::new(client_chan));
            let dispatcher = ClientDispatcher::new(client_end);
            let client = dispatcher.client();

            // The client dispatcher only finishes (with an error) once the server closes the
            // channel.
            let client_task = CurrentDispatcher
                .spawn(async move {
                    dispatcher.run(DriverClient).await.unwrap_err();
                })
                .unwrap();

            let channel_handle = server_chan.into_driver_handle().into_raw().get();
            let driver_server = initialize(channel_handle);
            assert_ne!(driver_server, 0);

            let res = client
                .start(fidl_next_fuchsia_driver_framework::DriverStartArgs::default())
                .await
                .unwrap();

            match expected {
                ExpectedStart::Succeeds => {
                    res.unwrap();
                    client.stop().await.unwrap();
                }
                ExpectedStart::FailsWith(status) => {
                    assert_eq!(res.unwrap_err(), status.into_raw());
                }
            }

            client_task.await.unwrap();

            destroy(driver_server);
        })
    }

    #[test]
    fn register_driver() {
        assert_eq!(__fuchsia_driver_registration__.version, 1);
        let initialize_func = __fuchsia_driver_registration__.v1.initialize.expect("initializer");
        let destroy_func = __fuchsia_driver_registration__.v1.destroy.expect("destroy function");

        run_server_lifecycle(
            "driver registration",
            move |handle| unsafe { initialize_func(handle) as usize },
            move |server| unsafe { destroy_func(server as *mut c_void) },
            ExpectedStart::Succeeds,
        );
    }

    struct TestDriverAnyhowSuccess;

    impl Driver for TestDriverAnyhowSuccess {
        const NAME: &str = "test_driver_anyhow_success";

        async fn start(_context: DriverContext) -> Result<Self, DriverError> {
            Ok(Self)
        }
        async fn stop(&self) {}
    }

    #[test]
    fn test_anyhow_success() {
        let registration = crate::macros::make_driver_registration::<TestDriverAnyhowSuccess>();
        let initialize_func = registration.v1.initialize.expect("initializer");
        let destroy_func = registration.v1.destroy.expect("destroy function");

        run_server_lifecycle(
            "driver anyhow success",
            move |handle| unsafe { initialize_func(handle) as usize },
            move |server| unsafe { destroy_func(server as *mut c_void) },
            ExpectedStart::Succeeds,
        );
    }

    struct TestDriverAnyhowFailure;

    impl Driver for TestDriverAnyhowFailure {
        const NAME: &str = "test_driver_anyhow_failure";

        async fn start(_context: DriverContext) -> Result<Self, DriverError> {
            Err(anyhow::anyhow!(Status::INVALID_ARGS).into())
        }
        async fn stop(&self) {}
    }

    #[test]
    fn test_anyhow_failure() {
        let registration = crate::macros::make_driver_registration::<TestDriverAnyhowFailure>();
        let initialize_func = registration.v1.initialize.expect("initializer");
        let destroy_func = registration.v1.destroy.expect("destroy function");

        run_server_lifecycle(
            "driver anyhow failure",
            move |handle| unsafe { initialize_func(handle) as usize },
            move |server| unsafe { destroy_func(server as *mut c_void) },
            ExpectedStart::FailsWith(Status::INVALID_ARGS),
        );
    }

    struct TestDriverAnyhowFailureDefault;

    impl Driver for TestDriverAnyhowFailureDefault {
        const NAME: &str = "test_driver_anyhow_failure_default";

        async fn start(_context: DriverContext) -> Result<Self, DriverError> {
            Err(anyhow::anyhow!("some generic error").into())
        }
        async fn stop(&self) {}
    }

    #[test]
    fn test_anyhow_failure_default() {
        let registration =
            crate::macros::make_driver_registration::<TestDriverAnyhowFailureDefault>();
        let initialize_func = registration.v1.initialize.expect("initializer");
        let destroy_func = registration.v1.destroy.expect("destroy function");

        // An error that carries no `Status` is expected to be reported as INTERNAL.
        run_server_lifecycle(
            "driver anyhow failure default",
            move |handle| unsafe { initialize_func(handle) as usize },
            move |server| unsafe { destroy_func(server as *mut c_void) },
            ExpectedStart::FailsWith(Status::INTERNAL),
        );
    }
}