题目
Description
Input
第一行包含一个整数n(≤100000)。
第二行是长度为n的由0到9组成的字符串。
第三行是一个整数m。
接下来m≤5·10行,第i行是一个由0到9组成的字符串s,保证单行字符串长度小于等于10^5,所有字符串长度和小于等于3·10^6
Output
输出m行,第i行表示第si和S匹配所比较的次数。
Sample Input
7
1090901
4
87650
0901
109
090
Sample Output
7
10
3
4
题解
首先发现答案就是$B$串每个前缀在$A$串中出现的次数和(匹配成功之前)
所以可以先找到$B$在$A$中最早匹配成功的位置,现在只需要查询某个节点子树里节点在$A$串中位置在某个区间中的权值和,也就是查询某个子串在$A$串中某个区间出现次数
从叶子开始一路线段树合并即可,线段树要开$\log n$倍的空间……
代码
# include<iostream>
# include<cstring>
# include<cstdio>
# include<algorithm>
# define tl s[k].l
# define tr s[k].r
# define mid (l+r>>1)
using namespace std;
const int MAX=2e5+5;
int n,N;
char a[MAX];
struct SAM{
int l,r,L,tot;
int len[MAX],fa[MAX],rt[MAX],id[MAX],pos[MAX],tax[MAX];
int son[MAX][10];
SAM() {l=r=1;}
struct p{
int x,l,r;
}s[MAX<<5];
void pus(int k) {s[k].x=s[tl].x+s[tr].x;}
void change(int l,int r,int &k,int x)
{
if(!k) k=++tot;
++s[k].x;
if(l==r) return;
if(x<=mid) change(l,mid,tl,x);
else change(mid+1,r,tr,x);
}
void ins(int x,int id)
{
int tt=r;
len[r=++l]=++L,change(1,n,rt[r],id),pos[r]=id;
for(;tt&&!son[tt][x];tt=fa[tt])
son[tt][x]=r;
if(!tt) return void(fa[r]=1);
int qwq=son[tt][x];
if(len[qwq]==len[tt]+1) return void(fa[r]=qwq);
len[++l]=len[tt]+1,fa[l]=fa[qwq],fa[qwq]=fa[r]=l;
for(int i=0;i<10;++i)
son[l][i]=son[qwq][i];
for(int i=tt;son[i][x]==qwq;i=fa[i])
son[i][x]=l;
}
int merge(int x,int y)
{
if(!x||!y) return x+y;
int k=++tot;
return tl=merge(s[x].l,s[y].l),tr=merge(s[x].r,s[y].r),pus(k),k;
}
int ask(int l,int r,int k,int L,int R)
{
if(l>R||L>r||!k) return 0;
if(l>=L&&r<=R) return s[k].x;
return ask(l,mid,tl,L,R)+ask(mid+1,r,tr,L,R);
}
void Init()
{
for(int i=1;i<=l;++i)
++tax[len[i]];
for(int i=1;i<=n;++i)
tax[i]+=tax[i-1];
for(int i=l;i>=1;--i)
id[tax[len[i]]--]=i;
for(int i=l;i>=1;--i)
if(fa[id[i]]) pos[fa[id[i]]]=min(pos[fa[id[i]]],pos[id[i]]),rt[fa[id[i]]]=merge(rt[fa[id[i]]],rt[id[i]]);
}
int GET_LEN(int _n)
{
int x=1;
for(int i=1;i<=_n&&x;++i)
x=son[x][a[i]-'0'];
return pos[x];
}
int Solve(int _n)
{
int _L=GET_LEN(_n),ans=0;
for(int i=1,x=1;i<_n;++i)
{
x=son[x][a[i]-'0'];
if(!x) break;
ans+=ask(1,n,rt[x],1,_L?_L-_n+i:n);
}
return ans+(_L?_L-_n+1:n);
}
}_S;
int main()
{
scanf("%d%s",&n,a+1);
memset(_S.pos,1,sizeof(_S.pos)),_S.pos[0]=0;
for(int i=1;i<=n;++i)
_S.ins(a[i]-'0',i);
_S.Init(),scanf("%d",&N);
for(int i=1,_n;i<=N;++i)
scanf("%s",a+1),_n=strlen(a+1),printf("%d\n",_S.Solve(_n));
return 0;
}