1use crate::channel::{Cbw, Channel};
12use crate::ie::intersect::*;
13use crate::ie::{
14 self, HtCapabilities, SupportedRate, VhtCapabilities, parse_ht_capabilities,
15 parse_vht_capabilities,
16};
17use crate::mac::CapabilityInfo;
18use anyhow::{Context as _, Error, format_err};
19use {fidl_fuchsia_wlan_ieee80211 as fidl_ieee80211, fidl_fuchsia_wlan_mlme as fidl_mlme};
20
// Fixed values that `override_capability_info` writes into every Capability Info
// it is given, regardless of the incoming bits.

/// Value forced into the Capability Info ESS subfield.
const OVERRIDE_CAP_INFO_ESS: bool = true;
/// Value forced into the Capability Info IBSS subfield.
const OVERRIDE_CAP_INFO_IBSS: bool = false;

/// Value forced into the Capability Info CF-Pollable subfield.
const OVERRIDE_CAP_INFO_CF_POLLABLE: bool = false;
/// Value forced into the Capability Info CF-Poll Request subfield.
const OVERRIDE_CAP_INFO_CF_POLL_REQUEST: bool = false;

/// Value forced into the Capability Info Privacy subfield.
const OVERRIDE_CAP_INFO_PRIVACY: bool = false;

/// Value forced into the Capability Info Spectrum Management subfield.
const OVERRIDE_CAP_INFO_SPECTRUM_MGMT: bool = false;

/// Value forced into the HT Capability Info Tx STBC bit by `override_ht_capabilities`.
const OVERRIDE_HT_CAP_INFO_TX_STBC: bool = false;

/// Supported Channel Width Set value that `override_vht_capabilities` writes for
/// every channel width other than 160 MHz and 80+80 MHz.
const OVERRIDE_VHT_CAP_INFO_SUPPORTED_CBW_SET: u32 = 0;
47
48fn override_capability_info(capability_info: CapabilityInfo) -> CapabilityInfo {
51 capability_info
52 .with_ess(OVERRIDE_CAP_INFO_ESS)
53 .with_ibss(OVERRIDE_CAP_INFO_IBSS)
54 .with_cf_pollable(OVERRIDE_CAP_INFO_CF_POLLABLE)
55 .with_cf_poll_req(OVERRIDE_CAP_INFO_CF_POLL_REQUEST)
56 .with_privacy(OVERRIDE_CAP_INFO_PRIVACY)
57 .with_spectrum_mgmt(OVERRIDE_CAP_INFO_SPECTRUM_MGMT)
58}
59
60pub fn derive_join_capabilities(
65 bss_channel: Channel,
66 bss_rates: &[SupportedRate],
67 device_info: &fidl_mlme::DeviceInfo,
68) -> Result<ClientCapabilities, Error> {
69 let band_cap = get_band_cap_for_channel(&device_info.bands[..], bss_channel)
71 .context(format!("iface does not support BSS channel {}", bss_channel.primary))?;
72
73 let capability_info =
77 override_capability_info(CapabilityInfo(device_info.softmac_hardware_capability as u16));
78
79 let client_rates = band_cap.basic_rates.iter().map(|&r| SupportedRate(r)).collect::<Vec<_>>();
82 let rates = intersect_rates(ApRates(bss_rates), ClientRates(&client_rates))
83 .map_err(|error| format_err!("could not intersect rates: {:?}", error))
84 .context(format!("deriving rates: {:?} + {:?}", band_cap.basic_rates, bss_rates))?;
85
86 let (ht_cap, vht_cap) =
89 override_ht_vht(band_cap.ht_cap.as_ref(), band_cap.vht_cap.as_ref(), bss_channel.cbw)?;
90
91 Ok(ClientCapabilities(StaCapabilities { capability_info, rates, ht_cap, vht_cap }))
92}
93
94fn override_ht_vht(
97 fidl_ht_cap: Option<&Box<fidl_ieee80211::HtCapabilities>>,
98 fidl_vht_cap: Option<&Box<fidl_ieee80211::VhtCapabilities>>,
99 cbw: Cbw,
100) -> Result<(Option<HtCapabilities>, Option<VhtCapabilities>), Error> {
101 if fidl_ht_cap.is_none() && fidl_vht_cap.is_some() {
102 return Err(format_err!("VHT Cap without HT Cap is invalid."));
103 }
104
105 let ht_cap = match fidl_ht_cap {
106 Some(h) => {
107 let ht_cap = *parse_ht_capabilities(&h.bytes[..]).context("verifying HT Cap")?;
108 Some(override_ht_capabilities(ht_cap, cbw))
109 }
110 None => None,
111 };
112
113 let vht_cap = match fidl_vht_cap {
114 Some(v) => {
115 let vht_cap = *parse_vht_capabilities(&v.bytes[..]).context("verifying VHT Cap")?;
116 Some(override_vht_capabilities(vht_cap, cbw))
117 }
118 None => None,
119 };
120 Ok((ht_cap, vht_cap))
121}
122
123fn override_ht_capabilities(mut ht_cap: HtCapabilities, cbw: Cbw) -> HtCapabilities {
126 let mut ht_cap_info = ht_cap.ht_cap_info.with_tx_stbc(OVERRIDE_HT_CAP_INFO_TX_STBC);
127 match cbw {
128 Cbw::Cbw20 => ht_cap_info.set_chan_width_set(ie::ChanWidthSet::TWENTY_ONLY),
129 _ => (),
130 }
131 ht_cap.ht_cap_info = ht_cap_info;
132 ht_cap
133}
134
135fn override_vht_capabilities(mut vht_cap: VhtCapabilities, cbw: Cbw) -> VhtCapabilities {
138 let mut vht_cap_info = vht_cap.vht_cap_info;
139 if vht_cap_info.supported_cbw_set() != OVERRIDE_VHT_CAP_INFO_SUPPORTED_CBW_SET {
140 match cbw {
145 Cbw::Cbw160 | Cbw::Cbw80P80 { secondary80: _ } => (),
146 _ => vht_cap_info.set_supported_cbw_set(OVERRIDE_VHT_CAP_INFO_SUPPORTED_CBW_SET),
147 }
148 }
149 vht_cap.vht_cap_info = vht_cap_info;
150 vht_cap
151}
152
153pub fn get_band_cap_for_channel(
154 bands: &[fidl_mlme::BandCapability],
155 channel: Channel,
156) -> Result<&fidl_mlme::BandCapability, anyhow::Error> {
157 let target = channel.get_band().context("Failed to retrieve band capabilities")?;
158 bands
159 .iter()
160 .find(|b| b.band == target && b.operating_channels.contains(&channel.primary))
161 .ok_or_else(|| format_err!("No band capability for channel {channel:?}: {bands:?}"))
162}
163
/// Capabilities of a single station: its Capability Info field, its supported rates,
/// and its HT and VHT capabilities when present.
#[derive(Debug, PartialEq)]
pub struct StaCapabilities {
    /// Capability Info field.
    pub capability_info: CapabilityInfo,
    /// Supported rates.
    pub rates: Vec<SupportedRate>,
    /// HT capabilities, or `None` when absent.
    pub ht_cap: Option<HtCapabilities>,
    /// VHT capabilities, or `None` when absent.
    pub vht_cap: Option<VhtCapabilities>,
}

/// [`StaCapabilities`] of a station acting as a client, such as the value returned by
/// [`derive_join_capabilities`].
#[derive(Debug, PartialEq)]
pub struct ClientCapabilities(pub StaCapabilities);
/// [`StaCapabilities`] of a station acting as an access point.
#[derive(Debug, PartialEq)]
pub struct ApCapabilities(pub StaCapabilities);
181
182pub fn intersect_with_ap_as_client(
184 client: &ClientCapabilities,
185 ap: &ApCapabilities,
186) -> Result<StaCapabilities, Error> {
187 let rates = intersect_rates(ApRates(&ap.0.rates[..]), ClientRates(&client.0.rates[..]))
188 .map_err(|e| format_err!("could not intersect rates: {:?}", e))?;
189 let (capability_info, ht_cap, vht_cap) = intersect(&client.0, &ap.0);
190 Ok(StaCapabilities { rates, capability_info, ht_cap, vht_cap })
191}
192
193pub fn intersect_with_remote_client_as_ap(
195 ap: &ApCapabilities,
196 remote_client: &ClientCapabilities,
197) -> StaCapabilities {
198 let rates = intersect_rates(ApRates(&ap.0.rates[..]), ClientRates(&remote_client.0.rates[..]))
200 .unwrap_or(vec![]);
201 let (capability_info, ht_cap, vht_cap) = intersect(&ap.0, &remote_client.0);
202 StaCapabilities { rates, capability_info, ht_cap, vht_cap }
203}
204
205fn intersect(
206 ours: &StaCapabilities,
207 theirs: &StaCapabilities,
208) -> (CapabilityInfo, Option<HtCapabilities>, Option<VhtCapabilities>) {
209 let capability_info = CapabilityInfo(ours.capability_info.raw() & theirs.capability_info.raw());
211 let ht_cap = match (ours.ht_cap, theirs.ht_cap) {
212 (Some(ours), Some(theirs)) => Some(ours.intersect(&theirs)),
214 _ => None,
215 };
216 let vht_cap = match (ours.vht_cap, theirs.vht_cap) {
217 (Some(ours), Some(theirs)) => Some(ours.intersect(&theirs)),
219 _ => None,
220 };
221 (capability_info, ht_cap, vht_cap)
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mac;
    use crate::test_utils::fake_capabilities::fake_5ghz_band_capability_ht;
    use assert_matches::assert_matches;
    use fidl_fuchsia_wlan_common as fidl_common;

    #[test]
    fn test_build_cap_info() {
        let capability_info = CapabilityInfo(0)
            .with_ess(!OVERRIDE_CAP_INFO_ESS)
            .with_ibss(!OVERRIDE_CAP_INFO_IBSS)
            .with_cf_pollable(!OVERRIDE_CAP_INFO_CF_POLLABLE)
            .with_cf_poll_req(!OVERRIDE_CAP_INFO_CF_POLL_REQUEST)
            .with_privacy(!OVERRIDE_CAP_INFO_PRIVACY)
            .with_spectrum_mgmt(!OVERRIDE_CAP_INFO_SPECTRUM_MGMT);
        let capability_info = override_capability_info(capability_info);
        assert_eq!(capability_info.ess(), OVERRIDE_CAP_INFO_ESS);
        assert_eq!(capability_info.ibss(), OVERRIDE_CAP_INFO_IBSS);
        assert_eq!(capability_info.cf_pollable(), OVERRIDE_CAP_INFO_CF_POLLABLE);
        assert_eq!(capability_info.cf_poll_req(), OVERRIDE_CAP_INFO_CF_POLL_REQUEST);
        assert_eq!(capability_info.privacy(), OVERRIDE_CAP_INFO_PRIVACY);
        assert_eq!(capability_info.spectrum_mgmt(), OVERRIDE_CAP_INFO_SPECTRUM_MGMT);
    }

    #[test]
    fn test_override_ht_cap() {
        let mut ht_cap = ie::fake_ht_capabilities();
        let ht_cap_info = ht_cap
            .ht_cap_info
            .with_tx_stbc(!OVERRIDE_HT_CAP_INFO_TX_STBC)
            .with_chan_width_set(ie::ChanWidthSet::TWENTY_FORTY);
        ht_cap.ht_cap_info = ht_cap_info;
        let mut channel = Channel { primary: 153, cbw: Cbw::Cbw20 };

        let ht_cap_info = override_ht_capabilities(ht_cap, channel.cbw).ht_cap_info;
        assert_eq!(ht_cap_info.tx_stbc(), OVERRIDE_HT_CAP_INFO_TX_STBC);
        assert_eq!(ht_cap_info.chan_width_set(), ie::ChanWidthSet::TWENTY_ONLY);

        // Wider channels keep the device's channel width set.
        channel.cbw = Cbw::Cbw40;
        let ht_cap_info = override_ht_capabilities(ht_cap, channel.cbw).ht_cap_info;
        assert_eq!(ht_cap_info.chan_width_set(), ie::ChanWidthSet::TWENTY_FORTY);
    }

    #[test]
    fn test_override_vht_cap() {
        let mut vht_cap = ie::fake_vht_capabilities();
        let vht_cap_info = vht_cap.vht_cap_info.with_supported_cbw_set(2);
        vht_cap.vht_cap_info = vht_cap_info;
        let mut channel = Channel { primary: 153, cbw: Cbw::Cbw20 };

        // Channels narrower than 160 MHz reset the supported channel width set.
        let vht_cap_info = override_vht_capabilities(vht_cap, channel.cbw).vht_cap_info;
        assert_eq!(vht_cap_info.supported_cbw_set(), OVERRIDE_VHT_CAP_INFO_SUPPORTED_CBW_SET);

        channel.cbw = Cbw::Cbw40;
        let vht_cap_info = override_vht_capabilities(vht_cap, channel.cbw).vht_cap_info;
        assert_eq!(vht_cap_info.supported_cbw_set(), OVERRIDE_VHT_CAP_INFO_SUPPORTED_CBW_SET);

        channel.cbw = Cbw::Cbw80;
        let vht_cap_info = override_vht_capabilities(vht_cap, channel.cbw).vht_cap_info;
        assert_eq!(vht_cap_info.supported_cbw_set(), OVERRIDE_VHT_CAP_INFO_SUPPORTED_CBW_SET);

        // 160 MHz and 80+80 MHz channels keep the device's value.
        channel.cbw = Cbw::Cbw160;
        let vht_cap_info = override_vht_capabilities(vht_cap, channel.cbw).vht_cap_info;
        assert_eq!(vht_cap_info.supported_cbw_set(), 2);

        channel.cbw = Cbw::Cbw80P80 { secondary80: 42 };
        let vht_cap_info = override_vht_capabilities(vht_cap, channel.cbw).vht_cap_info;
        assert_eq!(vht_cap_info.supported_cbw_set(), 2);
    }

    #[test]
    fn test_override_ht_vht_without_caps() {
        assert_matches!(override_ht_vht(None, None, Cbw::Cbw20), Ok((None, None)));
    }

    #[test]
    fn test_override_ht_vht_rejects_vht_without_ht() {
        // The check happens before any parsing, so the VHT bytes' content is irrelevant.
        let vht_cap = Box::new(fidl_ieee80211::VhtCapabilities { bytes: Default::default() });
        assert!(override_ht_vht(None, Some(&vht_cap), Cbw::Cbw80).is_err());
    }

    /// Device info with a single 5 GHz band capability.
    fn fake_5ghz_device_info() -> fidl_mlme::DeviceInfo {
        fidl_mlme::DeviceInfo {
            sta_addr: [0; 6],
            role: fidl_common::WlanMacRole::Client,
            bands: vec![fake_5ghz_band_capability_ht(ie::ChanWidthSet::TWENTY_FORTY)],
            softmac_hardware_capability: 0,
            qos_capable: true,
        }
    }

    #[test]
    fn test_get_device_band_cap() {
        let device_info = fake_5ghz_device_info();
        assert_eq!(
            fidl_ieee80211::WlanBand::FiveGhz,
            get_band_cap_for_channel(&device_info.bands[..], Channel::new(36, Cbw::Cbw20))
                .unwrap()
                .band
        );
    }

    #[test]
    fn test_get_device_band_cap_rejects_unsupported_channel() {
        // Channel 1 is a 2.4 GHz channel; the fake device only has a 5 GHz band.
        let device_info = fake_5ghz_device_info();
        assert!(
            get_band_cap_for_channel(&device_info.bands[..], Channel::new(1, Cbw::Cbw20)).is_err()
        );
    }

    fn fake_client_join_cap() -> ClientCapabilities {
        ClientCapabilities(StaCapabilities {
            capability_info: mac::CapabilityInfo(0x1234),
            rates: [101, 102, 103, 104].iter().cloned().map(SupportedRate).collect(),
            ht_cap: Some(HtCapabilities {
                ht_cap_info: ie::HtCapabilityInfo(0).with_rx_stbc(2).with_tx_stbc(false),
                ..ie::fake_ht_capabilities()
            }),
            vht_cap: Some(ie::fake_vht_capabilities()),
        })
    }

    fn fake_ap_join_cap() -> ApCapabilities {
        ApCapabilities(StaCapabilities {
            capability_info: mac::CapabilityInfo(0x4321),
            // 101 + 128 marks rate 101 as a basic rate.
            rates: [101 + 128, 102, 9].iter().cloned().map(SupportedRate).collect(),
            ht_cap: Some(HtCapabilities {
                ht_cap_info: ie::HtCapabilityInfo(0).with_rx_stbc(1).with_tx_stbc(true),
                ..ie::fake_ht_capabilities()
            }),
            vht_cap: Some(ie::fake_vht_capabilities()),
        })
    }

    #[test]
    fn client_intersect_with_ap() {
        let caps = assert_matches!(
            intersect_with_ap_as_client(&fake_client_join_cap(), &fake_ap_join_cap()),
            Ok(caps) => caps
        );
        assert_eq!(
            caps,
            StaCapabilities {
                // 0x1234 & 0x4321
                capability_info: mac::CapabilityInfo(0x0220),
                rates: [229, 102].iter().cloned().map(SupportedRate).collect(),
                ht_cap: Some(HtCapabilities {
                    ht_cap_info: ie::HtCapabilityInfo(0).with_rx_stbc(2).with_tx_stbc(false),
                    ..ie::fake_ht_capabilities()
                }),
                ..fake_client_join_cap().0
            }
        )
    }

    #[test]
    fn ap_intersect_with_remote_client() {
        assert_eq!(
            intersect_with_remote_client_as_ap(&fake_ap_join_cap(), &fake_client_join_cap()),
            StaCapabilities {
                // 0x4321 & 0x1234
                capability_info: mac::CapabilityInfo(0x0220),
                rates: [229, 102].iter().cloned().map(SupportedRate).collect(),
                ht_cap: Some(HtCapabilities {
                    ht_cap_info: ie::HtCapabilityInfo(0).with_rx_stbc(0).with_tx_stbc(true),
                    ..ie::fake_ht_capabilities()
                }),
                ..fake_ap_join_cap().0
            }
        );
    }
}