rcpl

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub ruthen71/rcpl

✔ 2D Segment Tree (2 次元セグメント木)
(segment_tree/segment_tree_2d.hpp)

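基本的に algebra/monoid 以下のファイルをインクルードして使う (Usage: include a monoid definition from `algebra/monoid/` and pass it as the template argument `MS`, as in the example below.)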

// Point update, rectangle sum (1 点更新矩形和)
#include "algebra/monoid/monoid_plus.hpp"
#include "segment_tree/segment_tree_2d.hpp"
int main() {
    // Initial grid: 3 rows x 4 columns of zeros. The grid must be non-empty
    // and rectangular (every row has the same length).
    std::vector<std::vector<int>> A(3, std::vector<int>(4, 0));
    SegmentTree2D<MonoidPlus<int>> seg(A);
    seg.add(1, 2, 5);                  // A[1][2] += 5
    int total = seg.prod(0, 0, 3, 4);  // sum over rows [0, 3) x columns [0, 4)
    (void)total;
}

Depends on: misc/bit_ceil.hpp, misc/countr_zero.hpp

Verified with

Code

#pragma once

#include "../misc/bit_ceil.hpp"
#include "../misc/countr_zero.hpp"

#include <cassert>
#include <vector>

// Segment Tree 2D
//
// A segment tree over rows whose nodes are themselves segment trees over
// columns, storing an h x w grid of MS::value_type. A point update or a
// rectangle query performs O(log h * log w) combinations.
//
// Requirements on MS (as used below):
//   MS::value_type       element type S
//   MS::identity()       identity element of MS::operation
//   MS::operation(a, b)  binary operation; expected to be associative
//
// Layout: d has 2 * sizeh rows and 2 * sizew columns, where sizeh / sizew are
// h / w rounded up to a power of two. Both indices are heap-style node indices
// (root 1, children of k are 2k and 2k + 1); grid cell (r, c) is stored at
// d[r + sizeh][c + sizew]. Row 0 and column 0 are never read.
template <class MS> struct SegmentTree2D {
  public:
    using S = typename MS::value_type;

    // Leaves the tree empty; assign a constructed tree before using it.
    SegmentTree2D() = default;

    // Builds a rows x cols grid filled with MS::identity()
    // (rows == 0 yields a 0 x 0 grid).
    explicit SegmentTree2D(int rows, int cols)
        : SegmentTree2D(
              std::vector(rows, std::vector<S>(cols, MS::identity()))) {}

    // Builds the tree from the grid v, indexed v[row][col]. All rows must have
    // the same length; an empty v yields a 0 x 0 grid.
    explicit SegmentTree2D(const std::vector<std::vector<S>>& v)
        : h((int)(v.size())), w(v.empty() ? 0 : (int)(v[0].size())) {
        sizeh = bit_ceil(h);
        logh = countr_zero(sizeh);
        sizew = bit_ceil(w);
        logw = countr_zero(sizew);
        d = std::vector(sizeh << 1, std::vector<S>(sizew << 1, MS::identity()));
        for (int i = 0; i < h; i++) {
            assert((int)(v[i].size()) == w);  // the grid must be rectangular
            for (int j = 0; j < w; j++) d[i + sizeh][j + sizew] = v[i][j];
        }
        // Internal rows, leaf columns: combine the two child rows.
        for (int i = sizeh - 1; i >= 1; i--) {
            for (int j = sizew; j < (sizew << 1); j++) update_bottom(i, j);
        }
        // Every used row, internal columns: combine the two child columns.
        for (int i = 1; i < (sizeh << 1); i++) {
            for (int j = sizew - 1; j >= 1; j--) update_else(i, j);
        }
    }

    // Replaces the value of cell (row, col) with x.
    void set(int row, int col, const S& x) {
        assert(0 <= row and row < h and 0 <= col and col < w);
        row += sizeh;
        col += sizew;
        d[row][col] = x;
        pull(row, col);
    }

    // Replaces the value of cell (row, col) with MS::operation(old, x).
    void add(int row, int col, const S& x) {
        assert(0 <= row and row < h and 0 <= col and col < w);
        row += sizeh;
        col += sizew;
        d[row][col] = MS::operation(d[row][col], x);
        pull(row, col);
    }

    // Value of cell (row, col); same as get().
    S operator()(int row, int col) const { return get(row, col); }

    // Value of cell (row, col).
    S get(int row, int col) const {
        assert(0 <= row and row < h and 0 <= col and col < w);
        return d[row + sizeh][col + sizew];
    }

    // Product of d[i][j1], d[i][j1 + 1], ..., d[i][j2 - 1] inside the single
    // row node i. NOTE: i, j1 and j2 are internal node indices (already
    // offset by sizeh / sizew), not grid coordinates; use prod() for queries
    // in grid coordinates.
    S inner_prod(int i, int j1, int j2) const {
        S sml = MS::identity(), smr = MS::identity();
        while (j1 < j2) {
            if (j1 & 1) sml = MS::operation(sml, d[i][j1++]);
            if (j2 & 1) smr = MS::operation(d[i][--j2], smr);
            j1 >>= 1;
            j2 >>= 1;
        }
        return MS::operation(sml, smr);
    }

    // Product over the half-open rectangle [h1, h2) x [w1, w2) in grid
    // coordinates. An empty rectangle yields
    // MS::operation(MS::identity(), MS::identity()).
    S prod(int h1, int w1, int h2, int w2) const {
        assert(0 <= h1 and h1 <= h2 and h2 <= h);
        assert(0 <= w1 and w1 <= w2 and w2 <= w);
        S sml = MS::identity(), smr = MS::identity();
        h1 += sizeh;
        h2 += sizeh;
        w1 += sizew;
        w2 += sizew;

        while (h1 < h2) {
            if (h1 & 1) sml = MS::operation(sml, inner_prod(h1++, w1, w2));
            if (h2 & 1) smr = MS::operation(inner_prod(--h2, w1, w2), smr);
            h1 >>= 1;
            h2 >>= 1;
        }
        return MS::operation(sml, smr);
    }

    // Product of the whole grid (the root node). Requires a tree built by one
    // of the non-default constructors.
    S all_prod() const { return d[1][1]; }

    // Returns a copy of the current h x w grid.
    std::vector<std::vector<S>> make_vector() const {
        std::vector vec(h, std::vector<S>(w));
        for (int i = 0; i < h; i++) {
            for (int j = 0; j < w; j++) vec[i][j] = get(i, j);
        }
        return vec;
    }

  private:
    int h = 0, logh = 0, sizeh = 0, w = 0, logw = 0, sizew = 0;
    std::vector<std::vector<S>> d;

    // d[i][j] = combination of the child rows 2i and 2i + 1 at column j.
    inline void update_bottom(int i, int j) {
        d[i][j] = MS::operation(d[(i << 1) | 0][j], d[(i << 1) | 1][j]);
    }

    // d[i][j] = combination of the child columns 2j and 2j + 1 in row i.
    inline void update_else(int i, int j) {
        d[i][j] = MS::operation(d[i][(j << 1) | 0], d[i][(j << 1) | 1]);
    }

    // Recomputes every node that covers the leaf d[i][j] (internal indices).
    // First, column j in each ancestor row. Then the ancestor columns of j in
    // row i and in every ancestor row.
    void pull(int i, int j) {
        for (int k = 1; k <= logh; k++) update_bottom(i >> k, j);
        for (int k = 0; k <= logh; k++) {
            for (int l = 1; l <= logw; l++) update_else(i >> k, j >> l);
        }
    }
};
#line 2 "segment_tree/segment_tree_2d.hpp"

#line 2 "misc/bit_ceil.hpp"

#include <cassert>

#if __cplusplus >= 202002L
#include <bit>
#endif

// bit_ceil: smallest power of two that is >= x.
// (0, 1, 2, 3, 4) -> (1, 1, 2, 4, 4)
//
// The functions are `inline` because this header is meant to be included from
// several translation units. Non-inline definitions in a header would violate
// the one-definition rule (duplicate symbols at link time).
#if __cplusplus >= 202002L
using std::bit_ceil;
#else
inline unsigned int bit_ceil(unsigned int x) {
    // The result must be representable. Otherwise p would wrap around to 0
    // and the loop below would never terminate.
    assert(x <= (~0u >> 1) + 1u);
    unsigned int p = 1;
    while (p < x) p *= 2;
    return p;
}
inline unsigned long long int bit_ceil(unsigned long long int x) {
    assert(x <= (~0ull >> 1) + 1ull);
    unsigned long long int p = 1;
    while (p < x) p *= 2;
    return p;
}
#endif
// Signed overloads: x must be non-negative and the result must fit in the
// signed type.
inline int bit_ceil(int x) {
    assert(x >= 0);
    const unsigned int r = bit_ceil((unsigned int)(x));
    assert(r <= (~0u >> 1));
    return (int)(r);
}
inline long long int bit_ceil(long long int x) {
    assert(x >= 0);
    const unsigned long long int r = bit_ceil((unsigned long long int)(x));
    assert(r <= (~0ull >> 1));
    return (long long int)(r);
}
#line 2 "misc/countr_zero.hpp"

#if __cplusplus >= 202002L
#include <bit>
#endif

// countr_zero: number of trailing zero bits (32 / 64 when x == 0).
// (000, 001, 010, 011, 100) -> (32, 0, 1, 0, 2)
//
// The functions are `inline` because this header is meant to be included from
// several translation units. Non-inline definitions in a header would violate
// the one-definition rule (duplicate symbols at link time).
#if __cplusplus >= 202002L
using std::countr_zero;
#else
inline int countr_zero(unsigned int x) {
    // __builtin_ctz(0) is undefined, so the zero case is handled explicitly.
    return x == 0 ? 32 : __builtin_ctz(x);
}
inline int countr_zero(unsigned long long int x) {
    return x == 0 ? 64 : __builtin_ctzll(x);
}
#endif
// Signed overloads count on the same bit pattern viewed as unsigned.
inline int countr_zero(int x) { return countr_zero((unsigned int)(x)); }
inline int countr_zero(long long int x) {
    return countr_zero((unsigned long long int)(x));
}
#line 5 "segment_tree/segment_tree_2d.hpp"

#line 7 "segment_tree/segment_tree_2d.hpp"
#include <vector>

// Segment Tree 2D
//
// A segment tree over rows whose nodes are themselves segment trees over
// columns, storing an h x w grid of MS::value_type. A point update or a
// rectangle query performs O(log h * log w) combinations.
//
// Requirements on MS (as used below):
//   MS::value_type       element type S
//   MS::identity()       identity element of MS::operation
//   MS::operation(a, b)  binary operation; expected to be associative
//
// Layout: d has 2 * sizeh rows and 2 * sizew columns, where sizeh / sizew are
// h / w rounded up to a power of two. Both indices are heap-style node indices
// (root 1, children of k are 2k and 2k + 1); grid cell (r, c) is stored at
// d[r + sizeh][c + sizew]. Row 0 and column 0 are never read.
template <class MS> struct SegmentTree2D {
  public:
    using S = typename MS::value_type;

    // Leaves the tree empty; assign a constructed tree before using it.
    SegmentTree2D() = default;

    // Builds a rows x cols grid filled with MS::identity()
    // (rows == 0 yields a 0 x 0 grid).
    explicit SegmentTree2D(int rows, int cols)
        : SegmentTree2D(
              std::vector(rows, std::vector<S>(cols, MS::identity()))) {}

    // Builds the tree from the grid v, indexed v[row][col]. All rows must have
    // the same length; an empty v yields a 0 x 0 grid.
    explicit SegmentTree2D(const std::vector<std::vector<S>>& v)
        : h((int)(v.size())), w(v.empty() ? 0 : (int)(v[0].size())) {
        sizeh = bit_ceil(h);
        logh = countr_zero(sizeh);
        sizew = bit_ceil(w);
        logw = countr_zero(sizew);
        d = std::vector(sizeh << 1, std::vector<S>(sizew << 1, MS::identity()));
        for (int i = 0; i < h; i++) {
            assert((int)(v[i].size()) == w);  // the grid must be rectangular
            for (int j = 0; j < w; j++) d[i + sizeh][j + sizew] = v[i][j];
        }
        // Internal rows, leaf columns: combine the two child rows.
        for (int i = sizeh - 1; i >= 1; i--) {
            for (int j = sizew; j < (sizew << 1); j++) update_bottom(i, j);
        }
        // Every used row, internal columns: combine the two child columns.
        for (int i = 1; i < (sizeh << 1); i++) {
            for (int j = sizew - 1; j >= 1; j--) update_else(i, j);
        }
    }

    // Replaces the value of cell (row, col) with x.
    void set(int row, int col, const S& x) {
        assert(0 <= row and row < h and 0 <= col and col < w);
        row += sizeh;
        col += sizew;
        d[row][col] = x;
        pull(row, col);
    }

    // Replaces the value of cell (row, col) with MS::operation(old, x).
    void add(int row, int col, const S& x) {
        assert(0 <= row and row < h and 0 <= col and col < w);
        row += sizeh;
        col += sizew;
        d[row][col] = MS::operation(d[row][col], x);
        pull(row, col);
    }

    // Value of cell (row, col); same as get().
    S operator()(int row, int col) const { return get(row, col); }

    // Value of cell (row, col).
    S get(int row, int col) const {
        assert(0 <= row and row < h and 0 <= col and col < w);
        return d[row + sizeh][col + sizew];
    }

    // Product of d[i][j1], d[i][j1 + 1], ..., d[i][j2 - 1] inside the single
    // row node i. NOTE: i, j1 and j2 are internal node indices (already
    // offset by sizeh / sizew), not grid coordinates; use prod() for queries
    // in grid coordinates.
    S inner_prod(int i, int j1, int j2) const {
        S sml = MS::identity(), smr = MS::identity();
        while (j1 < j2) {
            if (j1 & 1) sml = MS::operation(sml, d[i][j1++]);
            if (j2 & 1) smr = MS::operation(d[i][--j2], smr);
            j1 >>= 1;
            j2 >>= 1;
        }
        return MS::operation(sml, smr);
    }

    // Product over the half-open rectangle [h1, h2) x [w1, w2) in grid
    // coordinates. An empty rectangle yields
    // MS::operation(MS::identity(), MS::identity()).
    S prod(int h1, int w1, int h2, int w2) const {
        assert(0 <= h1 and h1 <= h2 and h2 <= h);
        assert(0 <= w1 and w1 <= w2 and w2 <= w);
        S sml = MS::identity(), smr = MS::identity();
        h1 += sizeh;
        h2 += sizeh;
        w1 += sizew;
        w2 += sizew;

        while (h1 < h2) {
            if (h1 & 1) sml = MS::operation(sml, inner_prod(h1++, w1, w2));
            if (h2 & 1) smr = MS::operation(inner_prod(--h2, w1, w2), smr);
            h1 >>= 1;
            h2 >>= 1;
        }
        return MS::operation(sml, smr);
    }

    // Product of the whole grid (the root node). Requires a tree built by one
    // of the non-default constructors.
    S all_prod() const { return d[1][1]; }

    // Returns a copy of the current h x w grid.
    std::vector<std::vector<S>> make_vector() const {
        std::vector vec(h, std::vector<S>(w));
        for (int i = 0; i < h; i++) {
            for (int j = 0; j < w; j++) vec[i][j] = get(i, j);
        }
        return vec;
    }

  private:
    int h = 0, logh = 0, sizeh = 0, w = 0, logw = 0, sizew = 0;
    std::vector<std::vector<S>> d;

    // d[i][j] = combination of the child rows 2i and 2i + 1 at column j.
    inline void update_bottom(int i, int j) {
        d[i][j] = MS::operation(d[(i << 1) | 0][j], d[(i << 1) | 1][j]);
    }

    // d[i][j] = combination of the child columns 2j and 2j + 1 in row i.
    inline void update_else(int i, int j) {
        d[i][j] = MS::operation(d[i][(j << 1) | 0], d[i][(j << 1) | 1]);
    }

    // Recomputes every node that covers the leaf d[i][j] (internal indices).
    // First, column j in each ancestor row. Then the ancestor columns of j in
    // row i and in every ancestor row.
    void pull(int i, int j) {
        for (int k = 1; k <= logh; k++) update_bottom(i >> k, j);
        for (int k = 0; k <= logh; k++) {
            for (int l = 1; l <= logw; l++) update_else(i >> k, j >> l);
        }
    }
};
Back to top page