rcpl

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub ruthen71/rcpl

✔ Minimum Steiner Tree (最小シュタイナー木)
(graph/minimum_steiner_tree.hpp)

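使い方 (Usage)

Graph<T> g;
std::vector<int> terminals;
auto dp = minimum_steiner_tree(g, terminals, INF);
auto wt = minimum_steiner_tree_mst(g, terminals, INF);

Here `dp[(1 << k) - 1][terminals[0]]` (k = |terminals|) is the minimum total weight of a tree connecting all terminals. `wt` is the same weight, computed by enumerating subsets of the non-terminal vertices, which is practical when n - k ≤ 20.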

参考文献 (References)

Depends on

Verified with

Code

#pragma once

#include "graph/graph_template.hpp"
#include "data_structure/unionfind.hpp"

#include <vector>
#include <queue>
#include <algorithm>
#include <cassert>

// minimum steiner tree
// O(3 ^ k n + 2 ^ k (n + m) \log (n + m)) (n = |V|, m = |E|, k = |terminals|)
// https://www.slideshare.net/wata_orz/ss-12131479#50
// https://kopricky.github.io/code/Academic/steiner_tree.html
// https://atcoder.jp/contests/abc364/editorial/10547
//
// Subset DP over the terminals: merge two trees that share a vertex, then extend
// trees along edges with a multi-source Dijkstra (one run per terminal subset).
//
// g         : graph; edge costs must be non-negative (the extension step is Dijkstra)
// terminals : terminal vertices; terminals[i] corresponds to bit i of the subset index
// inf       : sentinel for "no tree"; it is never used as an operand of +, so a value
//             such as std::numeric_limits<T>::max() does not overflow
// returns   : dp (size 2^k x n), dp[bit][v] = minimum total cost of a tree containing
//             the terminals in `bit` and the vertex v, or inf if none exists.
//             dp[2^k - 1][terminals[0]] is normally the answer.
template <class T> std::vector<std::vector<T>> minimum_steiner_tree(Graph<T>& g, const std::vector<int>& terminals, const T inf) {
    const int n = (int)(g.size());
    const int k = (int)(terminals.size());
    assert(k < 31);  // 1 << k must fit in int
    const int k2 = 1 << k;

    // dp[bit][v] = minimum Steiner tree of the terminal subset bit (indices 0 .. k - 1) that also contains vertex v
    std::vector dp(k2, std::vector<T>(n, inf));
    for (int i = 0; i < k; i++) {
        assert(0 <= terminals[i] and terminals[i] < n);
        dp[1 << i][terminals[i]] = T(0);
    }

    using tp = std::pair<T, int>;
    // bit = 0 (empty subset) never takes part in a merge, so start from 1;
    // increasing order guarantees every proper subset of bit is final before bit is processed
    for (int bit = 1; bit < k2; bit++) {
        std::vector<T>& cur = dp[bit];

        // Merge: cur[v] = min(cur[v], dp[sub][v] + dp[bit ^ sub][v]).
        // Plain version: for (int sub = bit; sub > 0; sub = (sub - 1) & bit)
        // Faster version: the lowest set bit of `bit` always stays in bit ^ sub, so each
        // unordered split {sub, bit ^ sub} into two non-empty parts is visited exactly once.
        const int rest = bit ^ (bit & -bit);
        for (int sub = rest; sub > 0; sub = (sub - 1) & rest) {
            const std::vector<T>& lhs = dp[sub];
            const std::vector<T>& rhs = dp[bit ^ sub];
            for (int v = 0; v < n; v++) {
                // skip unreachable states so inf never takes part in an addition (no overflow)
                if (lhs[v] == inf or rhs[v] == inf) continue;
                cur[v] = std::min(cur[v], lhs[v] + rhs[v]);
            }
        }

        // Extend: cur[v] = min(cur[v], cur[u] + cost(u, v)) via Dijkstra seeded with every reachable u
        std::priority_queue<tp, std::vector<tp>, std::greater<tp>> que;
        for (int u = 0; u < n; u++) {
            if (cur[u] != inf) que.emplace(cur[u], u);
        }
        while (!que.empty()) {
            const auto [d, u] = que.top();
            que.pop();
            if (cur[u] != d) continue;  // stale queue entry
            for (auto&& e : g[u]) {
                const T nd = d + e.cost;
                if (cur[e.to] > nd) {
                    cur[e.to] = nd;
                    que.emplace(nd, e.to);
                }
            }
        }
    }
    return dp;
}

// O(m \log m + 2 ^ {n - k} (n + m) \alpha(n)) (n = |V|, m = |E|, k = |terminals|)
// https://yukicoder.me/problems/no/114/editorial
// n - k <= 20
//
// Minimum Steiner tree weight by brute force over the non-terminal vertices: for every
// subset S of non-terminals, run Kruskal on the subgraph induced by terminals ∪ S and
// keep the cheapest result that forms a single spanning tree.
//
// g         : undirected graph; each edge is taken once, from the adjacency entry with
//             e.from < e.to (entries with e.from >= e.to, including self-loops, are ignored)
// terminals : terminal vertices; duplicated entries are counted once
// inf       : returned when no subset yields a connected induced subgraph
// returns   : minimum total cost of a tree containing all terminals, or inf
template <class T> T minimum_steiner_tree_mst(Graph<T>& g, const std::vector<int>& terminals, const T inf) {
    const int n = (int)(g.size());

    // collect the vertices that are not terminals (others)
    std::vector<int> used(n, 0);
    for (const int t : terminals) {
        assert(0 <= t and t < n);
        used[t] = 1;
    }
    std::vector<int> others;
    for (int v = 0; v < n; v++) {
        if (used[v] == 0) others.push_back(v);
    }
    // derive the counts from the marks so duplicated terminals do not skew the enumeration
    const int num_others = (int)(others.size());
    const int num_terminals = n - num_others;
    assert(num_others < 31);  // 1 << num_others must fit in int

    // edge list with each undirected edge once, sorted by cost for Kruskal
    std::vector<Edge<T>> edges;
    for (int v = 0; v < n; v++) {
        for (auto&& e : g[v]) {
            if (e.from < e.to) edges.push_back(e);
        }
    }
    std::sort(edges.begin(), edges.end(), [](const Edge<T>& a, const Edge<T>& b) { return a.cost < b.cost; });

    // enumerate every subset of others -> minimum spanning tree of terminals ∪ subset
    T ans = inf;
    UnionFind uf(n);
    for (int bit = 0; bit < (1 << num_others); bit++) {
        // vertex set in use: all terminals plus the others selected by bit
        // (every others entry is rewritten here, so no reset is needed afterwards)
        for (int i = 0; i < num_others; i++) used[others[i]] = bit >> i & 1;

        // Kruskal restricted to the subgraph induced by the used vertices
        uf.init(n);
        T cur = 0;
        int merged = 0;
        for (auto&& e : edges) {
            if (!(used[e.from] and used[e.to])) continue;
            if (uf.merge(e.from, e.to)) {
                cur += e.cost;
                merged++;
            }
        }

        // a spanning tree on V' vertices has exactly V' - 1 edges
        if (merged + 1 == num_terminals + __builtin_popcount(bit)) ans = std::min(ans, cur);
    }
    return ans;
}
#line 2 "graph/minimum_steiner_tree.hpp"

#line 2 "graph/graph_template.hpp"

#include <vector>
#include <cassert>

template <class T> struct Edge {
    int from, to;
    T cost;
    int id;

    Edge() = default;
    // cost defaults to 1 (unweighted edge), id defaults to -1 (unassigned)
    Edge(const int src, const int dst, const T weight = T(1), const int index = -1) : from(src), to(dst), cost(weight), id(index) {}

    // edges compare by cost only; endpoints and id are ignored
    friend bool operator<(const Edge<T>& lhs, const Edge<T>& rhs) { return lhs.cost < rhs.cost; }

    // writes the edge as {id: cost(from, to) = cost}
    friend std::ostream& operator<<(std::ostream& os, const Edge<T>& e) {
        os << "{" << e.id << ": cost(";
        os << e.from << ", " << e.to;
        os << ") = " << e.cost << "}";
        return os;
    }
};
// plain list of edges
template <class T> using Edges = std::vector<Edge<T>>;

template <class T> struct Graph {
    struct EdgeIterators {
       public:
        using Iterator = typename std::vector<Edge<T>>::iterator;
        EdgeIterators() = default;
        EdgeIterators(const Iterator& begit, const Iterator& endit) : begit(begit), endit(endit) {}
        Iterator begin() const { return begit; }
        Iterator end() const { return endit; }
        size_t size() const { return std::distance(begit, endit); }
        Edge<T>& operator[](int i) const { return begit[i]; }

       private:
        Iterator begit, endit;
    };

    int n, m;
    bool is_build, is_directed;
    std::vector<Edge<T>> edges;

    // CSR (Compressed Row Storage) 形式用
    std::vector<int> start;
    std::vector<Edge<T>> csr_edges;

    Graph() = default;
    Graph(const int n, const bool directed = false) : n(n), m(0), is_build(false), is_directed(directed), start(n + 1, 0) {}

    // 辺を追加し, その辺が何番目に追加されたかを返す
    int add_edge(const int from, const int to, const T cost = T(1), int id = -1) {
        assert(!is_build);
        assert(0 <= from and from < n);
        assert(0 <= to and to < n);
        if (id == -1) id = m;
        edges.emplace_back(from, to, cost, id);
        return m++;
    }

    // CSR 形式でグラフを構築
    void build() {
        assert(!is_build);
        for (auto&& e : edges) {
            start[e.from + 1]++;
            if (!is_directed) start[e.to + 1]++;
        }
        for (int v = 0; v < n; v++) start[v + 1] += start[v];
        auto counter = start;
        csr_edges.resize(start.back() + 1);
        for (auto&& e : edges) {
            csr_edges[counter[e.from]++] = e;
            if (!is_directed) csr_edges[counter[e.to]++] = Edge(e.to, e.from, e.cost, e.id);
        }
        is_build = true;
    }

    EdgeIterators operator[](int i) {
        if (!is_build) build();
        return EdgeIterators(csr_edges.begin() + start[i], csr_edges.begin() + start[i + 1]);
    }

    size_t size() const { return (size_t)(n); }

    friend std::ostream& operator<<(std::ostream& os, Graph<T>& g) {
        os << "[";
        for (int i = 0; i < (int)(g.size()); i++) {
            os << "[";
            for (int j = 0; j < (int)(g[i].size()); j++) {
                os << g[i][j];
                if (j + 1 != (int)(g[i].size())) os << ", ";
            }
            os << "]";
            if (i + 1 != (int)(g.size())) os << ", ";
        }
        return os << "]";
    }
};
#line 2 "data_structure/unionfind.hpp"

#line 4 "data_structure/unionfind.hpp"
#include <algorithm>

// Disjoint-set union with union by size and path compression.
// parents[x] < 0 means x is a leader and -parents[x] is its component size;
// otherwise parents[x] is x's parent.
struct UnionFind {
    int n = 0;  // number of elements (kept in sync with parents.size())
    std::vector<int> parents;

    UnionFind() {}
    UnionFind(int n) : n(n), parents(n, -1) {}

    // representative of x's component; compresses the path on the way back
    int leader(int x) { return parents[x] < 0 ? x : parents[x] = leader(parents[x]); }

    // unites the components of x and y; returns false if they were already united
    bool merge(int x, int y) {
        x = leader(x), y = leader(y);
        if (x == y) return false;
        // attach the smaller component (less negative size) under the larger one
        if (parents[x] > parents[y]) std::swap(x, y);
        parents[x] += parents[y];
        parents[y] = x;
        return true;
    }

    bool same(int x, int y) { return leader(x) == leader(y); }

    // number of elements in x's component
    int size(int x) { return -parents[leader(x)]; }

    // all components, each listed in increasing element order; empty components are removed
    std::vector<std::vector<int>> groups() {
        std::vector<int> leader_buf(n), group_size(n);
        for (int i = 0; i < n; i++) {
            leader_buf[i] = leader(i);
            group_size[leader_buf[i]]++;
        }
        std::vector<std::vector<int>> result(n);
        for (int i = 0; i < n; i++) {
            result[i].reserve(group_size[i]);
        }
        for (int i = 0; i < n; i++) {
            result[leader_buf[i]].push_back(i);
        }
        result.erase(std::remove_if(result.begin(), result.end(), [&](const std::vector<int>& v) { return v.empty(); }), result.end());
        return result;
    }

    // reset to n singleton components; the member n must be updated too, or groups() reads a stale size
    void init(int n) {
        this->n = n;
        parents.assign(n, -1);
    }
};
#line 5 "graph/minimum_steiner_tree.hpp"

#line 7 "graph/minimum_steiner_tree.hpp"
#include <queue>
#line 10 "graph/minimum_steiner_tree.hpp"

// minimum steiner tree
// O(3 ^ k n + 2 ^ k (n + m) \log (n + m)) (n = |V|, m = |E|, k = |terminals|)
// https://www.slideshare.net/wata_orz/ss-12131479#50
// https://kopricky.github.io/code/Academic/steiner_tree.html
// https://atcoder.jp/contests/abc364/editorial/10547
//
// Subset DP over the terminals: merge two trees that share a vertex, then extend
// trees along edges with a multi-source Dijkstra (one run per terminal subset).
//
// g         : graph; edge costs must be non-negative (the extension step is Dijkstra)
// terminals : terminal vertices; terminals[i] corresponds to bit i of the subset index
// inf       : sentinel for "no tree"; it is never used as an operand of +, so a value
//             such as std::numeric_limits<T>::max() does not overflow
// returns   : dp (size 2^k x n), dp[bit][v] = minimum total cost of a tree containing
//             the terminals in `bit` and the vertex v, or inf if none exists.
//             dp[2^k - 1][terminals[0]] is normally the answer.
template <class T> std::vector<std::vector<T>> minimum_steiner_tree(Graph<T>& g, const std::vector<int>& terminals, const T inf) {
    const int n = (int)(g.size());
    const int k = (int)(terminals.size());
    assert(k < 31);  // 1 << k must fit in int
    const int k2 = 1 << k;

    // dp[bit][v] = minimum Steiner tree of the terminal subset bit (indices 0 .. k - 1) that also contains vertex v
    std::vector dp(k2, std::vector<T>(n, inf));
    for (int i = 0; i < k; i++) {
        assert(0 <= terminals[i] and terminals[i] < n);
        dp[1 << i][terminals[i]] = T(0);
    }

    using tp = std::pair<T, int>;
    // bit = 0 (empty subset) never takes part in a merge, so start from 1;
    // increasing order guarantees every proper subset of bit is final before bit is processed
    for (int bit = 1; bit < k2; bit++) {
        std::vector<T>& cur = dp[bit];

        // Merge: cur[v] = min(cur[v], dp[sub][v] + dp[bit ^ sub][v]).
        // Plain version: for (int sub = bit; sub > 0; sub = (sub - 1) & bit)
        // Faster version: the lowest set bit of `bit` always stays in bit ^ sub, so each
        // unordered split {sub, bit ^ sub} into two non-empty parts is visited exactly once.
        const int rest = bit ^ (bit & -bit);
        for (int sub = rest; sub > 0; sub = (sub - 1) & rest) {
            const std::vector<T>& lhs = dp[sub];
            const std::vector<T>& rhs = dp[bit ^ sub];
            for (int v = 0; v < n; v++) {
                // skip unreachable states so inf never takes part in an addition (no overflow)
                if (lhs[v] == inf or rhs[v] == inf) continue;
                cur[v] = std::min(cur[v], lhs[v] + rhs[v]);
            }
        }

        // Extend: cur[v] = min(cur[v], cur[u] + cost(u, v)) via Dijkstra seeded with every reachable u
        std::priority_queue<tp, std::vector<tp>, std::greater<tp>> que;
        for (int u = 0; u < n; u++) {
            if (cur[u] != inf) que.emplace(cur[u], u);
        }
        while (!que.empty()) {
            const auto [d, u] = que.top();
            que.pop();
            if (cur[u] != d) continue;  // stale queue entry
            for (auto&& e : g[u]) {
                const T nd = d + e.cost;
                if (cur[e.to] > nd) {
                    cur[e.to] = nd;
                    que.emplace(nd, e.to);
                }
            }
        }
    }
    return dp;
}

// O(m \log m + 2 ^ {n - k} (n + m) \alpha(n)) (n = |V|, m = |E|, k = |terminals|)
// https://yukicoder.me/problems/no/114/editorial
// n - k <= 20
//
// Minimum Steiner tree weight by brute force over the non-terminal vertices: for every
// subset S of non-terminals, run Kruskal on the subgraph induced by terminals ∪ S and
// keep the cheapest result that forms a single spanning tree.
//
// g         : undirected graph; each edge is taken once, from the adjacency entry with
//             e.from < e.to (entries with e.from >= e.to, including self-loops, are ignored)
// terminals : terminal vertices; duplicated entries are counted once
// inf       : returned when no subset yields a connected induced subgraph
// returns   : minimum total cost of a tree containing all terminals, or inf
template <class T> T minimum_steiner_tree_mst(Graph<T>& g, const std::vector<int>& terminals, const T inf) {
    const int n = (int)(g.size());

    // collect the vertices that are not terminals (others)
    std::vector<int> used(n, 0);
    for (const int t : terminals) {
        assert(0 <= t and t < n);
        used[t] = 1;
    }
    std::vector<int> others;
    for (int v = 0; v < n; v++) {
        if (used[v] == 0) others.push_back(v);
    }
    // derive the counts from the marks so duplicated terminals do not skew the enumeration
    const int num_others = (int)(others.size());
    const int num_terminals = n - num_others;
    assert(num_others < 31);  // 1 << num_others must fit in int

    // edge list with each undirected edge once, sorted by cost for Kruskal
    std::vector<Edge<T>> edges;
    for (int v = 0; v < n; v++) {
        for (auto&& e : g[v]) {
            if (e.from < e.to) edges.push_back(e);
        }
    }
    std::sort(edges.begin(), edges.end(), [](const Edge<T>& a, const Edge<T>& b) { return a.cost < b.cost; });

    // enumerate every subset of others -> minimum spanning tree of terminals ∪ subset
    T ans = inf;
    UnionFind uf(n);
    for (int bit = 0; bit < (1 << num_others); bit++) {
        // vertex set in use: all terminals plus the others selected by bit
        // (every others entry is rewritten here, so no reset is needed afterwards)
        for (int i = 0; i < num_others; i++) used[others[i]] = bit >> i & 1;

        // Kruskal restricted to the subgraph induced by the used vertices
        uf.init(n);
        T cur = 0;
        int merged = 0;
        for (auto&& e : edges) {
            if (!(used[e.from] and used[e.to])) continue;
            if (uf.merge(e.from, e.to)) {
                cur += e.cost;
                merged++;
            }
        }

        // a spanning tree on V' vertices has exactly V' - 1 edges
        if (merged + 1 == num_terminals + __builtin_popcount(bit)) ans = std::min(ans, cur);
    }
    return ans;
}
Back to top page