题目
题目描述
给出n个结点的树结构T,其中每一个结点上有一个字符,这里我们所说的字符只考虑大写字母A到Z,再给出长度为m的模式串s,其中每一位仍然是A到z的大写字母。
Alice希望知道,有多少对结点<u,v>满足T上从u到V的最短路径形成的字符串可以由模式串S重复若干次得到?
这里结点对<u,v>是有序的,也就是说<u,v>和<v,u>需要被区分。
所谓模式串的重复,是将若干个模式串S依次相接(不能重叠)。例如当S=PLUS
的时候,重复两次会得到PLUSPLUS
,重复三次会得到PLUSPLUSPLUS
,同时要注恿,重复必须是整数次的。例如当S=XYXY
时,因为必须重复整数次,所以XYXYXY
不能看作是S重复若干次得到的。
输入输出格式
输入格式:
每一个数据有多组测试,
第一行输入一个整数C,表示总的测试个数。
对于每一组测试来说:
第一行输入两个整数,分别表示树T的结点个数n与模式长度m。结点被依次编号为1到n,
之后一行,依次给出了n个大写字母(以一个长度为n的字符串的形式给出),依次对应树上每一个结点上的字符(第i个字符对应了第i个结点)。
之后n-1行,每行有两个整数u和v表示树上的一条无向边,之后一行给定一个长度为m的由大写字母组成的字符串,为模式串S。
输出格式:
给出C行,对应C组测试。
每一行输出一个整数,表示有多少对节点<u,v>满足从u到v的路径形成的字符串恰好是模式串的若干次重复.
输入输出样例
输入样例#1:
1
11 4
IODSSDSOIOI
1 2
2 3
3 4
1 5
5 6
6 7
3 8
8 9
6 10
10 11
SDOI
输出样例#1:
5
说明
1<=C<=10,3<=$\sum N$<=1000000,3<=$\sum M$<=1000000
题解
思路很简单,点分治统计路径,用$Hash$判断是否能拼成一个完整的模式串
开始处理边边角角的条件自闭了……其实直接处理出每个深度的$Hash$值即可
代码
# include<iostream>
# include<cstring>
# include<cstdio>
# include<algorithm>
# define LL long long
using namespace std;
const int MAX=1e5+5,mod=998244353,Base=233;
struct p{
int x,y;
}c[MAX<<1];
int n,m,cnt,cnt1,num,T,Sum,rt,maxn;
LL ans;
int h[MAX],F[MAX],G[MAX],siz[MAX],f[MAX],qwq[MAX],fx[MAX],gx[MAX],h1[MAX],h2[MAX];
char s[MAX],S[MAX];
bool use[MAX];
int read()
{
int x(0);
char ch=getchar();
for(;!isdigit(ch);ch=getchar());
for(;isdigit(ch);x=x*10+ch-48,ch=getchar());
return x;
}
void add(int x,int y)
{
c[++num]=(p){h[x],y},h[x]=num;
c[++num]=(p){h[y],x},h[y]=num;
}
void GET_ROOT(int x=1,int fa=0)
{
siz[x]=1,f[x]=0;
for(int i=h[x];i;i=c[i].x)
if(!use[c[i].y]&&(c[i].y^fa)) GET_ROOT(c[i].y,x),siz[x]+=siz[c[i].y],f[x]=max(f[x],siz[c[i].y]);
f[x]=max(f[x],Sum-siz[x]);
if(f[x]<f[rt]) rt=x;
}
void GET_NUM(int x,int fa,int d,int hash)
{
maxn=max(maxn,d);
int tt=(d-1)%m;
if(hash==h1[d]) ++fx[tt],ans+=G[m-tt-1];
if(hash==h2[d]) ++gx[tt],ans+=F[m-tt-1];
for(int i=h[x];i;i=c[i].x)
if(!use[c[i].y]&&(c[i].y^fa)) GET_NUM(c[i].y,x,d+1,(1ll*hash*Base+s[c[i].y])%mod);
}
void ask(int x)
{
G[0]=F[0]=1;
int D=0;
for(int i=h[x],M;i;i=c[i].x)
if(!use[c[i].y])
{
maxn=0,GET_NUM(c[i].y,x,2,s[x]*Base+s[c[i].y]),M=min(m-1,maxn),D=max(M,D);
for(int i=0;i<=M;++i)
G[i]+=gx[i],F[i]+=fx[i],gx[i]=fx[i]=0;
}
for(int i=0;i<=D;++i)
G[i]=F[i]=0;
}
void dfs(int x=rt)
{
use[x]=1,ask(x);
for(int i=h[x];i;i=c[i].x)
if(!use[c[i].y]) f[rt=0]=Sum=siz[c[i].y],GET_ROOT(c[i].y,x),dfs();
}
int main()
{
T=read(),qwq[0]=1;
int Mx=0;
while(T--)
{
n=read(),m=read(),scanf("%s",s+1);
memset(h,0,sizeof(h));
memset(use,0,sizeof(use));
num=ans=0;
for(int i=Mx+1;i<=n;++i)
qwq[i]=1ll*qwq[i-1]*Base%mod;
Mx=max(Mx,n);
for(int i=1;i<n;++i)
add(read(),read());
scanf("%s",S+1);
for(int i=1;i<=n;++i)
h1[i]=(h1[i-1]+1ll*S[(i-1)%m+1]*qwq[i-1])%mod,h2[i]=(h2[i-1]+1ll*S[m-(i-1)%m]*qwq[i-1])%mod;
f[rt]=Sum=n,GET_ROOT(),dfs(),printf("%lld\n",ans);
}
return 0;
}