use std::cmp::min;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet, VecDeque};
use std::fmt::{Debug, Display};
use std::hash::Hash;
9
10pub struct DirectedGraph<T: PartialEq + Hash + Copy + Ord + Debug + Display>(
12 HashMap<T, DirectedNode<T>>,
13);
14
15impl<T: PartialEq + Hash + Copy + Ord + Debug + Display> DirectedGraph<T> {
16 pub fn new() -> Self {
18 Self(HashMap::new())
19 }
20
21 pub fn add_edge(&mut self, source: T, target: T) {
23 self.0.entry(source).or_insert_with(DirectedNode::new).add_target(target);
24 self.0.entry(target).or_insert_with(DirectedNode::new);
25 }
26
27 pub fn get_targets(&self, id: T) -> Option<&HashSet<T>> {
29 self.0.get(&id).as_ref().map(|node| &node.0)
30 }
31
32 pub fn get_closure(&self, start: T) -> HashSet<T> {
36 let mut res = HashSet::new();
37 res.insert(start);
38 loop {
39 let mut entries_added = false;
40
41 for source in self.0.keys() {
42 match self.get_targets(*source) {
43 None => continue,
44 Some(targets) if targets.is_empty() => continue,
45 Some(targets) => {
46 for target in targets {
47 if !res.contains(target) {
48 continue;
49 }
50 if res.insert(source.clone()) {
51 entries_added = true
52 }
53 }
54 }
55 }
56 }
57
58 if !entries_added {
59 return res;
60 }
61 }
62 }
63
64 pub fn topological_sort(&self) -> Result<Vec<T>, Error<T>> {
67 TarjanSCC::new(self).run()
68 }
69
70 pub fn find_shortest_path(&self, from: T, to: T) -> Option<Vec<T>> {
73 let mut shortest_path_edges: HashMap<T, T> = HashMap::new();
83
84 let mut discovered_nodes = VecDeque::new();
86 discovered_nodes.push_back(from);
87
88 loop {
89 let Some(current_node) = discovered_nodes.pop_front() else {
91 return None;
93 };
94 match self.get_targets(current_node) {
95 None => continue,
96 Some(targets) if targets.is_empty() => continue,
97 Some(targets) => {
98 for target in targets {
99 if !shortest_path_edges.contains_key(target) {
102 shortest_path_edges.insert(*target, current_node);
103 discovered_nodes.push_back(*target);
104 }
105 if *target == to {
108 let mut result = vec![*target];
109 let mut path_node: T = *target;
110 loop {
111 path_node = *shortest_path_edges.get(&path_node).unwrap();
112 result.push(path_node);
113 if path_node == from {
114 break;
115 }
116 }
117 result.reverse();
118 return Some(result);
119 }
120 }
121 }
122 }
123 }
124 }
125}
126
127impl<T: PartialEq + Hash + Copy + Ord + Debug + Display> Default for DirectedGraph<T> {
128 fn default() -> Self {
129 Self(HashMap::new())
130 }
131}
132
/// The outgoing edges of a single graph node: the set of nodes it points at.
#[derive(Eq, PartialEq)]
pub struct DirectedNode<T: PartialEq + Hash + Copy + Ord + Debug + Display>(HashSet<T>);

impl<T: PartialEq + Hash + Copy + Ord + Debug + Display> DirectedNode<T> {
    /// Creates a node with no outgoing edges.
    pub fn new() -> Self {
        Self(HashSet::new())
    }

    /// Records an edge to `target`. A target that is already present is
    /// stored only once.
    pub fn add_target(&mut self, target: T) {
        self.0.insert(target);
    }
}

impl<T: PartialEq + Hash + Copy + Ord + Debug + Display> Default for DirectedNode<T> {
    /// Builds a node with no outgoing edges, exactly like [`DirectedNode::new`].
    fn default() -> Self {
        Self::new()
    }
}
148
/// Errors reported when a graph cannot be topologically sorted.
#[derive(Debug)]
pub enum Error<T: PartialEq + Hash + Copy + Ord + Debug + Display> {
    /// The graph contains at least one cycle. Each inner `Vec` is one cycle,
    /// given as the sequence of nodes along it.
    CyclesDetected(HashSet<Vec<T>>),
}

impl<T: PartialEq + Hash + Copy + Ord + Debug + Display> Error<T> {
    /// Renders the detected cycles as a single human-readable string.
    ///
    /// The nodes of a cycle are joined with `" -> "` and wrapped in braces;
    /// the cycles are sorted, joined with `", "` and wrapped in braces again,
    /// for example `{{a -> b -> a}, {c -> d -> c}}`. An empty set renders as
    /// `{}`.
    pub fn format_cycle(&self) -> String {
        match self {
            Error::CyclesDetected(cycles) => {
                // Sort references rather than clones: the set has no stable
                // iteration order, and sorting makes the output deterministic.
                let mut cycles: Vec<&Vec<T>> = cycles.iter().collect();
                cycles.sort_unstable();

                let rendered: Vec<String> = cycles
                    .iter()
                    .map(|cycle| {
                        let nodes: Vec<String> =
                            cycle.iter().map(|node| node.to_string()).collect();
                        format!("{{{}}}", nodes.join(" -> "))
                    })
                    .collect();
                format!("{{{}}}", rendered.join(", "))
            }
        }
    }
}
183
184struct TarjanSCC<'a, T: PartialEq + Hash + Copy + Ord + Debug + Display> {
190 index: u64,
192 indices: HashMap<T, u64>,
194 low_links: HashMap<T, u64>,
196 stack: Vec<T>,
198 on_stack: HashSet<T>,
201 cycles: HashSet<Vec<T>>,
203 node_order: Vec<T>,
205 graph: &'a DirectedGraph<T>,
207}
208
209impl<'a, T: Hash + Copy + Ord + Debug + Display> TarjanSCC<'a, T> {
210 fn new(graph: &'a DirectedGraph<T>) -> Self {
211 TarjanSCC {
212 index: 0,
213 indices: HashMap::new(),
214 low_links: HashMap::new(),
215 stack: Vec::new(),
216 on_stack: HashSet::new(),
217 cycles: HashSet::new(),
218 node_order: Vec::new(),
219 graph,
220 }
221 }
222
223 fn run(mut self) -> Result<Vec<T>, Error<T>> {
226 let mut nodes: Vec<_> = self.graph.0.keys().cloned().collect();
229 nodes.sort_unstable();
230 for node in &nodes {
231 if !self.indices.contains_key(node) {
234 self.visit(*node);
235 }
236 }
237
238 if self.cycles.is_empty() {
239 Ok(self.node_order.drain(..).collect())
240 } else {
241 Err(Error::CyclesDetected(self.cycles.drain().collect()))
242 }
243 }
244
245 fn visit(&mut self, current_node: T) {
246 self.indices.insert(current_node, self.index);
248 self.low_links.insert(current_node, self.index);
249 self.index += 1;
250 self.stack.push(current_node);
251 self.on_stack.insert(current_node);
252
253 let mut targets: Vec<_> = self.graph.0[¤t_node].0.iter().cloned().collect();
254 targets.sort_unstable();
255
256 for target in &targets {
257 if !self.indices.contains_key(target) {
258 self.visit(*target);
260 let current_node_low_link = *self.low_links.get(¤t_node).unwrap();
262 let target_low_link = *self.low_links.get(&target).unwrap();
263 self.low_links.insert(current_node, min(current_node_low_link, target_low_link));
264 } else if self.on_stack.contains(target) {
265 let current_node_low_link = *self.low_links.get(¤t_node).unwrap();
266 let target_index = *self.indices.get(&target).unwrap();
267 self.low_links.insert(current_node, min(current_node_low_link, target_index));
268 }
269 }
270
271 if self.low_links.get(¤t_node) == self.indices.get(¤t_node) {
273 let mut strongly_connected_nodes = HashSet::new();
274 let mut stack_node;
275 loop {
276 stack_node = self.stack.pop().unwrap();
277 self.on_stack.remove(&stack_node);
278 strongly_connected_nodes.insert(stack_node);
279 if stack_node == current_node {
280 break;
281 }
282 }
283 self.insert_cycles_from_scc(
284 &strongly_connected_nodes,
285 stack_node,
286 HashSet::new(),
287 vec![],
288 );
289 }
290 self.node_order.push(current_node);
291 }
292
293 fn insert_cycles_from_scc(
296 &mut self,
297 scc_nodes: &HashSet<T>,
298 current_node: T,
299 mut visited_nodes: HashSet<T>,
300 mut path: Vec<T>,
301 ) {
302 if visited_nodes.contains(¤t_node) {
303 let (current_node_path_index, _) =
306 path.iter().enumerate().find(|(_, val)| val == &¤t_node).unwrap();
307 let mut cycle = path[current_node_path_index..].to_vec();
308
309 Self::rotate_cycle(&mut cycle);
312 cycle.push(*cycle.first().unwrap());
315 self.cycles.insert(cycle);
316 return;
317 }
318
319 visited_nodes.insert(current_node);
320 path.push(current_node);
321
322 let targets_in_scc: Vec<_> =
323 self.graph.0[¤t_node].0.iter().filter(|n| scc_nodes.contains(n)).collect();
324 for target in targets_in_scc {
325 self.insert_cycles_from_scc(scc_nodes, *target, visited_nodes.clone(), path.clone());
326 }
327 }
328
329 fn rotate_cycle(cycle: &mut Vec<T>) {
333 let mut lowest_index = 0;
334 let mut lowest_value = cycle.first().unwrap();
335 for (index, node) in cycle.iter().enumerate() {
336 if node < lowest_value {
337 lowest_index = index;
338 lowest_value = node;
339 }
340 }
341 cycle.rotate_left(lowest_index);
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the edge lists used by every test case.
    type Edges<'e> = &'e [(&'static str, &'static str)];

    // Generates one `#[test]` per case, checking the exact topological order.
    macro_rules! test_topological_sort {
        (
            $(
                $case:ident => {
                    edges = $graph_edges:expr,
                    order = $expected_order:expr,
                },
            )+
        ) => {
            $(
                #[test]
                fn $case() {
                    check_topological_order(&$graph_edges, &$expected_order);
                }
            )+
        }
    }

    // Generates one `#[test]` per case, checking the exact set of reported cycles.
    macro_rules! test_cycles {
        (
            $(
                $case:ident => {
                    edges = $graph_edges:expr,
                    cycles = $expected_cycles:expr,
                },
            )+
        ) => {
            $(
                #[test]
                fn $case() {
                    check_cycles(&$graph_edges, &$expected_cycles);
                }
            )+
        }
    }

    // Generates one `#[test]` per case, checking the shortest path between two nodes.
    macro_rules! test_shortest_path {
        (
            $(
                $case:ident => {
                    edges = $graph_edges:expr,
                    from = $start:expr,
                    to = $goal:expr,
                    shortest_path = $expected_path:expr,
                },
            )+
        ) => {
            $(
                #[test]
                fn $case() {
                    check_shortest_path($graph_edges, $start, $goal, $expected_path);
                }
            )+
        }
    }

    /// Builds a graph containing exactly the given edges.
    fn graph_from(edges: Edges<'_>) -> DirectedGraph<&'static str> {
        let mut graph = DirectedGraph::new();
        for &(source, target) in edges {
            graph.add_edge(source, target);
        }
        graph
    }

    /// Asserts that sorting the graph succeeds and yields exactly `order`.
    fn check_topological_order(edges: Edges<'_>, order: &[&'static str]) {
        let sorted = graph_from(edges).topological_sort().expect("found a cycle");
        assert_eq!(order.to_vec(), sorted);
    }

    /// Asserts that sorting the graph fails and reports exactly `cycles`.
    fn check_cycles(edges: Edges<'_>, cycles: &[&[&'static str]]) {
        let Error::CyclesDetected(found) = graph_from(edges)
            .topological_sort()
            .expect_err("topological sort succeeded on a dataset with a cycle");

        let wanted: HashSet<Vec<&'static str>> =
            cycles.iter().map(|cycle| cycle.to_vec()).collect();
        assert_eq!(found, wanted);
    }

    /// Asserts that the shortest path from `from` to `to` is `expected_shortest_path`.
    fn check_shortest_path(
        edges: Edges<'_>,
        from: &'static str,
        to: &'static str,
        expected_shortest_path: Option<&[&'static str]>,
    ) {
        let found = graph_from(edges).find_shortest_path(from, to);
        let wanted = expected_shortest_path.map(|path| path.to_vec());
        assert_eq!(found, wanted);
    }

    test_topological_sort! {
        test_empty => {
            edges = [],
            order = [],
        },
        test_fan_out => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("b", "d"),
                ("d", "e"),
            ],
            order = ["c", "e", "d", "b", "a"],
        },
        test_fan_in => {
            edges = [
                ("a", "b"),
                ("b", "d"),
                ("c", "d"),
                ("d", "e"),
            ],
            order = ["e", "d", "b", "a", "c"],
        },
        test_forest => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("d", "e"),
            ],
            order = ["c", "b", "a", "e", "d"],
        },
        test_diamond => {
            edges = [
                ("a", "b"),
                ("a", "c"),
                ("b", "d"),
                ("c", "d"),
            ],
            order = ["d", "b", "c", "a"],
        },
        test_lattice => {
            edges = [
                ("a", "b"),
                ("a", "c"),
                ("b", "d"),
                ("b", "e"),
                ("c", "d"),
                ("e", "f"),
                ("d", "f"),
            ],
            order = ["f", "d", "e", "b", "c", "a"],
        },
        test_deduped_edge => {
            edges = [
                ("a", "b"),
                ("a", "b"),
                ("b", "c"),
            ],
            order = ["c", "b", "a"],
        },
    }

    test_cycles! {
        test_cycle_self_referential => {
            edges = [
                ("a", "a"),
            ],
            cycles = [
                &["a", "a"],
            ],
        },
        test_cycle_two_nodes => {
            edges = [
                ("a", "b"),
                ("b", "a"),
            ],
            cycles = [
                &["a", "b", "a"],
            ],
        },
        test_cycle_two_nodes_with_path_in => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "d"),
                ("d", "c"),
            ],
            cycles = [
                &["c", "d", "c"],
            ],
        },
        test_cycle_two_nodes_with_path_out => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("b", "c"),
                ("c", "d"),
            ],
            cycles = [
                &["a", "b", "a"],
            ],
        },
        test_cycle_three_nodes => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "a"),
            ],
            cycles = [
                &["a", "b", "c", "a"],
            ],
        },
        test_cycle_three_nodes_with_inner_cycle => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "b"),
                ("c", "a"),
            ],
            cycles = [
                &["a", "b", "c", "a"],
                &["b", "c", "b"],
            ],
        },
        test_cycle_three_nodes_doubly_linked => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("b", "c"),
                ("c", "b"),
                ("c", "a"),
                ("a", "c"),
            ],
            cycles = [
                &["a", "b", "a"],
                &["b", "c", "b"],
                &["a", "c", "a"],
                &["a", "b", "c", "a"],
                &["a", "c", "b", "a"],
            ],
        },
        test_cycle_with_inner_cycle => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "a"),

                ("b", "d"),
                ("d", "e"),
                ("e", "c"),
            ],
            cycles = [
                &["a", "b", "c", "a"],
                &["a", "b", "d", "e", "c", "a"],
            ],
        },
        test_two_join_cycles => {
            edges = [
                ("a", "b"),
                ("b", "c"),
                ("c", "a"),
                ("b", "d"),
                ("d", "a"),
            ],
            cycles = [
                &["a", "b", "c", "a"],
                &["a", "b", "d", "a"],
            ],
        },
        test_cycle_four_nodes_doubly_linked => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("b", "c"),
                ("c", "b"),
                ("c", "d"),
                ("d", "c"),
                ("d", "a"),
                ("a", "d"),
            ],
            cycles = [
                &["a", "b", "c", "d", "a"],
                &["a", "b", "a"],
                &["a", "d", "c", "b", "a"],
                &["a", "d", "a"],
                &["b", "c", "b"],
                &["c", "d", "c"],
            ],
        },

        test_cycle_self_referential_islands => {
            edges = [
                ("a", "a"),
                ("b", "b"),
                ("c", "c"),
                ("d", "e"),
            ],
            cycles = [
                &["a", "a"],
                &["b", "b"],
                &["c", "c"],
            ],
        },
        test_cycle_two_sets_of_two_nodes => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("c", "d"),
                ("d", "c"),
            ],
            cycles = [
                &["a", "b", "a"],
                &["c", "d", "c"],
            ],
        },
        test_cycle_two_sets_of_two_nodes_connected => {
            edges = [
                ("a", "b"),
                ("b", "a"),
                ("c", "d"),
                ("d", "c"),
                ("a", "c"),
            ],
            cycles = [
                &["a", "b", "a"],
                &["c", "d", "c"],
            ],
        },
    }

    test_shortest_path! {
        test_empty_graph => {
            edges = &[],
            from = "a",
            to = "b",
            shortest_path = None,
        },
        test_two_nodes => {
            edges = &[
                ("a", "b"),
            ],
            from = "a",
            to = "b",
            shortest_path = Some(&["a", "b"]),
        },
        test_path_to_self => {
            edges = &[
                ("a", "a"),
            ],
            from = "a",
            to = "a",
            shortest_path = Some(&["a", "a"]),
        },
        test_path_to_self_no_edge => {
            edges = &[
                ("a", "b"),
            ],
            from = "a",
            to = "a",
            shortest_path = None,
        },
        test_path_three_nodes => {
            edges = &[
                ("a", "b"),
                ("b", "c"),
            ],
            from = "a",
            to = "c",
            shortest_path = Some(&["a", "b", "c"]),
        },
        test_path_multiple_options => {
            edges = &[
                ("a", "b"),
                ("b", "c"),
                ("a", "c"),
            ],
            from = "a",
            to = "c",
            shortest_path = Some(&["a", "c"]),
        },
        test_path_two_islands => {
            edges = &[
                ("a", "b"),
                ("c", "d"),
            ],
            from = "a",
            to = "d",
            shortest_path = None,
        },
        test_path_with_cycle => {
            edges = &[
                ("a", "b"),
                ("b", "a"),
            ],
            from = "a",
            to = "b",
            shortest_path = Some(&["a", "b"]),
        },
        test_path_with_cycle_2 => {
            edges = &[
                ("a", "b"),
                ("b", "c"),
                ("c", "b"),
            ],
            from = "a",
            to = "b",
            shortest_path = Some(&["a", "b"]),
        },
        test_path_with_cycle_3 => {
            edges = &[
                ("a", "b"),
                ("b", "c"),
                ("c", "b"),
                ("b", "d"),
                ("d", "e"),
            ],
            from = "a",
            to = "e",
            shortest_path = Some(&["a", "b", "d", "e"]),
        },
    }
}