LeetCode - Count Complete Tree Nodes
Problem description
Given a complete binary tree, count the number of nodes.
Note:
Definition of a complete binary tree from Wikipedia: In a complete binary tree every level, except possibly the last, is completely filled, and all nodes in the last level are as far left as possible. It can have between 1 and 2^h nodes inclusive at the last level h.
Example:
Input:

    1
   / \
  2   3
 / \  /
4  5 6

Output: 6
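To try the solutions below you need a complete tree to feed them. This sketch assumes a minimal TreeNode with the same shape as LeetCode's (val, left, right), nested here only so the snippet compiles on its own; the class name CompleteTreeBuilder and its build method are hypothetical. It relies on the usual array layout of a complete tree, where node i has children 2i + 1 and 2i + 2, so filling indices 0 to n - 1 always yields a complete tree.

```java
class CompleteTreeBuilder {
    // Assumed stand-in for LeetCode's TreeNode (val, left, right).
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
    }

    // Builds a complete tree holding the values 1..n in level order.
    // Node i (0-based) gets children 2i + 1 and 2i + 2 when those indices
    // are below n, so every level is full except the last, whose nodes
    // are packed to the left.
    static TreeNode build(int n) {
        if (n <= 0) return null;
        TreeNode[] nodes = new TreeNode[n];
        for (int i = 0; i < n; i++) nodes[i] = new TreeNode(i + 1);
        for (int i = 0; i < n; i++) {
            int l = 2 * i + 1, r = 2 * i + 2;
            if (l < n) nodes[i].left = nodes[l];
            if (r < n) nodes[i].right = nodes[r];
        }
        return nodes[0];
    }
}
```

Calling CompleteTreeBuilder.build(6) returns root 1 with children 2 and 3, then 4, 5 under 2 and 6 under 3, which is the input shown above.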
Analysis
The first idea is to count the nodes the same way we would for any binary tree: a recursive DFS that adds 1 for every node it visits.
public int countNodes(TreeNode root) {
    return dfs(root);
}

int dfs(TreeNode node) {
    if (node == null) {
        return 0;
    }
    return dfs(node.right) + dfs(node.left) + 1;
}
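However, this visits every node in O(n) time and never uses the fact that the tree is complete.
Because every level except the last is completely filled, the only unknown is how many nodes the last level holds. Those nodes are packed as far left as possible, so the ones that exist form a prefix of the level, and we can binary-search for where that prefix ends. Each existence check walks d levels and the search makes about d checks, so the total cost is O(d^2) = O(log^2 n).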
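The exists method in the solution below halves the index range [0, 2^d - 1] at each level. Because that range always splits on a power-of-two boundary, this is equivalent to reading the bits of idx from the most significant one down: a 0 bit means go left, a 1 bit means go right. The helper pathTo is a hypothetical sketch of that equivalence, not part of the original solution.

```java
class LastLevelPath {
    // Returns the root-to-leaf path for last-level index idx in a tree whose
    // last level is at depth d (root at depth 0), as 'L'/'R' characters.
    // Bit (d - 1) of idx decides the first move and bit 0 the last one:
    // a 0 bit means left, a 1 bit means right.
    static String pathTo(int idx, int d) {
        StringBuilder path = new StringBuilder();
        for (int bit = d - 1; bit >= 0; bit--) {
            path.append(((idx >> bit) & 1) == 0 ? 'L' : 'R');
        }
        return path.toString();
    }
}
```

For the example tree (d = 2), pathTo(2, 2) returns "RL", which leads from 1 to 3 to 6, while pathTo(3, 2) returns "RR", the empty fourth slot on the last level.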
class Solution {
    // Return the tree depth (number of edges on the leftmost path) in O(d) time.
    public int computeDepth(TreeNode node) {
        int d = 0;
        while (node.left != null) {
            node = node.left;
            ++d;
        }
        return d;
    }

    // Last-level nodes are enumerated from 0 to 2**d - 1 (left -> right).
    // Return true if the last-level node idx exists.
    // Binary search with O(d) complexity.
    public boolean exists(int idx, int d, TreeNode node) {
        int left = 0, right = (int) Math.pow(2, d) - 1;
        int pivot;
        for (int i = 0; i < d; ++i) {
            pivot = left + (right - left) / 2;
            if (idx <= pivot) {
                node = node.left;
                right = pivot;
            } else {
                node = node.right;
                left = pivot + 1;
            }
        }
        return node != null;
    }

    public int countNodes(TreeNode root) {
        // If the tree is empty.
        if (root == null) return 0;
        int d = computeDepth(root);
        // If the tree contains 1 node.
        if (d == 0) return 1;
        // Last-level nodes are enumerated from 0 to 2**d - 1 (left -> right).
        // Node 0 always exists, so binary-search indices 1..2**d - 1
        // to find how many last-level nodes exist.
        int left = 1, right = (int) Math.pow(2, d) - 1;
        int pivot;
        while (left <= right) {
            pivot = left + (right - left) / 2;
            if (exists(pivot, d, root)) left = pivot + 1;
            else right = pivot - 1;
        }
        // Levels 0 through d - 1 are full and contain 2**d - 1 nodes;
        // the last level contains `left` nodes.
        return (int) Math.pow(2, d) - 1 + left;
    }
}
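One way to gain confidence in the binary-search version is to compare it with the plain DFS count on every complete tree up to some size. The harness below is a sketch under assumptions: the class name CountCheck and its helpers are hypothetical, TreeNode is a nested stand-in for LeetCode's, and fastCount mirrors the Solution above but uses 1 << d in place of Math.pow.

```java
class CountCheck {
    // Assumed stand-in for LeetCode's TreeNode (val, left, right).
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
    }

    // Complete tree with n nodes in level order: children of i are 2i+1, 2i+2.
    static TreeNode build(int n) {
        if (n <= 0) return null;
        TreeNode[] nodes = new TreeNode[n];
        for (int i = 0; i < n; i++) nodes[i] = new TreeNode(i + 1);
        for (int i = 0; i < n; i++) {
            if (2 * i + 1 < n) nodes[i].left = nodes[2 * i + 1];
            if (2 * i + 2 < n) nodes[i].right = nodes[2 * i + 2];
        }
        return nodes[0];
    }

    // O(n) baseline: visits every node.
    static int dfsCount(TreeNode node) {
        return node == null ? 0 : 1 + dfsCount(node.left) + dfsCount(node.right);
    }

    // Same descent as Solution.exists: halve [0, 2^d - 1] at each level.
    static boolean exists(int idx, int d, TreeNode node) {
        int left = 0, right = (1 << d) - 1;
        for (int i = 0; i < d; i++) {
            int pivot = left + (right - left) / 2;
            if (idx <= pivot) { node = node.left; right = pivot; }
            else { node = node.right; left = pivot + 1; }
        }
        return node != null;
    }

    // Same algorithm as Solution.countNodes, O(log^2 n).
    static int fastCount(TreeNode root) {
        if (root == null) return 0;
        int d = 0;
        for (TreeNode n = root.left; n != null; n = n.left) d++;
        if (d == 0) return 1;
        int left = 1, right = (1 << d) - 1;
        while (left <= right) {
            int pivot = left + (right - left) / 2;
            if (exists(pivot, d, root)) left = pivot + 1;
            else right = pivot - 1;
        }
        return (1 << d) - 1 + left;
    }

    // First size in [0, maxN] where a count disagrees with n, or -1 if none.
    static int firstMismatch(int maxN) {
        for (int n = 0; n <= maxN; n++) {
            TreeNode root = build(n);
            if (fastCount(root) != n || dfsCount(root) != n) return n;
        }
        return -1;
    }
}
```

firstMismatch(200) returns -1 when both counts match the true size for every complete tree with 0 to 200 nodes; a non-negative result pinpoints the smallest failing size.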