LeetCode #3786 — HARD

Total Sum of Interaction Cost in Tree Groups

Break down a hard problem into reliable checkpoints, edge-case handling, and complexity trade-offs.

The Problem

Problem Statement

You are given an integer n and an undirected tree with n nodes numbered from 0 to n - 1. This is represented by a 2D array edges of length n - 1, where edges[i] = [ui, vi] indicates an undirected edge between nodes ui and vi.

You are also given an integer array group of length n, where group[i] denotes the group label assigned to node i.

  • Two nodes u and v are considered part of the same group if group[u] == group[v].
  • The interaction cost between u and v is defined as the number of edges on the unique path connecting them in the tree.

Return an integer denoting the sum of interaction costs over all unordered pairs (u, v) with u != v such that group[u] == group[v].

Example 1:

Input: n = 3, edges = [[0,1],[1,2]], group = [1,1,1]

Output: 4

Explanation:

All nodes belong to group 1. The interaction costs between the pairs of nodes are:

  • Nodes (0, 1): 1
  • Nodes (1, 2): 1
  • Nodes (0, 2): 2

Thus, the total interaction cost is 1 + 1 + 2 = 4.

Example 2:

Input: n = 3, edges = [[0,1],[1,2]], group = [3,2,3]

Output: 2

Explanation:

  • Nodes 0 and 2 belong to group 3. The interaction cost between this pair is 2.
  • Node 1 belongs to a different group and forms no valid pair. Therefore, the total interaction cost is 2.

Example 3:

Input: n = 4, edges = [[0,1],[0,2],[0,3]], group = [1,1,4,4]

Output: 3

Explanation:

Nodes belonging to the same groups and their interaction costs are:

  • Group 1: Nodes (0, 1): 1
  • Group 4: Nodes (2, 3): 2

Thus, the total interaction cost is 1 + 2 = 3.

Example 4:

Input: n = 2, edges = [[0,1]], group = [9,8]

Output: 0

Explanation:

All nodes belong to different groups and there are no valid pairs. Therefore, the total interaction cost is 0.

Constraints:

  • 1 <= n <= 10^5
  • edges.length == n - 1
  • edges[i] = [ui, vi]
  • 0 <= ui, vi <= n - 1
  • group.length == n
  • 1 <= group[i] <= 20
  • The input is generated such that edges represents a valid tree.

Roadmap

  1. Brute Force Baseline
  2. Core Insight
  3. Algorithm Walkthrough
  4. Edge Cases
  5. Full Annotated Code
  6. Interactive Study Demo
  7. Complexity Analysis
Step 01

Brute Force Baseline


Baseline thinking

Start with the most direct exhaustive search. That gives a correctness anchor before optimizing.
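A minimal Go sketch of that baseline, assuming a hypothetical function name `bruteForceInteractionCosts`. It runs a BFS from every node `u` and adds `dist(u, v)` for each later node `v` with the same label. Because only `v > u` is summed, each unordered pair is counted once. This takes O(n^2) time and O(n) memory: fine for checking small cases, far too slow for n = 10^5.

```go
package main

import "fmt"

// bruteForceInteractionCosts (hypothetical name): BFS from every node, then sum
// distances to later nodes with the same label. O(n^2) time, O(n) memory.
func bruteForceInteractionCosts(n int, edges [][]int, group []int) int64 {
	adj := make([][]int, n)
	for _, e := range edges {
		adj[e[0]] = append(adj[e[0]], e[1])
		adj[e[1]] = append(adj[e[1]], e[0])
	}
	var ans int64
	dist := make([]int, n)
	for u := 0; u < n; u++ {
		for i := range dist {
			dist[i] = -1 // -1 marks "not reached yet"
		}
		dist[u] = 0
		queue := []int{u}
		for len(queue) > 0 {
			x := queue[0]
			queue = queue[1:]
			for _, y := range adj[x] {
				if dist[y] == -1 {
					dist[y] = dist[x] + 1
					queue = append(queue, y)
				}
			}
		}
		// Only v > u, so each unordered pair is added exactly once.
		for v := u + 1; v < n; v++ {
			if group[v] == group[u] {
				ans += int64(dist[v])
			}
		}
	}
	return ans
}

func main() {
	fmt.Println(bruteForceInteractionCosts(3, [][]int{{0, 1}, {1, 2}}, []int{1, 1, 1}))             // 4
	fmt.Println(bruteForceInteractionCosts(3, [][]int{{0, 1}, {1, 2}}, []int{3, 2, 3}))             // 2
	fmt.Println(bruteForceInteractionCosts(4, [][]int{{0, 1}, {0, 2}, {0, 3}}, []int{1, 1, 4, 4})) // 3
}
```

The BFS version is deliberately naive. Its value is as a reference to diff against the optimized solution on small random trees.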

Pattern signal: Array · Tree

Step 02

Core Insight

What unlocks the optimal approach

  • Run a postorder DFS and count how many nodes of each group lie in each subtree.
  • For each edge and each group, contribution = subtree_count * (total_count - subtree_count), where both counts refer to that group.
  • Sum these contributions over all edges and all groups.
Interview move: turn each hint into an invariant you can check after every iteration/recursion step.
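The reason the edge formula works can be written out as a short derivation. The notation here is introduced for illustration: c_g(e) is the number of label-g nodes in the subtree below edge e, and T_g is the total number of label-g nodes. A path's length counts the edges it crosses, and an edge lies on the u-v path exactly when it separates u from v. Swapping the order of summation therefore turns "sum over pairs" into "sum over edges". Across edge e, the number of separated label-g pairs is the count on one side times the count on the other.

```latex
d(u,v) = \sum_{e \in E} \mathbf{1}\left[\, e \text{ separates } u \text{ and } v \,\right]

\text{Answer}
  = \sum_{g} \sum_{\substack{\{u,v\},\ u \ne v \\ \mathrm{group}[u] = \mathrm{group}[v] = g}} d(u,v)
  = \sum_{e \in E} \sum_{g} c_g(e)\,\bigl(T_g - c_g(e)\bigr)

\text{Example 1 (root 0): } \underbrace{2\,(3-2)}_{e=(0,1)} + \underbrace{1\,(3-1)}_{e=(1,2)} = 4
```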
Step 03

Algorithm Walkthrough

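Iteration Checklist

  1. Build an adjacency list from edges and count total[g], the number of nodes with label g.
  2. Root the tree at node 0 and run a postorder DFS that returns cnt[g] for the current subtree, skipping the parent so the search only walks down.
  3. For each child y of x, the edge (x, y) separates cnt_y[g] label-g nodes from the other total[g] - cnt_y[g]; add their product for every label g, then merge cnt_y into cnt_x.
  4. When the DFS returns from the root, the accumulated sum is the answer.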
Use the first example testcase as your mental trace to verify each transition.
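The trace can also be printed mechanically. The following is a minimal Go sketch. `traceContributions` is a hypothetical helper name, and it assumes labels lie in 1..20 as the constraints state. It runs the postorder DFS from node 0 and prints each edge's contribution before returning the total.

```go
package main

import "fmt"

// traceContributions (hypothetical helper) runs the postorder DFS from node 0
// and prints each edge's contribution c * (total - c) for every label that
// occurs below that edge, then returns the summed answer.
func traceContributions(n int, edges [][]int, group []int) int64 {
	const maxLabel = 20 // assumption: labels lie in 1..20, per the constraints
	adj := make([][]int, n)
	for _, e := range edges {
		adj[e[0]] = append(adj[e[0]], e[1])
		adj[e[1]] = append(adj[e[1]], e[0])
	}
	var total [maxLabel + 1]int64
	for _, g := range group {
		total[g]++
	}

	var ans int64
	var dfs func(x, parent int) [maxLabel + 1]int64
	dfs = func(x, parent int) [maxLabel + 1]int64 {
		var cnt [maxLabel + 1]int64 // label counts in the subtree of x
		cnt[group[x]] = 1
		for _, y := range adj[x] {
			if y == parent {
				continue // walk only downward
			}
			sub := dfs(y, x)
			for g, c := range sub {
				if c == 0 {
					continue
				}
				contrib := c * (total[g] - c) // same-label pairs split by edge (x, y)
				fmt.Printf("edge (%d,%d), label %d: %d * (%d - %d) = %d\n", x, y, g, c, total[g], c, contrib)
				ans += contrib
				cnt[g] += c
			}
		}
		return cnt
	}
	dfs(0, -1)
	return ans
}

func main() {
	ans := traceContributions(3, [][]int{{0, 1}, {1, 2}}, []int{1, 1, 1})
	fmt.Println("total:", ans)
	// Output:
	// edge (1,2), label 1: 1 * (3 - 1) = 2
	// edge (0,1), label 1: 2 * (3 - 2) = 2
	// total: 4
}
```

The deeper edge (1,2) is reported first because postorder finishes children before parents. That ordering is the invariant the checklist relies on.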
Step 04

Edge Cases

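Minimum Input
n = 1 (edges is empty)
There are no pairs, so the answer is 0; make sure building the adjacency list and starting the DFS at node 0 still work with no edges.
Duplicates & Repeats
All nodes share one label, or every label is unique
With one shared label the answer is the sum of all pairwise distances; with all labels distinct every group has one node and contributes 0. The edge formula counts each unordered pair once, so no extra halving is needed.
Extreme Constraints
n = 10^5 nodes, up to 20 distinct labels
An O(n^2) all-pairs search is too slow, while the per-edge formula runs in O(n · G). On a path of 10^5 same-label nodes the answer is about 1.67 × 10^14, so accumulate in a 64-bit integer. The recursion depth can reach n on a path: Go's growable stacks handle this, but a recursive DFS can overflow the default call stack in Java or Python.
Invalid / Corner Shape
Labels that never occur; star- and path-shaped trees (the input is always a connected tree, so disconnected graphs cannot appear)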
Handle special-case structure before the core algorithm path.
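The minimum and extreme cases can be checked against a closed form. Consider a path of n nodes that all share one label: cutting the k-th edge leaves k nodes on one side and n - k on the other, so the answer is the sum of k(n - k), which equals (n^3 - n)/6. The sketch below uses a hypothetical helper name. It evaluates that sum for n = 1, for n = 3 (Example 1), and for n = 10^5, which shows why a 32-bit accumulator is not enough.

```go
package main

import (
	"fmt"
	"math"
)

// pathPairDistanceSum (hypothetical helper) returns the answer for a path of n
// nodes that all share one label: the k-th edge separates k nodes from n-k.
func pathPairDistanceSum(n int64) int64 {
	var ans int64
	for k := int64(1); k < n; k++ {
		ans += k * (n - k)
	}
	return ans
}

func main() {
	fmt.Println(pathPairDistanceSum(1)) // 0: a single node forms no pair
	fmt.Println(pathPairDistanceSum(3)) // 4: same as Example 1
	big := pathPairDistanceSum(100000)
	fmt.Println(big, big > math.MaxInt32) // 166666666650000 true
}
```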
Step 05

Full Annotated Code

Source-backed implementations are provided below for direct study and interview prep.

// Accepted solution for LeetCode #3786: Total Sum of Interaction Cost in Tree Groups
// Go reference implementation (Go 1.22+) with three approaches, G = largest label (at most 20):
//   interactionCosts1: one DFS returning per-label subtree counts, O(n * G) time
//   interactionCosts:  one DFS per label, O(n * G) time, O(n) space
//   interactionCosts3: one virtual (auxiliary) tree per label, O(n log n) time
package main

import (
	"fmt"
	"math/bits"
	"slices"
)

// https://space.bilibili.com/206214
// interactionCosts1: dfs returns cntX[i], the number of label-i nodes in the
// subtree of x. Edge (x, y) separates cntY[i] label-i nodes from the other
// total[i]-cntY[i], so it lies on cntY[i]*(total[i]-cntY[i]) same-label paths.
func interactionCosts1(n int, edges [][]int, group []int) (ans int64) {
	g := make([][]int, n)
	for _, e := range edges {
		x, y := e[0], e[1]
		g[x] = append(g[x], y)
		g[y] = append(g[y], x)
	}

	mx := slices.Max(group)
	total := make([]int, mx+1) // labels start at 1, so size mx+1
	for _, x := range group {
		total[x]++
	}

	var dfs func(int, int) []int
	dfs = func(x, fa int) []int {
		cntX := make([]int, mx+1)
		cntX[group[x]] = 1
		for _, y := range g[x] {
			if y == fa { // skip the parent so the DFS only walks down
				continue
			}
			cntY := dfs(y, x)
			for i, c := range cntY {
				ans += int64(c) * int64(total[i]-c)
				cntX[i] += c
			}
		}
		return cntX
	}
	dfs(0, -1)
	return
}

// interactionCosts: the same contribution idea, but one DFS per label keeps a
// single integer per stack frame instead of a whole count array.
func interactionCosts(n int, edges [][]int, group []int) (ans int64) {
	g := make([][]int, n)
	for _, e := range edges {
		x, y := e[0], e[1]
		g[x] = append(g[x], y)
		g[y] = append(g[y], x)
	}

	mx := slices.Max(group)
	total := make([]int, mx+1)
	for _, x := range group {
		total[x]++
	}

	for target, tot := range total {
		if tot == 0 { // label absent from the tree
			continue
		}
		var dfs func(int, int) int
		dfs = func(x, fa int) (cntX int) {
			if group[x] == target {
				cntX = 1
			}
			for _, y := range g[x] {
				if y == fa {
					continue
				}
				cntY := dfs(y, x)
				ans += int64(cntY) * int64(tot-cntY)
				cntX += cntY
			}
			return
		}
		dfs(0, -1)
	}
	return
}

// interactionCosts3: for each label, build the virtual (auxiliary) tree over that
// label's nodes and their pairwise LCAs, then apply the contribution technique on
// virtual-tree edges weighted by depth difference. LCA uses binary lifting.
func interactionCosts3(n int, edges [][]int, group []int) (ans int64) {
	g := make([][]int, n)
	for _, e := range edges {
		v, w := e[0], e[1]
		g[v] = append(g[v], w)
		g[w] = append(g[w], v)
	}

	dfn := make([]int, n) // preorder (DFS entry) index of each node
	ts := 0
	pa := make([][17]int, n) // pa[v][i] = 2^i-th ancestor; 17 levels cover n <= 10^5
	dep := make([]int, n)
	var build func(int, int)
	build = func(v, p int) {
		dfn[v] = ts
		ts++
		pa[v][0] = p
		for _, w := range g[v] {
			if w != p {
				dep[w] = dep[v] + 1
				build(w, v)
			}
		}
	}
	build(0, -1)
	mx := bits.Len(uint(n))
	for i := range mx - 1 {
		for v := range pa {
			p := pa[v][i]
			if p != -1 {
				pa[v][i+1] = pa[p][i]
			} else {
				pa[v][i+1] = -1
			}
		}
	}
	// uptoDep lifts v to depth d by jumping over the set bits of dep[v]-d
	uptoDep := func(v, d int) int {
		for k := uint32(dep[v] - d); k > 0; k &= k - 1 {
			v = pa[v][bits.TrailingZeros32(k)]
		}
		return v
	}
	getLCA := func(v, w int) int {
		if dep[v] > dep[w] {
			v, w = w, v
		}
		w = uptoDep(w, dep[v])
		if w == v {
			return v
		}
		for i := mx - 1; i >= 0; i-- {
			pv, pw := pa[v][i], pa[w][i]
			if pv != pw {
				v, w = pv, pw
			}
		}
		return pa[v][0]
	}

	nodesMap := map[int][]int{}
	for i, x := range group {
		nodesMap[x] = append(nodesMap[x], i)
	}

	vt := make([][]int, n)   // virtual tree
	isNode := make([]int, n) // distinguishes key nodes from pure LCA nodes
	for i := range isNode {
		isNode[i] = -1
	}
	addVtEdge := func(v, w int) {
		vt[v] = append(vt[v], w) // add a directed edge v -> w to the virtual tree
	}
	const root = 0
	st := []int{root} // the root acts as a sentinel at the bottom of the stack

	for val, nodes := range nodesMap {
		// build the virtual tree for this group of key nodes sharing label val
		slices.SortFunc(nodes, func(a, b int) int { return dfn[a] - dfn[b] })
		vt[root] = vt[root][:0] // reset the virtual tree
		st = st[:1]
		for _, v := range nodes {
			isNode[v] = val
			if v == root {
				continue
			}
			vt[v] = vt[v][:0]
			lca := getLCA(st[len(st)-1], v) // the path's turning point (the LCA) also joins the virtual tree
			// backtrack, adding edges
			for len(st) > 1 && dfn[lca] <= dfn[st[len(st)-2]] {
				addVtEdge(st[len(st)-2], st[len(st)-1])
				st = st[:len(st)-1]
			}
			if lca != st[len(st)-1] { // lca is not on the stack yet (first encounter)
				vt[lca] = vt[lca][:0]
				addVtEdge(lca, st[len(st)-1])
				st[len(st)-1] = lca // put lca on the stack in place of the current top
			}
			st = append(st, v)
		}
		// final backtrack, adding edges
		for i := 1; i < len(st); i++ {
			addVtEdge(st[i-1], st[i])
		}

		var dfs func(int) int
		dfs = func(v int) (size int) {
			// if isNode[v] != val, v is only a turning point (an LCA) on paths between key nodes
			if isNode[v] == val {
				size = 1
			}
			for _, w := range vt[v] {
				sz := dfs(w)
				wt := dep[w] - dep[v] // virtual-tree edge weight
				// contribution technique: wt original edges, each crossed by sz*(len(nodes)-sz) pairs
				ans += int64(wt) * int64(sz) * int64(len(nodes)-sz)
				size += sz
			}
			return
		}

		rt := root
		if isNode[rt] != val && len(vt[rt]) == 1 {
			// root is only a sentinel and may not be on the virtual tree, so start from the real root
			rt = vt[rt][0]
		}
		dfs(rt)
	}

	return
}

func main() {
	edges := [][]int{{0, 1}, {1, 2}}
	group := []int{1, 1, 1}
	fmt.Println(interactionCosts1(3, edges, group)) // 4 (Example 1)
	fmt.Println(interactionCosts(3, edges, group))  // 4
	fmt.Println(interactionCosts3(3, edges, group)) // 4
}
Step 06

Interactive Study Demo

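Step through the workflow by hand on Example 3 (a star centered at node 0 with group = [1,1,4,4], rooted at node 0):

  1. Totals: label 1 appears on nodes 0 and 1, and label 4 on nodes 2 and 3, so each total is 2.
  2. Edge (0,1): the subtree {1} holds one label-1 node, contributing 1 * (2 - 1) = 1.
  3. Edge (0,2): the subtree {2} holds one label-4 node, contributing 1 * (2 - 1) = 1.
  4. Edge (0,3): the subtree {3} holds one label-4 node, contributing 1 * (2 - 1) = 1.
  5. Sum: 1 + 1 + 1 = 3, matching the expected output.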
Step 07

Complexity Analysis

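Time
O(n · G), where G is the largest label (at most 20), so effectively O(n)
Space
O(n) for the adjacency list, plus O(h) recursion depth for the per-label DFS (h = tree height)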

Approach Breakdown

LEVEL ORDER
O(n) time
O(n) space

BFS with a queue visits every node exactly once — O(n) time. The queue may hold an entire level of the tree, which for a complete binary tree is up to n/2 nodes = O(n) space. This is optimal in time but costly in space for wide trees.
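A level-order variant of this problem's solution avoids recursion entirely. It records a BFS order from node 0, then walks that order backwards so every child is finalized before its parent, which is the guarantee the postorder DFS provides. This is a sketch under stated assumptions: `interactionCostsBFS` is a hypothetical name, and labels are assumed to lie in 1..20 as the constraints state. The fixed-size count array per node costs O(n · 21) memory in exchange for having no call stack.

```go
package main

import "fmt"

// interactionCostsBFS (hypothetical name): BFS order from node 0, processed in
// reverse so children finish before parents, as in a postorder DFS.
func interactionCostsBFS(n int, edges [][]int, group []int) int64 {
	const maxLabel = 20 // assumption: labels lie in 1..20, per the constraints
	adj := make([][]int, n)
	for _, e := range edges {
		adj[e[0]] = append(adj[e[0]], e[1])
		adj[e[1]] = append(adj[e[1]], e[0])
	}
	total := make([]int64, maxLabel+1)
	for _, g := range group {
		total[g]++
	}

	// Level-order traversal recording each node's parent.
	parent := make([]int, n)
	parent[0] = -1
	order := []int{0}
	for i := 0; i < len(order); i++ {
		x := order[i]
		for _, y := range adj[x] {
			if y != parent[x] {
				parent[y] = x
				order = append(order, y)
			}
		}
	}

	// cnt[x][g] = number of label-g nodes in the subtree rooted at x.
	cnt := make([][maxLabel + 1]int64, n)
	var ans int64
	for i := n - 1; i >= 0; i-- {
		x := order[i]
		cnt[x][group[x]]++ // children were already folded into cnt[x]
		if p := parent[x]; p != -1 {
			for g := 1; g <= maxLabel; g++ {
				c := cnt[x][g]
				ans += c * (total[g] - c) // same-label pairs split by edge (p, x)
				cnt[p][g] += c
			}
		}
	}
	return ans
}

func main() {
	fmt.Println(interactionCostsBFS(3, [][]int{{0, 1}, {1, 2}}, []int{3, 2, 3}))             // 2 (Example 2)
	fmt.Println(interactionCostsBFS(4, [][]int{{0, 1}, {0, 2}, {0, 3}}, []int{1, 1, 4, 4})) // 3 (Example 3)
	fmt.Println(interactionCostsBFS(1, [][]int{}, []int{5}))                                // 0 (n = 1)
}
```

Reversing a BFS order is a common way to get a valid bottom-up processing order without recursion. It sidesteps the deep-path stack concern in languages with small default call stacks.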

DFS TRAVERSAL
O(n) time
O(h) space

Every node is visited exactly once, giving O(n) time. Space depends on tree shape: O(h) for recursive DFS (stack depth = height h), or O(w) for BFS (queue width = widest level). For balanced trees h = log n; for skewed trees h = n.

Shortcut: Visit every node once per label → O(n · G) time, effectively O(n) with G ≤ 20. Recursion depth = tree height → O(h) extra space.
Coach Notes

Common Mistakes

Review these before coding to avoid predictable interview regressions.

Off-by-one on range boundaries

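Wrong move: Sizing the per-label count array as max(group) instead of max(group) + 1, even though labels run from 1 to 20.

Usually fails on: Every input, because the node carrying the largest label indexes past the end of the array.

Fix: Allocate max(group) + 1 slots (or a fixed 21), as the reference code does with make([]int, mx+1).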

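Forgetting base-case handling (the parent check)

Wrong move: Treating every neighbor in the adjacency list as a child, as if the undirected tree were already rooted.

Usually fails on: Any tree with at least one edge: the DFS walks back to the parent and recurses without end.

Fix: Pass the parent into the DFS and skip it; a leaf then needs no special case and simply returns the count for its own label.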