dtoj#4317. 隨機(random)

題目描述:

有一棵$N$個點的樹和$M$個操做$x,y,S$,表示給$x$到$y$的鏈上的節點都加入一個數字字符串$S$。node

全部操做都結束後須要對每一個點進行一次詢問。首先在該節點的全部字符串造成的`trie`上隨機選擇一個點(此時爲步數$0$),這裏定義`trie`的根節點爲空串,而後每次隨機一個與當前點相鄰的點移動過去且步數$+1$,若是移動到了某個字符串對應的節點則結束。詢問每一個點對應`trie`的指望步數。c++

輸出模$998244353$意義下的值。假如答案爲$\frac{x}{y}(gcd(x,y)=1)$,那麼須要輸出$r(0≤r<998244353)$知足$x \equiv y×r(mod~998244353)$。git

數據保證答案不會存在$y \equiv 0(mod~998244353)$的狀況。ide

思路:

考慮對於一棵固定的 $trie​$ 樹,一個點走到終止點的指望會是
$$
f[x]=\frac{1}{d[x]}\times((\sum_{son}f[to])+f[fa])+1
$$
那麼若是咱們對於一顆 $trie$ 樹 $dfs$ 時,對於孩子節點的狀況已經求得了,咱們把 $f[fax]$ 看做未知數, $f[x]$ 能夠表示成一個關於 $f[fax]$ 的一次函數。
$$
f[x]=A_x\times f[fax]+B_x
$$
考慮如何肯定每個值的 $A_x$ 和 $B_x$ ,訪問到 $x$ 對於每個孩子的狀況已經肯定,即:
$$
f[to]=A_{to}\times f[x]+B_{to}
$$
那麼:
$$
f[x]=\frac{1}{d[x]}\times((\sum_{son}A_{to}f[x]+B_{to})+f[fax])+1
$$
整理一下獲得:
$$
f[x]=\frac{1}{d[x]-\sum_{son}A_{to}}f[fax]+\frac{(\sum_{son}B_{to})+d[x]}{d[x]-\sum_{son}A_{to}}
$$
容易知道最後對於一棵 $trie​$ 樹的答案是
$$
Ans=\frac{\sum_{i=1}^{cnt} f[x]}{cnt}
$$
( $cnt$ 表示 $trie$ 樹的點數 )函數

因此要考慮維護整個子樹的和ui

咱們再用同樣的方法表示出子樹的和 $g[x]$ :
$$
g[x]=f[x]+\sum_{son}g[to]
$$
同理 $g[x]$ 能夠表示成關於 $ f[fax] $ 的一次函數:
$$
g[to]=C_xf[fax]+D_x
$$spa

$$
g[x]=f[x]+\sum_{son}(C_{x}f[x]+D_{x})
$$code

$$
g[x]=A_xf[fax]+Bx+\sum_{son}(C_{to}(A_xf[fax]+B_x)+D_{to})
$$blog

$$
g[x]=((1+\sum_{son}C_{to})A_{x})f[fax]+(\sum_{son}C_{to}+1)B_x+\sum_{son}D_{to}
$$字符串

對於樹上結點的刪除與添加在 $trie$ 樹上動態修改

如下代碼:

#include<bits/stdc++.h>
#define il inline
#define pb push_back
#define LL long long
#define _(d) while(d(isdigit(ch=getchar())))
using namespace std;
const int N=3e5+5,M=2e6+5,p=998244353;
char s[M],c[M];
int n,head[N],ne[N<<1],to[N<<1],cnt,fa[N][21],d[N],res[N],be[N];
int sz[M],rt[N],A[M],B[M],C[M],D[M],ch[M][12],num[M],tag[M],m;
struct node{int x,c;};
vector<node> v[N];
il int read(){
   int x,f=1;char ch;
   _(!)ch=='-'?f=-1:f;x=ch^48;
   _()x=(x<<1)+(x<<3)+(ch^48);
   return f*x;
}
il int mu(int x,int y){
    return x+y>=p?x+y-p:x+y;
}
il int ksm(LL a,int y){
    LL b=1;
    while(y){
        if(y&1)b=b*a%p;
        a=a*a%p;y>>=1;
    }
    return b;
}
il void ins(int x,int y){
    ne[++cnt]=head[x];
    head[x]=cnt;to[cnt]=y;
}
il void dfs1(int x){
    for(int i=1;fa[x][i-1];i++)fa[x][i]=fa[fa[x][i-1]][i-1];
    for(int i=head[x];i;i=ne[i]){
        if(fa[x][0]==to[i])continue;
        fa[to[i]][0]=x;
        d[to[i]]=d[x]+1;dfs1(to[i]);
    }
}
il int Lca(int x,int y){
    if(d[x]<d[y])swap(x,y);
    for(int i=20;i>=0;i--)if(d[fa[x][i]]>=d[y])x=fa[x][i];
    if(x==y)return x;
    for(int i=20;i>=0;i--)if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
    return fa[x][0];
}
il void update(int x,int f=0){
    int Sa=0,Sb=0,Sc=0,Sd=0,d=0;
    A[x]=B[x]=C[x]=D[x]=0;d=f^1;
    if(!num[x]){sz[x]=0;return;}sz[x]=1;
    for(int i=0,to;i<=9;i++)if(sz[to=ch[x][i]]>0){
        d++;sz[x]+=sz[to];
        Sa=mu(Sa,A[to]);Sb=mu(Sb,B[to]);
        Sc=mu(Sc,C[to]);Sd=mu(Sd,D[to]);
    }
    if(tag[x]){D[x]=Sd;return;}
    A[x]=ksm(mu(d,p-Sa),p-2);B[x]=1ll*A[x]*mu(Sb,d)%p;
    C[x]=1ll*A[x]*mu(Sc,1)%p;D[x]=mu(1ll*B[x]*mu(Sc,1)%p,Sd);
}
il void add(int &x,int id,int l){
    if(!x)x=++cnt;num[x]++;
    if(l>=be[id+1]){tag[x]++;update(x);return;}
    add(ch[x][s[l]-'0'],id,l+1);update(x);
}
il void del(int x,int id,int l){
    num[x]--;
    if(l>=be[id+1]){tag[x]--;update(x);return;}
    del(ch[x][s[l]-'0'],id,l+1);update(x);
}
il void merge(int &x,int y){
    if(!num[x]||!num[y]){x=(num[x]?x:y);return;}
    num[x]+=num[y];tag[x]+=tag[y];
    for(int i=0;i<=9;i++)merge(ch[x][i],ch[y][i]);
    update(x);
}
il int query(int x){
    update(x,1);
    return 1ll*D[x]*ksm(sz[x],p-2)%p;
}
il void dfs(int x,int fa){
    int son=0;
    for(int i=head[x];i;i=ne[i]){
        if(fa^to[i]){
            dfs(to[i],x);son=to[i];
        }
    }
    if(son)rt[x]=rt[son];
    for(int i=head[x];i;i=ne[i])
        if(fa^to[i]&&to[i]^son)merge(rt[x],rt[to[i]]);
    for(int i=0;i<v[x].size();i++){
        node k=v[x][i];
        if(k.c>0)add(rt[x],k.x,be[k.x]);
        else del(rt[x],k.x,be[k.x]);
    }
    res[x]=query(rt[x]);
}
int main()
{
    n=read();
    for(int i=1;i<n;i++){
        int x=read(),y=read();
        ins(x,y);ins(y,x);
    }
    d[1]=1;dfs1(1);m=read();
    int now=1;
    for(int i=1;i<=m;i++){
        int x=read(),y=read();
        scanf(" %s",c+1);
        be[i]=now;int l=strlen(c+1);
        for(int i=1;i<=l;i++)s[now++]=c[i];
        int lca=Lca(x,y);
        if(d[x]>d[y])swap(x,y);
        if(x==lca)v[fa[x][0]].pb((node){i,-1}),v[y].pb((node){i,1});
        else v[x].pb((node){i,1}),v[y].pb((node){i,1}),v[lca].pb((node){i,-1}),v[fa[lca][0]].pb((node){i,-1});
    }
    be[m+1]=now;
    cnt=0;dfs(1,0);
    for(int i=1;i<=n;i++)printf("%d\n",res[i]);
    return 0;
} 
View Code
相關文章
相關標籤/搜索