题目
题目描述
小 F 的生日还有一个多月,大 F 早早地准备起了礼物。
“你想要什么礼物呀?嗯…要不要好吃的?”
“才不要呢,我想要好看的花,永远不会凋谢的花。”
小 F 和大 F 一起生活的国家—— Fairy 国,可以抽象成一棵 $N$ 个节点的树,每个节点就是一个城市,编号为 $1\ldots N$。
大 F 要游历各个城市,为心爱的小 F 寻找好看的花。
Fairy 国的每个城市都有一座山,山上有恰好一朵永远不会凋谢的花,编号为 $i$ 的城市的花的美丽值为 $B_i$。大 F 要在 $N$ 个城市中选出恰好 $M$ 个,并摘来这 $M$ 个城市中的 $M$ 朵花送给小F。可是呢,如果树上的一条边连接的两个城市的花都被摘去,这条边就会塌陷,Fairy 国就会陷入分裂,大 F 作为一个善良的人,不希望这样的情况发生。所以,一种摘法合法,当且仅当对于每条边,这条边相连的两个节点的花不被同时摘去。
大 F 希望小 F 快乐,小 F 的快乐程度将是摘来的 $M$ 朵花的美丽程度的积。大 F 今天闲着没事,想要求出对于所有合法的摘法,小 F 的快乐程度之和对 $998244353$ 取模的结果。
输入格式
第一行两个正整数 $N$ 和 $M$,表示节点的个数与大 F 要为小 F 摘的花的朵数。
第二行 $N$ 个整数 $B_{1\ldots N}$,表示 $N$ 朵花的美丽程度。
接下来 $N-1$ 行,每行两个正整数,描述树上的一条边,保证形成一棵树。
输出格式
一个整数,表示对于所有合法的摘法,小 F 的快乐程度之和对 $998244353$ 取模的结果。
样例
样例输入 1
5 3
3 5 4 8 11
1 2
1 3
2 4
2 5
样例输出 1
616
样例解释 1
有两种选法,选的点集分别是${1,4,5}$ 和 ${3,4,5}$,所以答案是 $3\times 8 \times 11 + 4 \times 8 \times 11$,即 $616$。
样例输入 2
15 6
9 10 2 7 2 4 5 9 3 2 1 9 3 10 7
12 3
4 3
15 8
2 14
7 14
8 14
3 15
6 1
11 1
7 11
9 14
8 5
10 5
13 15
样例输出 2
8214265
样例解释 2
这个样例解释,真要写起来的话就会很长,所以我不解释了,你自己写个暴力看看题意有没有理解错吧 QAQ ,辛苦啦。
数据范围与提示
对于所有数据,保证 $1 \le M \le N \le 8 \times 10^4$,$0 \le B_i < 998244353$。
下表为各个 Subtask 的额外限制与得分,空格表示该项无额外限制。你只有通过一个 Subtask 的所有数据才能得到该 Subtask 的分。
Subtask 编号 | $N$ | $M$ | 特殊限制 | 分值 |
---|---|---|---|---|
1 | $\le 500$ | 7 | ||
2 | $\le 4000$ | 15 | ||
3 | $\le 10$ | 15 | ||
4 | $\forall 1\le i < N$,读入的第 $i$ 条边是 $(i,i+1)$ | 18 | ||
5 | $\forall 1\le i < N$,读入的第 $i$ 条边是 $(1,i+1)$ | 20 | ||
6 | 25 |
题解
首先有简单的$dp$方程,$f_{i,j,0/1}$表示$i$号节点子树中选择了$j$个点,$i$号点是否被选的所有答案和
$f_{i,j,1}=\sum_{k=0}^{j}f_{i,k,1}f_{son_i,j-k,0}$
$f_{i,j,0}=\sum_{k=0}^{j}f_{i,k,0}(f_{son_i,j-k,1}+f_{son_i,j-k,0})$
发现这是个卷积的形式!那么可以$NTT$优化了,复杂度$O(nm\log m)$
期望得分:$7$分太真实了
当然优化优化可能$Subtask\ 3$也能过
$Subtask\ 2$对应的是$N$比较小的情况,卡卡卡卡卡卡卡卡卡卡卡卡卡常可能能过?暂时先放在一边
发现$Subtask\ 4$对应一条链(序列)的情况,可以处理出来每个点的初始值,考虑怎么优化合并
可以分治合并!这样复杂度就是$O(m\log m\times \log n)$的了
发现$Subtask\ 5$对应一个短腿菊花图的情况,考虑怎么优化合并
可以套个堆!堆每次取出两个最短的多项式然后合并,跟上面的分治差不多一个道理,复杂度$O(m\log m\times \log^2 n)$(大概)
这个复杂度很优秀(能过),考虑能不能对应到一个普通的树上
可以套用树链剖分处理轻链重链的方法,把轻链暴力往上合并(用上面菊花图的方法),对于重链分治处理(用上面链的方法)
这样复杂度就是$O(m\log m\times \log^2n)$了(大概)
然后我卡了一上午的常……最后发现我$NTT$常数太大换了种写法直接优化了$3s$多……
代码
# include<iostream>
# include<cstring>
# include<cstdio>
# include<vector>
# include<queue>
# include<algorithm>
# define mid (l+r>>1)
# define Vo vector<o>
# define Vi vector<int>
using namespace std;
const int MAX=4e5+5,mod=998244353;
struct p{
int x,y;
}c[MAX];
struct q{
Vi A,B;
bool operator< (const q &a)
const{
return A.size()>a.A.size();
}
}f[MAX],g,G;
struct o{
o(){}
o(Vi a,Vi b):A___(a),___A(b){}
Vi A___,_A__,__A_,___A;
}ouo[MAX];
int n,m,num,cnt,lim=1,L,ans,Num;
int R[MAX],h[MAX],fa[MAX],inv[MAX],son[MAX],siz[MAX],id[MAX],val[MAX],re[MAX],a[MAX],b[MAX];
int w[2][MAX];
Vi O;
priority_queue<q> qu;
void dfs(int x=1,int F=0)
{
re[id[x]=++cnt]=x,siz[x]=1,fa[x]=F;
for(int i=h[x];i;i=c[i].x)
if(c[i].y^F)
{
dfs(c[i].y,x),siz[x]+=siz[c[i].y];
if(siz[c[i].y]>siz[son[x]]) son[x]=c[i].y;
}
}
int Pow(int a,int b)
{
int ans=1;
for(;b;b>>=1,a=1ll*a*a%mod)
if(b&1) ans=1ll*ans*a%mod;
return ans;
}
void Pre(int n)
{
Num=n;
int G=Pow(3,(mod-1)/n);
w[0][0]=w[1][0]=1;
for(int i=1;i<n;++i)
w[1][i]=1ll*w[1][i-1]*G%mod;
for(int i=1;i<n;++i)
w[0][i]=w[1][n-i];
}
void NTT(int *A,int r=1)
{
for(int i=1;i<lim;++i)
if(R[i]>i) swap(A[i],A[R[i]]);
for(int i=1;i<lim;i<<=1)
for(int l=i<<1,j=0;j<lim;j+=l)
for(int k=0,x,y;k<i;++k)
x=A[j+k],y=1ll*A[i+j+k]*w[r][Num/(i<<1)*k]%mod,A[j+k]=(x+y)%mod,A[i+j+k]=(x-y+mod)%mod;
if(!r) for(int i=0,G=Pow(lim,mod-2);i<lim;++i)
A[i]=1ll*A[i]*G%mod;
}
Vi operator* (Vi A,Vi B)
{
if(!A.size()) return A;
if(!B.size()) return B;
int n=A.size(),m=B.size();
Vi c;
lim=1,L=0,c.resize(n+m-1);
while(lim<=n+m-2) lim<<=1,++L;
for(int i=0;i<=lim;++i)
R[i]=(R[i>>1]>>1)|((i&1)<<L-1),a[i]=b[i]=0;
for(int i=0;i<n;++i)
a[i]=A[i];
for(int i=0;i<m;++i)
b[i]=B[i];
NTT(a),NTT(b);
for(int i=0;i<=lim;++i)
a[i]=1ll*a[i]*b[i]%mod;
NTT(a,0);
for(int i=0;i<n+m-1;++i)
c[i]=a[i];
return c;
}
Vi operator+ (Vi A,Vi B)
{
int n=A.size(),m=B.size();
if(m>n) A.resize(m);
for(int i=0;i<m;++i)
A[i]=(A[i]+B[i])%mod;
return A;
}
q operator* (q A,q B) {return (q){A.A*B.A,A.B*B.B};}
o operator* (o A,o B)
{
o a;
a.A___=(A.A___*(B.A___+B.__A_))+(A._A__*B.A___);
a._A__=(A.A___*(B._A__+B.___A))+(A._A__*B._A__);
a.__A_=(A.__A_*(B.__A_+B.A___))+(A.___A*B.A___);
a.___A=(A.__A_*(B.___A+B._A__))+(A.___A*B._A__);
return a;
}
void add(int x,int y)
{
c[++num]=(p){h[x],y},h[x]=num;
c[++num]=(p){h[y],x},h[y]=num;
}
int read()
{
int x(0);
char ch=getchar();
for(;!isdigit(ch);ch=getchar());
for(;isdigit(ch);x=x*10+ch-48,ch=getchar());
return x;
}
o Solve(int l,int r)
{
if(l==r) return ouo[l];
return Solve(l,mid)*Solve(mid+1,r);
}
void GET_MERGE(int x)
{
int oao=0;
for(int i=x,siz;i;i=son[i])
{
g.A.resize(1),g.A[0]=1,g.B.resize(2),g.B[0]=0,g.B[1]=val[i];
for(int j=h[i];j;j=c[j].x)
if((c[j].y^son[i])&&(c[j].y^fa[i])) qu.push((q){f[c[j].y].A+f[c[j].y].B,f[c[j].y].A});
siz=qu.size();
while(!qu.empty())
{
G=qu.top(),qu.pop();
if(qu.empty()) break;
G=G*qu.top(),qu.pop(),qu.push(G);
}
if(siz) g=g*G;
ouo[++oao]=(o(g.A,g.B));
}
o _=Solve(1,oao);
f[x]=(q){_.A___+_._A__,_.__A_+_.___A};
}
int main()
{
n=read(),m=read();
int M=1;
for(;M<=n;M<<=1);
Pre(M);
for(int i=1;i<=m<<2;i<<=1)
inv[i]=Pow(3,(mod-1)/i);
for(int i=1;i<=n;++i)
val[i]=read();
for(int i=1;i<n;++i)
add(read(),read());
dfs();
for(int i=n;i>=1;--i)
if(son[fa[re[i]]]!=re[i]) GET_MERGE(re[i]);
if(f[1].A.size()>m) ans+=f[1].A[m];
if(f[1].B.size()>m)
{
ans+=f[1].B[m];
if(ans>=mod) ans-=mod;
}
return printf("%d",ans),0;
}