rcpl

This documentation is automatically generated by online-judge-tools/verification-helper.

View the Project on GitHub ruthen71/rcpl

❌ verify/graph/auxiliary_tree.test.cpp (verification failed)

Depends on

Code

#define PROBLEM "https://atcoder.jp/contests/abc359/tasks/abc359_g"

#include <iostream>

#include "graph/read_graph.hpp"
#include "graph/auxiliary_tree.hpp"

int main() {
    int N;
    std::cin >> N;
    auto g = read_graph<int>(N, N - 1);
    std::vector<int> A(N);
    std::vector<std::vector<int>> vs(N);
    for (int i = 0; i < N; i++) {
        std::cin >> A[i];
        A[i]--;
        vs[A[i]].push_back(i);
    }
    AuxiliaryTree<int> aux(g);
    long long ans = 0;
    for (int col = 0; col < N; col++) {
        auto vec = vs[col];
        if ((int)(vec.size()) == 0) continue;
        auto res = aux.get(vec);
        auto caux = res.first;
        auto label = res.second;
        const int M = (int)(vec.size());
        std::vector<int> dp((int)(label.size()));
        auto rec = [&](auto f, int cur, int par) -> void {
            dp[cur] = (A[label[cur]] == col);
            for (auto&& e : caux[cur]) {
                if (e.to == par) continue;
                f(f, e.to, cur);
                dp[cur] += dp[e.to];
                ans += (long long)e.cost * dp[e.to] * (M - dp[e.to]);
            }
        };
        rec(rec, 0, -1);
    }
    std::cout << ans << '\n';
    return 0;
}
#line 1 "verify/graph/auxiliary_tree.test.cpp"
#define PROBLEM "https://atcoder.jp/contests/abc359/tasks/abc359_g"

#include <iostream>

#line 2 "graph/read_graph.hpp"

#line 2 "graph/graph_template.hpp"

#include <cassert>
#include <vector>

/// A (possibly weighted) edge `from -> to` with an identifier `id`.
template <class T> struct Edge {
    // Brace-initialized so that a default-initialized Edge (`Edge<int> e;`)
    // holds zeros instead of indeterminate values; this matches what
    // value-initialization (`Edge<T>()`, `vector::resize`) already produced.
    int from{}, to{};
    T cost{};
    int id{};

    Edge() = default;
    Edge(const int from, const int to, const T cost = T(1), const int id = -1)
        : from(from), to(to), cost(cost), id(id) {}

    // Orders edges by cost only; from/to/id are ignored.
    friend bool operator<(const Edge<T>& a, const Edge<T>& b) {
        return a.cost < b.cost;
    }

    friend std::ostream& operator<<(std::ostream& os, const Edge<T>& e) {
        // output format: {id: cost(from, to) = cost}
        return os << "{" << e.id << ": cost(" << e.from << ", " << e.to
                  << ") = " << e.cost << "}";
    }
};
/// Convenience alias for an edge list.
template <class T> using Edges = std::vector<Edge<T>>;

template <class T> struct Graph {
    struct EdgeIterators {
      public:
        using Iterator = typename std::vector<Edge<T>>::iterator;
        EdgeIterators() = default;
        EdgeIterators(const Iterator& begit, const Iterator& endit)
            : begit(begit), endit(endit) {}
        Iterator begin() const { return begit; }
        Iterator end() const { return endit; }
        size_t size() const { return std::distance(begit, endit); }
        Edge<T>& operator[](int i) const { return begit[i]; }

      private:
        Iterator begit, endit;
    };

    int n, m;
    bool is_build, is_directed;
    std::vector<Edge<T>> edges;

    // CSR (Compressed Row Storage) 形式用
    std::vector<int> start;
    std::vector<Edge<T>> csr_edges;

    Graph() = default;
    Graph(const int n, const bool directed = false)
        : n(n), m(0), is_build(false), is_directed(directed), start(n + 1, 0) {}

    // 辺を追加し, その辺が何番目に追加されたかを返す
    int add_edge(const int from,
                 const int to,
                 const T cost = T(1),
                 int id = -1) {
        assert(!is_build);
        assert(0 <= from and from < n);
        assert(0 <= to and to < n);
        if (id == -1) id = m;
        edges.emplace_back(from, to, cost, id);
        return m++;
    }

    // CSR 形式でグラフを構築
    void build() {
        assert(!is_build);
        for (auto&& e : edges) {
            start[e.from + 1]++;
            if (!is_directed) start[e.to + 1]++;
        }
        for (int v = 0; v < n; v++) start[v + 1] += start[v];
        auto counter = start;
        csr_edges.resize(start.back() + 1);
        for (auto&& e : edges) {
            csr_edges[counter[e.from]++] = e;
            if (!is_directed)
                csr_edges[counter[e.to]++] = Edge(e.to, e.from, e.cost, e.id);
        }
        is_build = true;
    }

    EdgeIterators operator[](int i) {
        if (!is_build) build();
        return EdgeIterators(csr_edges.begin() + start[i],
                             csr_edges.begin() + start[i + 1]);
    }

    size_t size() const { return (size_t)(n); }

    friend std::ostream& operator<<(std::ostream& os, Graph<T>& g) {
        os << "[";
        for (int i = 0; i < (int)(g.size()); i++) {
            os << "[";
            for (int j = 0; j < (int)(g[i].size()); j++) {
                os << g[i][j];
                if (j + 1 != (int)(g[i].size())) os << ", ";
            }
            os << "]";
            if (i + 1 != (int)(g.size())) os << ", ";
        }
        return os << "]";
    }
};
#line 4 "graph/read_graph.hpp"

// Reads an edge list of m lines from std::cin into a built Graph<T> with n
// vertices. Each line is "u v", followed by "w" when `weight` is set; vertex
// labels are shifted down by `offset`, and unweighted edges get cost 1.
template <class T>
Graph<T> read_graph(const int n,
                    const int m,
                    const bool weight = false,
                    const bool directed = false,
                    const int offset = 1) {
    Graph<T> graph(n, directed);
    for (int edge_idx = 0; edge_idx < m; ++edge_idx) {
        int u, v;
        std::cin >> u >> v;
        T w = 1;
        if (weight) std::cin >> w;
        graph.add_edge(u - offset, v - offset, w);
    }
    graph.build();
    return graph;
}

// Reads a tree given as a parent list from std::cin: for each vertex
// child = 1, ..., n - 1 one value "p" (followed by "w" when `weight` is set)
// adds the edge (p - offset) -> child. Unweighted edges get cost 1.
template <class T>
Graph<T> read_parent(const int n,
                     const bool weight = false,
                     const bool directed = false,
                     const int offset = 1) {
    Graph<T> tree(n, directed);
    for (int child = 1; child < n; ++child) {
        int p;
        std::cin >> p;
        T w = 1;
        if (weight) std::cin >> w;
        tree.add_edge(p - offset, child, w);
    }
    tree.build();
    return tree;
}
#line 2 "graph/auxiliary_tree.hpp"

#line 2 "graph/lowest_common_ancestor.hpp"

#line 2 "misc/topbit.hpp"

#line 2 "misc/countl_zero.hpp"

#if __cplusplus >= 202002L
#include <bit>
#endif

// countl_zero: number of leading zero bits.
// (000, 001, 010, 011, 100) -> (32, 31, 30, 30, 29)
//
// Every non-template overload is `inline`: this is header code, and a plain
// function definition included from two translation units violates the ODR
// (multiple-definition link error).
#if __cplusplus >= 202002L
using std::countl_zero;
#else
// __builtin_clz / __builtin_clzll are undefined for 0, hence the explicit
// zero checks. The constants assume 32-bit unsigned int and 64-bit
// unsigned long long.
inline int countl_zero(unsigned int x) {
    return x == 0 ? 32 : __builtin_clz(x);
}
inline int countl_zero(unsigned long long int x) {
    return x == 0 ? 64 : __builtin_clzll(x);
}
#endif
// Signed overloads count on the bit pattern of the value converted to unsigned.
inline int countl_zero(int x) {
    return countl_zero(static_cast<unsigned int>(x));
}
inline int countl_zero(long long int x) {
    return countl_zero(static_cast<unsigned long long int>(x));
}
#line 4 "misc/topbit.hpp"

// topbit: index of the most significant set bit, or -1 for 0.
// (000, 001, 010, 011, 100) -> (-1, 0, 1, 1, 2)
// `inline` because this is header code: without it, including the header
// from several translation units gives multiple-definition link errors.
inline int topbit(int x) { return 31 - countl_zero(x); }
inline int topbit(unsigned int x) { return 31 - countl_zero(x); }
inline int topbit(long long int x) { return 63 - countl_zero(x); }
inline int topbit(unsigned long long int x) { return 63 - countl_zero(x); }
#line 5 "graph/lowest_common_ancestor.hpp"

#line 7 "graph/lowest_common_ancestor.hpp"

// Lowest common ancestor queries on a rooted tree by binary lifting.
//
// Construction takes O(n log n) time and memory; lca(), level_ancestor() and
// distance() take O(log n). `g` is expected to be a tree: vertices not
// reachable from `root` keep parent 0 / depth 0 and their answers are not
// meaningful.
template <class T> struct LowestCommonAncestor {
    int n, lg;
    std::vector<int> depth;  // depth[v]: number of edges from the root to v
    // parent[k][v]: the 2^k-th ancestor of v, or -1 above the root.
    std::vector<std::vector<int>> parent;

    LowestCommonAncestor(Graph<T>& g, const int root = 0)
        : n((int)(g.size())),
          lg(topbit(n) + 1),
          depth(n, 0),
          parent(lg, std::vector<int>(n)) {
        // An empty graph has lg == 0, so there is no parent[0] to fill.
        if (n == 0) return;
        assert(0 <= root and root < n);
        // Iterative DFS with an explicit stack: a recursive DFS can overflow
        // the call stack on deep trees (e.g. a path of ~2e5 vertices). The
        // visiting order is irrelevant because only parent[0] and depth are
        // recorded.
        parent[0][root] = -1;
        std::vector<int> stack(1, root);
        while (!stack.empty()) {
            const int cur = stack.back();
            stack.pop_back();
            for (auto&& e : g[cur]) {
                if (e.to == parent[0][cur]) continue;
                parent[0][e.to] = cur;
                depth[e.to] = depth[cur] + 1;
                stack.push_back(e.to);
            }
        }
        // Doubling: parent[k + 1][v] = parent[k][parent[k][v]], keeping -1.
        for (int k = 0; k + 1 < lg; k++) {
            for (int v = 0; v < n; v++) {
                parent[k + 1][v] =
                    parent[k][v] < 0 ? -1 : parent[k][parent[k][v]];
            }
        }
    }

    // Lowest common ancestor of u and v.
    int lca(int u, int v) {
        assert((int)(depth.size()) == n);
        assert(0 <= u and u < n and 0 <= v and v < n);
        if (depth[u] > depth[v]) std::swap(u, v);
        // depth[u] <= depth[v]: lift v to u's depth one bit at a time.
        for (int k = 0; k < lg; k++) {
            if ((depth[v] - depth[u]) >> k & 1) v = parent[k][v];
        }
        if (u == v) return u;
        // Lift both while their 2^k-th ancestors differ; they stop at the two
        // children of the LCA.
        for (int k = lg - 1; k >= 0; k--) {
            if (parent[k][u] != parent[k][v]) {
                u = parent[k][u];
                v = parent[k][v];
            }
        }
        return parent[0][u];
    }

    // Ancestor of u that is d edges above it (u itself for d == 0), or -1
    // when depth[u] < d.
    int level_ancestor(int u, const int d) {
        assert((int)(depth.size()) == n);
        assert(d >= 0);
        if (depth[u] < d) return -1;
        for (int k = 0; k < lg; k++) {
            if (d >> k & 1) u = parent[k][u];
        }
        return u;
    }

    // Number of edges on the path between u and v.
    int distance(const int u, const int v) {
        return depth[u] + depth[v] - 2 * depth[lca(u, v)];
    }
};
#line 5 "graph/auxiliary_tree.hpp"

#include <algorithm>

// Auxiliary (virtual) tree builder.
//
// The constructor preprocesses the tree once (DFS preorder ranks, weighted
// depths, LCA). get() then compresses any vertex subset into the tree spanned
// by the subset together with the LCAs of its members.
template <class T> struct AuxiliaryTree {
    int n, root;
    // preorder: vertices in DFS preorder from `root`; rank: its inverse
    // (rank[preorder[i]] == i).
    std::vector<int> preorder, rank;
    std::vector<T> depth;  // weighted depth: sum of edge costs from `root`
    LowestCommonAncestor<T> lca;

    // The initializer list follows declaration order, which is the order in
    // which C++ initializes members regardless of how the list is written.
    AuxiliaryTree(Graph<T>& g, const int root = 0)
        : n((int)(g.size())),
          root(root),
          rank(n),
          depth(n, T(0)),
          lca(g, root) {
        if (n == 0) return;
        preorder.reserve(n);
        // Iterative preorder DFS with an explicit (vertex, parent) stack, so
        // deep trees cannot overflow the call stack. Children are pushed in
        // reverse adjacency order and therefore popped in adjacency order,
        // which reproduces the preorder of a recursive DFS exactly.
        std::vector<std::pair<int, int>> stack;
        stack.emplace_back(root, -1);
        while (!stack.empty()) {
            const auto [cur, par] = stack.back();
            stack.pop_back();
            preorder.push_back(cur);
            auto adj = g[cur];
            for (int j = (int)(adj.size()) - 1; j >= 0; j--) {
                const auto& e = adj[j];
                if (e.to == par) continue;
                depth[e.to] = depth[cur] + e.cost;
                stack.emplace_back(e.to, cur);
            }
        }
        for (int i = 0; i < (int)(preorder.size()); i++) rank[preorder[i]] = i;
    }

    // Builds the auxiliary tree of the vertices `vs` (original labels;
    // duplicates are allowed).
    //
    // Returns (aux, label): vertex i of aux is original vertex label[i].
    // label is sorted by preorder rank, so aux vertex 0 is the aux root, and
    // every aux edge (parent, child) has parent < child with cost
    // depth[label[child]] - depth[label[parent]]. An empty `vs` yields a
    // value-initialized pair.
    std::pair<Graph<T>, std::vector<int>> get(std::vector<int> vs) {
        if (vs.empty()) return {};

        auto comp = [&](int i, int j) -> bool { return rank[i] < rank[j]; };
        std::sort(vs.begin(), vs.end(), comp);
        // Adding the LCAs of preorder-adjacent vertices closes the set under
        // all pairwise LCAs.
        const int k = (int)(vs.size());
        vs.reserve(2 * k - 1);
        for (int i = 0; i + 1 < k; i++) {
            vs.push_back(lca.lca(vs[i], vs[i + 1]));
        }
        std::sort(vs.begin(), vs.end(), comp);
        vs.erase(std::unique(vs.begin(), vs.end()), vs.end());

        // i is the aux vertex id, vs[i] the original vertex. Because vs is in
        // preorder, scanning it is a DFS of the aux tree; rs is the stack of
        // aux ids on the path from the aux root to the last visited vertex.
        const int m = (int)(vs.size());
        Graph<T> aux(m, false);
        std::vector<int> rs(1, 0);
        for (int i = 1; i < m; i++) {
            // Pop back up to the LCA, which becomes the parent of vs[i].
            const int l = lca.lca(vs[rs.back()], vs[i]);
            while (vs[rs.back()] != l) rs.pop_back();
            aux.add_edge(rs.back(), i, depth[vs[i]] - depth[vs[rs.back()]]);
            rs.push_back(i);
        }
        aux.build();
        return {std::move(aux), std::move(vs)};
    }
};
#line 7 "verify/graph/auxiliary_tree.test.cpp"

int main() {
    int N;
    std::cin >> N;
    auto g = read_graph<int>(N, N - 1);
    std::vector<int> A(N);
    std::vector<std::vector<int>> vs(N);
    for (int i = 0; i < N; i++) {
        std::cin >> A[i];
        A[i]--;
        vs[A[i]].push_back(i);
    }
    AuxiliaryTree<int> aux(g);
    long long ans = 0;
    for (int col = 0; col < N; col++) {
        auto vec = vs[col];
        if ((int)(vec.size()) == 0) continue;
        auto res = aux.get(vec);
        auto caux = res.first;
        auto label = res.second;
        const int M = (int)(vec.size());
        std::vector<int> dp((int)(label.size()));
        auto rec = [&](auto f, int cur, int par) -> void {
            dp[cur] = (A[label[cur]] == col);
            for (auto&& e : caux[cur]) {
                if (e.to == par) continue;
                f(f, e.to, cur);
                dp[cur] += dp[e.to];
                ans += (long long)e.cost * dp[e.to] * (M - dp[e.to]);
            }
        };
        rec(rec, 0, -1);
    }
    std::cout << ans << '\n';
    return 0;
}
Back to top page