Păduri de mulțimi disjuncte (DSU)

Informație

O bună parte din conținutul acestui articol este preluat din cursul creat de același autor pentru lotul de juniori din 2023, curs care se poate găsi aici

Structurile de date sunt de multe ori foarte utile în multe contexte în programare, acestea dovedindu-se a fi în special foarte puternice și esențiale în lucrul problemelor date la diverse olimpiade și concursuri de informatică. Acest articol va prezenta o structură de date care nu e la prima vedere foarte complicată față de alte structuri de date mai consacrate, dar care se dovedește a fi foarte puternică în rezolvarea multor probleme de toate felurile.

Așa cum sugerează și titlul, vom prezenta în acest articol pădurile de mulțimi disjuncte, sau union-find, denumire dată după cele două operații principale pe care această structură de date le oferă. Union-Find poate fi folosit cu mare ușurință pentru probleme de tipul acelora în care ni se cere să aflăm pe parcurs ce valori sunt legate între ele printr-o relație, presupunând că relațiile dintre valori se adaugă treptat. Pe parcurs se vor remarca diverse optimizări, precum și diferitele clase de probleme în care se poate folosi o asemenea structură de date.

Pentru ușurarea explicațiilor, vom presupune că avem o situație ipotetică în care avem nn prieteni și ni se dau operații în care fie două persoane devin prietene, fie trebuie să decidem dacă două persoane aparțin aceluiași grup de prieteni.

Definirea operațiilor și funcționalității structurii de date

Fundamente

Pentru a reprezenta datele, vom ține într-un vector dimensiunea fiecărei mulțimi, iar într-un alt vector vom ține pentru fiecare poziție, nodul reprezentativ corespunzător grupului de prieteni din care face parte, la început fiecare nod fiind reprezentantul lui însuși.

vector<int> rad(n + 1), card(n + 1);

for (int i = 1; i <= n; ++i) {
    rad[i] = i;
    card[i] = 1;
}

Operația de unire (Union)

La acest pas, ni se dau două persoane și trebuie să stabilim relația de prietenie dintre ei. Deși această operație se face în timp constant, contează foarte mult modul în care facem relația de atribuire, aceasta putând schimba radical complexitatea algoritmului. Astfel, voi introduce prima optimizare, si anume optimizarea de unire după cardinalul mulțimii, astfel încât vom uni mereu mulțimea cu cardinal mai mic la mulțimea cu cardinal mai mare.

Motivul pentru care această optimizare duce la o complexitate mai mică va fi dat de numărul mai mic de operații pe care funcția Find le va face la fiecare pas. De asemenea, această optimizare de a uni mulțimile mai mici la cele mai mari se regăsește în mod frecvent și în alte contexte în diverse structuri de date și nu numai.

void Union(int a, int b) {
    if (card[a] < card[b]) {  // (1)
        swap(a, b);
    }
    rad[b] = a;          // (2)
    card[a] += card[b];  // (3)
}
  1. Vom vrea sa atașăm nodul bb la nodul aa.
  2. Rădăcina lui bb devine aa.
  3. Creștem cardinalul lui aa cu cardinalul lui bb.

Operația de căutare (Find)

La această operație, vrem să găsim pentru un nod, poziția nodului reprezentativ în structura noastră de date. În mod normal, această operație poate face cel mult O(n)\mathcal{O}(n) pași, în cazul în care arborele rezultat ar fi un lanț. Totuși, putem să ne folosim de parcurgerile pe care le facem pentru a reține rezultatele pentru toate nodurile de pe parcursul acelui drum, astfel încât la o parcurgere ulterioară, numărul de pași să se reducă spre un număr constant, structura arborelui ajungând similară cu cea a unui arbore stea.

int Find(int x) {
    if (rad[x] == x) {  // (1)
        return x;
    }
    rad[x] = Find(rad[x]);  // (2)
    return rad[x];
}
  1. Dacă nodul nostru este rădăcină, înseamnă că l-am găsit și-l returnăm în consecință.
  2. Rădăcina nodului nostru va deveni rădăcina rădăcinii curente.

Prime concluzii

Operația union are complexitatea O(1)\mathcal{O}(1), iar operația de find are complexitatea O(n)\mathcal{O}(n). Totuși, datorită optimizărilor menționate mai sus (compresia drumurilor și unirea după dimensiunea mulțimilor), numărul total de operații făcute este O(nlog\*n)\mathcal{O}(n \log^\* n), unde log\*x\log^\* x reprezintă inversul funcției Ackermann, valoare care se poate aproxima ca fiind o constantă.
De asemenea, nefolosirea optimizării de compresie a drumurilor ar duce la complexitatea O(nlogn)\mathcal{O}(n \log n), rezultat foarte important în contextul altor optimizări, cum ar fi tehnica small-to-large sau în general în demonstrarea diverselor rezultate ce țin de sume armonice.

Problema disjoint

InfoArena

View Problem

InfoArena
Deschide

Pentru fiecare operație citită de la intrare, vom implementa funcțiile necesare pentru a obține rezultatul problemei. Unirea a două mulțimi implică mai întâi folosirea funcției Find pentru a găsi rădăcinile, iar mai apoi folosim funcția Union pentru a face unirea propriu-zisă. Folosirea ambelor optimizări pentru îmbunătățirea complexității duce la soluția optimă, ce rulează într-un timp aproximativ liniar raportat la numărul de valori citite.

Soluția de 100 de puncte este următoarea:

#include <fstream>
#include <vector>

using namespace std;

constexpr int NMAX = 100002;

int n, m;

vector<int> rad(NMAX), card(NMAX);

int Find(int x) {
    if (rad[x] == x) {
        return x;
    }
    rad[x] = Find(rad[x]);
    return rad[x];
}

void Union(int a, int b) {
    if (card[a] < card[b]) {
        swap(a, b);
    }
    rad[b] = a;
    card[a] += card[b];
}

int main() {
    ifstream fin("disjoint.in");
    ofstream fout("disjoint.out");

    fin >> n >> m;

    for (int i = 1; i <= n; ++i) {
        rad[i] = i;
        card[i] = 1;
    }

    for (int i = 1; i <= m; ++i) {
        int cod, x, y;
        fin >> cod >> x >> y;

        if (cod == 1 && Find(x) != Find(y)) {
            Union(Find(x), Find(y));
        } else {
            fout << (Find(x) == Find(y)) ? "DA\n" : "NU\n";
        }
    }

    fin.close();
    fout.close();
    return 0;
}

Oare putem implementa mai eficient?

Inițial, noi am implementat această structură folosind doi vectori, anume cel în care ținem cardinalul fiecărei mulțimi, precum și cel în care ținem rădăcina fiecărei mulțimi. Totuși, se poate observa faptul că noi folosim o grămadă de informație inutilă din cauza faptului că pentru fiecare număr, practic ne interesează doar dacă e o rădăcină a unei mulțimi de valori sau nu. Astfel, vom recurge la a reprezenta pozițiile corespunzătoare rădăcinilor cu numere negative, reprezentând x-x, unde xx e cardinalul mulțimii reprezentat de acea valoare, respectiv reprezentarea nodurilor adiacente cu numere pozitive, reprezentând rădăcina mulțimii din care acea valoare face parte.

Soluția de 100 de puncte cu optimizarea de memorie este următoarea:

#include <fstream>
#include <vector>

using namespace std;

constexpr int NMAX = 100002;

int n, q;

vector<int> sz(NMAX);

int Find(int nod) {
    if (sz[nod] < 0) {
        return nod;
    }
    sz[nod] = Find(sz[nod]);
    return sz[nod];
}

void Union(int a, int b) {
    if (a == b) {
        return;
    }

    if (sz[a] > sz[b]) {
        swap(a, b);
    }

    sz[a] += sz[b];
    sz[b] = a;
}

int main() {
    ifstream fin("disjoint.in");
    ofstream fout("disjoint.out");

    fin >> n >> q;

    for (int i = 1; i <= n; ++i) {
        sz[i] = -1;
    }

    for (int i = 1; i <= q; ++i) {
        int cod, x, y;
        fin >> cod >> x >> y;

        if (cod == 1 && Find(x) != Find(y)) {
            Union(Find(x), Find(y));
        } else {
            fout << (Find(x) == Find(y)) ? "DA\n" : "NU\n";
        }
    }

    fin.close();
    fout.close();
    return 0;
}

Problema bile

InfoArena

View Problem

InfoArena
Deschide

Mai întâi, trebuie observat faptul că problema determinării conectivității dinamice este una foarte dificil de rezolvat (vezi acest articol de pe Wikipedia), deci nu are sens să ne chinuim cu asemenea implementări care nu fac obiectul cursului nostru sau în general a programelor olimpiadelor de informatică.

Asta ne duce cu gândul să încercăm să privim problema dintr-o perspectivă diferită, în special și datorită faptului că nu suntem forțați să răspundem la actualizări online. Din acest motiv, vom introduce o abordare care se folosește la multe soluții ce se bazează pe folosirea pădurilor de mulțimi disjuncte.

Practic, în loc să privim problema de la început la final, vom rezolva problema inversă, în care putem adăuga bile, ceea ce ne ajută să reducem problema la o aplicație standard a pădurilor de mulțimi disjuncte, răspunsurile ajungând în cele din urmă să fie afișate în ordinea inversă în care le-am aflat.

Soluția de 100 de puncte este următoarea:

#include <algorithm>
#include <fstream>
#include <utility>

using namespace std;

template <typename T, size_t N>
using Array = T[N];

int n, maxim;

Array<int, 251 * 251> rad, card, rasp;
Array<pair<int, int>, 251 * 251> elim;
Array<Array<int, 251>, 251> nr;
Array<Array<bool, 251>, 251> viz;

int dx[] = {-1, 0, 1, 0};
int dy[] = {0, 1, 0, -1};

int Find(int x) {
    if (rad[x] == x) {
        return x;
    }

    return rad[x] = Find(rad[x]);
}

void Union(int a, int b) {
    if (card[a] < card[b]) {
        swap(a, b);
    }

    rad[b] = a;
    card[a] += card[b];

    maxim = max(maxim, card[a]);
}

int main() {
    ifstream fin("bile.in");

    fin >> n;

    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= n; j++) {
            int idx = (i - 1) * n + j;
            nr[i][j] = idx;
            card[idx] = 1;
            rad[idx] = idx;
        }
    }

    for (int i = 1; i <= n * n; ++i) {
        fin >> elim[i].first >> elim[i].second;
    }

    fin.close();

    for (int i = n * n; i >= 1; i--) {
        rasp[i] = maxim;
        int x = elim[i].first;
        int y = elim[i].second;

        for (int j = 0; j < 4; j++) {
            int newX = x + dx[j];
            int newY = y + dy[j];

            if (newX >= 1 && newX <= n && newY >= 1 && newY <= n
                && viz[newX][newY]) {
                int b1 = nr[x][y];
                int b2 = nr[newX][newY];
                if (nr[newX][newY] != 0 && Find(b1) != Find(b2)) {
                    Union(Find(b1), Find(b2));
                }
            }
        }

        maxim = max(maxim, 1);
        viz[x][y] = true;
    }

    ofstream fout("bile.out");

    for (int i = 1; i <= n * n; ++i) {
        fout << rasp[i] << '\n';
    }

    fout.close();

    return 0;
}

Problema Secvmax

InfoArena

View Problem

InfoArena
Deschide

Aici putem folosi din nou prelucrarea numerelor în ordine crescătoare a numerelor din vector, iar atunci când adăugăm valorile în considerare, vom verifica fiecare vecin să vedem dacă putem uni valorile din cele două mulțimi, iar la fiecare pas răspunsul e cardinalul maxim al unei mulțimi, care e crescător pe măsură ce creștem valorile adăugate.

#include <algorithm>
#include <fstream>
#include <vector>

using namespace std;

ifstream f("secvmax.in");
ofstream g("secvmax.out");

struct Pair {
    int value;
    int index;

    constexpr bool operator<(const Pair &other) const {
        if (value == other.value) {
            return index < other.index;
        }
        return value < other.value;
    }
};

vector<Pair> queries, elements;
vector<int> result, parent, sequence, length;
int maxLength = 0;

int Find(int node) {
    int root;
    for (root = parent[node]; root != parent[root]; root = parent[root])
        ;
    int x = node;
    while (x != root) {
        swap(root, parent[x]);
    }
    return root;
}

void Union(int a, int b) {
    if (length[a] < length[b]) {
        length[b] += length[a];
        parent[a] = b;
        length[a] = 0;
    } else {
        length[a] += length[b];
        parent[b] = a;
        length[b] = 0;
    }

    maxLength = max(maxLength, length[a] + length[b]);
}

int main() {
    int n, q;
    f >> n >> q;

    elements.resize(n + 1);
    parent.resize(n + 1);
    sequence.resize(n + 1);
    length.resize(n + 1);

    queries.resize(q + 1);
    result.resize(q + 1);

    for (int i = 1; i <= n; ++i) {
        int value;
        f >> value;
        elements[i] = {value, i};
        sequence[i] = value;
        parent[i] = i;
    }

    sort(elements.begin() + 1, elements.end());

    for (int i = 1; i <= q; ++i) {
        int value;
        f >> value;
        queries[i] = {value, i};
    }

    sort(queries.begin() + 1, queries.end());

    int pos = 1;

    for (int i = 1; i <= q; ++i) {
        while (pos <= n && elements[pos].value <= queries[i].value) {
            int idx = elements[pos++].index;
            length[idx] = 1;
            maxLength = max(maxLength, 1);

            if (idx > 0 && sequence[idx - 1] <= sequence[idx]) {
                Union(Find(idx - 1), Find(idx));
            }

            if (idx < n - 1 && sequence[idx + 1] < sequence[idx]) {
                Union(Find(idx + 1), Find(idx));
            }
        }
        result[queries[i].index] = maxLength;
    }

    for (int i = 1; i <= q; ++i) {
        g << result[i] << '\n';
    }

    return 0;
}

Problema joingraf

Kilonova

View Problem

Kilonova

Could not fetch problem details

Deschide

Pentru a rezolva această problemă există mai multe abordări, plecând de la diverse moduri de a gândi problema, dar în contextul pădurilor de mulțimi disjuncte, ne vom concentra pe soluția cu DSU.

Mai întâi, trebuie să observăm că componentele conexe sunt ca niște intervale. De exemplu, să luăm n=7n = 7. Atunci, la început intervalele vor fi: \[1,1],\[2,2],\[3,3],\[4,4],\[5,5],\[6,6],\[7,7]\[1, 1], \[2, 2], \[3, 3], \[4, 4], \[5, 5], \[6, 6], \[7, 7]. Dacă unim muchiile de la 3 la 6, intervalele vor deveni: \[1,1],\[2,2],[3,6],\[7,7]\[1, 1], \[2, 2], [3, 6], \[7, 7].

Atunci, putem folosi o structură de tip DSU. Vom reține par_ipar\_i ca "părintele" nodului ii, sau mai ușor de înțeles, capătul stânga al intervalului în care este nodul ii. Este nevoie să reținem doar capătul dreapta, deoarece capătul dreapta al secvenței curente este predecesorul capătului stânga al secvenței următoare. Vom reține și nxt_inxt\_i capătul stânga al secvenței de după secvența în care este ii.

Iar atunci când avem update cu x,yx, y, mergem la fiecare secvență până la yy (adică când avansăm de la pp la următoarea, facem p=nxt_pp = nxt\_p) și o reunim cu secvența în care este xx.

Iar la query, verificăm dacă intervalul în care este xx este egal cu cel în care este yy. Complexitate: O(N+Q log N)\mathcal{O}(N + Q \ log \ N) timp, O(n)\mathcal{O}(n) memorie.

#include <iostream>
#include <vector>

using namespace std;

constexpr int N = 1e6 + 2;
int n, q;

vector<int> tt(N), jump(N);

int root(int nd) {
    while (tt[nd] != nd) {
        return tt[nd] = root(tt[nd]);
    }
    return nd;
}
void unite(int a, int b) {
    a = root(a);
    b = root(b);

    if (a == b) {
        return;
    }

    tt[b] = tt[a];
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> q;

    for (int i = 1; i <= n; i++) {
        tt[i] = jump[i] = i;
    }

    while (q--) {
        int t, x, y;
        cin >> t >> x >> y;

        if (t != 1) {
            cout << (root(x) == root(y) ? "Da\n" : "Nu\n");
            continue;
        }

        int p = x;
        while (p <= y) {
            unite(x, p);
            p = jump[p] + 1;
        }
        jump[x] = jump[y];
    }
    return 0;
}

Concluzii

Acest articol este menit să introducă audiența în folosirea pădurilor de mulțimi disjuncte, punând accentul pe funcționalitățile de bază, fără a menționa alte aplicații importante, cum ar fi algoritmul lui Kruskal sau algoritmii folosiți pentru dynamic connectivity. De asemenea, pădurile de mulțimi disjuncte pot fi folosite pentru a scurta foarte mult implementările aplicațiilor simple la grafuri.

Probleme suplimentare

Bibliografie și lectură suplimentară