题目
Description
Flute 很喜欢柠檬。
它准备了一串用树枝串起来的贝壳,打算用一种魔法把贝壳变成柠檬。贝壳一共有 N (1 ≤ N ≤ 100,000) 只,按顺序串在树枝上。
为了方便,我们从左到右给贝壳编号 1..N。每只贝壳的大小不一定相同, 贝壳 i 的大小为 si(1 ≤ si ≤10,000)。变柠檬的魔法要求,Flute 每次从树枝一端取下一小段连续的贝壳,并 选择一种贝壳的大小 s0。如果 这一小段贝壳中 大小为 s0 的贝壳有 t 只,那么魔法可以把这一小段贝壳变成 s 0t^2 只柠檬。
Flute 可以取任意多次贝壳,直到树枝上的贝壳被全部取完。各个小段中,Flute 选择的贝壳大小 s 0 可以不同。而最终 Flute 得到的柠檬数,就是所有小段柠檬数的总和。Flute 想知道,它最多能用这一串贝壳 变出多少柠檬。请你帮忙解决这个问题。
Input
第 1 行:一个整数,表示 N。
第 2 .. N + 1 行:每行一个整数,第 i + 1 行表示 si。
Output
仅一个整数,表示 Flute 最多能得到的柠檬数。
Sample Input
5
2
2
5
2
3
Sample Output
21
题解
想不到思路……好巧妙啊
首先有个重要的性质:
从$i$开始的一串贝壳选中的大小肯定为$s_i$
因为ta没要求划分多少段,所以根据贪心的思想可以得到这个结论
所以就有了$O(n^2)$的方程: $f_i=max(f_i,f_{j-1}+s_i\times (num_i-num_j+1)^2),s_j=s_i,j\leq i$
怎么优化呢?
发现有点斜率优化的样子,尝试化简:
$f_i=f_{j-1}+s_i\times num_i^2+s_i\times (num_j-1)^2-2s_i\times num_i(num_j-1)$
$f_{j-1}+s_i\times (num_j-1)^2=num_i\times 2s_i(num_j-1)+f_i-s_i\times num_i^2$ $Y k X b$
发现$k$和$X$都单调递增,要求的是f_i的最大值,所以可以用单调栈(用$vector$代替)维护一下,在栈上二分一下$j$的位置即可
代码
# include<iostream>
# include<cstring>
# include<cstdio>
# include<vector>
# include<algorithm>
# define LL long long
using namespace std;
const int MAX=1e5+5;
int n;
int s[MAX],sum[MAX],Sum[MAX];
LL f[MAX];
vector<int> ve[MAX];
int read()
{
int x(0);
char ch=getchar();
for(;!isdigit(ch);ch=getchar());
for(;isdigit(ch);x=x*10+ch-48,ch=getchar());
return x;
}
LL mi_2(int x)
{
return 1ll*x*x;
}
LL X(int i)
{
return 2ll*(Sum[i]-1);
}
LL Y(int i)
{
return f[i-1]+s[i]*mi_2(Sum[i]-1);
}
bool calc(int a,int b,int c)
{
return (X(a)-X(b))*(Y(c)-Y(b))-(Y(a)-Y(b))*(X(c)-X(b))>0;
}
int main()
{
n=read();
for(int i=1;i<=n;++i)
s[i]=read(),Sum[i]=++sum[s[i]];
for(int i=1,qwq;i<=n;++i)
{
qwq=ve[s[i]].size()-1;
while(qwq>0&&calc(i,ve[s[i]][qwq],ve[s[i]][qwq-1])) --qwq,ve[s[i]].pop_back();
ve[s[i]].push_back(i);
int l=1,r=qwq+1,ans=0;
while(l<=r)
{
int mid=(l+r>>1),qaq=ve[s[i]][mid],qvq=ve[s[i]][mid-1];
if(f[qaq-1]+s[i]*mi_2(Sum[i]-Sum[qaq]+1)>f[qvq-1]+s[i]*mi_2(Sum[i]-Sum[qvq]+1)) ans=mid,l=mid+1;
else r=mid-1;
}
f[i]=f[ve[s[i]][ans]-1]+s[i]*mi_2(Sum[i]-Sum[ve[s[i]][ans]]+1);
}
return printf("%lld",f[n]),0;
}