[HAOI2018]字串覆盖

Posted by Dispwnl on March 13, 2019

题目

题目描述

小C对字符串颇有研究,他觉得传统的字符串匹配太无聊了,于是他想到了这样一个问题.

对于两个长度为 $n$ 的串 $A, B$ , 小C每次会给出给出 $4$ 个参数 $s, t, l, r$ . 令 $A$ 从 $s$ 到 $t$ 的子串(从 $1$ 开始标号)为 $T$,令 $B$ 从 $l$ 到 $r$ 的子串为 $P$.然后他会进行下面的操作:

如果 $T​$ 的某个子串与 $P​$ 相同,我们就可以覆盖 $T​$ 的这个子串,并获得 $K - i​$ 的收益,其中 $i​$ 是初始时 $A​$ 中(注意不是 $T​$ 中)这个子串的起始位置,$K​$是给定的参数.一个位置不能被覆盖多次.覆盖操作可以进行任意多次,你需要输出获得收益的最大值.

注意每次询问都是独立的,即进行一次询问后,覆盖的位置会复原.

输入格式

第一行两个整数 $n, K$ ,表示字符串长度和参数.

接下来一行一个字符串 $A$.

接下来一行一个字符串 $B$ .

接下来一行一个整数 $q$ ,表示询问个数.

接下来 $q$ 行,每行四个整数 $s, t, l, r$ ,表示一次询问.

输出格式

输出 $q$ 行,每行一个整数,表示一个询问的答案.

样例

样例输入

10 11
abcbababab
ababcbabab
5
1 9 7 9
3 10 8 10
1 10 1 2
5 7 2 3
1 5 3 6

样例输出

6
10
22
5
10

样例解释

对于第一组询问 $T$ = abcbababa, $P$ = aba,覆盖加粗部分的子串,收益为$K - 5 = 6$.

对于第二组询问 $T$ = cbababab,$P$ = bab, 收益为$(K - 4) + (K - 8) = 10$.

数据范围与提示

对于所有数据,有 $1 \le n, q \le 10^5$ ,$A, B$ 仅由小写英文字母组成,$1 \le s \le t \le n, 1 \le l \le r \le n, n < K \le 10^9$.

对于 $n = 10^5$ 的测试点,满足 $51 \le r - l \le 2000$ 的询问不超过 $11000$ 个,且 $r - l$ 在该区间内均匀随机.

题解

因为$K>n$,所以不用考虑负数的问题,贪心可得最优答案肯定是从左往右能选则选

设$L=r-l+1$,这样$A$中一个区间$[s,t]$最多能选$\frac{t-s+1}{L}$个串

如果$L$比较大的话就可以暴力贪心了

对$A+B$串建后缀数组,一个子串在上面出现位置在$rank$上肯定是一段连续的区间,所以对$rank$建主席树,每次询问二分找到$rank$区间,然后主席树上二分找出现的最小位置即可

如果$L$比较小,暴力找肯定会TLE,考虑处理出来对于每个$L$,位置$i$对应子串$A[i…i+L-1]$,往后第一个出现相同且不与它重叠的位置$j$,即$A[j…j+L-1]=A[i…i+L-1]$,这样处理出来全部的关系就形成了一棵树,这样每个询问就可以先找到$[s,t]$中第一个位置,然后倍增找到最后一个在区间内的子串位置即可

代码

# include<iostream>
# include<cstring>
# include<cstdio>
# include<vector>
# include<map>
# include<algorithm>
# define tl s[k].l
# define tr s[k].r
# define mid (l+r>>1)
# define LL long long
using namespace std;
const int MAX=2e5+5,Base=233,mod=1e9+7;
struct p{
	int l,r,x;
}s[MAX<<5];
struct o{
	int s,t,id,U;
}qu[MAX];
int n,K,q,m,N,tot;
int het[MAX],sa[MAX],rk[MAX],tax[MAX],h[MAX],rt[MAX],Log[MAX],H[MAX],qwq[MAX],A[MAX],d[MAX];
LL ans[MAX];
int f[19][MAX],fa[19][MAX];
LL Sum[19][MAX];
char a[MAX],b[MAX];
map<int,int> mp;
vector<int> vec[51];
int read()
{
	int x(0);
	char ch=getchar();
	for(;!isdigit(ch);ch=getchar());
	for(;isdigit(ch);x=x*10+ch-48,ch=getchar());
	return x;
}
void rsort()
{
	for(int i=0;i<=m;++i)
	  tax[i]=0;
	for(int i=1;i<=n;++i)
	  ++tax[rk[h[i]]];
	for(int i=1;i<=m;++i)
	  tax[i]+=tax[i-1];
	for(int i=n;i>=1;--i)
	  sa[tax[rk[h[i]]]--]=h[i];
}
int Hash(int l,int r) {return (H[r]-1ll*H[l-1]*qwq[r-l+1]%mod+mod)%mod;}
bool cmp(int *a,int x,int y,int w) {return a[x]==a[y]&&a[x+w]==a[y+w];}
void Suffix()
{
	for(int i=1;i<=n;++i)
	  rk[i]=a[i],h[i]=i;
	m=127,rsort();
	for(int p=1,w=1,i;p<n;w+=w,m=p)
	  {
	  	for(p=0,i=n-w+1;i<=n;++i)
	  	  h[++p]=i;
	  	for(i=1;i<=n;++i)
	  	  if(sa[i]>w) h[++p]=sa[i]-w;
	  	rsort(),swap(rk,h),rk[sa[1]]=p=1;
	  	for(i=2;i<=n;++i)
	  	  rk[sa[i]]=cmp(h,sa[i],sa[i-1],w)?p:++p;
	  }
	int j,k=0;
	for(int i=1;i<=n;het[rk[i++]]=k)
	  for(k=k?k-1:k,j=sa[rk[i]-1];a[j+k]==a[i+k];++k);
}
void change(int l,int r,int &k,int pre,int x)
{
	s[k=++tot]=s[pre],++s[k].x;
	if(l==r) return;
	if(x<=mid) change(l,mid,tl,s[pre].l,x);
	else change(mid+1,r,tr,s[pre].r,x);
}
int ask(int l,int r,int pre,int k,int L,int R)
{
	if(!k||r<L||R<l||!(s[k].x-s[pre].x)) return -1;
	if(l==r) return l;
	int x=ask(l,mid,s[pre].l,tl,L,R);
	if(x!=-1) return x;
	return ask(mid+1,r,s[pre].r,tr,L,R);
}
void Init()
{
	for(int i=1;i<=n;++i)
	  f[0][i]=het[i];
	for(int i=1;i<=18;++i)
	  for(int j=0;j<=n;++j)
	    f[i][j]=min(f[i-1][j],f[i-1][min(n,j+(1<<i-1))]);
}
int GET_LCP(int l,int r)
{
	if(l==r) return n-sa[l]+1;
	int t=Log[r-l];
	return min(f[t][l+1],f[t][r-(1<<t)+1]);
}
int main()
{
	n=N=read(),K=read(),scanf("%s%s",a+1,b+1);
	Log[0]=-1,a[++n]='#',H[0]=1,qwq[0]=1;
	for(int i=1;i<=N;++i)
	  a[++n]=b[i],H[i]=(1ll*H[i-1]*Base%mod+a[i])%mod,qwq[i]=1ll*qwq[i-1]*Base%mod;
	for(int i=1;i<=n;++i)
	  Log[i]=Log[i>>1]+1;
	Suffix(),Init(),q=read();
	for(int i=1;i<=n;++i)
	  change(1,n,rt[i],rt[i-1],sa[i]);
	for(int i=1,s,t,l,r,L,x;i<=q;++i)
	  {
	  	s=read(),t=read(),l=read(),r=read(),L=r-l+1,x=0;
	  	if(L<=50) {vec[L].push_back(i),qu[i]=(o){s,t,i};}
		int ll=rk[N+1+l],rr=n,UR,UL,Ls=ll;
		while(ll<=rr)
		{
			int Mid(ll+rr>>1);
			if(GET_LCP(Ls,Mid)>=L) UR=Mid,ll=Mid+1;
			else rr=Mid-1;
		}
		ll=1,rr=rk[N+1+l];
		while(ll<=rr)
		{
			int Mid(ll+rr>>1);
			if(GET_LCP(Mid,Ls)>=L) UL=Mid,rr=Mid-1;
			else ll=Mid+1;
		}
		while(x!=-1&&s<=t-L+1)
		{
			x=ask(1,n,rt[UL-1],rt[UR],s,t-L+1);
			if(L<=50) {qu[i].U=x;break;}
			if(x!=-1) ans[i]+=K-x,s=x+L;
		}
	  }
	d[0]=1;
	for(int L=1,D,t;L<=50;++L)
	  if(vec[L].size())
	  {
	  	D=0,mp.clear();
	  	for(int i=N-L+1,R=N-L+1,x;i>=1;--i)
	  	  {
	  	  	if(i<=N-2*L+1) mp[A[R]]=R,--R;
	  	  	A[i]=Hash(i,i+L-1),x=mp[A[i]],fa[0][i]=x,Sum[0][i]=K-i,D=max(D,d[i]=d[x]+1);
		  }
		t=((1<<Log[D])==D)?Log[D]:Log[D]+1;
		for(int i=1;i<=t;++i)
		  for(int j=1;j<=N;++j)
		    fa[i][j]=fa[i-1][fa[i-1][j]],Sum[i][j]=Sum[i-1][j]+Sum[i-1][fa[i-1][j]];
		for(int i=0,id,cnt=vec[L].size(),_;i<cnt;++i)
		  {
		  	id=vec[L][i],_=qu[id].U;
		  	for(int j=t;j>=0;--j)
		  	  if(fa[j][_]&&fa[j][_]<=qu[id].t-L+1) ans[qu[id].id]+=Sum[j][_],_=fa[j][_];
		  	ans[qu[id].id]+=Sum[0][_];
		  }
	  }
	for(int i=1;i<=q;++i)
	  printf("%lld\n",ans[i]);
	return 0;
}