BZOJ4555: [Tjoi2016&Heoi2016]求和

BZOJ4555: [Tjoi2016&Heoi2016]求和

我們省選的題…
MirrorGray

考慮這個式子的組合意義,對於每一個i,列舉j表示將i個小球放入j個有序集合,且每個集合選擇或者不選的方案數。
我們用f[i]表示將i個小球放入任意個有序集合,且每個集合選擇或不選的方案數,則列舉最後一個集合的大小i-j,可以得到遞推式:
for(int i = 1;i <=n ;i )
for(int j = 0;j < i ;j )f[i]=(f[i] f[j]*c[i][j]*2);
c[i][j]是組合數。
意義就是有j個小球任意組合,剩下的組成最後一個集合,且這個集合選或不選的方案數。
將c[i][j]寫成fac[i]*inv[j]*inv[i-j],fac是階乘,inv是階乘的逆元,那麼顯然這是一個卷積的形式。由於是同層轉移,所以要分治一下。
做法就呼之欲出了,NTT 分治。

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
//by:MirrorGray
using namespace std;
const int N=411111,mod=998244353;
int fac[N],inv[N];
int f[N],wn[N],rev[N];
int po(int a,int b){
int ret=1;
while(b){
if(b&1)ret=(ll)ret*a%mod;
a=(ll)a*a%mod;b>>=1;
}
return ret;
}
struct FFT{
int n,bit,a[N];
void set(){
a[0]=0;
while(n>0)a[n--]=0;
}
void push(int x){
a[n  ]=x;
}
int reverse(int x){
int ret=0;
for(int i=0;i<bit;i  )if(x&(1<<i))ret|=1<<(bit-i-1);
return ret;
}
void fft(int *f,int o){
for(int i=0;i< 1<<bit ;i  )if(rev[i]>i)swap(f[rev[i]],f[i]);
for(int i=1;i<=bit;i  ){
for(int j=0;j< 1<<bit ;j =1<<i){
int t=(1<<(i-1)),W=0;
for(int k=0;k<t;k  ){
int tmp=f[j k],a=(ll)wn[W]*f[j k t]%mod;
f[j k]=(tmp a)%mod;
f[j k t]=(tmp-a mod)%mod;
W =o*(1<<(bit-i));
if(W<0)W =1<<bit;
}
}
}
int inv=po(1<<bit,mod-2);
if(o<0)for(int i=0;i< 1<<bit ;i  )f[i]=(ll)f[i]*inv%mod;
}
void operator *=(FFT&b){
for(bit=0;(1<<bit) <(max(n,b.n)<<1);bit  );
for(int i=0;i< 1<<bit ;i  ){
if (i&1) rev[i]=rev[i>>1]>>1|(1<<(bit-1));
else rev[i]=rev[i>>1]>>1;
}
wn[0]=1;int lala=po(3,(mod-1)/(1<<bit));
for(int i=1;i< 1<<bit ;i  )wn[i]=(ll)wn[i-1]*lala%mod;
fft(a,1);fft(b.a,1);
for(int i=0;i< 1<<bit ;i  )a[i]=(ll)a[i]*b.a[i]%mod;
fft(a,-1);n=b.n=1<<bit;
}
}a,b;
void solve(int l,int r){
if(l==r){
if(l)f[l]=(ll)f[l]*fac[l]%mod*2%mod;
return ;
}
int mid=(l r)>>1;
solve(l,mid);a.set();b.set();
for(int i=l;i<=mid;i  )a.push((ll)f[i]*inv[i]%mod);
for(int i=1;i<=r-l;i  )b.push(inv[i]);
a*=b;
for(int i=mid 1;i<=r;i  )f[i]=(f[i] a.a[i-l-1])%mod;
solve(mid 1,r);
}
int main(){
int n;scanf("%d",&n);
fac[0]=inv[0]=1;
for(int i=1;i<=n;i  ){
fac[i]=(ll)fac[i-1]*i%mod;
inv[i]=po(fac[i],mod-2);
}
f[0]=1;solve(0,n);
int ans=0;
for(int i=1;i<=n;i  )ans=(ans f[i])%mod;
printf("%d\n",(ans 1)%mod);
return 0;
}