This documentation is automatically generated by online-judge-tools/verification-helper
View the Project on GitHub ruthen71/rcpl
#include "graph/minimum_steiner_tree.hpp"
Graph<T> g; std::vector<int> terminals; auto dp = minimum_steiner_tree(g, terminals, INF); auto wt = minimum_steiner_tree_mst(g, terminals, INF);
minimum_steiner_tree(g, terminals, inf)
minimum_steiner_tree_mst(g, terminals, inf)
#pragma once #include "graph/graph_template.hpp" #include "data_structure/unionfind.hpp" #include <vector> #include <queue> #include <algorithm> #include <cassert> // minimum steiner tree // O(3 ^ k n + 2 ^ k m \log m) (n = |V|, m = |E|, k = |terminals|) // https://www.slideshare.net/wata_orz/ss-12131479#50 // https://kopricky.github.io/code/Academic/steiner_tree.html // https://atcoder.jp/contests/abc364/editorial/10547 template <class T> std::vector<std::vector<T>> minimum_steiner_tree(Graph<T>& g, const std::vector<int>& terminals, const T inf) { const int n = (int)(g.size()); const int k = (int)(terminals.size()); const int k2 = 1 << k; // dp[bit][v] = ターミナルの部分集合が bit (0 ~ k - 1 に圧縮), 加えて頂点 v も含まれる最小シュタイナー木 std::vector dp(k2, std::vector<T>(n, inf)); for (int i = 0; i < k; i++) dp[1 << i][terminals[i]] = T(0); for (int bit = 0; bit < (1 << k); bit++) { // dp[bit][v] = min(dp[bit][v], dp[sub][v] + dp[bit ^ sub][v]) // 通常の実装 // for (int sub = bit; sub > 0; sub = (sub - 1) & bit) { // 定数倍高速化 // bit の中で 1 要素だけ sub と bit ^ sub のどちらに属するか決める int bit2 = bit ^ (bit & -bit); for (int sub = bit2; sub > 0; sub = (sub - 1) & bit2) { for (int v = 0; v < n; v++) { dp[bit][v] = std::min(dp[bit][v], dp[sub][v] + dp[bit ^ sub][v]); } } // dp[bit][v] = min(dp[bit][v], dp[bit][u] + cost(u, v)) using tp = std::pair<T, int>; std::priority_queue<tp, std::vector<tp>, std::greater<tp>> que; for (int u = 0; u < n; u++) que.emplace(dp[bit][u], u); while (!que.empty()) { auto [d, u] = que.top(); que.pop(); if (dp[bit][u] != d) continue; for (auto&& e : g[u]) { if (dp[bit][e.to] > d + e.cost) { dp[bit][e.to] = d + e.cost; que.emplace(dp[bit][e.to], e.to); } } } } // dp[k2 - 1][i] = ターミナルと頂点 i を含む最小シュタイナー木 // dp[k2 - 1][terminals[0]] が基本的な答えになる return dp; } // O(2 ^ {n - k} (n + m)) (n = |V|, m = |E|, k = |terminals|) // https://yukicoder.me/problems/no/114/editorial // n - k <= 20 template <class T> T minimum_steiner_tree_mst(Graph<T>& g, const std::vector<int>& terminals, const T inf) { 
const int n = (int)(g.size()); const int k = (int)(terminals.size()); // ターミナルに含まれない点集合 (others) を取得 std::vector<int> used(n, 0); for (int i = 0; i < k; i++) used[terminals[i]] = 1; std::vector<int> others; for (int i = 0; i < n; i++) { if (used[i] == 0) others.push_back(i); } // 辺のリスト std::vector<Edge<T>> edges; for (int v = 0; v < n; v++) { for (auto&& e : g[v]) { if (e.from < e.to) edges.push_back(e); } } std::sort(edges.begin(), edges.end(), [&](Edge<T>& a, Edge<T>& b) -> bool { return a.cost < b.cost; }); // ターミナル + others の組合せを全列挙 -> Minimum Spanning Tree を求める T ans = inf; for (int bit = 0; bit < (1 << (n - k)); bit++) { // 使う頂点集合 (used) を計算 for (int i = 0; i < n - k; i++) used[others[i]] = bit >> i & 1; // Minimum Spanning Tree を計算 UnionFind uf(n); T cur = 0; int connected = 0; for (auto&& e : edges) { // subv に対する g の誘導部分グラフに含まれる辺のみ試す if (!(used[e.from] and used[e.to])) continue; if (!uf.same(e.from, e.to)) { uf.merge(e.from, e.to); cur += e.cost; connected++; } } // 全域木が作れたか判定 if (connected + 1 == k + __builtin_popcount(bit)) ans = std::min(ans, cur); // used をもとに戻す for (int i = 0; i < n - k; i++) used[others[i]] = 0; } return ans; }
#line 2 "graph/minimum_steiner_tree.hpp"

#line 2 "graph/graph_template.hpp"

#include <vector>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <ostream>

// A single edge from `from` to `to` with weight `cost` and an identifier `id`.
template <class T> struct Edge {
    int from, to;
    T cost;
    int id;
    Edge() = default;
    Edge(const int from, const int to, const T cost = T(1), const int id = -1) : from(from), to(to), cost(cost), id(id) {}
    // Edges are ordered by cost only.
    friend bool operator<(const Edge<T>& a, const Edge<T>& b) { return a.cost < b.cost; }
    friend std::ostream& operator<<(std::ostream& os, const Edge<T>& e) {
        // output format: {id: cost(from, to) = cost}
        return os << "{" << e.id << ": cost(" << e.from << ", " << e.to << ") = " << e.cost << "}";
    }
};

template <class T> using Edges = std::vector<Edge<T>>;

// Graph stored as an edge list and converted to CSR form by build()
// (called automatically on the first operator[] access).
template <class T> struct Graph {
    // Iterable view over the outgoing edges of one vertex.
    struct EdgeIterators {
       public:
        using Iterator = typename std::vector<Edge<T>>::iterator;
        EdgeIterators() = default;
        EdgeIterators(const Iterator& begit, const Iterator& endit) : begit(begit), endit(endit) {}
        Iterator begin() const { return begit; }
        Iterator end() const { return endit; }
        std::size_t size() const { return std::distance(begit, endit); }
        Edge<T>& operator[](int i) const { return begit[i]; }

       private:
        Iterator begit, endit;
    };

    int n = 0;
    int m = 0;
    bool is_build = false;
    bool is_directed = false;
    std::vector<Edge<T>> edges;

    // CSR (Compressed Row Storage) data
    std::vector<int> start;
    std::vector<Edge<T>> csr_edges;

    // An empty graph; delegates so that `start` always has n + 1 entries and build() stays valid.
    Graph() : Graph(0) {}
    Graph(const int n, const bool directed = false) : n(n), m(0), is_build(false), is_directed(directed), start(n + 1, 0) {}

    // Adds an edge and returns its insertion index. If id is -1 the insertion index is used as id.
    int add_edge(const int from, const int to, const T cost = T(1), int id = -1) {
        assert(!is_build);
        assert(0 <= from and from < n);
        assert(0 <= to and to < n);
        if (id == -1) id = m;
        edges.emplace_back(from, to, cost, id);
        return m++;
    }

    // Builds the CSR representation. For undirected graphs each edge is stored in both directions.
    void build() {
        assert(!is_build);
        for (auto&& e : edges) {
            start[e.from + 1]++;
            if (!is_directed) start[e.to + 1]++;
        }
        for (int v = 0; v < n; v++) start[v + 1] += start[v];
        auto counter = start;
        csr_edges.resize(start.back() + 1);
        for (auto&& e : edges) {
            csr_edges[counter[e.from]++] = e;
            if (!is_directed) csr_edges[counter[e.to]++] = Edge<T>(e.to, e.from, e.cost, e.id);
        }
        is_build = true;
    }

    // Outgoing edges of vertex i; builds the CSR form on first use.
    EdgeIterators operator[](int i) {
        assert(0 <= i and i < n);
        if (!is_build) build();
        return EdgeIterators(csr_edges.begin() + start[i], csr_edges.begin() + start[i + 1]);
    }

    std::size_t size() const { return (std::size_t)(n); }

    friend std::ostream& operator<<(std::ostream& os, Graph<T>& g) {
        os << "[";
        for (int i = 0; i < (int)(g.size()); i++) {
            os << "[";
            for (int j = 0; j < (int)(g[i].size()); j++) {
                os << g[i][j];
                if (j + 1 != (int)(g[i].size())) os << ", ";
            }
            os << "]";
            if (i + 1 != (int)(g.size())) os << ", ";
        }
        return os << "]";
    }
};
#line 2 "data_structure/unionfind.hpp"

#line 4 "data_structure/unionfind.hpp"
#include <algorithm>
#include <utility>

// Disjoint-set union. parents[x] < 0 means x is a root and -parents[x] is the size of its set.
struct UnionFind {
    int n = 0;
    std::vector<int> parents;

    UnionFind() = default;
    UnionFind(int n) : n(n), parents(n, -1) {}

    // Root of x's set, compressing the path on the way.
    int leader(int x) { return parents[x] < 0 ? x : parents[x] = leader(parents[x]); }

    // Unites the sets of x and y (smaller into larger); returns false if they were already united.
    bool merge(int x, int y) {
        x = leader(x), y = leader(y);
        if (x == y) return false;
        if (parents[x] > parents[y]) std::swap(x, y);
        parents[x] += parents[y];
        parents[y] = x;
        return true;
    }

    bool same(int x, int y) { return leader(x) == leader(y); }

    int size(int x) { return -parents[leader(x)]; }

    // All sets as lists of members; empty buckets are removed.
    std::vector<std::vector<int>> groups() {
        std::vector<int> leader_buf(n), group_size(n);
        for (int i = 0; i < n; i++) {
            leader_buf[i] = leader(i);
            group_size[leader_buf[i]]++;
        }
        std::vector<std::vector<int>> result(n);
        for (int i = 0; i < n; i++) {
            result[i].reserve(group_size[i]);
        }
        for (int i = 0; i < n; i++) {
            result[leader_buf[i]].push_back(i);
        }
        result.erase(std::remove_if(result.begin(), result.end(), [&](const std::vector<int>& v) { return v.empty(); }), result.end());
        return result;
    }

    // Resets to new_n singleton sets; the stored element count is updated as well so groups() stays consistent.
    void init(int new_n) {
        n = new_n;
        parents.assign(new_n, -1);
    }
};
#line 5 "graph/minimum_steiner_tree.hpp"

#line 7 "graph/minimum_steiner_tree.hpp"
#include <queue>
#include <functional>
#line 10 "graph/minimum_steiner_tree.hpp"

// Minimum Steiner tree by DP over subsets of terminals.
// O(3 ^ k n + 2 ^ k m \log m) (n = |V|, m = |E|, k = |terminals|)
// https://www.slideshare.net/wata_orz/ss-12131479#50
// https://kopricky.github.io/code/Academic/steiner_tree.html
// https://atcoder.jp/contests/abc364/editorial/10547
//
// g         : graph to search (taken by non-const reference because g[u] may build the CSR form)
// terminals : vertices that must be connected; terminals[i] is mapped to bit i
// inf       : sentinel meaning "no such tree"; entries equal to inf are never added to anything,
//             so a value close to the maximum of T is safe
// returns   : dp with dp[bit][v] = minimum cost of a tree containing the terminals in `bit` plus vertex v
//             (dp[0][v] is 0: the single vertex v). dp[(1 << k) - 1][terminals[0]] is the usual answer.
// NOTE: the relaxation step is Dijkstra-style, so edge costs are assumed to be non-negative.
template <class T> std::vector<std::vector<T>> minimum_steiner_tree(Graph<T>& g, const std::vector<int>& terminals, const T inf) {
    const int n = (int)(g.size());
    const int k = (int)(terminals.size());
    assert(k < 31);  // 1 << k must fit in int
    const int k2 = 1 << k;

    std::vector<std::vector<T>> dp(k2, std::vector<T>(n, inf));
    // Empty terminal set: the tree consisting of v alone costs nothing.
    std::fill(dp[0].begin(), dp[0].end(), T(0));
    for (int i = 0; i < k; i++) {
        assert(0 <= terminals[i] and terminals[i] < n);
        dp[1 << i][terminals[i]] = T(0);
    }

    using tp = std::pair<T, int>;
    for (int bit = 1; bit < k2; bit++) {
        std::vector<T>& cur = dp[bit];

        // dp[bit][v] = min(dp[bit][v], dp[sub][v] + dp[bit ^ sub][v])
        // Constant-factor speed-up: the lowest set bit of `bit` is fixed to the (bit ^ sub) side,
        // so every unordered split {sub, bit ^ sub} is visited once instead of twice.
        const int rest = bit ^ (bit & -bit);
        for (int sub = rest; sub > 0; sub = (sub - 1) & rest) {
            const std::vector<T>& lhs = dp[sub];
            const std::vector<T>& rhs = dp[bit ^ sub];
            for (int v = 0; v < n; v++) {
                // Never add to the sentinel: inf + x may overflow T.
                if (lhs[v] == inf or rhs[v] == inf) continue;
                cur[v] = std::min(cur[v], lhs[v] + rhs[v]);
            }
        }

        // dp[bit][v] = min(dp[bit][v], dp[bit][u] + cost(u, v))
        // Multi-source shortest path seeded with every vertex that already has a finite value.
        std::priority_queue<tp, std::vector<tp>, std::greater<tp>> que;
        for (int u = 0; u < n; u++) {
            if (cur[u] != inf) que.emplace(cur[u], u);
        }
        while (!que.empty()) {
            const auto [d, u] = que.top();
            que.pop();
            if (cur[u] != d) continue;  // stale heap entry
            for (auto&& e : g[u]) {
                const T nd = d + e.cost;
                if (nd < cur[e.to]) {
                    cur[e.to] = nd;
                    que.emplace(nd, e.to);
                }
            }
        }
    }
    // dp[k2 - 1][i] = minimum Steiner tree containing every terminal and vertex i
    return dp;
}

// Minimum Steiner tree by enumerating which non-terminal vertices are used and
// running Kruskal on the induced subgraph.
// O(2 ^ {n - k} (n + m)) (n = |V|, m = |E|, k = number of distinct terminals)
// https://yukicoder.me/problems/no/114/editorial
// Intended for n - k <= 20.
//
// g         : graph; only edges with from < to are collected, i.e. one copy of each undirected edge
// terminals : vertices that must be connected (duplicates are counted once)
// inf       : value returned when no vertex subset yields a spanning tree
// returns   : minimum total edge cost, T(0) when there are no terminals, or inf if impossible
template <class T> T minimum_steiner_tree_mst(Graph<T>& g, const std::vector<int>& terminals, const T inf) {
    const int n = (int)(g.size());

    // Mark terminals and count the distinct ones.
    std::vector<int> used(n, 0);
    int k = 0;
    for (const int t : terminals) {
        assert(0 <= t and t < n);
        if (!used[t]) {
            used[t] = 1;
            k++;
        }
    }
    if (k == 0) return T(0);  // nothing to connect

    // Vertices that are not terminals (optional Steiner points).
    std::vector<int> others;
    for (int v = 0; v < n; v++) {
        if (!used[v]) others.push_back(v);
    }
    const int free_count = (int)(others.size());
    assert(free_count < 31);  // 1 << free_count must fit in int

    // Edge list sorted by cost for Kruskal.
    std::vector<Edge<T>> edges;
    for (int v = 0; v < n; v++) {
        for (auto&& e : g[v]) {
            if (e.from < e.to) edges.push_back(e);
        }
    }
    std::sort(edges.begin(), edges.end(), [](const Edge<T>& a, const Edge<T>& b) { return a.cost < b.cost; });

    // Enumerate every (terminals + subset of others) and compute its minimum spanning tree.
    T ans = inf;
    for (int bit = 0; bit < (1 << free_count); bit++) {
        // Select the vertex set for this subset; `need` is its size.
        int need = k;
        for (int i = 0; i < free_count; i++) {
            const int on = bit >> i & 1;
            used[others[i]] = on;
            need += on;
        }

        UnionFind uf(n);
        T cur = 0;
        int merged = 0;
        for (const Edge<T>& e : edges) {
            if (merged + 1 == need) break;  // spanning tree already complete
            // Only edges of the subgraph induced by the selected vertices.
            if (!(used[e.from] and used[e.to])) continue;
            if (uf.merge(e.from, e.to)) {
                cur += e.cost;
                merged++;
            }
        }
        // A spanning tree on `need` vertices has exactly need - 1 edges.
        if (merged + 1 == need) ans = std::min(ans, cur);
    }
    return ans;
}