题目
题目描述
在2016年,佳媛姐姐刚刚学习了第二类斯特林数,非常开心。
现在他想计算这样一个函数的值:
$f(n)=\sum_{i=0}^n\sum_{j=0}^i S(i,j)\times 2^j \times (j!)$
S(i, j)表示第二类斯特林数,递推公式为:
$S(i, j) = j \times S(i - 1, j) + S(i - 1, j - 1), 1 \le j \le i - 1$。
边界条件为:$S(i, i) = 1(0 \le i), S(i, 0) = 0(1 \le i)$
你能帮帮他吗?
输入输出格式
输入格式:
输入只有一个正整数
输出格式:
输出f(n)。由于结果会很大,输出f(n)对998244353(7 × 17 × 223 + 1)取模的结果即可。
输入输出样例
输入样例#1:
3
输出样例#1:
87
说明
对于50%数据1 ≤ n ≤5000
对于100%数据1 ≤ n ≤ 100000
题解
先化简这个式子,对于第二类斯特林数,有通式:
$S(n,m)=\frac{1}{m!}\sum_{i=0}^{m}(-1)^iC_{m}^{i}(m-i)^n$
所以
$f(n)=\sum_{i=0}^{n}\sum_{j=0}^{i}\frac{1}{j!}\sum_{k=0}^{j}(-1)^kC_{j}^{k}(j-k)^i\times 2^j\times j!$
$f(n)=\sum_{j=0}^{n}2^j\times j!\sum_{i=0}^{n}\sum_{k=0}^{j}\frac{(-1)^k}{k!}\frac{(j-k)^i}{(j-k)!}$
即$f(n)=\sum_{j=0}^{n}2^j\times j!\sum_{k=0}^{j}\frac{(-1)^k}{k!}\frac{\sum_{i=0}^{n}(j-k)^i}{(j-k)!}$
$\sum_{i=0}^{n}a^i$这个玩意如果乘以$a-1$,相当于$(a+a^2+a^3+…+a^{n+1})-(1+a+a^2+…+a^n)$,错位相减为$a^{n+1}-1$,所以$\sum_{i=0}^{n}a^i=\frac{a^{n+1}-1}{a-1}$($a=1$时这个玩意就等于$n+1$了)
所以处理出来后面的两个多项式卷积就好了
代码
# include<iostream>
# include<cstring>
# include<cstdio>
# include<algorithm>
using namespace std;
const int MAX=4e5+5,mod=998244353;
int n,lim=1,L,ans;
int a[MAX],b[MAX],R[MAX],fac[MAX];
int Pow(int a,int b)
{
int ans=1;
for(;b;b>>=1,a=1ll*a*a%mod)
if(b&1) ans=1ll*ans*a%mod;
return ans;
}
void NTT(int *A,int tt=1)
{
for(int i=0;i<lim;++i)
if(i<R[i]) swap(A[i],A[R[i]]);
for(int i=1,w1;i<lim;i<<=1)
{
w1=Pow(3,(mod-1)/(i<<1));
for(int l=i<<1,w,j=0;j<lim;j+=l)
{
w=1;
for(int k=0,x,y;k<i;++k,w=1ll*w*w1%mod)
{
x=A[j+k],y=1ll*w*A[i+j+k]%mod,A[j+k]=x+y,A[i+j+k]=x-y+mod;
if(A[j+k]>=mod) A[j+k]-=mod;
if(A[i+j+k]>=mod) A[i+j+k]-=mod;
}
}
}
if(tt==-1)
{
reverse(A+1,A+lim);
for(int i=0,G=Pow(lim,mod-2);i<lim;++i)
A[i]=1ll*A[i]*G%mod;
}
}
int main()
{
scanf("%d",&n),fac[0]=a[0]=b[0]=1;
for(int i=1,x;i<=n;++i)
fac[i]=1ll*fac[i-1]*i%mod,x=Pow(fac[i],mod-2),a[i]=1ll*((i&1)?mod-1:1)*x%mod,b[i]=1ll*(Pow(i,n+1)-1+mod)%mod*Pow(i-1,mod-2)%mod*x%mod;
b[1]=n+1;
while(lim<=n<<1) lim<<=1,++L;
for(int i=0;i<=lim;++i)
R[i]=(R[i>>1]>>1)|((i&1)<<L-1);
NTT(a),NTT(b);
for(int i=0;i<=lim;++i)
a[i]=1ll*a[i]*b[i]%mod;
NTT(a,-1);
for(int i=0;i<=n;++i)
{
a[i]=1ll*a[i]*Pow(2,i)%mod*fac[i]%mod,ans+=a[i];
if(ans>=mod) ans-=mod;
}
return printf("%d",ans),0;
}