[SDOI2016]模式字符串

Posted by Dispwnl on April 14, 2019

题目

题目描述

给出n个结点的树结构T,其中每一个结点上有一个字符,这里我们所说的字符只考虑大写字母A到Z,再给出长度为m的模式串s,其中每一位仍然是A到z的大写字母。

Alice希望知道,有多少对结点<u,v>满足T上从u到V的最短路径形成的字符串可以由模式串S重复若干次得到?

这里结点对<u,v>是有序的,也就是说<u,v>和<v,u>需要被区分。

所谓模式串的重复,是将若干个模式串S依次相接(不能重叠)。例如当S=PLUS的时候,重复两次会得到PLUSPLUS,重复三次会得到PLUSPLUSPLUS,同时要注恿,重复必须是整数次的。例如当S=XYXY时,因为必须重复整数次,所以XYXYXY不能看作是S重复若干次得到的。

输入输出格式

输入格式:

每一个数据有多组测试,

第一行输入一个整数C,表示总的测试个数。

对于每一组测试来说:

第一行输入两个整数,分别表示树T的结点个数n与模式长度m。结点被依次编号为1到n,

之后一行,依次给出了n个大写字母(以一个长度为n的字符串的形式给出),依次对应树上每一个结点上的字符(第i个字符对应了第i个结点)。

之后n-1行,每行有两个整数u和v表示树上的一条无向边,之后一行给定一个长度为m的由大写字母组成的字符串,为模式串S。

输出格式:

给出C行,对应C组测试。

每一行输出一个整数,表示有多少对节点<u,v>满足从u到v的路径形成的字符串恰好是模式串的若干次重复.

输入输出样例

输入样例#1:

1
11 4
IODSSDSOIOI
1 2
2 3
3 4
1 5
5 6
6 7
3 8
8 9
6 10
10 11
SDOI

输出样例#1:

5

说明

1<=C<=10,3<=$\sum N$<=1000000,3<=$\sum M$<=1000000

题解

思路很简单,点分治统计路径,用$Hash$判断是否能拼成一个完整的模式串

开始处理边边角角的条件自闭了……其实直接处理出每个深度的$Hash$值即可

代码

# include<iostream>
# include<cstring>
# include<cstdio>
# include<algorithm>
# define LL long long
using namespace std;
const int MAX=1e5+5,mod=998244353,Base=233;
struct p{
	int x,y;
}c[MAX<<1];
int n,m,cnt,cnt1,num,T,Sum,rt,maxn;
LL ans;
int h[MAX],F[MAX],G[MAX],siz[MAX],f[MAX],qwq[MAX],fx[MAX],gx[MAX],h1[MAX],h2[MAX];
char s[MAX],S[MAX];
bool use[MAX];
int read()
{
	int x(0);
	char ch=getchar();
	for(;!isdigit(ch);ch=getchar());
	for(;isdigit(ch);x=x*10+ch-48,ch=getchar());
	return x;
}
void add(int x,int y)
{
	c[++num]=(p){h[x],y},h[x]=num;
	c[++num]=(p){h[y],x},h[y]=num;
}
void GET_ROOT(int x=1,int fa=0)
{
	siz[x]=1,f[x]=0;
	for(int i=h[x];i;i=c[i].x)
	  if(!use[c[i].y]&&(c[i].y^fa)) GET_ROOT(c[i].y,x),siz[x]+=siz[c[i].y],f[x]=max(f[x],siz[c[i].y]);
	f[x]=max(f[x],Sum-siz[x]);
	if(f[x]<f[rt]) rt=x;
}
void GET_NUM(int x,int fa,int d,int hash)
{
	maxn=max(maxn,d);
	int tt=(d-1)%m;
	if(hash==h1[d]) ++fx[tt],ans+=G[m-tt-1];
	if(hash==h2[d]) ++gx[tt],ans+=F[m-tt-1];
	for(int i=h[x];i;i=c[i].x)
	  if(!use[c[i].y]&&(c[i].y^fa)) GET_NUM(c[i].y,x,d+1,(1ll*hash*Base+s[c[i].y])%mod);
}
void ask(int x)
{
	G[0]=F[0]=1;
	int D=0;
	for(int i=h[x],M;i;i=c[i].x)
	  if(!use[c[i].y])
	  {
	  	maxn=0,GET_NUM(c[i].y,x,2,s[x]*Base+s[c[i].y]),M=min(m-1,maxn),D=max(M,D);
	  	for(int i=0;i<=M;++i)
	  	  G[i]+=gx[i],F[i]+=fx[i],gx[i]=fx[i]=0;
	  }
	for(int i=0;i<=D;++i)
	  G[i]=F[i]=0;
}
void dfs(int x=rt)
{
	use[x]=1,ask(x);
	for(int i=h[x];i;i=c[i].x)
	  if(!use[c[i].y]) f[rt=0]=Sum=siz[c[i].y],GET_ROOT(c[i].y,x),dfs();
}
int main()
{
	T=read(),qwq[0]=1;
	int Mx=0;
	while(T--)
	{
		n=read(),m=read(),scanf("%s",s+1);
		memset(h,0,sizeof(h));
		memset(use,0,sizeof(use));
		num=ans=0;
		for(int i=Mx+1;i<=n;++i)
		  qwq[i]=1ll*qwq[i-1]*Base%mod;
		Mx=max(Mx,n);
		for(int i=1;i<n;++i)
		  add(read(),read());
		scanf("%s",S+1);
		for(int i=1;i<=n;++i)
		  h1[i]=(h1[i-1]+1ll*S[(i-1)%m+1]*qwq[i-1])%mod,h2[i]=(h2[i-1]+1ll*S[m-(i-1)%m]*qwq[i-1])%mod;
		f[rt]=Sum=n,GET_ROOT(),dfs(),printf("%lld\n",ans);
	}
	return 0;
}