In mathematics you don't understand things, you just get used to them.

Description

Link.

对于一棵树,选出一条链 $(u,v)$,把链上结点从 $u$ 到 $v$ 放成一个 长度 $l$ 的数组,使得 $\sum_{i=1}^{l}\sum_{j=1}^{i}a_{j}$ 最大,$a$ 是点权。

Solution

可以发现那个式子等价于 $\sum_{i=1}^{l}ia_{i}$。

考虑点分,设当前根为 $x$。选出来的 $u,v$ 一定是叶子(点权为正),因为没有什么本质差别,所以可以一起算。我们把 $x$ 在 $(u,v)$ 中的位置记作 $o$,$(u,v)$ 的权值就为 $\sum_{i=1}^{l}ia_{i}=\sum_{i=1}^{o}ia_{i}+l\sum_{i=o+1}^{l}a_{i}+\sum_{i=o+1}^{l}(i-l)a_{i}$,这是个一次函数,令 $b_{1}=\sum_{i=1}^{l}ia_{i}=\sum_{i=1}^{o}ia_{i},b_{2}=\sum_{i=o+1}^{l}(i-l)a_{i},k=l$,得 $\sum_{i=1}^{l}ia_{i}=k\times\sum_{i=o+1}^{l}a_{i}+b_{1}+b_{2}$。

#include<bits/stdc++.h>
typedef long long ll;
#define sf(x) scanf("%d",&x)
#define ssf(x) scanf("%lld",&x)
struct Line {
    ll k,b;
    Line():k(0),b(0){}
    Line(ll _k,ll _b):k(_k),b(_b){}
}lns[10000010];
std::vector<int> G[200010];
ll a[200010],ans,stk[6][200010];
// stk[0]: sum(i=1~l)i*a[i]
// stk[1]: sum(i=o+1~l)(i-l)*a[i]
// stk[2]: sum(i=o+1~l)a[i]
// stk[3]: all the nodes we passed and possible to be the final node
// stk[4]: l
// stk[5]: where to belong to
int n,szf[200010],tot,tr[800010],top,rt,del[200010],siz[200010],mxdep,dep[200010];
ll ff(ll x,int i){return lns[i].k*x+lns[i].b;}
ll getk(int i){return lns[i].k;}
bool chk(ll x,int i,int j){return ff(x,i)>ff(x,j);}
void ins(int l,int r,int now,int t)
{
    if(l^r)
    {
        if(chk(l,t,tr[now]) && chk(r,t,tr[now]))    tr[now]=t;
        else if(chk(l,t,tr[now]) || chk(r,t,tr[now]))
        {
            int mid=(l+r)>>1;
            if(chk(mid,t,tr[now]))    tr[now]^=t^=tr[now]^=t;
            if(chk(l,t,tr[now]))    ins(l,mid,now<<1,t);
            else    ins(mid+1,r,now<<1|1,t); 
        }
    }
    else if(chk(l,t,tr[now]))    tr[now]=t;
}
int find(int l,int r,int now,int t) // query line id
{
    if(l^r)
    {
        int mid=(l+r)>>1,res;
        if(mid>=t)    res=find(l,mid,now<<1,t);
        else    res=find(mid+1,r,now<<1|1,t);
        if(chk(t,res,tr[now]))    return res;
        else    return tr[now];
    }
    else    return tr[now];
}
void clear(int l,int r,int now)
{
    int mid=(l+r)>>1;
    tr[now]=0;
    if(l^r)    clear(l,mid,now<<1),clear(mid+1,r,now<<1|1);
}
void get_root(int now,int las,int all)
{
    siz[now]=1;
    szf[now]=0;
    for(int to:G[now])
    {
        if((to^las) && !del[to])
        {
            get_root(to,now,all);
            siz[now]+=siz[to];
            szf[now]=std::max(szf[now],siz[to]);
        }
    }
    szf[now]=std::max(szf[now],all-siz[now]);
    if(szf[now]<szf[rt])    rt=now;
}
void get_value(int now,ll prf0,ll prf1,ll prf2,int wr,int las)
{
    if((now^rt) && !wr)    wr=now;
    mxdep=std::max(mxdep,dep[now]=dep[las]+1);
    bool lef=1;
    for(int to:G[now])    if((to^las) && !del[to])
        lef=0,get_value(to,prf0+prf2+a[to],prf1+a[to]*dep[now],prf2+a[to],wr,now);
    if(lef)
        ++top,stk[0][top]=prf0,stk[1][top]=prf1,stk[2][top]=prf2-a[rt],
        stk[3][top]=now,stk[4][top]=dep[now],stk[5][top]=wr;
}
void get_ans(int now)
{
    del[now]=1;
    top=mxdep=0;
    get_value(now,a[now],0,a[now],0,0);
    ++top;
    stk[0][top]=a[now];
    stk[1][top]=stk[2][top]=stk[5][top]=0;
    stk[3][top]=now;
    stk[4][top]=1;
    stk[5][top+1]=stk[5][0]=-1;
    clear(1,mxdep,1);
    int i=1,j;
    while(i<=top)
    {
        j=i;
        while(stk[5][i]==stk[5][j])    ans=std::max(ff(stk[4][j],find(1,mxdep,1,stk[4][j]))+stk[0][j],ans),++j;
        j=i;
        while(stk[5][i]==stk[5][j])    lns[++tot]=Line(stk[2][j],stk[1][j]),ins(1,mxdep,1,tot),++j;
        i=j;
    }
    clear(1,mxdep,1);
    i=top;
    while(i)
    {
        j=i;
        while(stk[5][i]==stk[5][j])    ans=std::max(ff(stk[4][j],find(1,mxdep,1,stk[4][j]))+stk[0][j],ans),--j;
        j=i;
        while(stk[5][i]==stk[5][j])    lns[++tot]=Line(stk[2][j],stk[1][j]),ins(1,mxdep,1,tot),--j;
        i=j;
    }
    for(int to:G[now])    if(!del[to])    rt=0,get_root(to,now,siz[to]),get_ans(rt);
}
int main()
{
    sf(n);
    for(int i=1,x,y;i<n;++i)
    {
        sf(x),sf(y);
        G[x].emplace_back(y);
        G[y].emplace_back(x);
    }
    for(int i=1;i<=n;++i)    ssf(a[i]);
    szf[0]=n+1;
    get_root(1,0,n);
    get_ans(rt);
    printf("%lld\n",ans);
    return 0;
}

trees divide and conquer

Solution Set -「CF 1525」
Prev «
Solution -「SCOI 2016」萌萌哒
» Next