use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use detect_stall::StallableRequestStream;
use fidl::endpoints::ServerEnd;
use futures::channel::oneshot::{self, Canceled};
use futures::future::FusedFuture;
use futures::{FutureExt, Stream, StreamExt};
use pin_project::pin_project;
use vfs::directory::immutable::connection::ImmutableConnection;
use vfs::directory::immutable::Simple;
use vfs::execution_scope::{ActiveGuard, ExecutionScope};
use vfs::ToObjectRequest;
use zx::MonotonicDuration;
use {fidl_fuchsia_io as fio, fuchsia_async as fasync};
use super::{ServiceFs, ServiceObjTrait};
/// Resolves once the served outgoing directory connection stops being served.
///
/// `Some` carries the server endpoint handed back by the [`StallableRequestStream`].
/// `None` means no endpoint was handed back, including when the reporting callback was
/// dropped without being called.
type StalledFut = Pin<Box<dyn FusedFuture<Output = Option<zx::Channel>>>>;
/// A stream over a [`ServiceFs`] whose outgoing directory connection is served with
/// stall detection.
///
/// Connection requests from the wrapped `ServiceFs` are yielded as [`Item::Request`].
/// The stream ends when both of these hold:
/// - the outgoing directory connection has been idle for `debounce_interval`, as judged
///   by [`StallableRequestStream`];
/// - the `ServiceFs` has finished.
///
/// At that point the directory server endpoint, if it was handed back, is yielded as
/// [`Item::Stalled`]. If the endpoint becomes readable before the `ServiceFs` finishes,
/// it is served again instead.
#[pin_project]
pub struct StallableServiceFs<ServiceObjTy: ServiceObjTrait> {
    /// The wrapped filesystem, polled both for connection requests and for completion.
    #[pin]
    fs: ServiceFs<ServiceObjTy>,
    /// Serves, and later re-serves, the outgoing directory on a server endpoint.
    connector: OutgoingConnector,
    /// Whether the outgoing directory is currently being served or has stalled.
    state: State,
    /// Idle period handed to the stall detector each time the directory is served.
    debounce_interval: zx::MonotonicDuration,
    /// Set once the final item has been produced; later polls return `None`.
    is_terminated: bool,
}
/// An item produced by [`StallableServiceFs`].
pub enum Item<Output> {
    /// A connection request from the underlying [`ServiceFs`], together with an
    /// [`ActiveGuard`] taken from the execution scope the filesystem shares.
    // NOTE(review): the guard presumably keeps the scope, and therefore the stream, from
    // finishing while the request is handled, so handlers should hold it that long.
    // Confirm against `ExecutionScope::active_guard`.
    Request(Output, ActiveGuard),
    /// The outgoing directory server endpoint, returned after the connection stalled and
    /// the `ServiceFs` finished. This is the last item of the stream.
    Stalled(zx::Channel),
}
/// Progress of the outgoing directory connection.
enum State {
    /// The outgoing directory is being served. `stalled` resolves when that connection
    /// stops being served.
    Running { stalled: StalledFut },
    /// Nothing is being served. `channel` waits for the returned server endpoint to become
    /// readable. It is `None` if no endpoint was handed back, or if the endpoint has
    /// already been taken.
    Stalled { channel: Option<fasync::OnSignals<'static, zx::Channel>> },
}
impl<ServiceObjTy: ServiceObjTrait> Stream for StallableServiceFs<ServiceObjTy> {
    type Item = Item<ServiceObjTy::Output>;

    /// Yields connection requests from the underlying [`ServiceFs`] as they arrive.
    ///
    /// Once the outgoing directory connection has stalled and the `ServiceFs` has
    /// finished, it yields the server endpoint as [`Item::Stalled`] (if one was handed
    /// back) and then ends the stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();
        // After the final item has been produced, keep reporting end-of-stream.
        if *this.is_terminated {
            return Poll::Ready(None);
        }
        // Always poll the underlying ServiceFs first so its waker is registered. Any
        // connection request is handed out immediately, paired with a guard from the
        // shared scope.
        let poll_fs = this.fs.poll_next_unpin(cx);
        if let Poll::Ready(Some(request)) = poll_fs {
            return Poll::Ready(Some(Item::Request(request, this.connector.scope.active_guard())));
        }
        // The ServiceFs is now either pending or finished. Drive the state machine until
        // it stops making progress. Each transition re-enters the loop, so the new state
        // is polled, and registers its waker, within this same call.
        loop {
            match &mut this.state {
                State::Running { stalled } => {
                    // Wait for the served connection to stop. `ready!` returns Pending
                    // from poll_next while the connection is still being served.
                    let channel = std::task::ready!(stalled.as_mut().poll(cx));
                    // If a server endpoint was handed back, watch it for new messages.
                    let channel = channel
                        .map(|c| fasync::OnSignals::new(c.into(), zx::Signals::CHANNEL_READABLE));
                    *this.state = State::Stalled { channel };
                }
                State::Stalled { channel } => {
                    // The ServiceFs finished while nothing is being served. End the stream,
                    // returning the endpoint to the caller if we still hold it.
                    if let Poll::Ready(None) = poll_fs {
                        *this.is_terminated = true;
                        return Poll::Ready(
                            channel.take().map(|wait| Item::Stalled(wait.take_handle().into())),
                        );
                    }
                    // There is no endpoint to watch. Only a new request or the ServiceFs
                    // finishing can make progress, and both wake through the ServiceFs
                    // waker registered above.
                    if channel.is_none() {
                        return Poll::Pending;
                    }
                    // Wait for the endpoint to become readable, i.e. a message was sent on
                    // the outgoing directory while it was not being served.
                    let readable = channel.as_mut().unwrap().poll_unpin(cx);
                    let _ = std::task::ready!(readable);
                    // Reclaim the endpoint from the waiter and serve it again.
                    let wait = channel.take().unwrap();
                    let stalled =
                        this.connector.serve(wait.take_handle().into(), *this.debounce_interval);
                    *this.state = State::Running { stalled };
                }
            }
        }
    }
}
/// Everything needed to serve the outgoing directory on a server endpoint, possibly
/// several times over the life of a [`StallableServiceFs`].
struct OutgoingConnector {
    /// Open flags applied to every connection that is served.
    flags: fio::OpenFlags,
    /// Scope in which connection tasks are spawned. It is also used to create the
    /// [`ActiveGuard`]s attached to requests.
    scope: ExecutionScope,
    /// The directory that is served.
    dir: Arc<Simple>,
}
impl OutgoingConnector {
    /// Serves the outgoing directory on `server_end` inside `self.scope`.
    ///
    /// The directory connection's request stream is wrapped in a
    /// [`StallableRequestStream`] using `debounce_interval`. The returned future resolves
    /// with whatever that stream reports when it stops:
    /// - `Some(channel)` carries the server endpoint back so it can be re-served later.
    /// - `None` means no endpoint was handed back.
    ///
    /// If the reporting callback is dropped without being called, the future also
    /// resolves to `None`.
    fn serve(
        &mut self,
        server_end: ServerEnd<fio::DirectoryMarker>,
        debounce_interval: MonotonicDuration,
    ) -> StalledFut {
        let (unbound_sender, unbound_receiver) = oneshot::channel();
        let object_request = self.flags.to_object_request(server_end);
        // Owned copies are moved into the spawned task. The spawn call itself only borrows
        // `self.scope`, so a second clone of the scope is not needed.
        let scope = self.scope.clone();
        let dir = self.dir.clone();
        let flags = self.flags;
        object_request.spawn(&self.scope, move |object_request_ref| {
            async move {
                ImmutableConnection::create_transform_stream(
                    scope,
                    dir,
                    flags,
                    object_request_ref,
                    move |stream| {
                        StallableRequestStream::new(
                            stream,
                            debounce_interval,
                            move |maybe_channel: Option<zx::Channel>| {
                                // The receiver may already be gone, e.g. when the
                                // StallableServiceFs was dropped; then there is no one
                                // to notify.
                                _ = unbound_sender.send(maybe_channel);
                            },
                        )
                    },
                )
            }
            .boxed()
        });
        Box::pin(
            unbound_receiver
                .map(|result| match result {
                    Ok(maybe_channel) => maybe_channel,
                    // The sender was dropped without reporting: nothing to hand back.
                    Err(Canceled) => None,
                })
                .fuse(),
        )
    }
}
impl<ServiceObjTy: ServiceObjTrait> StallableServiceFs<ServiceObjTy> {
    /// Wraps `fs` so that its single queued outgoing directory connection is served with
    /// stall detection using `debounce_interval`.
    ///
    /// The queued connection is taken out of `fs`. It is then served by an
    /// [`OutgoingConnector`] that shares `fs`'s execution scope and root directory, so
    /// the endpoint can later be handed back and re-served.
    ///
    /// # Panics
    ///
    /// - Panics if `fs.channel_queue` is `None`. Per the panic message, this happens when
    ///   the original `ServiceFs` was already polled.
    /// - Panics if `fs` does not have exactly one queued connection.
    pub(crate) fn new(
        mut fs: ServiceFs<ServiceObjTy>,
        debounce_interval: zx::MonotonicDuration,
    ) -> Self {
        let channel_queue =
            fs.channel_queue.as_mut().expect("Must not poll the original ServiceFs");
        assert!(
            channel_queue.len() == 1,
            "Must have exactly one connection to serve, \
            e.g. did you call ServiceFs::take_and_serve_directory_handle?"
        );
        // Take the queued connection out of `fs`, leaving the queue empty. It is served
        // by `connector` below instead.
        let server_end =
            std::mem::take(channel_queue).pop().expect("queue length was just asserted to be 1");
        let flags = ServiceFs::<ServiceObjTy>::base_connection_flags();
        let scope = fs.scope.clone();
        let dir = fs.dir.clone();
        let mut connector = OutgoingConnector { flags, scope, dir };
        let stalled = connector.serve(server_end, debounce_interval);
        Self {
            fs,
            connector,
            state: State::Running { stalled },
            debounce_interval,
            is_terminated: false,
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Mutex;

    use assert_matches::assert_matches;
    use fasync::TestExecutor;
    use fidl::endpoints::ClientEnd;
    use fidl_fuchsia_component_client_test::{
        ProtocolAMarker, ProtocolARequest, ProtocolARequestStream,
    };
    use futures::future::BoxFuture;
    use futures::{pin_mut, select, TryStreamExt};
    use test_util::Counter;
    use zx::AsHandleRef;

    use super::*;

    /// Services exposed under `svc` by the test `ServiceFs`.
    enum Requests {
        ServiceA(ProtocolARequestStream),
    }

    /// Test double that handles items from the stallable `ServiceFs` and records what it saw.
    #[derive(Clone)]
    struct MockServer {
        /// Number of `ProtocolA.Foo` calls answered.
        call_count: Arc<Counter>,
        /// Set once an `Item::Stalled` has been received.
        stalled: Arc<AtomicBool>,
        /// The server endpoint delivered by `Item::Stalled`, if any.
        server_end: Arc<Mutex<Option<zx::Channel>>>,
    }

    impl MockServer {
        fn new() -> Self {
            let call_count = Arc::new(Counter::new(0));
            let stalled = Arc::new(AtomicBool::new(false));
            let server_end = Arc::new(Mutex::new(None));
            Self { call_count, stalled, server_end }
        }

        /// Returns a future that handles one item.
        ///
        /// For a request, it answers `Foo` calls until the request stream ends or errors,
        /// holding the item's `ActiveGuard` the whole time. For `Item::Stalled`, it records
        /// the returned server endpoint.
        fn handle(&self, item: Item<Requests>) -> BoxFuture<'static, ()> {
            let stalled = self.stalled.clone();
            let call_count = self.call_count.clone();
            let server_end = self.server_end.clone();
            async move {
                match item {
                    Item::Request(requests, active_guard) => {
                        // Keep the guard alive until this connection is done.
                        let _active_guard = active_guard;
                        let Requests::ServiceA(mut request_stream) = requests;
                        while let Ok(Some(request)) = request_stream.try_next().await {
                            match request {
                                ProtocolARequest::Foo { responder } => {
                                    call_count.inc();
                                    let _ = responder.send();
                                }
                            }
                        }
                    }
                    Item::Stalled(channel) => {
                        *server_end.lock().unwrap() = Some(channel);
                        stalled.store(true, Ordering::SeqCst);
                    }
                }
            }
            .boxed()
        }

        /// Asserts that the recorded stalled endpoint is the peer of `client_end`.
        ///
        /// Panics if no endpoint was recorded.
        #[track_caller]
        fn assert_fs_gave_back_server_end(self, client_end: ClientEnd<fio::DirectoryMarker>) {
            let reclaimed_server_end: zx::Channel = self.server_end.lock().unwrap().take().unwrap();
            assert_eq!(
                client_end.get_koid().unwrap(),
                reclaimed_server_end.basic_info().unwrap().related_koid
            )
        }
    }

    /// Moves fake time to 0 and builds a `ServiceFs` that serves `server_end` with
    /// `svc/ProtocolA`.
    ///
    /// The filesystem is wrapped with `until_stalled(IDLE_DURATION)` and driven
    /// concurrently into a fresh `MockServer`. Returns the start instant, the mock, and
    /// the driving future.
    async fn setup_test(
        server_end: ServerEnd<fio::DirectoryMarker>,
    ) -> (fasync::MonotonicInstant, MockServer, impl FusedFuture<Output = ()>) {
        let initial = fasync::MonotonicInstant::from_nanos(0);
        TestExecutor::advance_to(initial).await;
        // The tests below declare the same 1ms value locally; keep them in sync.
        const IDLE_DURATION: MonotonicDuration = MonotonicDuration::from_nanos(1_000_000);

        let mut fs = ServiceFs::new();
        fs.serve_connection(server_end).unwrap().dir("svc").add_fidl_service(Requests::ServiceA);

        let mock_server = MockServer::new();
        let mock_server_clone = mock_server.clone();
        let fs = fs
            .until_stalled(IDLE_DURATION)
            .for_each_concurrent(None, move |item| mock_server_clone.handle(item));

        (initial, mock_server, fs)
    }

    /// Open protocol connections keep the fs running past the idle duration, and their
    /// requests are still answered. Once they are dropped, the fs completes and returns
    /// the server endpoint.
    #[fuchsia::test(allow_stalls = false)]
    async fn drain_request() {
        const IDLE_DURATION: MonotonicDuration = MonotonicDuration::from_nanos(1_000_000);
        const NUM_FOO_REQUESTS: usize = 10;
        let (client_end, server_end) = fidl::endpoints::create_endpoints::<fio::DirectoryMarker>();
        let (initial, mock_server, fs) = setup_test(server_end).await;
        pin_mut!(fs);

        let mut proxies = Vec::new();
        for _ in 0..NUM_FOO_REQUESTS {
            proxies.push(
                crate::client::connect_to_protocol_at_dir_svc::<ProtocolAMarker>(&client_end)
                    .unwrap(),
            );
        }

        // Even well past the idle duration, the fs does not finish while the proxies are
        // connected (their handlers hold `ActiveGuard`s).
        assert!(TestExecutor::poll_until_stalled(&mut fs).await.is_pending());
        TestExecutor::advance_to(initial + (IDLE_DURATION * 2)).await;
        assert!(TestExecutor::poll_until_stalled(&mut fs).await.is_pending());

        // Every proxy can still make calls.
        for proxy in proxies.iter() {
            select! {
                result = proxy.foo().fuse() => assert_matches!(result, Ok(_)),
                _ = fs => unreachable!(),
            };
        }

        assert!(TestExecutor::poll_until_stalled(&mut fs).await.is_pending());
        // Closing the protocol connections lets the fs complete.
        drop(proxies);
        fs.await;

        assert_eq!(mock_server.call_count.get(), NUM_FOO_REQUESTS);
        assert!(mock_server.stalled.load(Ordering::SeqCst));
        mock_server.assert_fs_gave_back_server_end(client_end);
    }

    /// With no traffic at all, the fs completes once the idle duration elapses and
    /// returns the server endpoint.
    #[fuchsia::test(allow_stalls = false)]
    async fn no_request() {
        const IDLE_DURATION: MonotonicDuration = MonotonicDuration::from_nanos(1_000_000);
        let (client_end, server_end) = fidl::endpoints::create_endpoints::<fio::DirectoryMarker>();
        let (initial, mock_server, fs) = setup_test(server_end).await;
        pin_mut!(fs);

        assert!(TestExecutor::poll_until_stalled(&mut fs).await.is_pending());
        TestExecutor::advance_to(initial + IDLE_DURATION).await;
        assert!(TestExecutor::poll_until_stalled(&mut fs).await.is_ready());

        assert_eq!(mock_server.call_count.get(), 0);
        assert!(mock_server.stalled.load(Ordering::SeqCst));
        mock_server.assert_fs_gave_back_server_end(client_end);
    }

    /// Closing the outgoing directory client ends the fs without yielding `Item::Stalled`.
    #[fuchsia::test(allow_stalls = false)]
    async fn outgoing_dir_client_closed() {
        let (client_end, server_end) = fidl::endpoints::create_endpoints::<fio::DirectoryMarker>();
        let (_initial, mock_server, fs) = setup_test(server_end).await;
        pin_mut!(fs);

        drop(client_end);
        assert!(TestExecutor::poll_until_stalled(&mut fs).await.is_ready());

        assert_eq!(mock_server.call_count.get(), 0);
        assert!(!mock_server.stalled.load(Ordering::SeqCst));
        assert!(mock_server.server_end.lock().unwrap().is_none());
    }

    /// A request already queued before the fs starts is answered. The fs then stalls
    /// once the proxy is gone and the idle duration has elapsed.
    #[fuchsia::test(allow_stalls = false)]
    async fn request_then_stalled() {
        const IDLE_DURATION: MonotonicDuration = MonotonicDuration::from_nanos(1_000_000);
        let (client_end, server_end) = fidl::endpoints::create_endpoints::<fio::DirectoryMarker>();

        // Queue a Foo call before anything is serving the directory.
        let proxy =
            crate::client::connect_to_protocol_at_dir_svc::<ProtocolAMarker>(&client_end).unwrap();
        let foo = proxy.foo().fuse();
        pin_mut!(foo);
        assert!(TestExecutor::poll_until_stalled(&mut foo).await.is_pending());

        let (initial, mock_server, fs) = setup_test(server_end).await;
        pin_mut!(fs);

        assert_eq!(mock_server.call_count.get(), 0);
        assert!(TestExecutor::poll_until_stalled(&mut fs).await.is_pending());
        assert_eq!(mock_server.call_count.get(), 1);
        assert_matches!(foo.await, Ok(_));

        drop(proxy);
        assert!(TestExecutor::poll_until_stalled(&mut fs).await.is_pending());
        TestExecutor::advance_to(initial + IDLE_DURATION).await;
        assert!(TestExecutor::poll_until_stalled(&mut fs).await.is_ready());

        assert_eq!(mock_server.call_count.get(), 1);
        assert!(mock_server.stalled.load(Ordering::SeqCst));
        mock_server.assert_fs_gave_back_server_end(client_end);
    }

    /// A request halfway through the idle period restarts it. The fs only completes a
    /// full idle duration after that request.
    #[fuchsia::test(allow_stalls = false)]
    async fn stalled_then_request() {
        const IDLE_DURATION: MonotonicDuration = MonotonicDuration::from_nanos(1_000_000);
        let (client_end, server_end) = fidl::endpoints::create_endpoints::<fio::DirectoryMarker>();
        let (initial, mock_server, fs) = setup_test(server_end).await;
        pin_mut!(fs);

        assert!(TestExecutor::poll_until_stalled(&mut fs).await.is_pending());
        TestExecutor::advance_to(initial + (IDLE_DURATION / 2)).await;
        assert!(TestExecutor::poll_until_stalled(&mut fs).await.is_pending());

        let proxy =
            crate::client::connect_to_protocol_at_dir_svc::<ProtocolAMarker>(&client_end).unwrap();
        select! {
            result = proxy.foo().fuse() => assert_matches!(result, Ok(_)),
            _ = fs => unreachable!(),
        };
        assert_eq!(mock_server.call_count.get(), 1);

        drop(proxy);
        assert!(TestExecutor::poll_until_stalled(&mut fs).await.is_pending());
        TestExecutor::advance_to(initial + (IDLE_DURATION / 2) + IDLE_DURATION).await;
        assert!(TestExecutor::poll_until_stalled(&mut fs).await.is_ready());

        assert!(mock_server.stalled.load(Ordering::SeqCst));
        mock_server.assert_fs_gave_back_server_end(client_end);
    }

    /// Requests arriving every half idle duration keep the fs alive. Once the gap grows
    /// to twice the idle duration, the fs stalls and later requests go unanswered.
    #[fuchsia::test(allow_stalls = false)]
    async fn periodic_requests() {
        const IDLE_DURATION: MonotonicDuration = MonotonicDuration::from_nanos(1_000_000);
        let (client_end, server_end) = fidl::endpoints::create_endpoints::<fio::DirectoryMarker>();
        let (mut current_time, mock_server, fs) = setup_test(server_end).await;
        let fs = fasync::Task::local(fs);

        // Gaps shorter than the idle duration: every request is answered.
        const NUM_FOO_REQUESTS: usize = 10;
        for _ in 0..NUM_FOO_REQUESTS {
            let request_interval = IDLE_DURATION / 2;
            current_time += request_interval;
            TestExecutor::advance_to(current_time).await;
            let proxy =
                crate::client::connect_to_protocol_at_dir_svc::<ProtocolAMarker>(&client_end)
                    .unwrap();
            assert_matches!(proxy.foo().await, Ok(_));
        }
        assert_eq!(mock_server.call_count.get(), NUM_FOO_REQUESTS);

        // Gaps longer than the idle duration: the fs stalls, and nothing answers.
        for _ in 0..NUM_FOO_REQUESTS {
            let request_interval = IDLE_DURATION * 2;
            current_time += request_interval;
            TestExecutor::advance_to(current_time).await;
            let proxy =
                crate::client::connect_to_protocol_at_dir_svc::<ProtocolAMarker>(&client_end)
                    .unwrap();
            let foo = proxy.foo();
            pin_mut!(foo);
            assert_matches!(TestExecutor::poll_until_stalled(&mut foo).await, Poll::Pending);
        }
        assert_eq!(mock_server.call_count.get(), NUM_FOO_REQUESTS);

        fs.await;
        mock_server.assert_fs_gave_back_server_end(client_end);
    }

    /// An open `svc` sub-directory connection keeps the fs from completing, even past the
    /// idle duration. Once it is dropped, the fs completes.
    #[fuchsia::test(allow_stalls = false)]
    async fn some_other_outgoing_dir_connection_blocks_stalling() {
        const IDLE_DURATION: MonotonicDuration = MonotonicDuration::from_nanos(1_000_000);
        let (client_end, server_end) = fidl::endpoints::create_endpoints::<fio::DirectoryMarker>();
        let (initial, mock_server, fs) = setup_test(server_end).await;
        pin_mut!(fs);

        assert!(TestExecutor::poll_until_stalled(&mut fs).await.is_pending());

        {
            let svc = crate::directory::open_directory_async(&client_end, "svc", fio::R_STAR_DIR)
                .unwrap();

            TestExecutor::advance_to(initial + IDLE_DURATION).await;
            assert!(TestExecutor::poll_until_stalled(&mut fs).await.is_pending());

            // The svc connection is still usable.
            assert_matches!(
                fuchsia_fs::directory::readdir(&svc).await,
                Ok(ref entries)
                if entries.len() == 1 && entries[0].name == "fuchsia.component.client.test.ProtocolA"
            );

            assert!(TestExecutor::poll_until_stalled(&mut fs).await.is_pending());
            TestExecutor::advance_to(initial + (IDLE_DURATION * 3)).await;
            assert!(TestExecutor::poll_until_stalled(&mut fs).await.is_pending());
        }
        // `svc` is dropped here.

        fs.await;
        assert!(mock_server.stalled.load(Ordering::SeqCst));
        mock_server.assert_fs_gave_back_server_end(client_end);
    }

    /// Simulates a component that repeatedly serves, stalls, and waits for its returned
    /// endpoint to become readable or closed before serving it again, with requests
    /// arriving at growing intervals.
    #[fuchsia::test(allow_stalls = false)]
    async fn end_to_end() {
        let initial = fasync::MonotonicInstant::from_nanos(0);
        TestExecutor::advance_to(initial).await;
        let mock_server = MockServer::new();
        let mock_server_clone = mock_server.clone();

        const MIN_REQUEST_INTERVAL: i64 = 10_000_000;
        let idle_duration = MonotonicDuration::from_nanos(MIN_REQUEST_INTERVAL * 5);
        let (client_end, server_end) = fidl::endpoints::create_endpoints::<fio::DirectoryMarker>();

        // Returns how many times the loop restarted from a stalled endpoint.
        let component_task = async move {
            let mut server_end = Some(server_end);
            let mut loop_count = 0;
            loop {
                let mut fs = ServiceFs::new();
                fs.serve_connection(server_end.unwrap())
                    .unwrap()
                    .dir("svc")
                    .add_fidl_service(Requests::ServiceA);

                let mock_server_clone = mock_server_clone.clone();
                fs.until_stalled(idle_duration)
                    .for_each_concurrent(None, move |item| mock_server_clone.handle(item))
                    .await;

                // No stalled endpoint means the client went away: stop.
                let stalled_server_end = mock_server.server_end.lock().unwrap().take();
                let Some(stalled_server_end) = stalled_server_end else {
                    return loop_count;
                };

                // Sleep until the client writes to, or closes, the outgoing directory.
                fasync::OnSignals::new(
                    &stalled_server_end,
                    zx::Signals::CHANNEL_READABLE | zx::Signals::CHANNEL_PEER_CLOSED,
                )
                .await
                .unwrap();
                server_end = Some(stalled_server_end.into());
                loop_count += 1;
            }
        };
        let component_task = fasync::Task::local(component_task);

        // Each iteration makes one request, then waits `delay_factor * MIN_REQUEST_INTERVAL`.
        let mut deadline = initial;
        const NUM_REQUESTS: usize = 30;
        for delay_factor in 0..NUM_REQUESTS {
            let proxy =
                crate::client::connect_to_protocol_at_dir_svc::<ProtocolAMarker>(&client_end)
                    .unwrap();
            proxy.foo().await.unwrap();
            drop(proxy);
            deadline += MonotonicDuration::from_nanos(MIN_REQUEST_INTERVAL * (delay_factor as i64));
            TestExecutor::advance_to(deadline).await;
        }

        drop(client_end);
        let loop_count = component_task.await;
        // The idle duration is 5 * MIN_REQUEST_INTERVAL, so delay factors 5..=29 (25 of
        // them) are presumably the gaps long enough to stall. Each stall leads to exactly
        // one restart: via the next request, or via the final client close.
        assert_eq!(loop_count, 25);
        assert_eq!(mock_server.call_count.get(), NUM_REQUESTS);
        assert!(mock_server.stalled.load(Ordering::SeqCst));
    }
}