1use core::ffi::c_void;
6use core::ptr::NonNull;
7use std::num::NonZero;
8use std::ops::ControlFlow;
9use std::sync::OnceLock;
10
11use log::{debug, warn};
12use zx::Status;
13
14use fdf::{Channel, Dispatcher, DispatcherBuilder, DispatcherRef};
15use fidl_fuchsia_driver_framework::DriverRequest;
16
17use fdf::{DriverHandle, Message, fdf_handle_t};
18
19use crate::{Driver, DriverContext};
20use fdf_sys::fdf_dispatcher_get_current_dispatcher;
21use fidl_fuchsia_driver_framework::DriverStartArgs;
22use fuchsia_async::LocalExecutorBuilder;
23
/// Serves `fidl_fuchsia_driver_framework` `Driver` requests (`Start` / `Stop`) for a Rust
/// driver of type `T`.
///
/// Instances are heap-allocated by [`DriverServer::initialize`] and handed to the caller as a
/// raw pointer, which must later be released with [`DriverServer::destroy`].
pub struct DriverServer<T> {
    /// Channel that driver framework requests are read from and responses are written to.
    /// It is taken out of the lock once the message loop finishes.
    server_handle: OnceLock<Channel<[u8]>>,
    /// The dispatcher that was current when [`DriverServer::initialize`] was called.
    root_dispatcher: DispatcherRef<'static>,
    /// The running driver: set after `T::start` succeeds and taken when a `Stop` is handled.
    driver: OnceLock<T>,
}
32
33impl<T: Driver> DriverServer<T> {
34 pub unsafe extern "C" fn initialize(server_handle: fdf_handle_t) -> *mut c_void {
41 let root_dispatcher = NonNull::new(unsafe { fdf_dispatcher_get_current_dispatcher() })
44 .expect("Non-null current dispatcher");
45 let server_handle = OnceLock::from(unsafe {
49 Channel::from_driver_handle(DriverHandle::new_unchecked(
50 NonZero::new(server_handle).expect("valid driver handle"),
51 ))
52 });
53
54 let root_dispatcher = unsafe { DispatcherRef::from_raw(root_dispatcher) };
56 let server_ptr = Box::into_raw(Box::new(Self {
59 server_handle,
60 root_dispatcher: root_dispatcher.clone(),
61 driver: OnceLock::default(),
62 }));
63
64 let server = unsafe { &mut *server_ptr };
69
70 let rust_async_dispatcher = DispatcherBuilder::new()
73 .name("fuchsia-async")
74 .allow_thread_blocking()
75 .create_released()
76 .expect("failure creating blocking dispatcher for rust async");
77 rust_async_dispatcher
80 .post_task_sync(move |status| {
81 let Status::OK = status else { return };
83 Dispatcher::override_current(root_dispatcher.clone(), || {
84 let port = zx::Port::create_with_opts(zx::PortOptions::BIND_TO_INTERRUPT);
88 let mut executor = LocalExecutorBuilder::new().port(port).build();
89 executor.run_singlethreaded(async move {
90 server.message_loop(root_dispatcher).await;
91 server.server_handle.take()
96 });
97 });
98 })
99 .expect("failure spawning main event loop for rust async dispatch");
100
101 server_ptr.cast()
105 }
106
107 pub unsafe extern "C" fn destroy(obj: *mut c_void) {
115 let obj: *mut Self = obj.cast();
116 unsafe { drop(Box::from_raw(obj)) }
119 }
120
121 async fn message_loop(&mut self, dispatcher: DispatcherRef<'_>) {
124 loop {
125 let server_handle_lock = self.server_handle.get_mut();
126 let Some(server_handle) = server_handle_lock else {
127 panic!("driver already shut down while message loop was running")
128 };
129 match server_handle.read_bytes(dispatcher.clone()).await {
130 Ok(Some(message)) => {
131 if let ControlFlow::Break(_) = self.handle_message(message).await {
132 return;
134 }
135 }
136 Ok(None) => panic!("unexpected empty message on server channel"),
137 Err(status @ Status::PEER_CLOSED) | Err(status @ Status::UNAVAILABLE) => {
138 warn!(
139 "Driver server channel closed before a stop message with status {status}, exiting main loop early but stop() will not be called."
140 );
141 return;
142 }
143 Err(e) => panic!("unexpected error on server channel {e}"),
144 }
145 }
146 }
147
148 async fn handle_start(&self, start_args: DriverStartArgs) -> Result<(), Status> {
155 let context = DriverContext::new(self.root_dispatcher.clone(), start_args)?;
156 context.start_logging(T::NAME)?;
157
158 log::debug!("driver starting");
159
160 let driver = T::start(context).await?;
161 self.driver.set(driver).map_err(|_| ()).expect("Driver received start message twice");
162 Ok(())
163 }
164
165 async fn handle_stop(&mut self) {
166 log::debug!("driver stopping");
167 self.driver
168 .take()
169 .expect("received stop message more than once or without successfully starting")
170 .stop()
171 .await;
172 }
173
174 async fn handle_message(&mut self, message: Message<[u8]>) -> ControlFlow<()> {
181 let (_, request) = DriverRequest::read_from_message(message).unwrap();
182 match request {
183 DriverRequest::Start { start_args, responder } => {
184 let res = self.handle_start(start_args).await.map_err(Status::into_raw);
185 let Some(server_handle) = self.server_handle.get() else {
186 panic!("driver shutting down before it was finished starting")
187 };
188 responder.send_response(server_handle, res).unwrap();
189 if res.is_ok() {
190 ControlFlow::Continue(())
191 } else {
192 debug!("driver failed to start, exiting main loop");
193 ControlFlow::Break(())
194 }
195 }
196 DriverRequest::Stop {} => {
197 self.handle_stop().await;
198 ControlFlow::Break(())
199 }
200 _ => panic!("Unknown message on server channel"),
201 }
202 }
203
204 pub(crate) fn testing_get_driver(&self) -> Option<&'_ T> {
207 self.driver.get()
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    use fdf::{CurrentDispatcher, OnDispatcher};
    use fdf_env::test::spawn_in_driver;
    use fidl_next_fuchsia_driver_framework::DriverClientHandler;
    use zx::Status;

    use fdf::Channel;
    use fidl_next::{ClientDispatcher, ClientEnd};

    /// Minimal driver used to exercise registration, start and stop.
    #[derive(Default)]
    struct TestDriver {
        // NOTE(review): appears to exist only so the type is not empty — confirm.
        _not_empty: bool,
    }

    impl Driver for TestDriver {
        const NAME: &str = "test_driver";

        async fn start(context: DriverContext) -> Result<Self, Status> {
            let DriverContext { root_dispatcher, start_args, .. } = context;
            println!("created new test driver on dispatcher: {root_dispatcher:?}");
            println!("driver start message: {start_args:?}");
            let fresh = TestDriver::default();
            Ok(fresh)
        }

        async fn stop(&self) {
            println!("driver stop message");
        }
    }

    crate::driver_register!(TestDriver);

    /// Client-side handler with no custom behavior.
    #[derive(Debug)]
    struct DriverClient;
    impl DriverClientHandler for DriverClient {}

    /// Drives a full register -> initialize -> start -> stop -> destroy cycle through the
    /// exported registration table.
    #[test]
    fn register_driver() {
        assert_eq!(__fuchsia_driver_registration__.version, 1);
        let init = __fuchsia_driver_registration__.v1.initialize.expect("initializer");
        let destroy = __fuchsia_driver_registration__.v1.destroy.expect("destroy function");

        let (driver_end, host_end) = Channel::<[fidl_next::Chunk]>::create();

        let (done_tx, done_rx) = futures::channel::oneshot::channel();
        spawn_in_driver("driver registration", async move {
            // Wrap the host end of the channel as a typed `Driver` client.
            let typed_end: ClientEnd<fidl_next_fuchsia_driver_framework::Driver, _> =
                ClientEnd::from_untyped(fdf_fidl::DriverChannel::new(host_end));
            let client_dispatcher = ClientDispatcher::new(typed_end);
            let driver_client = client_dispatcher.client();

            // Run the client dispatcher in the background. It is expected to finish with an
            // error (presumably once the driver side closes its channel), then signal `done_rx`.
            let spawned = CurrentDispatcher.spawn_task(async move {
                client_dispatcher.run(DriverClient).await.unwrap_err();
                done_tx.send(()).unwrap();
            });
            spawned.unwrap();

            // Hand the driver end to the registered initializer. The returned pointer is
            // carried as a `usize` across the awaits below.
            let raw_handle = driver_end.into_driver_handle().into_raw().get();
            let server_addr = unsafe { init(raw_handle) } as usize;
            assert_ne!(server_addr, 0);

            driver_client
                .start(fidl_next_fuchsia_driver_framework::DriverStartArgs::default())
                .await
                .unwrap()
                .unwrap();

            driver_client.stop().await.unwrap();
            done_rx.await.unwrap();

            // SAFETY: `server_addr` was returned by `init`. The driver has been stopped and the
            // client has observed its dispatcher finishing, so the message loop is presumably
            // no longer using the server.
            unsafe {
                destroy(server_addr as *mut c_void);
            }
        })
    }
}