[LOJ6289]花朵

Posted by Dispwnl on March 7, 2019

题目

题目描述

小 F 的生日还有一个多月,大 F 早早地准备起了礼物。

“你想要什么礼物呀?嗯…要不要好吃的?”

“才不要呢,我想要好看的花,永远不会凋谢的花。”

小 F 和大 F 一起生活的国家—— Fairy 国,可以抽象成一棵 $N$ 个节点的树,每个节点就是一个城市,编号为 $1\ldots N$。

大 F 要游历各个城市,为心爱的小 F 寻找好看的花。

Fairy 国的每个城市都有一座山,山上有恰好一朵永远不会凋谢的花,编号为 $i$ 的城市的花的美丽值为 $B_i$。大 F 要在 $N$ 个城市中选出恰好 $M$ 个,并摘来这 $M$ 个城市中的 $M$ 朵花送给小F。可是呢,如果树上的一条边连接的两个城市的花都被摘去,这条边就会塌陷,Fairy 国就会陷入分裂,大 F 作为一个善良的人,不希望这样的情况发生。所以,一种摘法合法,当且仅当对于每条边,这条边相连的两个节点的花不被同时摘去

大 F 希望小 F 快乐,小 F 的快乐程度将是摘来的 $M$ 朵花的美丽程度的积。大 F 今天闲着没事,想要求出对于所有合法的摘法,小 F 的快乐程度之和对 $998244353$ 取模的结果。

输入格式

第一行两个正整数 $N$ 和 $M$,表示节点的个数与大 F 要为小 F 摘的花的朵数。

第二行 $N$ 个整数 $B_{1\ldots N}$,表示 $N$ 朵花的美丽程度。

接下来 $N-1$ 行,每行两个正整数,描述树上的一条边,保证形成一棵树。

输出格式

一个整数,表示对于所有合法的摘法,小 F 的快乐程度之和对 $998244353$ 取模的结果。

样例

样例输入 1

5 3
3 5 4 8 11
1 2
1 3
2 4
2 5

样例输出 1

616

样例解释 1

有两种选法,选的点集分别是${1,4,5}$ 和 ${3,4,5}$,所以答案是 $3\times 8 \times 11 + 4 \times 8 \times 11$,即 $616$。

样例输入 2

15 6
9 10 2 7 2 4 5 9 3 2 1 9 3 10 7
12 3
4 3
15 8
2 14
7 14
8 14
3 15
6 1
11 1
7 11
9 14
8 5
10 5
13 15

样例输出 2

8214265

样例解释 2

这个样例解释,真要写起来的话就会很长,所以我不解释了,你自己写个暴力看看题意有没有理解错吧 QAQ ,辛苦啦。

数据范围与提示

对于所有数据,保证 $1 \le M \le N \le 8 \times 10^4$,$0 \le B_i < 998244353$。

下表为各个 Subtask 的额外限制与得分,空格表示该项无额外限制。你只有通过一个 Subtask 的所有数据才能得到该 Subtask 的分。

Subtask 编号 $N$ $M$ 特殊限制 分值
1 $\le 500$     7
2 $\le 4000$     15
3   $\le 10$   15
4     $\forall 1\le i < N$,读入的第 $i$ 条边是 $(i,i+1)$ 18
5     $\forall 1\le i < N$,读入的第 $i$ 条边是 $(1,i+1)$ 20
6       25

题解

首先有简单的$dp$方程,$f_{i,j,0/1}$表示$i$号节点子树中选择了$j$个点,$i$号点是否被选的所有答案和

$f_{i,j,1}=\sum_{k=0}^{j}f_{i,k,1}f_{son_i,j-k,0}​$

$f_{i,j,0}=\sum_{k=0}^{j}f_{i,k,0}(f_{son_i,j-k,1}+f_{son_i,j-k,0})$

发现这是个卷积的形式!那么可以$NTT$优化了,复杂度$O(nm\log m)$

期望得分:$7$分太真实了

当然优化优化可能$Subtask\ 3$也能过

$Subtask\ 2$对应的是$N$比较小的情况,卡卡卡卡卡卡卡卡卡卡卡卡卡常可能能过?暂时先放在一边

发现$Subtask\ 4$对应一条链(序列)的情况,可以处理出来每个点的初始值,考虑怎么优化合并

可以分治合并!这样复杂度就是$O(m\log m\times \log n)​$的了

发现$Subtask\ 5$对应一个短腿菊花图的情况,考虑怎么优化合并

可以套个堆!堆每次取出两个最短的多项式然后合并,跟上面的分治差不多一个道理,复杂度$O(m\log m\times \log^2 n)$(大概)

这个复杂度很优秀(能过),考虑能不能对应到一个普通的树上

可以套用树链剖分处理轻链重链的方法,把轻链暴力往上合并(用上面菊花图的方法),对于重链分治处理(用上面链的方法)

这样复杂度就是$O(m\log m\times \log^2n)$了(大概)

然后我卡了一上午的常……最后发现我$NTT$常数太大换了种写法直接优化了$3s$多……

代码

# include<iostream>
# include<cstring>
# include<cstdio>
# include<vector>
# include<queue>
# include<algorithm>
# define mid (l+r>>1)
# define Vo vector<o>
# define Vi vector<int>
using namespace std;
const int MAX=4e5+5,mod=998244353;
struct p{
	int x,y;
}c[MAX];
struct q{
	Vi A,B;
	bool operator< (const q &a)
	const{
		return A.size()>a.A.size();
	}
}f[MAX],g,G;
struct o{
	o(){}
	o(Vi a,Vi b):A___(a),___A(b){}
	Vi A___,_A__,__A_,___A;
}ouo[MAX];
int n,m,num,cnt,lim=1,L,ans,Num;
int R[MAX],h[MAX],fa[MAX],inv[MAX],son[MAX],siz[MAX],id[MAX],val[MAX],re[MAX],a[MAX],b[MAX];
int w[2][MAX];
Vi O;
priority_queue<q> qu;
void dfs(int x=1,int F=0)
{
	re[id[x]=++cnt]=x,siz[x]=1,fa[x]=F;
	for(int i=h[x];i;i=c[i].x)
	  if(c[i].y^F)
	  {
	  	dfs(c[i].y,x),siz[x]+=siz[c[i].y];
	  	if(siz[c[i].y]>siz[son[x]]) son[x]=c[i].y;
	  }
}
int Pow(int a,int b)
{
	int ans=1;
	for(;b;b>>=1,a=1ll*a*a%mod)
	  if(b&1) ans=1ll*ans*a%mod;
	return ans;
}
void Pre(int n)
{
	Num=n;
	int G=Pow(3,(mod-1)/n);
	w[0][0]=w[1][0]=1;
	for(int i=1;i<n;++i)
	  w[1][i]=1ll*w[1][i-1]*G%mod;
	for(int i=1;i<n;++i)
	  w[0][i]=w[1][n-i];
}
void NTT(int *A,int r=1)
{
	for(int i=1;i<lim;++i)
	  if(R[i]>i) swap(A[i],A[R[i]]);
	for(int i=1;i<lim;i<<=1)
	  for(int l=i<<1,j=0;j<lim;j+=l)
		for(int k=0,x,y;k<i;++k)
		  x=A[j+k],y=1ll*A[i+j+k]*w[r][Num/(i<<1)*k]%mod,A[j+k]=(x+y)%mod,A[i+j+k]=(x-y+mod)%mod;
	if(!r) for(int i=0,G=Pow(lim,mod-2);i<lim;++i)
	  A[i]=1ll*A[i]*G%mod;
}
Vi operator* (Vi A,Vi B)
{
	if(!A.size()) return A;
	if(!B.size()) return B;
	int n=A.size(),m=B.size();
	Vi c;
	lim=1,L=0,c.resize(n+m-1);
	while(lim<=n+m-2) lim<<=1,++L;
	for(int i=0;i<=lim;++i)
	  R[i]=(R[i>>1]>>1)|((i&1)<<L-1),a[i]=b[i]=0;
	for(int i=0;i<n;++i)
	  a[i]=A[i];
	for(int i=0;i<m;++i)
	  b[i]=B[i];
	NTT(a),NTT(b);
	for(int i=0;i<=lim;++i)
	  a[i]=1ll*a[i]*b[i]%mod;
	NTT(a,0);
	for(int i=0;i<n+m-1;++i)
	  c[i]=a[i];
	return c;
}
Vi operator+ (Vi A,Vi B)
{
	int n=A.size(),m=B.size();
	if(m>n) A.resize(m);
	for(int i=0;i<m;++i)
	  A[i]=(A[i]+B[i])%mod;
	return A;
}
q operator* (q A,q B) {return (q){A.A*B.A,A.B*B.B};}
o operator* (o A,o B)
{
	o a;
	a.A___=(A.A___*(B.A___+B.__A_))+(A._A__*B.A___);
	a._A__=(A.A___*(B._A__+B.___A))+(A._A__*B._A__);
	a.__A_=(A.__A_*(B.__A_+B.A___))+(A.___A*B.A___);
	a.___A=(A.__A_*(B.___A+B._A__))+(A.___A*B._A__);
	return a;
}
void add(int x,int y)
{
	c[++num]=(p){h[x],y},h[x]=num;
	c[++num]=(p){h[y],x},h[y]=num;
}
int read()
{
	int x(0);
	char ch=getchar();
	for(;!isdigit(ch);ch=getchar());
	for(;isdigit(ch);x=x*10+ch-48,ch=getchar());
	return x;
}
o Solve(int l,int r)
{
	if(l==r) return ouo[l];
	return Solve(l,mid)*Solve(mid+1,r);
}
void GET_MERGE(int x)
{
	int oao=0;
	for(int i=x,siz;i;i=son[i])
	  {
	  	g.A.resize(1),g.A[0]=1,g.B.resize(2),g.B[0]=0,g.B[1]=val[i];
	  	for(int j=h[i];j;j=c[j].x)
	      if((c[j].y^son[i])&&(c[j].y^fa[i])) qu.push((q){f[c[j].y].A+f[c[j].y].B,f[c[j].y].A});
	    siz=qu.size();
		while(!qu.empty())
		{
			G=qu.top(),qu.pop();
			if(qu.empty()) break;
			G=G*qu.top(),qu.pop(),qu.push(G);
		}
		if(siz) g=g*G;
		ouo[++oao]=(o(g.A,g.B));
	  }
	o _=Solve(1,oao);
	f[x]=(q){_.A___+_._A__,_.__A_+_.___A};
}
int main()
{
	n=read(),m=read();
	int M=1;
	for(;M<=n;M<<=1);
	Pre(M);
	for(int i=1;i<=m<<2;i<<=1)
	  inv[i]=Pow(3,(mod-1)/i);
	for(int i=1;i<=n;++i)
	  val[i]=read();
	for(int i=1;i<n;++i)
	  add(read(),read());
	dfs();
	for(int i=n;i>=1;--i)
	  if(son[fa[re[i]]]!=re[i]) GET_MERGE(re[i]);
	if(f[1].A.size()>m) ans+=f[1].A[m];
	if(f[1].B.size()>m)
	{
		ans+=f[1].B[m];
		if(ans>=mod) ans-=mod;
	}
	return printf("%d",ans),0;
}