题目
题目描述
图片不放了烦死
给定一个只含$a,b$的字符串,在其中选取一个子序列,使得:
- 位置和字符都关于某条对称轴对称
- 不能是连续的一段
输入输出样例
输入样例#1:
abaabaa
输出样例#1:
14
输入样例#2:
aaabbbaaa
输出样例#2:
44
输入样例#3:
aaaaaaaa
输出样例#3:
53
题解
发现要求不连续的是很难求的,所以可以求所有的回文子序列然后减去所有的回文子串
回文子串可以用$manacher$求,现在问题就是怎么求所有的回文子序列
考虑对每个对称轴统计答案
所以假设对称轴$\frac{i+j}{2}$两旁共有$n$对相同位置的相同字符,答案就是$\sum_{i=1}^{n}C_{n}^{i}=2^n-1$(单个字符单独处理)
假设$f_{\frac{i+j}{2}}$为对称轴$\frac{i+j}{2}$两旁相同位置的相同字符的对数,发现这个除以二很烦,直接统一乘个二消掉变成$f_{i+j}$
然后就变成了字符串匹配问题……直接枚举字符卷积就行了
开始因为有模数非常傻逼的敲了个$NTT$……
$fuge$:你写$NTT$干啥……对数最大才$1e5$啊
然后懒得改了,结果模数混了WA
了好几发……
代码
# include<iostream>
# include<cstring>
# include<cstdio>
# include<algorithm>
using namespace std;
const int MAX=4e5+5,mod=998244353,Mod=1e9+7;
int n,m,lim=1,L,ans,cnt;
int a[MAX],b[MAX],R[MAX],val[MAX],hw[MAX];
char s[MAX],ss[MAX];
int Pow(int a,int b,int MOD=mod)
{
int ans=1;
for(;b;b>>=1,a=1ll*a*a%MOD)
if(b&1) ans=1ll*ans*a%MOD;
return ans;
}
void NTT(int *A,int tt=1)
{
for(int i=0;i<lim;++i)
if(i<R[i]) swap(A[i],A[R[i]]);
for(int i=1,w1;i<lim;i<<=1)
{
w1=Pow(3,(mod-1)/(i<<1));
for(int l=i<<1,j=0,w;j<lim;j+=l)
{
w=1;
for(int k=0,x,y;k<i;++k,w=1ll*w*w1%mod)
{
x=A[j+k],y=1ll*w*A[i+j+k]%mod,A[j+k]=x+y,A[i+j+k]=x-y+mod;
if(A[j+k]>=mod) A[j+k]-=mod;
if(A[i+j+k]>=mod) A[i+j+k]-=mod;
}
}
}
if(tt==-1)
{
reverse(A+1,A+lim);
for(int i=0,G=Pow(lim,mod-2);i<lim;++i)
A[i]=1ll*A[i]*G%mod;
}
}
void NTT(char x)
{
memset(a,0,lim<<2);
memset(b,0,lim<<2);
for(int i=0;i<n;++i)
if(s[i]==x) a[i]=b[i]=1;
NTT(a),NTT(b);
for(int i=0;i<=lim;++i)
a[i]=1ll*a[i]*b[i]%mod;
NTT(a,-1);
for(int i=0;i<=2*n-2;++i)
val[i]+=a[i];
}
void Init()
{
ss[0]='@',ss[++cnt]='#';
for(int i=0;i<n;++i)
ss[++cnt]=s[i],ss[++cnt]='#';
ss[cnt+1]='$';
}
void Manacher()
{
Init();
for(int i=1,Mr=0,mid;i<=cnt;++i)
{
if(i<Mr) hw[i]=min(hw[(mid<<1)-i],hw[mid]+mid-i);
else hw[i]=1;
for(;ss[i-hw[i]]==ss[i+hw[i]];++hw[i]);
if(i+hw[i]>Mr) Mr=i+hw[i],mid=i;
}
for(int i=1;i<=cnt;++i)
ans=(ans-(((hw[i]-1-!isalpha(ss[i]))>>1)+1)+Mod)%Mod;
}
int main()
{
scanf("%s",s),n=strlen(s);
while(lim<=2*n-2) lim<<=1,++L;
for(int i=0;i<=lim;++i)
R[i]=(R[i>>1]>>1)|((i&1)<<L-1);
NTT('a'),NTT('b');
for(int i=0,x;i<=2*n-2;++i)
{
if(i&1) val[i]>>=1;
else --val[i],val[i]>>=1;
x=(Pow(2,val[i],Mod)-1+Mod)%Mod;
if(!(i&1)) x=1ll*x*2%Mod;
ans+=x;
if(ans>=Mod) ans-=Mod;
}
return Manacher(),printf("%d",(ans+n)%Mod),0;
}