[算法] 线段树

想必很多同学都遇到过区间 最值/和/积 问题。

线段树,是一种解决区间 最值/和/积 并且要求能够在O(log n)时间内查询区间值或进行区间修改的数据结构。

区间修改与查询问题

这种问题(要求修改或查询区间,而且要边修改边查询),学习过普及组的同学都不少见。现在,让我们归纳这些问题的解法。

单点修改、单点查询

使用数组即可。

查询:取出数组值。O(1)

修改:设置数组值。O(1)

单点修改、区间查询

方法一 暴力

查询需要遍历被查询的区间。O(n)

修改只需修改数组元素值即可。O(1)

因此,这种方法可以用于查询很少,修改很多的地方。

方法二 前缀和

查询只需将首尾两个前缀和相减。O(1)

修改需要修改i以及其后的元素。O(n)

因此,这种方法可以用于查询很多,修改很少(或根本不修改)的地方。但是对于查询和修改基本持平的题目,相对于暴力没有什么提升。

方法三 树状数组

本蒟蒻太弱了,没有学这个。

区间修改、单点查询

方法一 暴力

查询只需取出值。O(1)

修改需要区间修改。O(n)

方法二 差分(前缀和逆运算)

查询需要统计和。O(n)

修改只需要修改首尾元素O(1)

方法三 树状数组

区间修改、区间查询

现在不难看出,到了这里,前缀和与差分已经完全没有优势了。树状数组也不支持这一类题目(一个叫“树状数组区间操作”的东西可以)。那么,自然容易想到一种分块思想。

分块思想

本文中,我们认为查询是区间和,修改是区间内元素都增加一定值

这个时候我们可以把数组分成sqrt(n)块,并对于每段,都单独储存一个区间和。

初始状态

接下来我们要查询区间的值(例子:下标2-9的和)。那么我们只需取出每段的和(如果整段都被覆盖到)以及几个元素(整段没有覆盖到)的值相加。

统计

显然,全部被覆盖到的段最多sqrt(n)个。取到未覆盖到的元素个数一定小于2*sqrt(n)。因此时间复杂度是sqrt(n)

那么要修改呢?修改时我们只需修改段的修改标记(如果整段全部覆盖到)以及几个整段没有全部覆盖到的元素(别忘了同时更新段的和)。我们现在将2-9部分都加3.

修改

有了修改标记,查询和修改的过程就变得复杂一些。如果一个带修改标记的段整段被区间覆盖,那么显然,查询时很容易算出其修改后的和,修改时,只需在原修改标记上再加上新的修改标记。

如果没有完全覆盖呢?这时候就要清除修改标记,把修改标记传递给段内元素(段内元素按照修改标记修改)。还要将修改标记乘以段的长度,加到段的和上。

带修改标记的统计
带修改标记的修改

这样,区间修改与查询操作的时间复杂度都降到了O(sqrt(n))

继续改进

然而O(sqrt(n))并非一个理想情况。理想的是O(log n)。能否做到呢?

其实,我们需要将上面的和、修改标记以及段的左边界、右边界用一个Struct存起来。

再结合线段树的“树”字,你想到了什么?可以使用递归结构。先将原数组分为左右两部分,再将每个部分继续分下去,直到段的长度只有1。这就是线段树。

我们可以用数组存二叉树。令1为根,则如果节点下标为i,其左儿子是i<<1,右儿子是i<<1|1。这样就可以用数组存储它了。易证最大需要数组空间为4*n

现在,我们可以这么定义二叉树:

//应该叫做SegmentTree,“但是叫XtAkIoi也是没有什么问题的”

template<class T>
struct XtOkIai {
    int l,r; //左、右边界
    T v,tag; //v:和,tag:修改标记
    
    int size() {
        return r-l+1;
    }
};
template<class T,int len=1000>
struct XtAkIoi {
    XtOkIai<T> data[len*4+1];
    
    ...
};

要构建二叉树,我们需要的递归参数有:用于构建的数组,当前区间的左、右边界,当前节点在data数组上的下标。

为了方便,要在struct中将当前节点的左、右区间存下。线段树可以包含数组的所有功能,因此创建后脱离数组而存在。

如果当前节点是个叶子节点,那就可以直接将数组中的值写入当前节点的和。否则,就要利用其左、右儿子的和算出当前和。(这个操作定义为push_up,因为要被多次用到)。

那么,写出创建线段树的代码就不难了。

    void push_up(int p) {
        data[p].v =
            data[p<<1].v + data[p<<1].tag*data[p<<1].size()+
            data[p<<1|1].v + data[p<<1|1].tag*data[p<<1|1].size();
    }
    
    void make(int l,int r,T *base,int p=1) {
        data[p].l=l;
        data[p].r=r;
        data[p].tag=0;
        if (!(l<=r)) {
            cerr<<"XtAkIoi: incorrect make. Tried to input negative range ["
                <<l<<","<<r<<"]. Please check for errors."<<endl;
            int ______[2];
            ______[9999999] = ______[0];
        }
        if(l==r) {
            data[p].v=base[l];
        }
        else {
            int mid=(l+r)/2;
            make(l,mid,base,p<<1);
            make(mid+1,r,base,p<<1|1);
            push_up(p);
        }
    }

接下来要考虑的是区间查询(因为单点查询操作可以被区间查询替代,故不写单点查询)

根据上面的分块思想,如果某节点直接被命中,就直接从该节点读取值(叶子节点一定会直接命中的),并且还要考虑修改标记。否则分为左、右两部分继续递归(要先检查需要查询的区间是否覆盖到左或右部分,未覆盖则不继续)

如果一个节点没有直接命中,还有修改标记的话,那么修改标记就要向下传递(定义为push_down)。此处如果直接命中,也执行push_down。这样返回数值时就无需再考虑修改标记。由于叶子节点也可能执行push_down,那么就要判断是否是叶子节点,因为叶子节点没有子节点了。

    void push_down(int p) {
        data[p].v += data[p].tag*data[p].size(); //将修改标记加到和上
        if(data[p].l != data[p].r) { //判断是否叶子
            data[p<<1].tag += data[p].tag;
            data[p<<1|1].tag += data[p].tag;
        }
        data[p].tag=0; //清除修改标记
    }
    
    T query(int l,int r,int p=1) {
        if (!(data[p].l<=l && l<=r && r<=data[p].r)) {
            cerr<<"XtAkIoi: incorrect query. Range of the node is ["
                <<data[p].l<<","<<data[p].r<<"], but the range of query is ["
                <<l<<","<<r<<"]. Please check for errors."<<endl;
            int ______[2];
            ______[9999999] = ______[0];
        }
        push_down(p);
        //assert(data[p].tag==0);
        if(l==data[p].l && r==data[p].r) { //直接命中
            return data[p].v;
        }
        else { //未直接命中
            int mid=(data[p].l+data[p].r)/2;
            T a=0,b=0;
            if(l<=mid) { //是否覆盖到左部分
                a=query(l,min(mid,r),(p<<1));
            }
            if(mid+1<=r) { //是否覆盖到右部分
                b=query(max(mid+1,l),r,(p<<1|1));
            }
            //push_up(p); //由于查询除了push_down之外不作别的修改操作,不需要push_up.
            return a+b; //左右边取得的值相加。
        }
    }

修改与查询类似。无论是否直接命中,都先push_down,然后如果段被直接命中,那么更改修改标记,否则分为左右两部分继续。

    void modify(int l,int r,T d,int p=1) {
        if (!(data[p].l<=l && l<=r && r<=data[p].r)) {
            cerr<<"SegmentTree: incorrect modify. Range of the node is ["
                <<data[p].l<<","<<data[p].r<<"], but the range of query is ["
                <<l<<","<<r<<"]. Please check for errors."<<endl;
            int ______[2];
            ______[9999999] = ______[0];
        }
        push_down(p);
        //assert(data[p].tag==0);
        if(l==data[p].l && r==data[p].r) { //直接命中
            data[p].tag+=d;
            return;
        }
        else {
            int mid=(data[p].l+data[p].r)/2;
            if(l<=mid) { //是否覆盖到左部分
                modify(l,min(mid,r),d,(p<<1));
            }
            if(mid+1<=r) { //是否覆盖到右部分
                modify(max(mid+1,l),r,d,(p<<1|1));
            }
            push_up(p);
            return;
        }
    }

最后,由于有出锅的可能性,加上输出线段树内所有数据的调试函数(万一调试函数自己出锅了呢)

    void debug() {
        for(int i=1;i<=len*4;i++) {
            if(data[i].l>0) {
                cerr<<"#"<<i<<" ["<<data[i].l
                    <<","<<data[i].r<<"] v="
                    <<data[i].v<<", tag="
                    <<data[i].tag<<endl;;
            }
        }
    }

至此,线段树就写完了。可见,线段树建模的难点实际上是如何利用修改标记。

建议不要每次用线段树都抄模板。要理解之后自己写。

模板题

线段树1(基础):https://www.luogu.org/problemnew/show/P3372

线段树2(进阶): https://www.luogu.org/problemnew/show/P3373

线段树1的AC代码如下:

#include<bits/stdc++.h>
using namespace std;
typedef long long ll; 

template<class T>
struct XtOkIai {
    int l,r;
    T v,tag;
    
    int size() {
        return r-l+1;
    }
};

template<class T,int len=1000>
struct XtAkIoi {
    XtOkIai<T> data[len*4+1];
    
    void push_up(int p) {
        data[p].v =
            data[p<<1].v + data[p<<1].tag*data[p<<1].size()+
            data[p<<1|1].v + data[p<<1|1].tag*data[p<<1|1].size();
    }
    
    void push_down(int p) {
        data[p].v += data[p].tag*data[p].size();
        if(data[p].l != data[p].r) {
            data[p<<1].tag += data[p].tag;
            data[p<<1|1].tag += data[p].tag;
        }
        data[p].tag=0;
    }
    
    void make(int l,int r,T *base,int p=1) {
        data[p].l=l;
        data[p].r=r;
        data[p].tag=0;
        if (!(l<=r)) {
            cerr<<"SegmentTree: incorrect make. Tried to input negative range ["
                <<l<<","<<r<<"]. Please check for errors."<<endl;
            int ______[2];
            ______[9999999] = ______[0];
        }
        if(l==r) {
            data[p].v=base[l];
        }
        else {
            int mid=(l+r)/2;
            make(l,mid,base,p<<1);
            make(mid+1,r,base,p<<1|1);
            push_up(p);
        }
    }
    
    T query(int l,int r,int p=1) {
        if (!(data[p].l<=l && l<=r && r<=data[p].r)) {
            cerr<<"SegmentTree: incorrect query. Range of the node is ["
                <<data[p].l<<","<<data[p].r<<"], but the range of query is ["
                <<l<<","<<r<<"]. Please check for errors."<<endl;
            int ______[2];
            ______[9999999] = ______[0];
        }
        push_down(p);
        //assert(data[p].tag==0);
        if(l==data[p].l && r==data[p].r) {
            return data[p].v;
        }
        else {
            int mid=(data[p].l+data[p].r)/2;
            T a=0,b=0;
            if(l<=mid) {
                a=query(l,min(mid,r),(p<<1));
            }
            if(mid+1<=r) {
                b=query(max(mid+1,l),r,(p<<1|1));
            }
            //push_up(p);
            return a+b;
        }
    }
    
    void modify(int l,int r,T d,int p=1) {
        if (!(data[p].l<=l && l<=r && r<=data[p].r)) {
            cerr<<"SegmentTree: incorrect modify. Range of the node is ["
                <<data[p].l<<","<<data[p].r<<"], but the range of query is ["
                <<l<<","<<r<<"]. Please check for errors."<<endl;
            int ______[2];
            ______[9999999] = ______[0];
        }
        push_down(p);
        //assert(data[p].tag==0);
        if(l==data[p].l && r==data[p].r) {
            data[p].tag+=d;
            return;
        }
        else {
            int mid=(data[p].l+data[p].r)/2;
            if(l<=mid) {
                modify(l,min(mid,r),d,(p<<1));
            }
            if(mid+1<=r) {
                modify(max(mid+1,l),r,d,(p<<1|1));
            }
            push_up(p);
            return;
        }
    }
    
    void debug() {
        for(int i=1;i<=len*4;i++) {
            if(data[i].l>0) {
                cerr<<"#"<<i<<" ["<<data[i].l
                    <<","<<data[i].r<<"] v="
                    <<data[i].v<<", tag="
                    <<data[i].tag<<endl;;
            }
        }
    }
};

XtAkIoi<ll,100000> tr;

ll a[100001];
int n,m;

int main() {
    ios::sync_with_stdio(false);
    
    cin>>n>>m;
    for(int i=1;i<=n;i++) cin>>a[i];
    tr.make(1,n,a);
    
    while(m--) {
        int t;
        cin>>t;
        if(t==1) {
            int x,y;
            ll z;
            cin>>x>>y>>z;
            tr.modify(x,y,z);
            //tr.debug();
        }
        else {
            int x,y;
            cin>>x>>y;
            cout<<tr.query(x,y)<<endl;
            //tr.debug();
        }
    }
} 

线段树2(用吸氧代替读入优化,我觉得可以):

//吸氧
    #pragma GCC optimize("inline,Ofast",3)
    #pragma GCC optimize(2)
    #pragma GCC optimize(3)
    #pragma GCC optimize("Ofast")
    #pragma GCC optimize("inline")
    #pragma GCC optimize("-fgcse")
    #pragma GCC optimize("-fgcse-lm")
    #pragma GCC optimize("-fipa-sra")
    #pragma GCC optimize("-ftree-pre")
    #pragma GCC optimize("-ftree-vrp")
    #pragma GCC optimize("-fpeephole2")
    #pragma GCC optimize("-ffast-math")
    #pragma GCC optimize("-fsched-spec")
    #pragma GCC optimize("unroll-loops")
    #pragma GCC optimize("-falign-jumps")
    #pragma GCC optimize("-falign-loops")
    #pragma GCC optimize("-falign-labels")
    #pragma GCC optimize("-fdevirtualize")
    #pragma GCC optimize("-fcaller-saves")
    #pragma GCC optimize("-fcrossjumping")
    #pragma GCC optimize("-fthread-jumps")
    #pragma GCC optimize("-funroll-loops")
    #pragma GCC optimize("-fwhole-program")
    #pragma GCC optimize("-freorder-blocks")
    #pragma GCC optimize("-fschedule-insns")
    #pragma GCC optimize("inline-functions")
    #pragma GCC optimize("-fschedule-insns2")
    #pragma GCC optimize("-fstrict-aliasing")
    #pragma GCC optimize("-fstrict-overflow")
    #pragma GCC optimize("-falign-functions")
    #pragma GCC optimize("-fcse-skip-blocks")
    #pragma GCC optimize("-fcse-follow-jumps")
    #pragma GCC optimize("-fsched-interblock")
    #pragma GCC optimize("-fpartial-inlining")
    #pragma GCC optimize("no-stack-protector")
    #pragma GCC optimize("-freorder-functions")
    #pragma GCC optimize("-findirect-inlining")
    #pragma GCC optimize("inline-small-functions")
    #pragma GCC optimize("-finline-small-functions")
    #pragma GCC optimize("-ftree-switch-conversion")
    #pragma GCC optimize("-foptimize-sibling-calls")
    #pragma GCC optimize("-fexpensive-optimizations")
    #pragma GCC optimize("-funsafe-loop-optimizations")
    #pragma GCC optimize("inline-functions-called-once")
    #pragma GCC optimize("-fdelete-null-pointer-checks")
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define int ll
#ifdef DEBUG
#define CERR cerr
#endif
#ifndef DEBUG
#define CERR if(false)cerr
#endif
#define iter iterator
#define fake true
inline long long llread() {long long x;scanf("%lld",&x);return x;}
inline signed read() {signed x;scanf("%d",&x);return x;}
#define a%b (a%b+b)%b 
const int MAXN=100001;

int n,m,MOD;

//重写的线段树
template<int len>
struct SegmentTree{
    typedef int T;
    typedef int XT;
    const static T initval=0;
    struct SegmentTreeNode{
        const static T tag1d=0;
        const static XT tag2d=1;
        T data;
        T tag1;  //增加数
        XT tag2;  //翻倍数
        int l,r;

        int size() {return r-l+1;} //求长度
        T realvalue() {return (data%MOD*tag2%MOD+size()*tag1%MOD)%MOD;} //pushdown: 求真值
        //真值:(data)*tag2+tag1
        void pushvalue(T ntag1,XT ntag2) { //pushdown: 接受数值
            tag1*=ntag2; //注意两个tag的共生关系。
            //分配律:(data+tag1)*ntag2+ntag1=(data)*ntag2+tag1*ntag2
            tag1%=MOD;
            tag2*=ntag2;
            tag2%=MOD;
            tag1+=ntag1;
            tag1%=MOD;
        }
        void cleartag() {tag1=tag1d;tag2=tag2d;} //清除标记
    };
    SegmentTreeNode d[len*4+1];

    void push_up(int p) {
        if(d[p].size()>1) {
            d[p].data=d[p<<1].realvalue()+d[p<<1|1].realvalue();
            d[p].data%=MOD;
        }
    }

    void push_down(int p) {
        if(d[p].size()>1) {
            d[p<<1].pushvalue(d[p].tag1,d[p].tag2);
            d[p<<1|1].pushvalue(d[p].tag1,d[p].tag2);
        }
        d[p].data=d[p].realvalue();
        d[p].cleartag();
    }

    void init(T *arr,int l,int r,int p=1) {
        d[p].l=l;
        d[p].r=r;
        d[p].cleartag();
        if(l==r) {
            d[p].data=arr[l];
            return;
        }
        int mid=(l+r)/2;
        init(arr,l,mid,p<<1);
        init(arr,mid+1,r,p<<1|1);
        push_up(p);
    }

    T query(int ql,int qr,int p=1) {
        int l=d[p].l,r=d[p].r;
        push_down(p);
        if(l==ql && r==qr) {
            return d[p].data;
        }
        int mid=(l+r)/2;
        T ret=initval;
        if(ql<=mid) ret+=query(ql,min(mid,qr),p<<1);
        if(qr> mid) ret+=query(max(mid+1,ql),qr,p<<1|1);
        return ret%MOD;
    }

    void modify1(int ql,int qr,T val=SegmentTreeNode::tag1d,int p=1) {
        int l=d[p].l,r=d[p].r;
        push_down(p);
        if(l==ql && r==qr) {
            d[p].pushvalue(val,SegmentTreeNode::tag2d);
            return;
        }
        int mid=(l+r)/2;
        if(ql<=mid) modify1(ql,min(mid,qr),val,p<<1);
        if(qr> mid) modify1(max(mid+1,ql),qr,val,p<<1|1);
        push_up(p);
    }

    void modify2(int ql,int qr,XT val=SegmentTreeNode::tag2d,int p=1) {
        int l=d[p].l,r=d[p].r;
        push_down(p);
        if(l==ql && r==qr) {
            d[p].pushvalue(SegmentTreeNode::tag1d,val);
            return;
        }
        int mid=(l+r)/2;
        if(ql<=mid) modify2(ql,min(mid,qr),val,p<<1);
        if(qr> mid) modify2(max(mid+1,ql),qr,val,p<<1|1);
        push_up(p);
    }

    void debug() {
        cout<<"----SegmentTree data dump----"<<endl;
        for(int i=1;i<=4*len;i++) {
            if(d[i].l>0) {
                cout<<"#"<<i<<'['<<d[i].l<<','<<d[i].r<<"] ";
                cout<<"data="<<d[i].data<<' ';
                cout<<"tag1="<<d[i].tag1<<' ';
                cout<<"tag2="<<d[i].tag2<<endl;
            }
        }
        cout<<"-----------------------------"<<endl;
    }

    void debugArr() { //输出数组(确定query不出锅时才能用)。
        cout<<"----SegmentTree array dump----"<<endl;
        for(int i=d[1].l;i<=d[1].r;i++) {
            cout<<"#"<<i<<'\t'<<'\t'<<query(i,i)<<endl;
        }
        cout<<"------------------------------"<<endl;
    }
};

SegmentTree<MAXN> st;
int arr[MAXN];

signed main() {
    n=llread();m=llread();MOD=llread();

    for(int i=1;i<=n;i++) {
        cin>>arr[i];
    }
    st.init(arr,1,n);

    for(int i=1;i<=m;i++) {
        int g;
        cin>>g;
        if(g==1) {
            int x=llread(),y=llread(),k=llread();
            st.modify2(x,y,k);
        }
        else if(g==2) {
            int x=llread(),y=llread(),k=llread();
            st.modify1(x,y,k);
        }
        else if(g==3) {
            int x=llread(),y=llread();
            printf("%lld\n",st.query(x,y));
        }
        else printf("Error\n");
    }
}
点赞

发表评论

电子邮件地址不会被公开。必填项已用 * 标注