// File: directed_graph/lib.rs

// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use std::cmp::min;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet, VecDeque};
use std::fmt::{Debug, Display};
use std::hash::Hash;

10/// A directed graph, whose nodes contain an identifier of type `T`.
11pub struct DirectedGraph<T: PartialEq + Hash + Copy + Ord + Debug + Display>(
12    HashMap<T, DirectedNode<T>>,
13);
14
15impl<T: PartialEq + Hash + Copy + Ord + Debug + Display> DirectedGraph<T> {
16    /// Created a new empty `DirectedGraph`.
17    pub fn new() -> Self {
18        Self(HashMap::new())
19    }
20
21    /// Add an edge to the graph, adding nodes if necessary.
22    pub fn add_edge(&mut self, source: T, target: T) {
23        self.0.entry(source).or_insert_with(DirectedNode::new).add_target(target);
24        self.0.entry(target).or_insert_with(DirectedNode::new);
25    }
26
27    /// Get targets of all edges from this node.
28    pub fn get_targets(&self, id: T) -> Option<&HashSet<T>> {
29        self.0.get(&id).as_ref().map(|node| &node.0)
30    }
31
32    /// Given a dependency graph represented as a set of `edges`, find the set of
33    /// all nodes that the `start` node depends on, directly or indirectly. This
34    /// includes `start` itself.
35    pub fn get_closure(&self, start: T) -> HashSet<T> {
36        let mut res = HashSet::new();
37        res.insert(start);
38        loop {
39            let mut entries_added = false;
40
41            for source in self.0.keys() {
42                match self.get_targets(*source) {
43                    None => continue,
44                    Some(targets) if targets.is_empty() => continue,
45                    Some(targets) => {
46                        for target in targets {
47                            if !res.contains(target) {
48                                continue;
49                            }
50                            if res.insert(source.clone()) {
51                                entries_added = true
52                            }
53                        }
54                    }
55                }
56            }
57
58            if !entries_added {
59                return res;
60            }
61        }
62    }
63
64    /// Returns the nodes of the graph in reverse topological order, or an error if the graph
65    /// contains a cycle.
66    pub fn topological_sort(&self) -> Result<Vec<T>, Error<T>> {
67        TarjanSCC::new(self).run()
68    }
69
70    /// Finds the shortest path between the `from` and `to` nodes in this graph, if such a path
71    /// exists. Both `from` and `to` are included in the returned path.
72    pub fn find_shortest_path(&self, from: T, to: T) -> Option<Vec<T>> {
73        // Keeps track of edges in the shortest path to each node.
74        //
75        // The key in this map is a node whose shortest path to it is known. The value
76        // is the next-to-last node in the shortest path to the key node.
77        //
78        // For example, if the shortest path from `a` to `b` is `{a, b, c}`, this
79        // map will contain:
80        // (c, b)
81        // (b, a)
82        let mut shortest_path_edges: HashMap<T, T> = HashMap::new();
83
84        // Nodes which we have found in the graph but have not yet been visited.
85        let mut discovered_nodes = VecDeque::new();
86        discovered_nodes.push_back(from);
87
88        loop {
89            // Visit the first node in the list.
90            let Some(current_node) = discovered_nodes.pop_front() else {
91                // If there are no more nodes to visit, then a shortest path must not exist.
92                return None;
93            };
94            match self.get_targets(current_node) {
95                None => continue,
96                Some(targets) if targets.is_empty() => continue,
97                Some(targets) => {
98                    for target in targets {
99                        // If we haven't yet visited this node, add it to our set of edges and add
100                        // it to the set of nodes we should visit.
101                        if !shortest_path_edges.contains_key(target) {
102                            shortest_path_edges.insert(*target, current_node);
103                            discovered_nodes.push_back(*target);
104                        }
105                        // If this node is the node we're searching for a path to, then compute the
106                        // path based on the hashmap we've built and return it.
107                        if *target == to {
108                            let mut result = vec![*target];
109                            let mut path_node: T = *target;
110                            loop {
111                                path_node = *shortest_path_edges.get(&path_node).unwrap();
112                                result.push(path_node);
113                                if path_node == from {
114                                    break;
115                                }
116                            }
117                            result.reverse();
118                            return Some(result);
119                        }
120                    }
121                }
122            }
123        }
124    }
125}
126
127impl<T: PartialEq + Hash + Copy + Ord + Debug + Display> Default for DirectedGraph<T> {
128    fn default() -> Self {
129        Self(HashMap::new())
130    }
131}
132
/// A graph node. Contents contain the nodes mapped by edges from this node.
#[derive(Eq, PartialEq)]
pub struct DirectedNode<T: PartialEq + Hash + Copy + Ord + Debug + Display>(HashSet<T>);

impl<T: PartialEq + Hash + Copy + Ord + Debug + Display> DirectedNode<T> {
    /// Create an empty node, i.e. a node without any outgoing edges.
    pub fn new() -> Self {
        Self(HashSet::new())
    }

    /// Add edge from this node to `target`.
    ///
    /// Targets are kept in a set, so adding the same target again changes nothing.
    pub fn add_target(&mut self, target: T) {
        self.0.insert(target);
    }
}

impl<T: PartialEq + Hash + Copy + Ord + Debug + Display> Default for DirectedNode<T> {
    /// Equivalent to [`DirectedNode::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Errors produced by `DirectedGraph`.
#[derive(Debug)]
pub enum Error<T: PartialEq + Hash + Copy + Ord + Debug + Display> {
    /// The graph contains at least one cycle. Each entry lists the nodes along one cycle; as
    /// produced by `DirectedGraph::topological_sort`, a cycle starts at its lowest node and
    /// repeats that node at the end.
    CyclesDetected(HashSet<Vec<T>>),
}

impl<T: PartialEq + Hash + Copy + Ord + Debug + Display> Error<T> {
    /// Renders the detected cycles as a human-readable string such as
    /// `{{a -> b -> a}, {c -> d -> c}}`.
    ///
    /// The cycles are sorted before rendering, so the output does not depend on the iteration
    /// order of the underlying set.
    pub fn format_cycle(&self) -> String {
        match self {
            Error::CyclesDetected(cycles) => {
                // Sort references to the cycles so our output is stable
                let mut cycles: Vec<&Vec<T>> = cycles.iter().collect();
                cycles.sort_unstable();

                let rendered: Vec<String> = cycles
                    .iter()
                    .map(|cycle| {
                        let nodes: Vec<String> = cycle.iter().map(ToString::to_string).collect();
                        format!("{{{}}}", nodes.join(" -> "))
                    })
                    .collect();
                format!("{{{}}}", rendered.join(", "))
            }
        }
    }
}

184/// Runs the tarjan strongly connected components algorithm on a graph to produce either a reverse
185/// topological sort of the nodes in the graph, or a set of the cycles present in the graph.
186///
187/// Description of algorithm:
188/// https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
189struct TarjanSCC<'a, T: PartialEq + Hash + Copy + Ord + Debug + Display> {
190    // Each node is assigned an index in the order we find them. This tracks the next index to use.
191    index: u64,
192    // The mappings between nodes and indices
193    indices: HashMap<T, u64>,
194    // The lowest index (numerically) that's accessible from each node
195    low_links: HashMap<T, u64>,
196    // The set of nodes we're currently in the process of considering
197    stack: Vec<T>,
198    // A set containing the nodes in the stack, so we can more efficiently check if an element is
199    // in the stack
200    on_stack: HashSet<T>,
201    // Detected cycles
202    cycles: HashSet<Vec<T>>,
203    // Nodes sorted by reverse topological order
204    node_order: Vec<T>,
205    // The graph this run will be operating on
206    graph: &'a DirectedGraph<T>,
207}
208
209impl<'a, T: Hash + Copy + Ord + Debug + Display> TarjanSCC<'a, T> {
210    fn new(graph: &'a DirectedGraph<T>) -> Self {
211        TarjanSCC {
212            index: 0,
213            indices: HashMap::new(),
214            low_links: HashMap::new(),
215            stack: Vec::new(),
216            on_stack: HashSet::new(),
217            cycles: HashSet::new(),
218            node_order: Vec::new(),
219            graph,
220        }
221    }
222
223    /// Runs the tarjan scc algorithm. Must only be called once, as it will panic on subsequent
224    /// calls.
225    fn run(mut self) -> Result<Vec<T>, Error<T>> {
226        // Sort the nodes we visit, to make the output deterministic instead of being based on
227        // whichever node we find first.
228        let mut nodes: Vec<_> = self.graph.0.keys().cloned().collect();
229        nodes.sort_unstable();
230        for node in &nodes {
231            // Iterate over each node, visiting each one we haven't already visited. We determine
232            // if a node has been visited by if an index has been assigned to it yet.
233            if !self.indices.contains_key(node) {
234                self.visit(*node);
235            }
236        }
237
238        if self.cycles.is_empty() {
239            Ok(self.node_order.drain(..).collect())
240        } else {
241            Err(Error::CyclesDetected(self.cycles.drain().collect()))
242        }
243    }
244
245    fn visit(&mut self, current_node: T) {
246        // assign a new index for this node, and push it on to the stack
247        self.indices.insert(current_node, self.index);
248        self.low_links.insert(current_node, self.index);
249        self.index += 1;
250        self.stack.push(current_node);
251        self.on_stack.insert(current_node);
252
253        let mut targets: Vec<_> = self.graph.0[&current_node].0.iter().cloned().collect();
254        targets.sort_unstable();
255
256        for target in &targets {
257            if !self.indices.contains_key(target) {
258                // Target has not yet been visited; recurse on it
259                self.visit(*target);
260                // Set our lowlink to the min of our lowlink and the target's new lowlink
261                let current_node_low_link = *self.low_links.get(&current_node).unwrap();
262                let target_low_link = *self.low_links.get(&target).unwrap();
263                self.low_links.insert(current_node, min(current_node_low_link, target_low_link));
264            } else if self.on_stack.contains(target) {
265                let current_node_low_link = *self.low_links.get(&current_node).unwrap();
266                let target_index = *self.indices.get(&target).unwrap();
267                self.low_links.insert(current_node, min(current_node_low_link, target_index));
268            }
269        }
270
271        // If current_node is a root node, pop the stack and generate an SCC
272        if self.low_links.get(&current_node) == self.indices.get(&current_node) {
273            let mut strongly_connected_nodes = HashSet::new();
274            let mut stack_node;
275            loop {
276                stack_node = self.stack.pop().unwrap();
277                self.on_stack.remove(&stack_node);
278                strongly_connected_nodes.insert(stack_node);
279                if stack_node == current_node {
280                    break;
281                }
282            }
283            self.insert_cycles_from_scc(
284                &strongly_connected_nodes,
285                stack_node,
286                HashSet::new(),
287                vec![],
288            );
289        }
290        self.node_order.push(current_node);
291    }
292
293    /// Given a set of strongly connected components, computes the cycles present in the set and
294    /// adds those cycles to self.cycles.
295    fn insert_cycles_from_scc(
296        &mut self,
297        scc_nodes: &HashSet<T>,
298        current_node: T,
299        mut visited_nodes: HashSet<T>,
300        mut path: Vec<T>,
301    ) {
302        if visited_nodes.contains(&current_node) {
303            // We've already visited this node, we've got a cycle. Grab all the elements in the
304            // path starting at the first time we visited this node.
305            let (current_node_path_index, _) =
306                path.iter().enumerate().find(|(_, val)| val == &&current_node).unwrap();
307            let mut cycle = path[current_node_path_index..].to_vec();
308
309            // Rotate the cycle such that the lowest value comes first, so that the cycles we
310            // report are consistent.
311            Self::rotate_cycle(&mut cycle);
312            // Push a copy of the first node on to the end, so it's clear that this path ends where
313            // it starts
314            cycle.push(*cycle.first().unwrap());
315            self.cycles.insert(cycle);
316            return;
317        }
318
319        visited_nodes.insert(current_node);
320        path.push(current_node);
321
322        let targets_in_scc: Vec<_> =
323            self.graph.0[&current_node].0.iter().filter(|n| scc_nodes.contains(n)).collect();
324        for target in targets_in_scc {
325            self.insert_cycles_from_scc(scc_nodes, *target, visited_nodes.clone(), path.clone());
326        }
327    }
328
329    /// Rotates the cycle such that ordering is maintained and the lowest element comes first. This
330    /// is so that the reported cycles are consistent, as opposed to varying based on which node we
331    /// happened to find first.
332    fn rotate_cycle(cycle: &mut Vec<T>) {
333        let mut lowest_index = 0;
334        let mut lowest_value = cycle.first().unwrap();
335        for (index, node) in cycle.iter().enumerate() {
336            if node < lowest_value {
337                lowest_index = index;
338                lowest_value = node;
339            }
340        }
341        cycle.rotate_left(lowest_index);
342    }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    // Generates one `#[test]` per entry; each builds a graph from `edges` and expects
    // `topological_sort` to succeed with exactly `order`.
    macro_rules! test_topological_sort {
        (
            $(
                $test_name:ident => {
                    edges = $edges:expr,
                    order = $order:expr,
                },
            )+
        ) => {
            $(
                #[test]
                fn $test_name() {
                    topological_sort_test(&$edges, &$order);
                }
            )+
        }
    }

    // Generates one `#[test]` per entry; each builds a graph from `edges` and expects
    // `topological_sort` to fail with exactly the set `cycles`.
    macro_rules! test_cycles {
        (
            $(
                $test_name:ident => {
                    edges = $edges:expr,
                    cycles = $cycles:expr,
                },
            )+
        ) => {
            $(
                #[test]
                fn $test_name() {
                    cycles_test(&$edges, &$cycles);
                }
            )+
        }
    }

    // Generates one `#[test]` per entry; each builds a graph from `edges` and compares
    // `find_shortest_path(from, to)` against `shortest_path`.
    macro_rules! test_shortest_path {
        (
            $(
                $test_name:ident => {
                    edges = $edges:expr,
                    from = $from:expr,
                    to = $to:expr,
                    shortest_path = $shortest_path:expr,
                },
            )+
        ) => {
            $(
                #[test]
                fn $test_name() {
                    shortest_path_test($edges, $from, $to, $shortest_path);
                }
            )+
        }
    }

    fn topological_sort_test(edges: &[(&'static str, &'static str)], order: &[&'static str]) {
        let mut graph = DirectedGraph::new();
        edges.iter().for_each(|e| graph.add_edge(e.0, e.1));
        let actual_order = graph.topological_sort().expect("found a cycle");

        let expected_order = order.to_vec();
        assert_eq!(expected_order, actual_order);
    }

    fn cycles_test(edges: &[(&'static str, &'static str)], cycles: &[&[&'static str]]) {
        let mut graph = DirectedGraph::new();
        edges.iter().for_each(|e| graph.add_edge(e.0, e.1));
        let Error::CyclesDetected(reported_cycles) = graph
            .topological_sort()
            .expect_err("topological sort succeeded on a dataset with a cycle");

        let expected_cycles: HashSet<Vec<_>> = cycles.iter().map(|c| c.to_vec()).collect();
        assert_eq!(reported_cycles, expected_cycles);
    }

    fn shortest_path_test(
        edges: &[(&'static str, &'static str)],
        from: &'static str,
        to: &'static str,
        expected_shortest_path: Option<&[&'static str]>,
    ) {
        let mut graph = DirectedGraph::new();
        edges.iter().for_each(|e| graph.add_edge(e.0, e.1));
        let actual_shortest_path = graph.find_shortest_path(from, to);
        let expected_shortest_path = expected_shortest_path.map(|path| path.to_vec());
        assert_eq!(actual_shortest_path, expected_shortest_path);
    }

    // Builds a graph from `edges` and expects `get_closure(start)` to be exactly
    // `expected_closure`, regardless of order.
    fn closure_test(
        edges: &[(&'static str, &'static str)],
        start: &'static str,
        expected_closure: &[&'static str],
    ) {
        let mut graph = DirectedGraph::new();
        edges.iter().for_each(|e| graph.add_edge(e.0, e.1));
        let expected_closure: HashSet<_> = expected_closure.iter().copied().collect();
        assert_eq!(graph.get_closure(start), expected_closure);
    }

    // Tests with no cycles

    test_topological_sort! {
        test_empty => {
            edges = [],
            order = [],
        },
        test_fan_out => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("b", "d"),
                ("d", "e"),
            ],
            order = ["c", "e", "d", "b", "a"],
        },
        test_fan_in => {
            edges = [
                ("a", "b"),
                ("b", "d"),
                ("c", "d"),
                ("d", "e"),
            ],
            order = ["e", "d", "b", "a", "c"],
        },
        test_forest => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("d", "e"),
            ],
            order = ["c", "b", "a", "e", "d"],
        },
        test_diamond => {
            edges = [
                ("a", "b"),
                ("a", "c"),
                ("b", "d"),
                ("c", "d"),
            ],
            order = ["d", "b", "c", "a"],
        },
        test_lattice => {
            edges = [
                ("a", "b"),
                ("a", "c"),
                ("b", "d"),
                ("b", "e"),
                ("c", "d"),
                ("e", "f"),
                ("d", "f"),
            ],
            order = ["f", "d", "e", "b", "c", "a"],
        },
        test_deduped_edge => {
            edges = [
                ("a", "b"),
                ("a", "b"),
                ("b", "c"),
            ],
            order = ["c", "b", "a"],
        },
    }

    test_cycles! {
        // Tests where only 1 SCC contains cycles

        test_cycle_self_referential => {
            edges = [
                ("a", "a"),
            ],
            cycles = [
                &["a", "a"],
            ],
        },
        test_cycle_two_nodes => {
            edges = [
                ("a", "b"),
                ("b", "a"),
            ],
            cycles = [
                &["a", "b", "a"],
            ],
        },
        test_cycle_two_nodes_with_path_in => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "d"),
                ("d", "c"),
            ],
            cycles = [
                &["c", "d", "c"],
            ],
        },
        test_cycle_two_nodes_with_path_out => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("b", "c"),
                ("c", "d"),
            ],
            cycles = [
                &["a", "b", "a"],
            ],
        },
        test_cycle_three_nodes => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "a"),
            ],
            cycles = [
                &["a", "b", "c", "a"],
            ],
        },
        test_cycle_three_nodes_with_inner_cycle => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "b"),
                ("c", "a"),
            ],
            cycles = [
                &["a", "b", "c", "a"],
                &["b", "c", "b"],
            ],
        },
        test_cycle_three_nodes_doubly_linked => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("b", "c"),
                ("c", "b"),
                ("c", "a"),
                ("a", "c"),
            ],
            cycles = [
                &["a", "b", "a"],
                &["b", "c", "b"],
                &["a", "c", "a"],
                &["a", "b", "c", "a"],
                &["a", "c", "b", "a"],
            ],
        },
        test_cycle_with_inner_cycle => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "a"),

                ("b", "d"),
                ("d", "e"),
                ("e", "c"),
            ],
            cycles = [
                &["a", "b", "c", "a"],
                &["a", "b", "d", "e", "c", "a"],
            ],
        },
        test_two_join_cycles => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "a"),
                ("b", "d"),
                ("d", "a"),
            ],
            cycles = [
                &["a", "b", "c", "a"],
                &["a", "b", "d", "a"],
            ],
        },
        test_cycle_four_nodes_doubly_linked => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("b", "c"),
                ("c", "b"),
                ("c", "d"),
                ("d", "c"),
                ("d", "a"),
                ("a", "d"),
            ],
            cycles = [
                &["a", "b", "c", "d", "a"],
                &["a", "b", "a"],
                &["a", "d", "c", "b", "a"],
                &["a", "d", "a"],
                &["b", "c", "b"],
                &["c", "d", "c"],
            ],
        },

        // Tests with multiple SCCs that contain cycles

        test_cycle_self_referential_islands => {
            edges = [
                ("a", "a"),
                ("b", "b"),
                ("c", "c"),
                ("d", "e"),
            ],
            cycles = [
                &["a", "a"],
                &["b", "b"],
                &["c", "c"],
            ],
        },
        test_cycle_two_sets_of_two_nodes => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("c", "d"),
                ("d", "c"),
            ],
            cycles = [
                &["a", "b", "a"],
                &["c", "d", "c"],
            ],
        },
        test_cycle_two_sets_of_two_nodes_connected => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("c", "d"),
                ("d", "c"),
                ("a", "c"),
            ],
            cycles = [
                &["a", "b", "a"],
                &["c", "d", "c"],
            ],
        },
    }

    test_shortest_path! {
        test_empty_graph => {
            edges = &[],
            from = "a",
            to = "b",
            shortest_path = None,
        },
        test_two_nodes => {
            edges = &[
                ("a", "b"),
            ],
            from = "a",
            to = "b",
            shortest_path = Some(&["a", "b"]),
        },
        test_path_to_self => {
            edges = &[
                ("a", "a"),
            ],
            from = "a",
            to = "a",
            shortest_path = Some(&["a", "a"]),
        },
        test_path_to_self_no_edge => {
            edges = &[
                ("a", "b"),
            ],
            from = "a",
            to = "a",
            shortest_path = None,
        },
        test_path_three_nodes => {
            edges = &[
                ("a", "b"),
                ("b", "c"),
            ],
            from = "a",
            to = "c",
            shortest_path = Some(&["a", "b", "c"]),
        },
        test_path_multiple_options => {
            edges = &[
                ("a", "b"),
                ("b", "c"),
                ("a", "c"),
            ],
            from = "a",
            to = "c",
            shortest_path = Some(&["a", "c"]),
        },
        test_path_two_islands => {
            edges = &[
                ("a", "b"),
                ("c", "d"),
            ],
            from = "a",
            to = "d",
            shortest_path = None,
        },
        test_path_with_cycle => {
            edges = &[
                ("a", "b"),
                ("b", "a"),
            ],
            from = "a",
            to = "b",
            shortest_path = Some(&["a", "b"]),
        },
        test_path_with_cycle_2 => {
            edges = &[
                ("a", "b"),
                ("b", "c"),
                ("c", "b"),
            ],
            from = "a",
            to = "b",
            shortest_path = Some(&["a", "b"]),
        },
        test_path_with_cycle_3 => {
            edges = &[
                ("a", "b"),
                ("b", "c"),
                ("c", "b"),
                ("b", "d"),
                ("d", "e"),
            ],
            from = "a",
            to = "e",
            shortest_path = Some(&["a", "b", "d", "e"]),
        },
    }

    // Tests for `get_closure`: the closure of a node is the node itself plus every node that has
    // a path leading to it.

    #[test]
    fn test_closure_unknown_node() {
        closure_test(&[], "a", &["a"]);
        closure_test(&[("a", "b")], "z", &["z"]);
    }

    #[test]
    fn test_closure_chain() {
        let edges = [("a", "b"), ("b", "c")];
        closure_test(&edges, "a", &["a"]);
        closure_test(&edges, "b", &["a", "b"]);
        closure_test(&edges, "c", &["a", "b", "c"]);
    }

    #[test]
    fn test_closure_fan_in() {
        let edges = [("a", "b"), ("b", "c"), ("d", "c"), ("c", "e")];
        closure_test(&edges, "c", &["a", "b", "c", "d"]);
        closure_test(&edges, "e", &["a", "b", "c", "d", "e"]);
        closure_test(&edges, "d", &["d"]);
    }

    #[test]
    fn test_closure_with_cycle() {
        let edges = [("a", "b"), ("b", "a"), ("b", "c")];
        closure_test(&edges, "a", &["a", "b"]);
        closure_test(&edges, "c", &["a", "b", "c"]);
    }

    // Tests for `Error::format_cycle`

    #[test]
    fn test_format_cycle_is_sorted() {
        let cycles = HashSet::from([vec!["c", "d", "c"], vec!["a", "b", "a"]]);
        assert_eq!(
            Error::CyclesDetected(cycles).format_cycle(),
            "{{a -> b -> a}, {c -> d -> c}}"
        );
    }

    #[test]
    fn test_format_cycle_without_cycles() {
        let cycles: HashSet<Vec<&'static str>> = HashSet::new();
        assert_eq!(Error::CyclesDetected(cycles).format_cycle(), "{}");
    }
}