conspire/math/set/dsu/
mod.rs1use std::collections::HashMap;
2
/// Union–find forest over the indices `0..n`, combining union by rank with path compression.
struct DisjointSetUnion {
    /// `parent[i]` is the parent of node `i`; a node is a root exactly when it is its own parent.
    parent: Vec<usize>,
    /// Rank of each node. Only updated for roots, and used to decide which root absorbs the other.
    rank: Vec<u8>,
}

impl DisjointSetUnion {
    /// Creates `n` singleton sets, one for every index in `0..n`, all with rank zero.
    fn new(n: usize) -> Self {
        let parent: Vec<usize> = (0..n).collect();
        let rank = vec![0u8; n];
        Self { parent, rank }
    }

    /// Returns the root of the tree containing `x`.
    ///
    /// Every node visited on the way is re-pointed directly at that root.
    ///
    /// Panics if `x` is out of range.
    fn find(&mut self, x: usize) -> usize {
        // First pass: climb parent links until reaching a self-parented node.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Second pass: walk the same path again, attaching each node straight to the root.
        let mut node = x;
        while node != root {
            let next = self.parent[node];
            self.parent[node] = root;
            node = next;
        }
        root
    }

    /// Merges the sets containing `a` and `b`; does nothing when they already share a root.
    fn union(&mut self, a: usize, b: usize) {
        let (root_a, root_b) = (self.find(a), self.find(b));
        if root_a == root_b {
            return;
        }
        // The higher-ranked root becomes the parent; on a tie, `a`'s root wins.
        let (upper, lower) = if self.rank[root_a] < self.rank[root_b] {
            (root_b, root_a)
        } else {
            (root_a, root_b)
        };
        self.parent[lower] = upper;
        // Joining two trees of equal rank raises the surviving root's rank by one.
        if self.rank[upper] == self.rank[lower] {
            self.rank[upper] += 1;
        }
    }
}
37
38pub fn disjoint_set_union<const N: usize>(
39 set_members: &[[usize; N]],
40 num_members: usize,
41) -> Vec<Vec<usize>> {
42 let num_sets = set_members.len();
43 let mut member_sets = vec![vec![]; num_members];
44 set_members.iter().enumerate().for_each(|(set, members)| {
45 members
46 .iter()
47 .for_each(|&member| member_sets[member].push(set))
48 });
49 let mut dsu = DisjointSetUnion::new(num_sets);
50 member_sets
51 .into_iter()
52 .filter(|v| v.len() >= 2)
53 .for_each(|sets| {
54 let first = sets[0];
55 sets[1..].iter().for_each(|&s| dsu.union(first, s))
56 });
57 let mut disjoint_sets = HashMap::<_, Vec<_>>::new();
58 (0..num_sets).for_each(|s| disjoint_sets.entry(dsu.find(s)).or_default().push(s));
59 disjoint_sets
60 .into_values()
61 .map(|sets| {
62 let mut members = sets
63 .into_iter()
64 .flat_map(|set| set_members[set])
65 .collect::<Vec<_>>();
66 members.sort_unstable();
67 members.dedup();
68 members
69 })
70 .collect()
71}