[BZOJ3413]匹配

Posted by Dispwnl on January 9, 2019

题目

Description

Input

第一行包含一个整数n(≤100000)。

第二行是长度为n的由0到9组成的字符串。

第三行是一个整数m。

接下来m≤5·10行,第i行是一个由0到9组成的字符串s,保证单行字符串长度小于等于10^5,所有字符串长度和小于等于3·10^6

Output

输出m行,第i行表示第si和S匹配所比较的次数。

Sample Input

7
1090901
4
87650
0901
109
090

Sample Output

7
10
3
4

题解

首先发现答案就是$B$串每个前缀在$A$串中出现的次数和(匹配成功之前)

所以可以先找到$B$在$A$中最早匹配成功的位置,现在只需要查询某个节点子树里节点在$A$串中位置在某个区间中的权值和,也就是查询某个子串在$A$串中某个区间出现次数

从叶子开始一路线段树合并即可,线段树要开$\log n$倍的空间……

代码

# include<iostream>
# include<cstring>
# include<cstdio>
# include<algorithm>
# define tl s[k].l
# define tr s[k].r
# define mid (l+r>>1) 
using namespace std;
const int MAX=2e5+5;
int n,N;
char a[MAX];
struct SAM{
    int l,r,L,tot;
    int len[MAX],fa[MAX],rt[MAX],id[MAX],pos[MAX],tax[MAX];
    int son[MAX][10];
    SAM() {l=r=1;}
    struct p{
        int x,l,r;
    }s[MAX<<5];
    void pus(int k) {s[k].x=s[tl].x+s[tr].x;}
    void change(int l,int r,int &k,int x)
    {
        if(!k) k=++tot;
        ++s[k].x;
        if(l==r) return;
        if(x<=mid) change(l,mid,tl,x);
        else change(mid+1,r,tr,x);
    }
    void ins(int x,int id)
    {
        int tt=r;
        len[r=++l]=++L,change(1,n,rt[r],id),pos[r]=id;
        for(;tt&&!son[tt][x];tt=fa[tt])
          son[tt][x]=r;
        if(!tt) return void(fa[r]=1);
        int qwq=son[tt][x];
        if(len[qwq]==len[tt]+1) return void(fa[r]=qwq);
        len[++l]=len[tt]+1,fa[l]=fa[qwq],fa[qwq]=fa[r]=l;
        for(int i=0;i<10;++i)
          son[l][i]=son[qwq][i];
        for(int i=tt;son[i][x]==qwq;i=fa[i])
          son[i][x]=l;
    }
    int merge(int x,int y)
    {
        if(!x||!y) return x+y;
        int k=++tot;
        return tl=merge(s[x].l,s[y].l),tr=merge(s[x].r,s[y].r),pus(k),k;
    }
    int ask(int l,int r,int k,int L,int R)
    {
        if(l>R||L>r||!k) return 0;
        if(l>=L&&r<=R) return s[k].x;
        return ask(l,mid,tl,L,R)+ask(mid+1,r,tr,L,R);
    }
    void Init()
    {
        for(int i=1;i<=l;++i)
          ++tax[len[i]];
        for(int i=1;i<=n;++i)
          tax[i]+=tax[i-1];
        for(int i=l;i>=1;--i)
          id[tax[len[i]]--]=i;
        for(int i=l;i>=1;--i)
          if(fa[id[i]]) pos[fa[id[i]]]=min(pos[fa[id[i]]],pos[id[i]]),rt[fa[id[i]]]=merge(rt[fa[id[i]]],rt[id[i]]);
    }
    int GET_LEN(int _n)
    {
        int x=1;
        for(int i=1;i<=_n&&x;++i)
          x=son[x][a[i]-'0'];
        return pos[x];
    }
    int Solve(int _n)
    {
        int _L=GET_LEN(_n),ans=0;
        for(int i=1,x=1;i<_n;++i)
          {
            x=son[x][a[i]-'0'];
            if(!x) break;
            ans+=ask(1,n,rt[x],1,_L?_L-_n+i:n);
          }
        return ans+(_L?_L-_n+1:n);
    }
}_S;
int main()
{
    scanf("%d%s",&n,a+1);
    memset(_S.pos,1,sizeof(_S.pos)),_S.pos[0]=0;
    for(int i=1;i<=n;++i)
      _S.ins(a[i]-'0',i);
    _S.Init(),scanf("%d",&N);
    for(int i=1,_n;i<=N;++i)
      scanf("%s",a+1),_n=strlen(a+1),printf("%d\n",_S.Solve(_n));
    return 0;
}