fdf_component/
server.rs

1// Copyright 2024 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use core::ffi::c_void;
6use core::ptr::NonNull;
7use std::num::NonZero;
8use std::ops::ControlFlow;
9use std::sync::OnceLock;
10
11use log::{debug, warn};
12use zx::Status;
13
14use fdf::{Channel, Dispatcher, DispatcherBuilder, DispatcherRef};
15use fidl_fuchsia_driver_framework::DriverRequest;
16
17use fdf::{DriverHandle, Message, fdf_handle_t};
18
19use crate::{Driver, DriverContext};
20use fdf_sys::fdf_dispatcher_get_current_dispatcher;
21use fidl_fuchsia_driver_framework::DriverStartArgs;
22use fuchsia_async::LocalExecutorBuilder;
23
24/// Implements the lifecycle management of a rust driver, including starting and stopping it
25/// and setting up the rust async dispatcher and logging for the driver to use, and running a
26/// message loop for the driver start and stop messages.
27pub struct DriverServer<T> {
28    server_handle: OnceLock<Channel<[u8]>>,
29    root_dispatcher: DispatcherRef<'static>,
30    driver: OnceLock<T>,
31}
32
33impl<T: Driver> DriverServer<T> {
34    /// Called by the driver host to start the driver.
35    ///
36    /// # Safety
37    ///
38    /// The caller must provide a valid non-zero driver transport channel handle for
39    /// `server_handle`.
40    pub unsafe extern "C" fn initialize(server_handle: fdf_handle_t) -> *mut c_void {
41        // SAFETY: We verify that the pointer returned is non-null, ensuring that this was
42        // called from within a driver context.
43        let root_dispatcher = NonNull::new(unsafe { fdf_dispatcher_get_current_dispatcher() })
44            .expect("Non-null current dispatcher");
45        // SAFETY: We use NonZero::new to verify that we've been given a non-zero
46        // driver handle, and expect that the caller (which is the driver runtime) has given us
47        // a valid driver transport fidl channel.
48        let server_handle = OnceLock::from(unsafe {
49            Channel::from_driver_handle(DriverHandle::new_unchecked(
50                NonZero::new(server_handle).expect("valid driver handle"),
51            ))
52        });
53
54        // SAFETY: the root dispatcher is expected to live as long as this driver is loaded.
55        let root_dispatcher = unsafe { DispatcherRef::from_raw(root_dispatcher) };
56        // We leak the box holding the server so that the driver runtime can take control over the
57        // lifetime of the server object.
58        let server_ptr = Box::into_raw(Box::new(Self {
59            server_handle,
60            root_dispatcher: root_dispatcher.clone(),
61            driver: OnceLock::default(),
62        }));
63
64        // Reconstitute the pointer to the `DriverServer` as a mut reference to use it in the main
65        // loop.
66        // SAFETY: We are the exclusive owner of the object until we drop the server handle,
67        // triggering the driver host to call `destroy`.
68        let server = unsafe { &mut *server_ptr };
69
70        // Build a new dispatcher that we can have spin on a fuchsia-async executor main loop
71        // to act as a reactor for non-driver events.
72        let rust_async_dispatcher = DispatcherBuilder::new()
73            .name("fuchsia-async")
74            .allow_thread_blocking()
75            .create_released()
76            .expect("failure creating blocking dispatcher for rust async");
77        // Post the task to the dispatcher that will run the fuchsia-async loop, and have it run
78        // the server's message loop waiting for start and stop messages from the driver host.
79        rust_async_dispatcher
80            .post_task_sync(move |status| {
81                // bail immediately if we were somehow cancelled before we started
82                let Status::OK = status else { return };
83                Dispatcher::override_current(root_dispatcher.clone(), || {
84                    // create and run a fuchsia-async executor, giving it the "root" dispatcher to
85                    // actually execute driver tasks on, as this thread will be effectively blocked
86                    // by the reactor loop.
87                    let port = zx::Port::create_with_opts(zx::PortOptions::BIND_TO_INTERRUPT);
88                    let mut executor = LocalExecutorBuilder::new().port(port).build();
89                    executor.run_singlethreaded(async move {
90                        server.message_loop(root_dispatcher).await;
91                        // take the server handle so it can drop after the async block is done,
92                        // which will signal to the driver host that the driver has finished
93                        // shutdown, so that we are can guarantee that when `destroy` is called, we
94                        // are not still using `server`.
95                        server.server_handle.take()
96                    });
97                });
98            })
99            .expect("failure spawning main event loop for rust async dispatch");
100
101        // Take the pointer of the server object to use as the identifier for the server to the
102        // driver runtime. It uses this as an opaque identifier and expects no particular layout of
103        // the object pointed to, and we use it to free the box at unload in `Self::destroy`.
104        server_ptr.cast()
105    }
106
107    /// Called by the driver host after shutting down a driver and once the handle passed to
108    /// [`Self::initialize`] is dropped.
109    ///
110    /// # Safety
111    ///
112    /// This must only be called after the handle provided to [`Self::initialize`] has been
113    /// dropped, which indicates that the main event loop of the driver lifecycle has ended.
114    pub unsafe extern "C" fn destroy(obj: *mut c_void) {
115        let obj: *mut Self = obj.cast();
116        // SAFETY: We built this object in `initialize` and gave ownership of its
117        // lifetime to the driver framework, which is now giving it to us to free.
118        unsafe { drop(Box::from_raw(obj)) }
119    }
120
121    /// Implements the main message loop for handling start and stop messages from rust
122    /// driver host and passing them on to the implementation of [`Driver`] we contain.
123    async fn message_loop(&mut self, dispatcher: DispatcherRef<'_>) {
124        loop {
125            let server_handle_lock = self.server_handle.get_mut();
126            let Some(server_handle) = server_handle_lock else {
127                panic!("driver already shut down while message loop was running")
128            };
129            match server_handle.read_bytes(dispatcher.clone()).await {
130                Ok(Some(message)) => {
131                    if let ControlFlow::Break(_) = self.handle_message(message).await {
132                        // driver shut down or failed to start, exit message loop
133                        return;
134                    }
135                }
136                Ok(None) => panic!("unexpected empty message on server channel"),
137                Err(status @ Status::PEER_CLOSED) | Err(status @ Status::UNAVAILABLE) => {
138                    warn!(
139                        "Driver server channel closed before a stop message with status {status}, exiting main loop early but stop() will not be called."
140                    );
141                    return;
142                }
143                Err(e) => panic!("unexpected error on server channel {e}"),
144            }
145        }
146    }
147
148    /// Handles the start message by initializing logging and calling the [`Driver::start`] with
149    /// a constructed [`DriverContext`].
150    ///
151    /// # Panics
152    ///
153    /// This method panics if the start message has already been received.
154    async fn handle_start(&self, start_args: DriverStartArgs) -> Result<(), Status> {
155        let context = DriverContext::new(self.root_dispatcher.clone(), start_args)?;
156        context.start_logging(T::NAME)?;
157
158        log::debug!("driver starting");
159
160        let driver = T::start(context).await?;
161        self.driver.set(driver).map_err(|_| ()).expect("Driver received start message twice");
162        Ok(())
163    }
164
165    async fn handle_stop(&mut self) {
166        log::debug!("driver stopping");
167        self.driver
168            .take()
169            .expect("received stop message more than once or without successfully starting")
170            .stop()
171            .await;
172    }
173
174    /// Dispatches messages from the driver host to the appropriate implementation.
175    ///
176    /// # Panics
177    ///
178    /// This method panics if the messages are received out of order somehow (two start messages,
179    /// stop before start, etc).
180    async fn handle_message(&mut self, message: Message<[u8]>) -> ControlFlow<()> {
181        let (_, request) = DriverRequest::read_from_message(message).unwrap();
182        match request {
183            DriverRequest::Start { start_args, responder } => {
184                let res = self.handle_start(start_args).await.map_err(Status::into_raw);
185                let Some(server_handle) = self.server_handle.get() else {
186                    panic!("driver shutting down before it was finished starting")
187                };
188                responder.send_response(server_handle, res).unwrap();
189                if res.is_ok() {
190                    ControlFlow::Continue(())
191                } else {
192                    debug!("driver failed to start, exiting main loop");
193                    ControlFlow::Break(())
194                }
195            }
196            DriverRequest::Stop {} => {
197                self.handle_stop().await;
198                ControlFlow::Break(())
199            }
200            _ => panic!("Unknown message on server channel"),
201        }
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    use fdf::{CurrentDispatcher, OnDispatcher};
    use fdf_env::test::spawn_in_driver;
    use fidl_next_fuchsia_driver_framework::DriverClientHandler;
    use zx::Status;

    use fdf::Channel;
    use fidl_next::{ClientDispatcher, ClientEnd};

    /// Trivial [`Driver`] used to exercise the start/stop lifecycle plumbing.
    #[derive(Default)]
    struct TestDriver {
        _not_empty: bool,
    }

    impl Driver for TestDriver {
        const NAME: &str = "test_driver";

        async fn start(context: DriverContext) -> Result<Self, Status> {
            let DriverContext { root_dispatcher, start_args, .. } = context;
            println!("created new test driver on dispatcher: {root_dispatcher:?}");
            println!("driver start message: {start_args:?}");
            Ok(Self::default())
        }

        async fn stop(&self) {
            println!("driver stop message");
        }
    }

    crate::driver_register!(TestDriver);

    /// Client-side handler for the driver protocol; no custom behavior is needed.
    #[derive(Debug)]
    struct DriverClient;

    impl DriverClientHandler<fdf_fidl::DriverChannel> for DriverClient {}

    /// Drives a full initialize -> start -> stop -> destroy cycle through the registration
    /// table emitted by `driver_register!`.
    #[test]
    fn register_driver() {
        assert_eq!(__fuchsia_driver_registration__.version, 1);
        let init_fn = __fuchsia_driver_registration__.v1.initialize.expect("initializer");
        let destroy_fn = __fuchsia_driver_registration__.v1.destroy.expect("destroy function");

        let (server_channel, client_channel) = Channel::<[fidl_next::Chunk]>::create();
        let (exit_tx, exit_rx) = futures::channel::oneshot::channel();

        spawn_in_driver("driver registration", async move {
            let typed_client_end: ClientEnd<fidl_next_fuchsia_driver_framework::Driver, _> =
                ClientEnd::from_untyped(fdf_fidl::DriverChannel::new(client_channel));
            let client_dispatcher = ClientDispatcher::new(typed_client_end);
            let driver_client = client_dispatcher.client();

            // Run the client side until the server closes the channel, then report back.
            CurrentDispatcher
                .spawn_task(async move {
                    client_dispatcher.run(DriverClient).await.unwrap_err();
                    exit_tx.send(()).unwrap();
                })
                .unwrap();

            let raw_handle = server_channel.into_driver_handle().into_raw().get();
            // SAFETY: `raw_handle` is a valid, non-zero driver transport channel handle whose
            // ownership was just released via `into_raw`, and we are on a driver dispatcher.
            let server_addr = unsafe { init_fn(raw_handle) } as usize;
            assert_ne!(server_addr, 0);

            driver_client
                .start(fidl_next_fuchsia_driver_framework::DriverStartArgs::default())
                .await
                .unwrap()
                .unwrap();

            driver_client.stop().await.unwrap();
            // The client task only finishes once the server has dropped its channel end,
            // which is the precondition for calling `destroy`.
            exit_rx.await.unwrap();

            // SAFETY: `server_addr` came from `init_fn`, and the server's channel handle has
            // been dropped (observed above), so the message loop is finished.
            unsafe {
                destroy_fn(server_addr as *mut c_void);
            }
        })
    }
}