wlan_sae/
state.rs

1// Copyright 2019 The Fuchsia Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use super::boringssl::{Bignum, BignumCtx};
6use super::frame::{write_commit, write_confirm};
7use super::internal::{FiniteCyclicGroup, SaeParameters};
8use super::{
9    AntiCloggingTokenMsg, CommitMsg, ConfirmMsg, Key, RejectReason, SaeHandshake, SaeUpdate,
10    SaeUpdateSink, Timeout,
11};
12use anyhow::{bail, format_err, Error};
13use log::{error, warn};
14use wlan_statemachine::*;
15
16// TODO(https://fxbug.dev/42118302): Handle received timeouts.
17// TODO(https://fxbug.dev/42118769): Handle BadGrp/DiffGrp.
18// TODO(https://fxbug.dev/42118770): Handle frame status.
19
20/// We store an FcgConstructor rather than a FiniteCyclicGroup so that our handshake
21/// can impl `Send`. FCGs are not generally `Send`, so we construct them on the fly.
22type FcgConstructor<E> =
23    Box<dyn Fn() -> Result<Box<dyn FiniteCyclicGroup<Element = E>>, Error> + Send + 'static>;
24
25struct SaeConfiguration<E> {
26    fcg: FcgConstructor<E>,
27    params: SaeParameters,
28    pwe: Vec<u8>,
29}
30
31struct Commit<E> {
32    scalar: Bignum,
33    element: E,
34}
35
36struct SerializedCommit {
37    scalar: Vec<u8>,
38    element: Vec<u8>,
39}
40
41#[derive(Debug, PartialEq)]
42struct Kck(Vec<u8>);
43
44impl<E> Commit<E> {
45    /// IEEE 802.11-2016 12.4.7.4
46    /// Returns the serialized scalar and element with appropriate padding as needed.
47    fn serialize(&self, config: &SaeConfiguration<E>) -> Result<SerializedCommit, Error> {
48        let fcg = (config.fcg)()?;
49        let scalar_size = fcg.scalar_size()?;
50        let scalar = self.scalar.to_be_vec(scalar_size);
51        let element = fcg.element_to_octets(&self.element)?;
52        Ok(SerializedCommit { scalar, element })
53    }
54}
55
56impl SerializedCommit {
57    fn deserialize<E>(&self, config: &SaeConfiguration<E>) -> Result<Commit<E>, Error> {
58        let fcg = (config.fcg)()?;
59        let scalar = Bignum::new_from_slice(&self.scalar[..])?;
60        let element = match fcg.element_from_octets(&self.element)? {
61            Some(element) => element,
62            None => bail!("Attempted to deserialize invalid FCG element"),
63        };
64        Ok(Commit { scalar, element })
65    }
66}
67
68struct SaeNew<E> {
69    config: SaeConfiguration<E>,
70}
71
72struct SaeCommitted<E> {
73    config: SaeConfiguration<E>,
74    rand: Vec<u8>,
75    commit: SerializedCommit,
76    sync: u16,
77    anti_clogging_token: Vec<u8>,
78}
79
80struct SaeConfirmed<E> {
81    config: SaeConfiguration<E>,
82    commit: SerializedCommit,
83    peer_commit: SerializedCommit,
84    kck: Kck,
85    key: Key,
86    sc: u16, // send confirm
87    rc: u16, // receive confirm
88    sync: u16,
89}
90
91// Everything is finished in this state. We keep around the old SaeConfirmed struct in case we need
92// to replay our confirm frame.
93struct SaeAccepted<E>(SaeConfirmed<E>);
94struct SaeFailed;
95
// Declares the `SaeHandshakeState` enum over the states below, together with the transitions
// each state may make (IEEE 802.11-2016 12.4.8.6).
statemachine!(
    enum SaeHandshakeState<E>,
    // Every handshake begins in SaeNew.
    () => SaeNew<E>,
    // SaeNew moves to SaeCommitted when we initiate, or straight to SaeConfirmed when the peer's
    // commit arrives before we have sent our own.
    SaeNew<E> => [SaeCommitted<E>, SaeConfirmed<E>, SaeFailed],
    // SaeCommitted does not self-loop because a retry does not update any state.
    SaeCommitted<E> => [SaeConfirmed<E>, SaeFailed],
    // SaeConfirmed can self-loop because retries increment send_confirm.
    SaeConfirmed<E> => [SaeAccepted<E>, SaeFailed, SaeConfirmed<E>],
    // An accepted handshake can only fail afterwards (e.g. on key expiration or too many retries).
    SaeAccepted<E> => SaeFailed,
);
106
/// Outcome of an operation on a received frame: it either succeeded and produced a value, or
/// the frame must be silently discarded.
enum FrameResult<T> {
    /// The frame was processed and the given output produced as a result.
    Proceed(T),
    /// The frame was incorrect and should be dropped silently.
    Drop,
}

// Written by hand so that `T` does not need to implement `Debug`; the payload is never printed.
impl<T> std::fmt::Debug for FrameResult<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            FrameResult::Drop => "FrameResult::Drop",
            FrameResult::Proceed(_) => "FrameResult::Proceed",
        };
        f.write_str(label)
    }
}
124
125/// IEEE 802.11-2016 12.4.5.4
126/// Returns the calculated pairwise key and peer commit, or None if the given peer element is
127/// invalid.
128fn process_commit<E>(
129    config: &SaeConfiguration<E>,
130    rand: &Bignum,
131    commit: &Commit<E>,
132    peer_scalar: &[u8],
133    peer_element: &[u8],
134) -> Result<FrameResult<(Commit<E>, Kck, Key)>, RejectReason> {
135    let fcg = (config.fcg)()?;
136    // Parse the peer element.
137    let peer_commit = match fcg.element_from_octets(peer_element)? {
138        Some(element) => Commit { scalar: Bignum::new_from_slice(peer_scalar)?, element },
139        None => return Ok(FrameResult::Drop),
140    };
141
142    let pwe =
143        fcg.element_from_octets(&config.pwe)?.ok_or_else(|| format_err!("Could not unwrap PWE"))?;
144    let element_k = fcg.scalar_op(
145        rand,
146        &fcg.elem_op(&fcg.scalar_op(&peer_commit.scalar, &pwe)?, &peer_commit.element)?,
147    )?;
148    let k = match fcg.map_to_secret_value(&element_k)? {
149        Some(k) => k,
150        None => return Ok(FrameResult::Drop), // This is an auth failure.
151    };
152    let ctx = BignumCtx::new()?;
153    let keyseed = config.params.hmac.hkdf_extract(&[0u8; 32][..], &k);
154    let sha_ctx = peer_commit
155        .scalar
156        .mod_add(&commit.scalar, &fcg.order()?, &ctx)?
157        .to_be_vec(fcg.scalar_size()?);
158    let q = config.params.hmac.bits();
159    let kck_and_pmk =
160        config.params.hmac.kdf_hash_length(&keyseed[..], "SAE KCK and PMK", &sha_ctx[..], q + 256);
161    let kck = kck_and_pmk[0..q / 8].to_vec();
162    let pmk = kck_and_pmk[q / 8..(q + 256) / 8].to_vec();
163    let pmkid = sha_ctx[0..16].to_vec();
164    Ok(FrameResult::Proceed((peer_commit, Kck(kck), Key { pmk, pmkid })))
165}
166
167/// IEEE 802.11-2016 12.4.5.{5,6}
168/// Computes the confirm value for sending or validating a confirm message. This can only fail from
169/// an internal error.
170fn compute_confirm<E>(
171    config: &SaeConfiguration<E>,
172    kck: &Kck,
173    send_confirm: u16,
174    commit1: &SerializedCommit,
175    commit2: &SerializedCommit,
176) -> Result<Vec<u8>, RejectReason> {
177    Ok(config.params.hmac.confirm(
178        &kck.0[..],
179        send_confirm,
180        &[&commit1.scalar[..], &commit1.element[..], &commit2.scalar[..], &commit2.element[..]],
181    ))
182}
183
184/// Helper function to reject the authentication after too many retries.
185fn check_sync(sync: &u16) -> Result<(), RejectReason> {
186    // IEEE says we should only fail if sync exceeds our limit, but failing on equality as well gives
187    // MAX_RETRIES_PER_EXCHANGE slightly more obvious behavior.
188    if *sync >= super::MAX_RETRIES_PER_EXCHANGE {
189        Err(RejectReason::TooManyRetries)
190    } else {
191        Ok(())
192    }
193}
194
195/// IEEE 802.11-2016 12.4.5.3
196impl<E> SaeNew<E> {
197    fn commit(&self) -> Result<(Vec<u8>, SerializedCommit), Error> {
198        let fcg = (self.config.fcg)()?;
199        let order = fcg.order()?;
200        let ctx = BignumCtx::new()?;
201        let (rand, mask, scalar) = loop {
202            // 1 < rand < order
203            let rand = Bignum::rand(&order.sub(Bignum::new_from_u64(2)?)?)?
204                .add(Bignum::new_from_u64(2)?)?;
205            // 1 < mask < order
206            let mask = Bignum::rand(&order.sub(Bignum::new_from_u64(2)?)?)?
207                .add(Bignum::new_from_u64(2)?)?;
208            let commit_scalar = rand.mod_add(&mask, &order, &ctx)?;
209            if !commit_scalar.is_zero() && !commit_scalar.is_one() {
210                break (rand, mask, commit_scalar);
211            }
212        };
213        let pwe = fcg
214            .element_from_octets(&self.config.pwe)?
215            .ok_or_else(|| format_err!("Could not unwrap PWE"))?;
216        let element = fcg.inverse_op(fcg.scalar_op(&mask, &pwe)?)?;
217        Ok((
218            rand.to_be_vec(fcg.scalar_size()?),
219            Commit { scalar, element }.serialize(&self.config)?,
220        ))
221    }
222
223    fn send_first_commit(
224        &self,
225        sink: &mut SaeUpdateSink,
226    ) -> Result<(Vec<u8>, SerializedCommit), RejectReason> {
227        let (rand, commit) = self.commit()?;
228        let group_id = (self.config.fcg)()?.group_id();
229        sink.push(SaeUpdate::SendFrame(write_commit(
230            group_id,
231            &commit.scalar[..],
232            &commit.element[..],
233            &[],
234        )));
235        sink.push(SaeUpdate::ResetTimeout(Timeout::Retransmission));
236        Ok((rand, commit))
237    }
238
239    fn handle_commit(
240        &self,
241        sink: &mut SaeUpdateSink,
242        commit_msg: &CommitMsg<'_>,
243    ) -> Result<(Vec<u8>, SerializedCommit, SerializedCommit, Kck, Key), RejectReason> {
244        let (serialized_rand, serialized_commit) = self.commit()?;
245        let commit = serialized_commit.deserialize(&self.config)?;
246        let rand = Bignum::new_from_slice(&serialized_rand[..])?;
247        let (peer_commit, kck, key) = match process_commit(
248            &self.config,
249            &rand,
250            &commit,
251            &commit_msg.scalar[..],
252            &commit_msg.element[..],
253        )? {
254            FrameResult::Proceed(res) => res,
255            // If we drop the first frame, reject the authentication immediately.
256            FrameResult::Drop => return Err(RejectReason::AuthFailed),
257        };
258        let peer_commit = peer_commit.serialize(&self.config)?;
259        let confirm = compute_confirm(&self.config, &kck, 1, &serialized_commit, &peer_commit)?;
260        // We do not send our own commit message unless we process the peer's successfully.
261        let group_id = (self.config.fcg)()?.group_id();
262        sink.push(SaeUpdate::SendFrame(write_commit(
263            group_id,
264            &serialized_commit.scalar[..],
265            &serialized_commit.element[..],
266            &[],
267        )));
268        sink.push(SaeUpdate::SendFrame(write_confirm(1, &confirm[..])));
269        sink.push(SaeUpdate::ResetTimeout(Timeout::Retransmission));
270        Ok((serialized_rand, serialized_commit, peer_commit, kck, key))
271    }
272}
273
274/// IEEE 802.11-2016 12.4.8.6.4
275impl<E> SaeCommitted<E> {
276    fn handle_commit(
277        &self,
278        sink: &mut SaeUpdateSink,
279        commit_msg: &CommitMsg<'_>,
280    ) -> Result<FrameResult<(SerializedCommit, Kck, Key)>, RejectReason> {
281        if &commit_msg.scalar[..] == &self.commit.scalar[..]
282            && &commit_msg.element[..] == &self.commit.element[..]
283        {
284            // This is a reflection attack.
285            sink.push(SaeUpdate::ResetTimeout(Timeout::Retransmission));
286            return Ok(FrameResult::Drop);
287        }
288        let (peer_commit, kck, key) = match process_commit(
289            &self.config,
290            &Bignum::new_from_slice(&self.rand[..])?,
291            &self.commit.deserialize(&self.config)?,
292            &commit_msg.scalar[..],
293            &commit_msg.element[..],
294        )? {
295            FrameResult::Proceed(res) => res,
296            // IEEE doesn't specify that we do anything in this case. It might make sense to reset
297            // the retransmission timer, but we stick with the spec.
298            FrameResult::Drop => return Ok(FrameResult::Drop),
299        };
300        let peer_commit = peer_commit.serialize(&self.config)?;
301        let confirm = compute_confirm(&self.config, &kck, 1, &self.commit, &peer_commit)?;
302        sink.push(SaeUpdate::SendFrame(write_confirm(1, &confirm[..])));
303        sink.push(SaeUpdate::ResetTimeout(Timeout::Retransmission));
304        Ok(FrameResult::Proceed((peer_commit, kck, key)))
305    }
306
307    fn resend_last_frame(&mut self, sink: &mut SaeUpdateSink) -> Result<(), RejectReason> {
308        check_sync(&self.sync)?;
309        self.sync += 1;
310        // We resend our last commit.
311        let group_id = (self.config.fcg)()?.group_id();
312        sink.push(SaeUpdate::SendFrame(write_commit(
313            group_id,
314            &self.commit.scalar[..],
315            &self.commit.element[..],
316            &self.anti_clogging_token[..],
317        )));
318        sink.push(SaeUpdate::ResetTimeout(Timeout::Retransmission));
319        Ok(())
320    }
321
322    fn handle_confirm(
323        &mut self,
324        sink: &mut SaeUpdateSink,
325        _confirm_msg: &ConfirmMsg<'_>,
326    ) -> Result<(), RejectReason> {
327        self.resend_last_frame(sink)
328    }
329
330    fn handle_anti_clogging_token(
331        &mut self,
332        sink: &mut SaeUpdateSink,
333        act_msg: &AntiCloggingTokenMsg<'_>,
334    ) -> Result<(), RejectReason> {
335        self.anti_clogging_token = act_msg.anti_clogging_token.to_vec();
336        self.resend_last_frame(sink)
337    }
338
339    fn handle_timeout(
340        &mut self,
341        sink: &mut SaeUpdateSink,
342        timeout: Timeout,
343    ) -> Result<(), RejectReason> {
344        match timeout {
345            Timeout::Retransmission => self.resend_last_frame(sink),
346            Timeout::KeyExpiration => {
347                Err(format_err!("Unexpected key expiration timout before PMKSA established.")
348                    .into())
349            }
350        }
351    }
352}
353
354/// IEEE 802.11-2016 12.4.8.6.5
355impl<E> SaeConfirmed<E> {
356    fn handle_commit(
357        &mut self,
358        sink: &mut SaeUpdateSink,
359        _commit_msg: &CommitMsg<'_>,
360    ) -> Result<(), RejectReason> {
361        // The peer did not receive our previous commit or confirm.
362        check_sync(&self.sync)?;
363        // IEEE Std 802.11 does *not* specify that we verify the peer sent the same commit. We just
364        // assume this to be the case.
365        self.sync += 1;
366        self.sc += 1;
367        let confirm =
368            compute_confirm(&self.config, &self.kck, self.sc, &self.commit, &self.peer_commit)?;
369        let group_id = (self.config.fcg)()?.group_id();
370        sink.push(SaeUpdate::SendFrame(write_commit(
371            group_id,
372            &self.commit.scalar[..],
373            &self.commit.element[..],
374            &[],
375        )));
376        sink.push(SaeUpdate::SendFrame(write_confirm(self.sc, &confirm[..])));
377        sink.push(SaeUpdate::ResetTimeout(Timeout::Retransmission));
378        Ok(())
379    }
380
381    fn handle_confirm(
382        &mut self,
383        sink: &mut SaeUpdateSink,
384        confirm_msg: &ConfirmMsg<'_>,
385    ) -> Result<FrameResult<()>, RejectReason> {
386        let verifier = compute_confirm(
387            &self.config,
388            &self.kck,
389            confirm_msg.send_confirm,
390            &self.peer_commit,
391            &self.commit,
392        )?;
393        if confirm_msg.confirm == &verifier[..] {
394            sink.push(SaeUpdate::CancelTimeout(Timeout::Retransmission));
395            sink.push(SaeUpdate::ResetTimeout(Timeout::KeyExpiration));
396            self.rc = confirm_msg.send_confirm;
397            // We use u16::max_value() where IEEE specifies 2^16 - 1.
398            self.sc = u16::max_value();
399            sink.push(SaeUpdate::Success(self.key.clone()));
400            Ok(FrameResult::Proceed(()))
401        } else {
402            Ok(FrameResult::Drop)
403        }
404    }
405
406    fn handle_timeout(
407        &mut self,
408        sink: &mut SaeUpdateSink,
409        timeout: Timeout,
410    ) -> Result<(), RejectReason> {
411        match timeout {
412            Timeout::Retransmission => {
413                // Resend our confirm message.
414                check_sync(&self.sync)?;
415                self.sync += 1;
416                self.sc += 1;
417                let confirm = compute_confirm(
418                    &self.config,
419                    &self.kck,
420                    self.sc,
421                    &self.commit,
422                    &self.peer_commit,
423                )?;
424                sink.push(SaeUpdate::SendFrame(write_confirm(self.sc, &confirm[..])));
425                sink.push(SaeUpdate::ResetTimeout(Timeout::Retransmission));
426                Ok(())
427            }
428            Timeout::KeyExpiration => {
429                Err(format_err!("Unexpected key expiration timout before PMKSA established.")
430                    .into())
431            }
432        }
433    }
434}
435
436/// IEEE 802.11-2016 12.4.8.6.6
437impl<E> SaeAccepted<E> {
438    // This function does not return a FrameResult because there is no state transition in the
439    // successful case.
440    fn handle_confirm(
441        &mut self,
442        sink: &mut SaeUpdateSink,
443        confirm_msg: &ConfirmMsg<'_>,
444    ) -> Result<(), RejectReason> {
445        check_sync(&self.0.sync)?;
446        // We use u16::max_value() where IEEE specifies 2^16 - 1.
447        if confirm_msg.send_confirm <= self.0.rc || confirm_msg.send_confirm == u16::max_value() {
448            return Ok(());
449        }
450        // If we fail to verify, the message is dropped silently.
451        if let Ok(verifier) = compute_confirm(
452            &self.0.config,
453            &self.0.kck,
454            confirm_msg.send_confirm,
455            &self.0.peer_commit,
456            &self.0.commit,
457        ) {
458            if verifier == &confirm_msg.confirm[..] {
459                self.0.rc = confirm_msg.send_confirm;
460                self.0.sync += 1;
461                let confirm = compute_confirm(
462                    &self.0.config,
463                    &self.0.kck,
464                    self.0.sc,
465                    &self.0.commit,
466                    &self.0.peer_commit,
467                )?;
468                sink.push(SaeUpdate::SendFrame(write_confirm(self.0.sc, &confirm[..])));
469            }
470        }
471        Ok(())
472    }
473
474    fn handle_timeout(&mut self, timeout: Timeout) -> Result<(), RejectReason> {
475        match timeout {
476            Timeout::Retransmission => {
477                // This is weird, but probably shouldn't kill our PMKSA.
478                error!("Unexpected retransmission timeout after completed SAE handshake.");
479                Ok(())
480            }
481            Timeout::KeyExpiration => Err(RejectReason::KeyExpiration),
482        }
483    }
484}
485
486impl<E> SaeHandshakeState<E> {
487    fn initiate_sae(self, sink: &mut SaeUpdateSink) -> Self {
488        match self {
489            SaeHandshakeState::SaeNew(state) => match state.send_first_commit(sink) {
490                Ok((rand, commit)) => {
491                    let (transition, state) = state.release_data();
492                    transition
493                        .to(SaeCommitted {
494                            config: state.config,
495                            rand,
496                            commit,
497                            sync: 0,
498                            anti_clogging_token: vec![],
499                        })
500                        .into()
501                }
502                Err(reject) => {
503                    sink.push(SaeUpdate::Reject(reject));
504                    state.transition_to(SaeFailed).into()
505                }
506            },
507            _ => {
508                error!("Unexpected call to initiate_sae");
509                self
510            }
511        }
512    }
513
514    fn handle_commit(self, sink: &mut SaeUpdateSink, commit_msg: &CommitMsg<'_>) -> Self {
515        match self {
516            SaeHandshakeState::SaeNew(state) => {
517                match state.handle_commit(sink, commit_msg) {
518                    Ok((_rand, commit, peer_commit, kck, key)) => {
519                        let (transition, state) = state.release_data();
520                        transition
521                            .to(SaeConfirmed {
522                                config: state.config,
523                                commit,
524                                peer_commit,
525                                kck,
526                                key,
527                                sc: 1,
528                                rc: 0,
529                                sync: 0,
530                            })
531                            .into()
532                    }
533                    // We always reject the authentication if the first commit is invalid.
534                    Err(reject) => {
535                        sink.push(SaeUpdate::Reject(reject));
536                        state.transition_to(SaeFailed).into()
537                    }
538                }
539            }
540            SaeHandshakeState::SaeCommitted(state) => match state.handle_commit(sink, commit_msg) {
541                Ok(FrameResult::Proceed((peer_commit, kck, key))) => {
542                    let (transition, committed) = state.release_data();
543                    let confirmed = SaeConfirmed {
544                        config: committed.config,
545                        commit: committed.commit,
546                        peer_commit,
547                        kck,
548                        key,
549                        sc: 1,
550                        rc: 0,
551                        sync: committed.sync,
552                    };
553                    transition.to(confirmed).into()
554                }
555                Ok(FrameResult::Drop) => state.into(),
556                Err(reject) => {
557                    sink.push(SaeUpdate::Reject(reject));
558                    state.transition_to(SaeFailed).into()
559                }
560            },
561            SaeHandshakeState::SaeConfirmed(mut state) => {
562                match state.handle_commit(sink, commit_msg) {
563                    Ok(()) => state.into(),
564                    Err(reject) => {
565                        sink.push(SaeUpdate::Reject(reject));
566                        state.transition_to(SaeFailed).into()
567                    }
568                }
569            }
570            _ => {
571                warn!("Unexpected SAE commit received");
572                self
573            }
574        }
575    }
576
577    fn handle_confirm(self, sink: &mut SaeUpdateSink, confirm_msg: &ConfirmMsg<'_>) -> Self {
578        match self {
579            SaeHandshakeState::SaeCommitted(mut state) => {
580                match state.handle_confirm(sink, confirm_msg) {
581                    Ok(()) => state.into(),
582                    Err(reject) => {
583                        sink.push(SaeUpdate::Reject(reject));
584                        state.transition_to(SaeFailed).into()
585                    }
586                }
587            }
588            SaeHandshakeState::SaeConfirmed(mut state) => {
589                match state.handle_confirm(sink, confirm_msg) {
590                    Ok(FrameResult::Proceed(())) => {
591                        let (transition, state) = state.release_data();
592                        transition.to(SaeAccepted(state)).into()
593                    }
594                    Ok(FrameResult::Drop) => state.into(),
595                    Err(e) => {
596                        sink.push(SaeUpdate::Reject(e.into()));
597                        state.transition_to(SaeFailed).into()
598                    }
599                }
600            }
601            SaeHandshakeState::SaeAccepted(mut state) => {
602                match state.handle_confirm(sink, confirm_msg) {
603                    Ok(()) => state.into(),
604                    Err(reject) => {
605                        sink.push(SaeUpdate::Reject(reject));
606                        state.transition_to(SaeFailed).into()
607                    }
608                }
609            }
610            _ => {
611                warn!("Unexpected SAE confirm received");
612                self
613            }
614        }
615    }
616
617    fn handle_anti_clogging_token(
618        self,
619        sink: &mut SaeUpdateSink,
620        act_msg: &AntiCloggingTokenMsg<'_>,
621    ) -> Self {
622        match self {
623            SaeHandshakeState::SaeCommitted(mut state) => {
624                match state.handle_anti_clogging_token(sink, act_msg) {
625                    Ok(()) => state.into(),
626                    Err(reject) => {
627                        sink.push(SaeUpdate::Reject(reject));
628                        state.transition_to(SaeFailed).into()
629                    }
630                }
631            }
632            _ => {
633                error!("Unexpected anti clogging token received");
634                self
635            }
636        }
637    }
638
639    fn handle_timeout(self, sink: &mut SaeUpdateSink, timeout: Timeout) -> Self {
640        match self {
641            SaeHandshakeState::SaeCommitted(mut state) => {
642                match state.handle_timeout(sink, timeout) {
643                    Ok(()) => state.into(),
644                    Err(reject) => {
645                        sink.push(SaeUpdate::Reject(reject));
646                        state.transition_to(SaeFailed).into()
647                    }
648                }
649            }
650            SaeHandshakeState::SaeConfirmed(mut state) => {
651                match state.handle_timeout(sink, timeout) {
652                    Ok(()) => state.into(),
653                    Err(reject) => {
654                        sink.push(SaeUpdate::Reject(reject));
655                        state.transition_to(SaeFailed).into()
656                    }
657                }
658            }
659            SaeHandshakeState::SaeAccepted(mut state) => match state.handle_timeout(timeout) {
660                Ok(()) => state.into(),
661                Err(reject) => {
662                    sink.push(SaeUpdate::Reject(reject));
663                    state.transition_to(SaeFailed).into()
664                }
665            },
666            _ => {
667                error!("Unexpected SAE timeout triggered");
668                self
669            }
670        }
671    }
672}
673
674pub struct SaeHandshakeImpl<E>(StateMachine<SaeHandshakeState<E>>);
675
676impl<E> SaeHandshakeImpl<E> {
677    pub fn new(fcg_constructor: FcgConstructor<E>, params: SaeParameters) -> Result<Self, Error> {
678        let fcg = fcg_constructor()?;
679        let pwe = fcg.element_to_octets(&fcg.generate_pwe(&params)?)?;
680        Ok(Self(StateMachine::new(SaeHandshakeState::from(State::new(SaeNew {
681            config: SaeConfiguration { fcg: fcg_constructor, params, pwe },
682        })))))
683    }
684}
685
686impl<E> SaeHandshake for SaeHandshakeImpl<E> {
687    fn initiate_sae(&mut self, sink: &mut SaeUpdateSink) {
688        self.0.replace_state(|state| state.initiate_sae(sink));
689    }
690
691    fn handle_commit(&mut self, sink: &mut SaeUpdateSink, commit_msg: &CommitMsg<'_>) {
692        self.0.replace_state(|state| state.handle_commit(sink, commit_msg));
693    }
694
695    fn handle_confirm(&mut self, sink: &mut SaeUpdateSink, confirm_msg: &ConfirmMsg<'_>) {
696        self.0.replace_state(|state| state.handle_confirm(sink, confirm_msg));
697    }
698
699    fn handle_anti_clogging_token(
700        &mut self,
701        sink: &mut SaeUpdateSink,
702        act_msg: &AntiCloggingTokenMsg<'_>,
703    ) {
704        self.0.replace_state(|state| state.handle_anti_clogging_token(sink, act_msg));
705    }
706
707    fn handle_timeout(&mut self, sink: &mut SaeUpdateSink, timeout: Timeout) {
708        self.0.replace_state(|state| state.handle_timeout(sink, timeout));
709    }
710}
711
// Most testing is done in sae/mod.rs, so we only test internal functions here.
#[cfg(test)]
mod test {
    use super::*;
    use crate::boringssl::{Bignum, EcGroupId};
    use crate::hmac_utils::HmacUtilsImpl;
    use crate::{ecc, PweMethod};
    use hex::FromHex;
    use ieee80211::{MacAddr, Ssid};
    use lazy_static::lazy_static;
    use mundane::hash::Sha256;
    use std::convert::TryFrom;
    use wlan_common::assert_variant;

    // IEEE Std 802.11-18/1104r0: "New Test Vectors for SAE" provides all of these values.
    // TEST_PWD is slightly modified by concatenating the password identifier field; IEEE Std
    // 802.11-2020 specifies that a password identifier may not be used with PWE generation by
    // looping.
    lazy_static! {
        static ref TEST_STA_A: MacAddr = MacAddr::from([0x82, 0x7b, 0x91, 0x9d, 0xd4, 0xb9]);
        static ref TEST_STA_B: MacAddr = MacAddr::from([0x1e, 0xec, 0x49, 0xea, 0x64, 0x88]);
    }

    const TEST_GROUP: EcGroupId = EcGroupId::P256;
    const TEST_SSID: &str = "SSID in from 802.11-18/r1104r0";
    const TEST_PWD: &str = "mekmitasdigoatpsk4internet";
    const TEST_RAND_A: &str = "a906f61e4d3a5d4eb2965ff34cf917dd044445c878c17ca5d5b93786da9f83cf";
    const TEST_SCALAR_A: &str =
        "eb3bab1964e4a0ab05925ddf3339519138bc65d6cdc0f813dd6fd4344eb4bfe4";
    const TEST_ELEMENT_A: &str = "4b5c21597658f4e3eddfb4b99f25b4d6540f32ff1fd5c530c60a794448610bc6de3d92bdbbd47d935980ca6cf8988ab6630be6764c885ceb9793970f695217ee";
    const TEST_CONFIRM_A: &str =
        "12d9d5c78c500526d36c41dbc56aedf2914cedddd7cad4a58c48f83dbde9fc77";
    const TEST_RAND_B: &str = "a47d07bbd3d1b618b325dfde02413a450a90fd1ee1ac35f4d3856cc9cb77128c";
    const TEST_SCALAR_B: &str =
        "5564f045b2ea1e566cf1dd741f70d9be35d2df5b9a5502946ee03cf8dae27e1e";
    const TEST_ELEMENT_B: &str = "05b8430eb7a99e24877ce69baf3dc580e309633d6b385f83ee1c3ec3591f1a5393c06e805ddceb2fde50930dd7cfebb987c6ff9666af164eb5184d8e6662ed6a";
    const TEST_CONFIRM_B: &str =
        "02871cf906898b8060ec184143be77b8c08a8019b13eb6d0aef0d8383dfac2fd";
    const KEY_PMK: &str = "7aead86fba4c3221fc437f5f14d70d854ea5d5aac1690116793081eda4d557c5";
    const KEY_KCK: &str = "599d6f1e27548be8499dceed2feccf94818ce1c79f1b4eb3d6a53228a09bf3ed";
    const KEY_PMKID: &str = "40a09b6017cebf0072843b5352aa2b4f";

    type EccElement = <ecc::Group as FiniteCyclicGroup>::Element;

    /// Decodes a hex-encoded test vector.
    fn unhex(s: &str) -> Vec<u8> {
        Vec::from_hex(s).unwrap()
    }

    /// Decodes a hex-encoded test vector into a Bignum.
    fn bignum_from_hex(s: &str) -> Bignum {
        Bignum::new_from_slice(&unhex(s)[..]).unwrap()
    }

    fn make_ecc_config() -> SaeConfiguration<EccElement> {
        let params = SaeParameters {
            hmac: Box::new(HmacUtilsImpl::<Sha256>::new()),
            pwe_method: PweMethod::Loop,
            ssid: Ssid::try_from(TEST_SSID).unwrap(),
            password: Vec::from(TEST_PWD),
            password_id: None, // Cannot be used with PweMethod::Loop
            sta_a_mac: *TEST_STA_A,
            sta_b_mac: *TEST_STA_B,
        };
        let fcg_constructor = Box::new(|| {
            ecc::Group::new(TEST_GROUP)
                .map(|group| Box::new(group) as Box<dyn FiniteCyclicGroup<Element = EccElement>>)
        });
        let pwe = {
            let fcg = fcg_constructor().unwrap();
            fcg.element_to_octets(&fcg.generate_pwe(&params).unwrap()).unwrap()
        };
        SaeConfiguration { fcg: fcg_constructor, params, pwe }
    }

    fn make_commit<E>(config: &SaeConfiguration<E>, scalar: &str, element: &str) -> Commit<E> {
        let scalar = bignum_from_hex(scalar);
        let fcg = (config.fcg)().unwrap();
        let element = fcg.element_from_octets(&unhex(element)[..]).unwrap().unwrap();
        Commit { scalar, element }
    }

    fn expected_kck() -> Kck {
        Kck(unhex(KEY_KCK))
    }

    fn expected_key() -> Key {
        Key { pmk: unhex(KEY_PMK), pmkid: unhex(KEY_PMKID) }
    }

    /// Runs `process_commit` from one station's point of view and checks the derived keys
    /// against the published test vectors.
    fn check_process_commit(
        rand: &str,
        scalar: &str,
        element: &str,
        peer_scalar: &str,
        peer_element: &str,
    ) {
        let config = make_ecc_config();
        let commit = make_commit(&config, scalar, element);
        let rand = bignum_from_hex(rand);
        let peer_scalar = unhex(peer_scalar);
        let peer_element = unhex(peer_element);

        let result =
            process_commit(&config, &rand, &commit, &peer_scalar[..], &peer_element[..]).unwrap();
        let (_peer_commit, kck, key) = assert_variant!(result, FrameResult::Proceed(res) => res);

        assert_eq!(kck, expected_kck());
        assert_eq!(key, expected_key());
    }

    #[test]
    fn process_commit_success_sta_a() {
        check_process_commit(
            TEST_RAND_A,
            TEST_SCALAR_A,
            TEST_ELEMENT_A,
            TEST_SCALAR_B,
            TEST_ELEMENT_B,
        );
    }

    #[test]
    fn process_commit_success_sta_b() {
        check_process_commit(
            TEST_RAND_B,
            TEST_SCALAR_B,
            TEST_ELEMENT_B,
            TEST_SCALAR_A,
            TEST_ELEMENT_A,
        );
    }

    #[test]
    fn process_commit_fails_bad_peer_element() {
        let config = make_ecc_config();
        let commit_a = make_commit(&config, TEST_SCALAR_A, TEST_ELEMENT_A);
        let rand_a = bignum_from_hex(TEST_RAND_A);
        let scalar_b = unhex(TEST_SCALAR_B);
        let mut corrupted_element_b = unhex(TEST_ELEMENT_B);
        corrupted_element_b[0] += 1;

        let result = process_commit(
            &config,
            &rand_a,
            &commit_a,
            &scalar_b[..],
            &corrupted_element_b[..],
        )
        .unwrap();
        assert_variant!(result, FrameResult::Drop);
    }

    #[test]
    fn test_compute_confirm() {
        let config = make_ecc_config();
        let serialize = |scalar, element| make_commit(&config, scalar, element).serialize(&config);
        let commit_a = serialize(TEST_SCALAR_A, TEST_ELEMENT_A).unwrap();
        let commit_b = serialize(TEST_SCALAR_B, TEST_ELEMENT_B).unwrap();
        let kck = expected_kck();

        let confirm_a = compute_confirm(&config, &kck, 1, &commit_a, &commit_b).unwrap();
        assert_eq!(confirm_a, unhex(TEST_CONFIRM_A));

        let confirm_b = compute_confirm(&config, &kck, 1, &commit_b, &commit_a).unwrap();
        assert_eq!(confirm_b, unhex(TEST_CONFIRM_B));
    }
}