conspire/math/set/dsu/mod.rs

1use std::collections::HashMap;
2
/// Union-find (disjoint-set) forest over the indices `0..n`.
///
/// Merging uses union by rank, and lookups apply full path compression in `find`.
struct DisjointSetUnion {
    /// `parent[i]` is the parent of node `i`; node `i` is a root when `parent[i] == i`.
    parent: Vec<usize>,
    /// Rank of each node (only updated while the node is a root). Union by rank keeps
    /// every rank at most `log2(n)`, so a `u8` is ample.
    rank: Vec<u8>,
}

impl DisjointSetUnion {
    /// Creates `n` singleton sets: every node is its own root with rank zero.
    fn new(n: usize) -> Self {
        let parent: Vec<usize> = (0..n).collect();
        let rank = vec![0; n];
        Self { parent, rank }
    }

    /// Returns the root of the set containing `x`, re-pointing every node on the
    /// path from `x` directly at that root.
    ///
    /// # Panics
    /// Panics if `x` (or any parent index reached from it) is out of bounds.
    fn find(&mut self, x: usize) -> usize {
        // Pass 1: climb parent links until reaching the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Pass 2: walk the same path again, attaching each node straight to the root.
        let mut node = x;
        while node != root {
            let next = self.parent[node];
            self.parent[node] = root;
            node = next;
        }
        root
    }

    /// Merges the sets containing `a` and `b`; does nothing if they already share a root.
    ///
    /// The lower-rank root is attached beneath the higher-rank root. On a tie, `b`'s
    /// root is attached beneath `a`'s root and `a`'s root gains one rank.
    fn union(&mut self, a: usize, b: usize) {
        use std::cmp::Ordering;

        let root_a = self.find(a);
        let root_b = self.find(b);
        if root_a == root_b {
            return;
        }
        match self.rank[root_a].cmp(&self.rank[root_b]) {
            Ordering::Less => self.parent[root_a] = root_b,
            Ordering::Greater => self.parent[root_b] = root_a,
            Ordering::Equal => {
                self.parent[root_b] = root_a;
                self.rank[root_a] += 1;
            }
        }
    }
}
37
38pub fn disjoint_set_union<const N: usize>(
39    set_members: &[[usize; N]],
40    num_members: usize,
41) -> Vec<Vec<usize>> {
42    let num_sets = set_members.len();
43    let mut member_sets = vec![vec![]; num_members];
44    set_members.iter().enumerate().for_each(|(set, members)| {
45        members
46            .iter()
47            .for_each(|&member| member_sets[member].push(set))
48    });
49    let mut dsu = DisjointSetUnion::new(num_sets);
50    member_sets
51        .into_iter()
52        .filter(|v| v.len() >= 2)
53        .for_each(|sets| {
54            let first = sets[0];
55            sets[1..].iter().for_each(|&s| dsu.union(first, s))
56        });
57    let mut disjoint_sets = HashMap::<_, Vec<_>>::new();
58    (0..num_sets).for_each(|s| disjoint_sets.entry(dsu.find(s)).or_default().push(s));
59    disjoint_sets
60        .into_values()
61        .map(|sets| {
62            let mut members = sets
63                .into_iter()
64                .flat_map(|set| set_members[set])
65                .collect::<Vec<_>>();
66            members.sort_unstable();
67            members.dedup();
68            members
69        })
70        .collect()
71}