rcpl

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub ruthen71/rcpl

:heavy_check_mark: Persistent Segment Tree (永続セグメント木)
(segment_tree/persistent_segment_tree.hpp)

通常のセグメント木に加えて、過去のすべての状態に対して 1 点変更と総積の取得が可能なデータ構造です。

$t$ 回目の 1 点変更後が行われた後の長さ $n$ の 配列 $a^t$ に対し

を $O(\log n)$ 時間で処理することが出来ます。

テンプレート引数として、モノイド $(S, \cdot)$ を M として受け取ります。 モノイドとは以下の条件を満たす代数構造です。

例えば、$\cdot$ として $\max$ を計算するモノイドは ここ に定義されています。

計算量は $\cdot$, $e$ が定数時間で計算できると仮定したときのものを記述します。

コンストラクタ

(1) PersistentSegmentTree<M> seg(int n)
(2) PersistentSegmentTree<M> seg(std::vector<S> v)

(1)

長さ $n$ の数列 $a^0$ を作ります。初期値は全部 $e$ です。

(2)

長さ $ n = \left| v \right| $ の数列 $a^0$ を作ります。 $v$ の内容が初期値となります。

計算量

get_time

int seg.get_time()

永続セグメント木が作成された時刻を 0 とし、最後の set・add が行われた時刻を返します。

計算量

set

(1) Node* seg.set(int p, S x)
(2) Node* seg.set(int p, S x, int t)
(3) Node* seg.set(int p, S x, Node* root)

(1)

最新の状態の $a_p$ に $x$ を代入します。

(2)

時刻 $t$ の $a_p^t$ に $x$ を代入します。

(3)

時刻 $t$ のセグメント木のルートノードのポインタを root とし、$a_p^t$ に $x$ を代入します。

set のいずれについても、更新後のセグメント木のルートノードのポインタを返します。 このポインタでその後の set・add での操作対象のセグメント木を指定することができます。

制約

計算量

add

(1) Node* seg.add(int p, S x)
(2) Node* seg.add(int p, S x, int t)
(3) Node* seg.add(int p, S x, Node* root)

(1)

最新の状態の $a_p$ に $a_p \cdot x$ を代入します。

(2)

時刻 $t$ の $a_p^t$ に $a_p^t \cdot x$ を代入します。

(3)

時刻 $t$ のセグメント木のルートノードのポインタを root とし、$a_p^t$ に $a_p^t \cdot x$ を代入します。

add のいずれについても、更新後のセグメント木のルートノードのポインタを返します。 このポインタでその後の set・add での操作対象のセグメント木を指定することができます。

制約

計算量

get

(1-1) S seg.get(int p)
(1-2) S seg[int p]
(2) S seg.get(int p, int t)
(3) S seg.get(int p, Node* root)

(1)

最新の状態の $a_p$ を返します。

(2)

時刻 $t$ の $a_p^t$ を返します。

(3)

時刻 $t$ のセグメント木のルートノードのポインタを root とし、$a_p^t$ を返します。

制約

計算量

prod

(1) S seg.prod(int l, int r)
(2) S seg.prod(int l, int r, int t)
(3) S seg.prod(int l, int r, Node* root)

(1)

最新の状態の $a_l \cdot … \cdot a_{r - 1}$ を計算します。

(2)

時刻 $t$ の $a_l^t \cdot … \cdot a_{r - 1}^t$ を計算します。

(3)

時刻 $t$ のセグメント木のルートノードのポインタを root とし、$a_l^t \cdot … \cdot a_{r - 1}^t$ を計算します。

prod のいずれについても、モノイドの性質を満たしていると仮定して計算します。 $l = r$ のときは $e$ を返します。

制約

計算量

all_prod

(1) S seg.all_prod()
(2) S seg.all_prod(int t)
(3) S seg.all_prod(Node* root)

(1)

最新の状態の $a_0 \cdot …\cdot a_{n - 1}$ を計算します。

(2)

時刻 $t$ の $a_0^t \cdot …\cdot a_{n - 1}^t$ を計算します。

(3)

時刻 $t$ のセグメント木のルートノードのポインタを root とし、$a_0^t \cdot …\cdot a_{n - 1}^t$ を計算します。

all_prod のいずれについても、$n = 0$ のときは $e$ を返します。

計算量

make_vector

(1) std::vector<S> seg.make_vector()
(2) std::vector<S> seg.make_vector(int t)
(3) std::vector<S> seg.make_vector(Node* root)

(1)

最新の状態の数列 $a$ を返します。

(2)

時刻 $t$ の数列 $a^t$ を返します。

(3)

時刻 $t$ のセグメント木のルートノードのポインタを root とし、数列 $a^t$ を返します。

計算量

参考文献

Required by

Verified with

Code

#pragma once

#include <cassert>
#include <vector>

// Persistent Segment Tree
// N + Q log_2(N) は N = Q = 500000 のとき 10000000 くらい
template <class MS, int MAX_NODES = 20'000'000> struct PersistentSegmentTree {
  public:
    using S = typename MS::value_type;

    struct Node {
        S d;
        Node *l, *r;
        Node() = default;
        Node(S v, Node* l = nullptr, Node* r = nullptr) : d(v), l(l), r(r) {}
    };

    PersistentSegmentTree() = default;

    explicit PersistentSegmentTree(int n)
        : PersistentSegmentTree(std::vector<S>(n, MS::identity())) {}

    explicit PersistentSegmentTree(const std::vector<S>& v)
        : n((int)(v.size())) {
        roots.push_back(build(v, 0, n));
    }

    int get_time() { return (int)(roots.size()) - 1; }

    Node* set(int p, const S& x, Node* root) {
        assert(0 <= p and p < n);
        roots.push_back(set(p, x, 0, n, root));
        return roots.back();
    }

    Node* set(int p, const S& x) { return set(p, x, roots.back()); }

    Node* set(int p, const S& x, int t) {
        assert(0 <= t and t < (int)(roots.size()));
        return set(p, x, roots[t]);
    }

    Node* add(int p, const S& x, Node* root) {
        assert(0 <= p and p < n);
        roots.push_back(add(p, x, 0, n, root));
        return roots.back();
    }

    Node* add(int p, const S& x) { return add(p, x, roots.back()); }

    Node* add(int p, const S& x, int t) {
        assert(0 <= t and t < (int)(roots.size()));
        return add(p, x, roots[t]);
    }

    S get(int p, Node* root) const {
        assert(0 <= p and p < n);
        return prod(p, p + 1, root);
    }

    S get(int p) const { return get(p, roots.back()); }

    S get(int p, int t) const {
        assert(0 <= t and t < (int)(roots.size()));
        return get(p, roots[t]);
    }

    S operator[](int p) const {
        assert(0 <= p and p < n);
        return prod(p, p + 1);
    }

    S prod(int l, int r, Node* root) const {
        assert(0 <= l and l <= r and r <= n);
        return prod(l, r, 0, n, root);
    }

    S prod(int l, int r) const { return prod(l, r, roots.back()); }

    S prod(int l, int r, int t) const {
        assert(0 <= t and t < (int)(roots.size()));
        return prod(l, r, roots[t]);
    }

    S all_prod(Node* root) const { return root->d; }

    S all_prod() const { return all_prod(roots.back()); }

    S all_prod(int t) const {
        assert(0 <= t and t < (int)(roots.size()));
        return all_prod(roots[t]);
    }

    std::vector<S> make_vector(Node* root) const {
        std::vector<S> vec(n);
        for (int i = 0; i < n; i++) vec[i] = get(i, root);
        return vec;
    }

    std::vector<S> make_vector() const { return make_vector(roots.back()); }

    std::vector<S> make_vector(int t) const {
        assert(0 <= t and t < (int)(roots.size()));
        return make_vector(roots[t]);
    }

  private:
    int n;
    std::vector<Node*> roots;
    static inline Node pool[MAX_NODES];
    static inline int pool_idx = 0;

    Node* new_node(S v, Node* l = nullptr, Node* r = nullptr) {
        return &(pool[pool_idx++] = Node(v, l, r));
    }

    Node* merge(Node* l, Node* r) {
        return new_node(MS::operation(l->d, r->d), l, r);
    }

    Node* build(const std::vector<S>& v, int l, int r) {
        if (l + 1 == r) {
            return new_node(v[l]);
        }
        int m = (l + r) / 2;
        return merge(build(v, l, m), build(v, m, r));
    }

    Node* set(int p, const S& x, int l, int r, Node* np) {
        if (l + 1 == r) {
            return new_node(x);
        }
        int m = (l + r) / 2;
        if (l <= p and p < m) {
            return merge(set(p, x, l, m, np->l), np->r);
        } else {
            return merge(np->l, set(p, x, m, r, np->r));
        }
    }

    Node* add(int p, const S& x, int l, int r, Node* np) {
        if (l + 1 == r) {
            return new_node(MS::operation(np->d, x));
        }
        int m = (l + r) / 2;
        if (l <= p and p < m) {
            return merge(add(p, x, l, m, np->l), np->r);
        } else {
            return merge(np->l, add(p, x, m, r, np->r));
        }
    }

    S prod(int ql, int qr, int l, int r, Node* np) const {
        // [ql, qr) と [l, r) が交差しない
        if (qr <= l or r <= ql) return MS::identity();
        // [ql, qr) が [l, r) を完全に含んでいる
        if (ql <= l and r <= qr) return np->d;
        int m = (l + r) / 2;
        return MS::operation(prod(ql, qr, l, m, np->l),
                             prod(ql, qr, m, r, np->r));
    }
};
#line 2 "segment_tree/persistent_segment_tree.hpp"

#include <cassert>
#include <vector>

// Persistent Segment Tree
// N + Q log_2(N) は N = Q = 500000 のとき 10000000 くらい
template <class MS, int MAX_NODES = 20'000'000> struct PersistentSegmentTree {
  public:
    using S = typename MS::value_type;

    struct Node {
        S d;
        Node *l, *r;
        Node() = default;
        Node(S v, Node* l = nullptr, Node* r = nullptr) : d(v), l(l), r(r) {}
    };

    PersistentSegmentTree() = default;

    explicit PersistentSegmentTree(int n)
        : PersistentSegmentTree(std::vector<S>(n, MS::identity())) {}

    explicit PersistentSegmentTree(const std::vector<S>& v)
        : n((int)(v.size())) {
        roots.push_back(build(v, 0, n));
    }

    int get_time() { return (int)(roots.size()) - 1; }

    Node* set(int p, const S& x, Node* root) {
        assert(0 <= p and p < n);
        roots.push_back(set(p, x, 0, n, root));
        return roots.back();
    }

    Node* set(int p, const S& x) { return set(p, x, roots.back()); }

    Node* set(int p, const S& x, int t) {
        assert(0 <= t and t < (int)(roots.size()));
        return set(p, x, roots[t]);
    }

    Node* add(int p, const S& x, Node* root) {
        assert(0 <= p and p < n);
        roots.push_back(add(p, x, 0, n, root));
        return roots.back();
    }

    Node* add(int p, const S& x) { return add(p, x, roots.back()); }

    Node* add(int p, const S& x, int t) {
        assert(0 <= t and t < (int)(roots.size()));
        return add(p, x, roots[t]);
    }

    S get(int p, Node* root) const {
        assert(0 <= p and p < n);
        return prod(p, p + 1, root);
    }

    S get(int p) const { return get(p, roots.back()); }

    S get(int p, int t) const {
        assert(0 <= t and t < (int)(roots.size()));
        return get(p, roots[t]);
    }

    S operator[](int p) const {
        assert(0 <= p and p < n);
        return prod(p, p + 1);
    }

    S prod(int l, int r, Node* root) const {
        assert(0 <= l and l <= r and r <= n);
        return prod(l, r, 0, n, root);
    }

    S prod(int l, int r) const { return prod(l, r, roots.back()); }

    S prod(int l, int r, int t) const {
        assert(0 <= t and t < (int)(roots.size()));
        return prod(l, r, roots[t]);
    }

    S all_prod(Node* root) const { return root->d; }

    S all_prod() const { return all_prod(roots.back()); }

    S all_prod(int t) const {
        assert(0 <= t and t < (int)(roots.size()));
        return all_prod(roots[t]);
    }

    std::vector<S> make_vector(Node* root) const {
        std::vector<S> vec(n);
        for (int i = 0; i < n; i++) vec[i] = get(i, root);
        return vec;
    }

    std::vector<S> make_vector() const { return make_vector(roots.back()); }

    std::vector<S> make_vector(int t) const {
        assert(0 <= t and t < (int)(roots.size()));
        return make_vector(roots[t]);
    }

  private:
    int n;
    std::vector<Node*> roots;
    static inline Node pool[MAX_NODES];
    static inline int pool_idx = 0;

    Node* new_node(S v, Node* l = nullptr, Node* r = nullptr) {
        return &(pool[pool_idx++] = Node(v, l, r));
    }

    Node* merge(Node* l, Node* r) {
        return new_node(MS::operation(l->d, r->d), l, r);
    }

    Node* build(const std::vector<S>& v, int l, int r) {
        if (l + 1 == r) {
            return new_node(v[l]);
        }
        int m = (l + r) / 2;
        return merge(build(v, l, m), build(v, m, r));
    }

    Node* set(int p, const S& x, int l, int r, Node* np) {
        if (l + 1 == r) {
            return new_node(x);
        }
        int m = (l + r) / 2;
        if (l <= p and p < m) {
            return merge(set(p, x, l, m, np->l), np->r);
        } else {
            return merge(np->l, set(p, x, m, r, np->r));
        }
    }

    Node* add(int p, const S& x, int l, int r, Node* np) {
        if (l + 1 == r) {
            return new_node(MS::operation(np->d, x));
        }
        int m = (l + r) / 2;
        if (l <= p and p < m) {
            return merge(add(p, x, l, m, np->l), np->r);
        } else {
            return merge(np->l, add(p, x, m, r, np->r));
        }
    }

    S prod(int ql, int qr, int l, int r, Node* np) const {
        // [ql, qr) と [l, r) が交差しない
        if (qr <= l or r <= ql) return MS::identity();
        // [ql, qr) が [l, r) を完全に含んでいる
        if (ql <= l and r <= qr) return np->d;
        int m = (l + r) / 2;
        return MS::operation(prod(ql, qr, l, m, np->l),
                             prod(ql, qr, m, r, np->r));
    }
};
Back to top page