Knowledge Guide

Sum of Distances in Tree (Hard)

Problem Statement

You are given an undirected, connected tree with n nodes labeled from 0 to n - 1 and n - 1 edges.

You are given the integer n and the array edges, where edges[i] = [ai, bi] indicates that there is an edge between nodes ai and bi in the tree.

Calculate the sum of the distances from each node to all other nodes in the tree (the distance between two nodes is the number of edges on the path between them), and return these sums in an array whose ith element is the total distance from node i.


Examples

Example 1:

  • Input: n = 4, edges = [[0, 1], [1, 2], [1, 3]]
  • Expected Output: [5, 3, 5, 5]
  • Justification: Node 1 is connected to every other node. From node 0, the distances to the other nodes are 1 (to node 1), 2 (to node 2), and 2 (to node 3), summing to 5. From node 1, every other node is at distance 1, so the sum is 3. Nodes 2 and 3 have the same distance profile as node 0, so their sums are also 5.

Example 2:

  • Input: n = 3, edges = [[0, 1], [1, 2]]
  • Expected Output: [3, 2, 3]
  • Justification: Node 0 has a total distance of 1 (to node 1) + 2 (to node 2) = 3. Node 1 is centrally located, so its total distance is 1 (to node 0) + 1 (to node 2) = 2. Node 2 has the same total distance as node 0.

Example 3:

  • Input: n = 5, edges = [[0, 2], [0, 3], [1, 3], [2, 4]]
  • Expected Output: [6, 10, 7, 7, 10]
  • Justification: The edges form the path 4 - 2 - 0 - 3 - 1. Node 0 is in the middle, so its sum is 1 + 1 + 2 + 2 = 6. Nodes 2 and 3 are one step off-center: 1 + 1 + 2 + 3 = 7. Nodes 1 and 4 are the endpoints: 1 + 2 + 3 + 4 = 10.

Solution

To solve this problem, we'll use two depth-first search (DFS) passes over the tree rooted at node 0. The first pass (post-order) computes, for every node, the number of nodes in its subtree (including the node itself) and the sum of distances from that node to the nodes in its subtree. At the root, this sum already covers the whole tree, so it is the root's final answer. The second pass (pre-order) pushes answers down from each parent to its children. The rationale is that moving the root from a node to an adjacent child brings each node in the child's subtree one step closer and moves every other node one step farther away.

This approach is effective because each node's answer is derived from its parent's answer in constant time, so the whole array is computed in two linear passes instead of running a separate traversal from every node.
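In symbols, writing count[v] for the size of v's subtree when the tree is rooted at node 0, the two passes compute the following. This is a restatement of the rule above in the document's own array names, not a formula quoted from an external source:

$$
\begin{aligned}
\text{count}[v] &= 1 + \sum_{c \in \text{children}(v)} \text{count}[c] \\
\text{result}[v] &= \sum_{c \in \text{children}(v)} \bigl(\text{result}[c] + \text{count}[c]\bigr) && \text{(first pass: distances within } v\text{'s subtree)} \\
\text{result}[c] &= \text{result}[p] - \text{count}[c] + \bigl(n - \text{count}[c]\bigr) = \text{result}[p] + n - 2\,\text{count}[c] && \text{(second pass, } c \text{ a child of } p\text{)}
\end{aligned}
$$

For a leaf the sum is empty, so its first-pass result is 0. As a quick check on Example 3, node 2 is a child of node 0 with count[2] = 2, so result[2] = 6 + 5 - 2·2 = 7.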

Step-by-Step Algorithm

  1. Initialize Data Structures:

    • Create an adjacency list tree to represent the tree, where tree is a list of sets (in Python) or an ArrayList of HashSet (in Java), to efficiently manage the edges and children of each node.
    • Initialize two arrays, count and result, of length n. Set every entry of count to 1 (each node counts itself) and every entry of result to 0. After the first DFS, count[i] holds the number of nodes in the subtree rooted at node i, including node i itself, and result[i] holds the sum of distances from node i to all nodes in its subtree.
  2. Build the Tree:

    • For each edge [u, v] in edges, add v to the adjacency list of u and u to the adjacency list of v. This step effectively builds the undirected tree structure.
  3. First DFS (Post-order):

    • Perform a depth-first search starting from an arbitrary root node (usually 0), traversing all the way down to the leaf nodes.
    • For each visited node, do the following:
      • Recursively visit every child child of node (every neighbor except the node's parent).
      • After the call on child returns, add count[child] to count[node] to account for the nodes in the subtree rooted at child.
      • Add result[child] + count[child] to result[node]. Every node in child's subtree is one edge farther from node than from child, so this accumulates the total distance from node to the nodes in that subtree.
  4. Second DFS (Pre-order):

    • Perform another DFS traversal starting from the same root node.
    • For each visited node and each of its children child, set result[child] = result[node] - count[child] + (n - count[child]), then recurse into child. The count[child] nodes in child's subtree are each one step closer to child than to node, and the other n - count[child] nodes are each one step farther. Because this pass runs top-down, result[node] is already the final answer for node when its children are updated.
  5. Return the Result:

    • After completing the second DFS, the result array contains the sum of distances from each node to all other nodes. Return this array as the final output.

Algorithm Walkthrough

Input: n = 5, edges = [[0, 2], [0, 3], [1, 3], [2, 4]]

To follow the walkthrough for this input, let's start by drawing the tree described by the edges.

Tree Diagram

The tree structure based on the given edges is as follows:

    0
   / \
  2   3
 /     \
4       1

This diagram represents the tree structure created by the edges. Node 0 is connected to nodes 2 and 3, node 2 is further connected to node 4, and node 3 is connected to node 1.


Initialization:

  • Create an adjacency list named tree, stored as a list of sets (one set of neighbors per node).
  • Initialize two arrays, count and result, of length n (5 in this case), where all elements of count are set to 1 (as each node counts itself) and all elements of result are set to 0.

Build the Tree:

  • Add edges to the tree, resulting in the following structure:
    • 0: {2, 3}
    • 1: {3}
    • 2: {0, 4}
    • 3: {0, 1}
    • 4: {2}
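Building the adjacency list takes a single pass over edges. The sketch below uses List<List<Integer>> instead of the List<Set<Integer>> in the solution code; this is an assumption for illustration, relying on the fact that a tree has no duplicate edges, so plain lists are enough. The class name AdjacencyDemo is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical demo class: builds the undirected adjacency list for a tree.
// Lists replace the solution's sets because a tree never repeats an edge.
class AdjacencyDemo {
    static List<List<Integer>> buildTree(int n, int[][] edges) {
        List<List<Integer>> tree = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            tree.add(new ArrayList<>()); // one neighbor list per node
        }
        for (int[] e : edges) {
            tree.get(e[0]).add(e[1]); // undirected: record both directions
            tree.get(e[1]).add(e[0]);
        }
        return tree;
    }

    public static void main(String[] args) {
        int[][] edges = { {0, 2}, {0, 3}, {1, 3}, {2, 4} };
        List<List<Integer>> tree = buildTree(5, edges);
        for (int i = 0; i < tree.size(); i++) {
            System.out.println(i + ": " + tree.get(i));
        }
    }
}
```

Running main prints one line per node, from "0: [2, 3]" to "4: [2]", matching the structure listed above.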

First DFS (Post-order Traversal):

  1. Start DFS from node 0, which has children 2 and 3.
  2. Go to node 2:
    • Node 2 has a child, node 4. Visit node 4:
      • Node 4 has no children. count[4] remains 1, and result[4] stays 0 because node 4 has no descendants.
    • After visiting node 4, update count[2] to 2 (itself + node 4) and result[2] to 1 (distance to node 4).
  3. Go to node 3:
    • Node 3 has a child, node 1. Visit node 1:
      • Node 1 has no children. count[1] remains 1, and result[1] is 0.
    • After visiting node 1, update count[3] to 2 (itself + node 1) and result[3] to 1 (distance to node 1).
  4. Return to node 0, update count[0] to 5 (itself + count[2] + count[3]) and result[0] to 6 (result[2] + count[2] + result[3] + count[3]).
  • After this step, we have:
    • count: [5, 1, 2, 2, 1] (size of each node's subtree, including the node itself)
    • result: [6, 0, 1, 1, 0] (sum of distances to subtree nodes)
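To reproduce these intermediate arrays, here is a sketch that runs only the first (post-order) pass on this input. It mirrors the solution's postOrder method but stores neighbors in plain lists and exposes the arrays through static fields; the class name PostOrderTrace and the run helper are illustrative only.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Illustrative trace of the first pass only (names are hypothetical).
class PostOrderTrace {
    static int[] count, result;
    static List<List<Integer>> tree;

    static void postOrder(int node, int parent) {
        for (int child : tree.get(node)) {
            if (child == parent) continue; // don't walk back up the edge
            postOrder(child, node);
            count[node] += count[child]; // absorb the child's subtree size
            // every node in child's subtree is one edge farther from node
            result[node] += result[child] + count[child];
        }
    }

    static void run(int n, int[][] edges) {
        tree = new ArrayList<>();
        for (int i = 0; i < n; i++) tree.add(new ArrayList<>());
        for (int[] e : edges) {
            tree.get(e[0]).add(e[1]);
            tree.get(e[1]).add(e[0]);
        }
        count = new int[n];
        result = new int[n];
        Arrays.fill(count, 1); // each node counts itself
        postOrder(0, -1);
    }

    public static void main(String[] args) {
        run(5, new int[][] { {0, 2}, {0, 3}, {1, 3}, {2, 4} });
        System.out.println("count:  " + Arrays.toString(count));
        System.out.println("result: " + Arrays.toString(result));
    }
}
```

The program prints count: [5, 1, 2, 2, 1] and result: [6, 0, 1, 1, 0], the same values as the trace above. Only result[0] is a final answer at this point.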

Second DFS (Pre-order Traversal):

  1. Start at Node 0 (Root Node):

    • result[0] = 6 is already final after the post-order DFS, because node 0's subtree is the entire tree. No adjustment is needed at the root.
  2. Move to Node 2 (Child of Node 0):

    • Adjust result[2] using the formula: result[parent] - count[child] + (n - count[child]).
    • Calculation: result[2] = result[0] - count[2] + (5 - count[2]).
    • result[2] becomes 6 - 2 + (5 - 2) = 7.
    • Dive into Node 4 (Child of Node 2):
      • Adjust result[4] using the same formula: result[4] = result[2] - count[4] + (5 - count[4]).
      • result[4] becomes 7 - 1 + (5 - 1) = 10.
  3. Back to Node 0, then Move to Node 3 (Another Child of Node 0):

    • Adjust result[3] using the formula: result[0] - count[3] + (5 - count[3]).
    • result[3] becomes 6 - 2 + (5 - 2) = 7.
    • Dive into Node 1 (Child of Node 3):
      • Adjust result[1] using the same formula: result[1] = result[3] - count[1] + (5 - count[1]).
      • result[1] becomes 7 - 1 + (5 - 1) = 10.
  4. Completion of Pre-order Traversal:

    • At this stage, every entry of result holds the total distance from that node to all other nodes. Each value was derived from its parent's final value and the subtree sizes computed in the post-order traversal.
    • The final answer is [6, 10, 7, 7, 10].
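As a sanity check on the walkthrough, the following sketch computes the same answers the slow way: one breadth-first search per node, summing the distances it finds. This is a brute-force cross-check, not the re-rooting technique; it runs in O(n^2) time and is only meant for small test trees. The class and method names are hypothetical.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical brute-force checker: one BFS from every node, O(n^2) overall.
class BruteForceCheck {
    static int[] sumOfDistancesBrute(int n, int[][] edges) {
        List<List<Integer>> tree = new ArrayList<>();
        for (int i = 0; i < n; i++) tree.add(new ArrayList<>());
        for (int[] e : edges) {
            tree.get(e[0]).add(e[1]);
            tree.get(e[1]).add(e[0]);
        }
        int[] answer = new int[n];
        for (int src = 0; src < n; src++) {
            int[] dist = new int[n];
            Arrays.fill(dist, -1); // -1 marks "not yet visited"
            dist[src] = 0;
            ArrayDeque<Integer> queue = new ArrayDeque<>();
            queue.add(src);
            while (!queue.isEmpty()) {
                int u = queue.poll();
                answer[src] += dist[u]; // add u's distance from src
                for (int v : tree.get(u)) {
                    if (dist[v] == -1) {
                        dist[v] = dist[u] + 1;
                        queue.add(v);
                    }
                }
            }
        }
        return answer;
    }

    public static void main(String[] args) {
        int[][] edges = { {0, 2}, {0, 3}, {1, 3}, {2, 4} };
        System.out.println(Arrays.toString(sumOfDistancesBrute(5, edges)));
    }
}
```

For this input it prints [6, 10, 7, 7, 10], agreeing with the two-pass result. Comparing both methods on small random trees is a cheap way to test an implementation.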

Code

```java
import java.util.*;

public class Solution {

  int[] count, result;
  List<Set<Integer>> tree;

  public int[] sumOfDistancesInTree(int n, int[][] edges) {
    // Initialize variables
    tree = new ArrayList<>();
    count = new int[n];
    result = new int[n];
    Arrays.fill(count, 1); // Initialize count of each node as 1 (itself)

    // Create adjacency list representation of the tree
    for (int i = 0; i < n; ++i) {
      tree.add(new HashSet<>());
    }

    // Populate the adjacency list
    for (int[] edge : edges) {
      tree.get(edge[0]).add(edge[1]);
      tree.get(edge[1]).add(edge[0]);
    }

    // First pass (post-order): subtree sizes and distances to subtree nodes
    postOrder(0, -1);
    // Second pass (pre-order): derive each child's answer from its parent's
    preOrder(0, -1);

    // Return the resulting distances
    return result;
  }

  // Post-order DFS: fills count[] and the subtree-only sums in result[]
  private void postOrder(int node, int parent) {
    for (int child : tree.get(node)) {
      if (child == parent) continue; // Skip the parent node
      postOrder(child, node);
      // Update count and result arrays for the current node
      count[node] += count[child];
      result[node] += result[child] + count[child];
    }
  }

  // Pre-order DFS: turns subtree sums into full-tree sums, top-down
  private void preOrder(int node, int parent) {
    for (int child : tree.get(node)) {
      if (child == parent) continue; // Skip the parent node
      // count[child] nodes get 1 closer, the other n - count[child] get 1 farther
      // (count.length == n)
      result[child] = result[node] - count[child] + count.length - count[child];
      preOrder(child, node);
    }
  }

  public static void main(String[] args) {
    Solution solution = new Solution();

    // Example 1
    System.out.println(
      Arrays.toString(
        solution.sumOfDistancesInTree(
          4,
          new int[][] { { 0, 1 }, { 1, 2 }, { 1, 3 } }
        )
      )
    );

    // Example 2
    System.out.println(
      Arrays.toString(
        solution.sumOfDistancesInTree(3, new int[][] { { 0, 1 }, { 1, 2 } })
      )
    );

    // Example 3
    System.out.println(
      Arrays.toString(
        solution.sumOfDistancesInTree(
          5,
          new int[][] { { 0, 2 }, { 0, 3 }, { 1, 3 }, { 2, 4 } }
        )
      )
    );
  }
}
```

Complexity Analysis

Time Complexity

  • $O(N)$: The algorithm performs two depth-first searches over the tree of N nodes, and each search visits every node exactly once. The work done at a node is proportional to its number of neighbors, and in a tree the neighbor counts sum to 2(N - 1), so each pass, and therefore the whole algorithm, runs in time linear in the number of nodes.

Space Complexity

  • $O(N)$: The space complexity is also linear, for three reasons:
    • The adjacency list stores each of the N - 1 edges twice (once from each endpoint), which takes O(N) space.
    • The recursion stack of each DFS can grow to N frames in the worst case, when the tree is a path (a fully skewed tree).
    • The count and result arrays each have length N.
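Because the recursive solution can reach a recursion depth of N on a path-shaped tree, a large enough input could throw StackOverflowError on the JVM's default thread stack (the exact limit depends on the JVM and its -Xss setting). The sketch below is one alternative, not the original author's code: it performs the same two passes with an explicit stack. A stack-based DFS records the order in which nodes are popped, and since every child is popped after its parent, walking that order backwards acts as the post-order pass and walking it forwards acts as the pre-order pass. Space stays O(N); only the call stack is replaced. The class name is hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical non-recursive variant of the two-pass (re-rooting) solution.
class IterativeSumOfDistances {
    static int[] sumOfDistancesInTree(int n, int[][] edges) {
        List<List<Integer>> tree = new ArrayList<>();
        for (int i = 0; i < n; i++) tree.add(new ArrayList<>());
        for (int[] e : edges) {
            tree.get(e[0]).add(e[1]);
            tree.get(e[1]).add(e[0]);
        }

        // Stack-based DFS from node 0: record each node's parent and the order
        // in which nodes are popped. A child is always popped after its parent.
        int[] parent = new int[n];
        int[] order = new int[n];
        int[] stack = new int[n]; // each node is pushed exactly once
        int size = 0, top = 0;
        parent[0] = -1;
        stack[top++] = 0;
        while (top > 0) {
            int node = stack[--top];
            order[size++] = node;
            for (int child : tree.get(node)) {
                if (child == parent[node]) continue;
                parent[child] = node;
                stack[top++] = child;
            }
        }

        int[] count = new int[n];
        int[] result = new int[n];
        Arrays.fill(count, 1); // each node counts itself

        // Post-order pass: walking `order` backwards finishes every child
        // before its parent. order[0] is the root, which has no parent.
        for (int i = n - 1; i > 0; i--) {
            int node = order[i];
            count[parent[node]] += count[node];
            result[parent[node]] += result[node] + count[node];
        }

        // Pre-order pass: walking `order` forwards reaches every parent before
        // its children, so result[parent[node]] is already final here.
        for (int i = 1; i < n; i++) {
            int node = order[i];
            result[node] = result[parent[node]] + n - 2 * count[node];
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(
            sumOfDistancesInTree(5, new int[][] { {0, 2}, {0, 3}, {1, 3}, {2, 4} })));
    }
}
```

On Example 3 this prints [6, 10, 7, 7, 10]. The pre-order update uses the simplified form result[parent] + n - 2 * count[child], which is algebraically the same as the formula in the walkthrough.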