
Codeforces Round 890 - E1. Permutree

Author: Wong Yen Hong

Problem URL: https://codeforces.com/contest/1856/problem/E1

For some reason, this problem had surprisingly many solves during the contest; it seems like many people are aware of the $O(n^2)$ DP-on-tree trick. Anyway, let's jump straight into the problem and examine the obscure time complexity of the solution.

Abridged Statement

You're given a tree of $n$ nodes ($1 \leq n \leq 5000$), rooted at node $1$, and an array $p$ where $p_i$ denotes the parent of node $i+1$. Let's define $f(a)$ as the score of a permutation $a$ of $1$ to $n$ assigned to the nodes of the tree, where $a[u]$ is the value of node $u$. The score is the number of pairs $(u, v)$ such that $a[u] < a[\mathrm{lca}(u, v)] < a[v]$, where $\mathrm{lca}(u, v)$ denotes the lowest common ancestor of nodes $u$ and $v$. Our task is to find the maximum $f(a)$ over all possible permutations $a$.

More about the lowest common ancestor (LCA) of two nodes can be found here.

Example

Assuming the input of our tree is as follows:

5
1 1 3 3

The given input can be visualized as follows:

[Figure: the example tree. Node 1 is the root with children 2 and 3, and node 3 has children 4 and 5.]

Assume $a = [2, 1, 4, 5, 3]$, where $a[i]$ is the value of node $i$:

[Figure: the same tree with the value $a[i]$ written beside each node $i$.]

The score $f(a)$ will be $4$, because there are $4$ pairs of nodes $(u, v)$ that satisfy $a[u] < a[\mathrm{lca}(u, v)] < a[v]$:

  • $(2, 3)$, since $\mathrm{lca}(2, 3) = 1$ and $1 < 2 < 4$
  • $(2, 4)$, since $\mathrm{lca}(2, 4) = 1$ and $1 < 2 < 5$
  • $(2, 5)$, since $\mathrm{lca}(2, 5) = 1$ and $1 < 2 < 3$
  • $(5, 4)$, since $\mathrm{lca}(5, 4) = 3$ and $3 < 4 < 5$

Solution

Let's first focus on finding, for an individual node $x$, the number of pairs $(u, v)$ with $x = \mathrm{lca}(u, v)$ that satisfy $a[u] < a[x] < a[v]$. Let's call this the score of node $x$.

For example, let's try to find the score of node $1$ in the following tree.

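[Figure: an example tree rooted at node 1, with the value $a[i]$ written beside each node $i$; node 2 has children 5 and 6.]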

Specifically, we will find the number of pairs $(u, v)$ that have $\mathrm{lca}(u, v) = 1$ and $a[u] < a[1] < a[v]$. The number beside each node $i$ represents $a[i]$.

We could count the scores by hand, but that would only help us understand the problem, not find a solution. Instead, we'll look for a more efficient and systematic way of counting the pairs that satisfy the condition, starting with a few observations.

The first observation is that two nodes in the subtree of the same child of node $i$ never have node $i$ as their LCA. For example, nodes $5$ and $6$ both belong to the subtree of node $2$ (a child of node $1$), so their LCA is node $2$ instead of node $1$. This is because both nodes have that child as a common ancestor, so their lowest common ancestor is the child itself or a node below it, never node $i$.

Thus, we can further abstract our tree in the following way.

[Figure: the abstracted tree, where each child subtree of node 1 is replaced by its pair $(d, s)$.]

The abstracted tree now only keeps the information we need about each child's subtree: $d$ denotes the number of nodes in the subtree (including its root) and $s$ denotes the number of nodes $i$ in the subtree with $a[i] < a[1]$. For example, the subtree of node $2$ consists of the nodes $\{2, 5, 6\}$, so $d = 3$, and the nodes $i$ in it with $a[i] < a[1]$ are $\{5, 6\}$, so $s = 2$.

*In this blog, the term "subtree" simply refers to a subtree rooted at one of the children.

To find the score, we can use the formula below:

$$\sum_{i=1}^{m} \left((d[i] - s[i]) \cdot \sum_{j=1, j \neq i}^{m} s[j]\right)$$

where $m$ denotes the number of children (subtrees) of node $1$. Every node in a child subtree has a value either smaller or larger than $a[1]$, so subtree $i$ contains $d[i] - s[i]$ nodes with values larger than $a[1]$. The formula pairs each of them with every node in a different subtree whose value is smaller than $a[1]$. The same formula gives the score of any node as the LCA, and $f(a)$ is the sum of these scores over all nodes.

Now we know how to find the score of any node as the LCA, assuming a permutation has already been assigned to the tree. But that's not our main objective for this problem: we're tasked with maximizing $f(a)$, not with computing $f(a)$ for a predefined $a$.

So, how do we find the optimal $a$ that maximizes $f(a)$?

The question is: do we even need to find it? Do we even need to know $a$ to find our answer? Based on the previous calculation, the only things we need to calculate the score of each LCA are the values $s$ and $d$ of each of its subtrees.

It turns out we don't really need to know the permutation $a$ to compute the answer.

For each LCA, we can freely choose the value of $s$ for each of its children's subtrees, as long as $0 \leq s \leq d$ (the number of nodes in a subtree with values smaller than $a[\mathrm{lca}]$ cannot exceed the size of the subtree), and there will always be a valid $a$ that realizes all of these choices at once. To see why this is true, let's go through the same example as above.

[Figure: the abstracted tree of node 1 with $d$ and the target $s$ for each child subtree.]

Suppose we don't know the permutation $a$, but we know $d$ for each child subtree $i$ and the target value of $s$. One way to achieve this is as follows:

[Figure: one assignment of values to the tree that achieves the target $s$ for every child subtree.]

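Generally, for an LCA $x$, let $S = \sum_{i=1}^{m} s[i]$ and let $o[t]$ denote the $t$-th smallest of the values still available for the subtree of $x$. We set $a[x] = o[S + 1]$. Then, for each child subtree $i$, we give $s[i]$ of its nodes distinct values smaller than $a[x]$ and the remaining $d[i] - s[i]$ nodes distinct values larger than $a[x]$, and repeat the process inside each subtree with the values it received. The counts always work out: exactly $S$ available values are below $a[x]$ and $\sum_{i=1}^{m} (d[i] - s[i])$ are above it.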

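Now, we can use Dynamic Programming (DP) to maximize the score of each LCA $x$. Since the choices of $s$ at different LCAs don't interfere with each other, maximizing each LCA separately also maximizes $f(a)$.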

Let's define the DP state as follows:

$dp[i][j]$ = the maximum score of $x$ counting only the first $i$ subtrees (children of $x$), such that exactly $j$ of their nodes have values $< a[x]$

To understand the DP transition, let's first look at the difference between the score for $m$ children and for $m + 1$ children.

$$
\begin{aligned}
&\left(\sum_{i=1}^{m+1} (d[i] - s[i]) \cdot \sum_{j=1, j \neq i}^{m+1} s[j]\right) - \left(\sum_{i=1}^{m} (d[i] - s[i]) \cdot \sum_{j=1, j \neq i}^{m} s[j]\right)\\
={}& \left(\sum_{i=1}^{m} (d[i] - s[i]) \cdot s[m+1]\right) + \left((d[m+1] - s[m+1]) \cdot \sum_{i=1}^{m} s[i]\right)
\end{aligned}
$$

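The equation can be derived by cancelling the common terms on both sides of the subtraction: for every old subtree $i \leq m$, the only new partners of its large nodes are the $s[m+1]$ small nodes of the new subtree, and the new subtree's $d[m+1] - s[m+1]$ large nodes pair with all the small nodes of the old subtrees. This leads us to the DP transition below: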

$$dp[i][j+k] = \max\left(dp[i][j+k],\; dp[i-1][j] + \left(\left(\sum_{l=1}^{i-1} d[l]\right) - j\right) \cdot k + (d[i] - k) \cdot j\right)$$

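In each transition, we're finding the maximum score for the first $i$ subtrees if we set $s[i] = k$: the $k$ small nodes of subtree $i$ pair with the $\left(\sum_{l=1}^{i-1} d[l]\right) - j$ large nodes of the earlier subtrees, and its $d[i] - k$ large nodes pair with the $j$ earlier small nodes.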

Finally, we take the best DP value $\max_j dp[m][j]$ for each LCA and sum these values over all nodes.

But wait a minute, isn't this DP transition $O(n^3)$? Well, hold your horses for a moment. I will prove the $O(n^2)$ time complexity in the analysis below.

Source Code

#include <bits/stdc++.h>
#define all(x) begin(x),end(x)
#define fir first
#define sec second
#define sz(x) x.size()
#define pb push_back
 
using namespace std;
using ll = long long;
using vi = vector<int>;
using pi = pair<int,int>;
using pdb = pair<double,double>;
using pll = pair<ll,ll>;
using vll = vector<ll>;
using ull = unsigned long long;
const double EPS = (1e-6);
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
 
void setio(string s){
    freopen((s + ".in").c_str(),"r",stdin);
    freopen((s + ".out").c_str(),"w",stdout);
}

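// Returns {best total score inside the subtree of node, size of that subtree}.
pll dfs(int node, vector<vi>& adj){
    ll score = 0;                     // best scores of all LCAs strictly below node
    vi desc(sz(adj[node]));           // desc[i] = size d of the i-th child's subtree
    int total = 0;                    // sum of all child subtree sizes
    for(int i{}; i < (int)sz(adj[node]); i++){
        pll cur = dfs(adj[node][i], adj);
        score += cur.fir;
        total += cur.sec;
        desc[i] = cur.sec;
    }

    // Rolling arrays: dp1 holds dp[i-1][*] and dp holds dp[i][*] from the text.
    vll dp(total+1);
    vll dp1(total+1);
    int cur = 0;                      // nodes in the child subtrees processed so far
    ll mx = 0;                        // best score with node as the LCA
    for(int i{}; i < (int)sz(adj[node]); i++){
        for(ll j{}; j <= cur; j++){           // j earlier nodes below a[node]
            for(ll k{}; k <= desc[i]; k++){   // k nodes of child i below a[node]
                dp[j+k] = max(dp[j+k], dp1[j] + k * (cur - j) + j * (desc[i] - k));
                mx = max(mx, dp[j+k]);
            }
        }
        cur += desc[i];
        swap(dp, dp1);
        fill(all(dp), 0);
    }

    return {mx+score, total + 1};
}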

void solve(){
    int n;
    cin >> n;

    vector<vi> adj(n);
    for(int i{1}; i <= n-1; i++){
        int p;
        cin >> p;
        p--;
        adj[p].pb(i);
    }
    cout << dfs(0, adj).fir << '\n';
}

int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int t = 1;
    while(t--){
        solve();
    }
    return 0;
}

Time Complexity Analysis

Before reading this section, you may want to read the visual proof of the $O(n^2)$ complexity of certain tree DPs in this article.

Now let's observe our DP transition in pseudocode.


cur = 0                            # nodes in the child subtrees of x processed so far
for i in range(len(d)):            # d[i] = size of the i-th child subtree of x
    for j in range(cur + 1):       # j = 0 .. cur
        for k in range(d[i] + 1):  # k = 0 .. d[i]
            pass                   # transition
    cur += d[i]

How does this match the one we see in the article?

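Although the transition looks different, it's essentially the same thing! For each child $i$ of $x$, the $j$ and $k$ loops run $(cur + 1)(d[i] + 1)$ times, where $cur$ is the number of nodes in the earlier subtrees. Since $d[i] \geq 1$, this is at most $2(cur + 1) \cdot d[i]$, and $(cur + 1) \cdot d[i]$ is exactly the number of pairs $(u, v)$ where $u$ is $x$ or a node in an earlier subtree and $v$ is a node in the $i$-th subtree. In other words, we're matching each batch of nodes in the $i$-th subtree to all the previous nodes. Every such pair has $x$ as its LCA, so over the whole tree each unordered pair of distinct nodes is matched exactly once. The total number of operations is therefore at most $2 \cdot \frac{n(n-1)}{2} = n(n-1)$, so the time complexity is $O(n^2)$. With $n \leq 5000$, that's at most about $2.5 \cdot 10^7$ iterations, which easily passes in 2 seconds.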