矩陣線段樹

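線段樹維護矩陣：葉子 i 存 Fibonacci 轉移矩陣 [[1,1],[1,0]] 的 a_i−1 次冪（其左上角元素即 F(a_i)）；區間加 x 相當於把整段左乘該矩陣的 x 次冪，用懶標記下傳；區間查詢把各葉子矩陣相加後取左上角元素。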

無標題.png

直接維護矩陣c++

摒棄以前難看的代碼,換上清真的git

#include<bits/stdc++.h>
using namespace std;
using LL = long long;
// Reads the next integer of type T from stdin.
// Any non-digit characters before the number are skipped; a '-' seen among them
// makes the result negative. The character that terminates the number is consumed.
// Returns 0 if the input ends before any digit is found (instead of spinning forever).
template<class T = int> T mian(){
    T s = 0, f = 1;
    // getchar() returns an int: keeping it as int preserves EOF and keeps bytes >= 0x80
    // non-negative, which std::isdigit requires (negative non-EOF values are UB).
    int ch = std::getchar();
    while (ch != EOF && !std::isdigit(ch)) {
        if (ch == '-') f = -1;
        ch = std::getchar();
    }
    if (ch == EOF) return T(0);
    for (; std::isdigit(ch); ch = std::getchar())
        s = s * 10 + (ch - '0');
    return s * f;
}
constexpr int maxn = 500005;   // NOTE(review): not referenced anywhere in the visible code
constexpr int p = 1000000007;  // modulus applied to every matrix entry
// 2x2 matrix whose entries are kept modulo p.
// Entries are assumed to already lie in [0, p), so each product of two entries
// (< p^2 ~ 1e18) fits in LL before being reduced.
struct Matrix{
    LL c[2][2];

    // A default-constructed matrix is all zeros.
    Matrix(){ clear(); }

    // Set every entry to zero.
    void clear(){
        for (auto &row : c)
            for (auto &cell : row)
                cell = 0;
    }

    // Turn this matrix into the identity.
    void e(){
        clear();
        c[0][0] = c[1][1] = 1;
    }

    // Row access, so entries can be written as m[i][j].
    LL *operator[](int x){ return c[x]; }
    const LL *operator[](int x) const { return c[x]; }

    // Matrix product modulo p; each partial product is reduced before summing.
    friend Matrix operator*(const Matrix &a, const Matrix &b){
        Matrix res;
        for (int i = 0; i < 2; ++i)
            for (int j = 0; j < 2; ++j)
                res[i][j] = (a[i][0] * b[0][j] % p + a[i][1] * b[1][j] % p) % p;
        return res;
    }

    // Entry-wise sum modulo p.
    friend Matrix operator+(const Matrix &a, const Matrix &b){
        Matrix res;
        for (int i = 0; i < 2; ++i)
            for (int j = 0; j < 2; ++j)
                res[i][j] = (a[i][j] + b[i][j]) % p;
        return res;
    }
} base;  // global matrix; main() fills it with [[1,1],[1,0]]

// Returns b^n by binary exponentiation (b^0 is the identity).
// n is expected to be >= 0: a negative n never reaches 0 under >>= and would loop forever.
Matrix ksm(Matrix b,int n){
    Matrix acc;
    acc.e();
    while (n != 0) {
        if (n & 1)
            acc = acc * b;  // fold in the current power for this set bit
        n >>= 1;
        b = b * b;          // b now holds the next power-of-two power
    }
    return acc;
}

// One node of a pointer-based lazy segment tree over positions [l, r].
// `sum` is the entry-wise sum of the leaf matrices in the range. `add` is a pending
// matrix that has already been applied to `sum` but not yet to the children.
struct Seg_Node{
    Seg_Node *lch, *rch;  // children (left null until build() creates them)
    int l, r;             // inclusive index range covered by this node
    // Nonzero when `add` holds a tag that still has to be pushed to the children.
    // NOTE(review): consider renaming (e.g. `has_tag`) together with build().
    int isfucked;
    Matrix sum, add;

    Seg_Node() : lch(nullptr), rch(nullptr), l(0), r(0), isfucked(0) {}

    // Split point: left child covers [l, mid], right child covers [mid+1, r].
    int mid(){ return (l + r) >> 1; }

    // Recompute this node's aggregate from its two children.
    void push_up(){
        sum = lch->sum + rch->sum;
    }

    // Left-multiply every leaf in this subtree by x.
    // The aggregate is updated now; the children are handled later by push_down().
    void plus(Matrix x){
        sum = x * sum;
        add = x * add;
        isfucked = 1;
    }

    // Hand the pending tag to both children, then reset it to the identity.
    void push_down(){
        if (isfucked) {
            lch->plus(add);
            rch->plus(add);
            add.e();
            isfucked = 0;
        }
    }
};

using ptr = Seg_Node*;
ptr root = nullptr;  // tree root, set by build()

// Allocates the subtree for [l, r] and stores it in `o` (the global root by default).
// Leaves are created left to right, and each one reads the next integer a from stdin
// and stores base^(a - 1). With base = [[1,1],[1,0]], the [0][0] entry of that
// power is the Fibonacci number F(a).
// Nodes are allocated with `new` and never freed in the visible code.
void build(int l,int r,ptr &o=root){
    o = new Seg_Node;
    o->l = l;
    o->r = r;
    o->add.e();        // identity tag = "nothing pending"
    o->isfucked = 0;
    if (l == r) {
        o->sum = ksm(base, mian() - 1);
        return;
    }
    const int md = o->mid();
    build(l, md, o->lch);
    build(md + 1, r, o->rch);
    o->push_up();
}

// Left-multiplies every leaf matrix in positions [l, r] by `val`,
// starting at `o` (the global root by default).
void addval(int l,int r,Matrix val,ptr o=root){
    // Node lies entirely inside the query range: tag it and stop descending.
    if (o->l >= l && r >= o->r) {
        o->plus(val);
        return;
    }
    o->push_down();
    const int md = o->mid();
    if (md >= l) addval(l, r, val, o->lch);
    if (md < r)  addval(l, r, val, o->rch);
    o->push_up();
}

// Returns the sum, modulo p, of the [0][0] entries of the leaf matrices in [l, r],
// searching from `o` (the global root by default).
LL getsum(int l,int r,ptr o=root){
    if (o->l >= l && r >= o->r)
        return o->sum[0][0];
    o->push_down();
    const int md = o->mid();
    LL total = 0;
    if (md >= l) total = (total + getsum(l, r, o->lch)) % p;
    if (md < r)  total = (total + getsum(l, r, o->rch)) % p;
    return total;
}

int n, m;

// Input: n, m, then n initial values, then m operations:
//   "1 l r k" - multiply leaves l..r by base^k (turns F(a_i) into F(a_i + k));
//   "2 l r"   - print the sum of F(a_i) over l..r, modulo p.
// Any other op code reads its l and r and does nothing.
int main(){
    // Fibonacci step matrix [[1,1],[1,0]].
    base[0][0] = 1;
    base[0][1] = 1;
    base[1][0] = 1;
    base[1][1] = 0;

    n = mian();
    m = mian();
    build(1, n);

    for (int q = 0; q < m; ++q) {
        const int op = mian();
        const int l = mian();
        const int r = mian();
        switch (op) {
        case 1: {
            const int k = mian();
            addval(l, r, ksm(base, k));
            break;
        }
        case 2:
            printf("%lld\n", getsum(l, r));
            break;
        default:
            break;
        }
    }
    return 0;
}
相關文章
相關標籤/搜索