[BZOJ3160]万径人踪灭

Posted by Dispwnl on March 1, 2019

题目

题目描述

图片不放了烦死

给定一个只含$a,b$的字符串,在其中选取一个子序列,使得:

  • 位置和字符都关于某条对称轴对称
  • 不能是连续的一段

输入输出样例

输入样例#1:

abaabaa

输出样例#1:

14

输入样例#2:

aaabbbaaa

输出样例#2:

44

输入样例#3:

aaaaaaaa

输出样例#3:

53

题解

发现要求不连续的是很难求的,所以可以求所有的回文子序列然后减去所有的回文子串

回文子串可以用$manacher​$求,现在问题就是怎么求所有的回文子序列

考虑对每个对称轴统计答案

所以假设对称轴$\frac{i+j}{2}$两旁共有$n$对相同位置的相同字符,答案就是$\sum_{i=1}^{n}C_{n}^{i}=2^n-1$(单个字符单独处理)

假设$f_{\frac{i+j}{2}}$为对称轴$\frac{i+j}{2}$两旁相同位置的相同字符的对数,发现这个除以二很烦,直接统一乘个二消掉变成$f_{i+j}$

然后就变成了字符串匹配问题……直接枚举字符卷积就行了

开始因为有模数非常傻逼的敲了个$NTT$……

$fuge$:你写$NTT$干啥……对数最大才$1e5$啊

然后懒得改了,结果模数混了WA了好几发……

代码

# include<iostream>
# include<cstring>
# include<cstdio>
# include<algorithm>
using namespace std;
const int MAX=4e5+5,mod=998244353,Mod=1e9+7;
int n,m,lim=1,L,ans,cnt;
int a[MAX],b[MAX],R[MAX],val[MAX],hw[MAX];
char s[MAX],ss[MAX];
int Pow(int a,int b,int MOD=mod)
{
	int ans=1;
	for(;b;b>>=1,a=1ll*a*a%MOD)
	  if(b&1) ans=1ll*ans*a%MOD;
	return ans;
}
void NTT(int *A,int tt=1)
{
	for(int i=0;i<lim;++i)
	  if(i<R[i]) swap(A[i],A[R[i]]);
	for(int i=1,w1;i<lim;i<<=1)
	  {
	  	w1=Pow(3,(mod-1)/(i<<1));
	  	for(int l=i<<1,j=0,w;j<lim;j+=l)
	  	  {
	  	  	w=1;
	  	  	for(int k=0,x,y;k<i;++k,w=1ll*w*w1%mod)
	  	  	  {
	  	  	  	x=A[j+k],y=1ll*w*A[i+j+k]%mod,A[j+k]=x+y,A[i+j+k]=x-y+mod;
	  	  	  	if(A[j+k]>=mod) A[j+k]-=mod;
	  	  	  	if(A[i+j+k]>=mod) A[i+j+k]-=mod;
			  }
		  }
	  }
	if(tt==-1)
	{
		reverse(A+1,A+lim);
		for(int i=0,G=Pow(lim,mod-2);i<lim;++i)
		  A[i]=1ll*A[i]*G%mod;
	}
}
void NTT(char x)
{
	memset(a,0,lim<<2);
	memset(b,0,lim<<2);
	for(int i=0;i<n;++i)
	  if(s[i]==x) a[i]=b[i]=1;
	NTT(a),NTT(b);
	for(int i=0;i<=lim;++i)
	  a[i]=1ll*a[i]*b[i]%mod;
	NTT(a,-1);
	for(int i=0;i<=2*n-2;++i)
	  val[i]+=a[i];
}
void Init()
{
	ss[0]='@',ss[++cnt]='#';
	for(int i=0;i<n;++i)
	  ss[++cnt]=s[i],ss[++cnt]='#';
	ss[cnt+1]='$';
}
void Manacher()
{
	Init();
	for(int i=1,Mr=0,mid;i<=cnt;++i)
	  {
	  	if(i<Mr) hw[i]=min(hw[(mid<<1)-i],hw[mid]+mid-i);
	  	else hw[i]=1;
	  	for(;ss[i-hw[i]]==ss[i+hw[i]];++hw[i]);
	  	if(i+hw[i]>Mr) Mr=i+hw[i],mid=i;
	  }
	for(int i=1;i<=cnt;++i)
	  ans=(ans-(((hw[i]-1-!isalpha(ss[i]))>>1)+1)+Mod)%Mod;
}
int main()
{
	scanf("%s",s),n=strlen(s);
	while(lim<=2*n-2) lim<<=1,++L;
	for(int i=0;i<=lim;++i)
	  R[i]=(R[i>>1]>>1)|((i&1)<<L-1);
	NTT('a'),NTT('b');
	for(int i=0,x;i<=2*n-2;++i)
	  {
	  	if(i&1) val[i]>>=1;
	  	else --val[i],val[i]>>=1;
	  	x=(Pow(2,val[i],Mod)-1+Mod)%Mod;
	  	if(!(i&1)) x=1ll*x*2%Mod;
	  	ans+=x;
	  	if(ans>=Mod) ans-=Mod;
	  }
	return Manacher(),printf("%d",(ans+n)%Mod),0;
}