Friday, September 4, 2015

Given a BST with 2 nodes swapped, fix it

Problem

Given a BST in which two nodes have been swapped, fix it.

Example
Consider the following correct BST:
         10
        /  \
       5    20
      / \
     2   8


Now we swap 8 and 20, and the tree is no longer a BST.

Input Tree:
         10
        /  \
       5    8
      / \
     2   20

In the above tree, nodes 20 and 8 must be swapped to fix the tree.  

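In the previous post, we counted how many pairs in the input tree violate the BST property. Here we will fix the tree. The inorder traversal of a valid BST is sorted, so after two nodes are swapped the inorder sequence falls into one of two cases: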

1. The swapped nodes are not adjacent in the inorder traversal of the BST.
   For example, nodes 5 and 25 are swapped in {3 5 7 8 10 15 20 25}.
   The inorder traversal of the given tree is 3 25 7 8 10 15 20 5.
   During the inorder traversal, we first find that node 7 is smaller than the previously visited node 25, so we save node 25 (the previous node). Later we find that node 5 is smaller than the previous node 20, and this time we save node 5 (the current node). Finally, we swap the values of the two saved nodes.
2. The swapped nodes are adjacent in the inorder traversal of the BST.
   For example, nodes 7 and 8 are swapped in {3 5 7 8 10 15 20 25}.
   The inorder traversal of the given tree is 3 5 8 7 10 15 20 25.
   Unlike case #1, there is only one point where a node's value is smaller than the previous node's value: node 7 is smaller than node 8.
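The two cases can be tried out on a plain array before dealing with tree pointers. This is a minimal sketch, assuming the inorder sequence is already available as an int[] (the class FixSwappedSequence and method fixSwapped are hypothetical names). It keeps the previous element at the first drop and the current element at the last drop, which covers both the non-adjacent and the adjacent case.

```java
class FixSwappedSequence {
    // Restores sorted order in an array where exactly two elements were swapped.
    // first  = index of the bigger element at the first drop (the "previous" value)
    // second = index of the smaller element at the last drop seen (the "current" value)
    static void fixSwapped(int[] a) {
        int first = -1, second = -1;
        for (int i = 1; i < a.length; i++) {
            if (a[i] < a[i - 1]) {          // a drop: current value smaller than previous
                if (first == -1) {
                    first = i - 1;          // both cases: keep the previous element
                    second = i;             // case 2 (adjacent): this is the partner
                } else {
                    second = i;             // case 1 (non-adjacent): keep the current element
                }
            }
        }
        if (first != -1) {                  // -1 means the array was already sorted
            int tmp = a[first];
            a[first] = a[second];
            a[second] = tmp;
        }
    }
}
```

Calling fixSwapped on {3, 25, 7, 8, 10, 15, 20, 5} or on {3, 5, 8, 7, 10, 15, 20, 25} leaves {3, 5, 7, 8, 10, 15, 20, 25} in the array.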

How to solve? We maintain three pointers: first, middle and last. At the first point where the current node's value is smaller than the previous node's value, we set first to the previous node and middle to the current node. At the second such point, we set last to the current node. In case #2 there is no second point, so last is never updated. After the traversal, if last is still null, the two swapped nodes are adjacent and we swap the values of first and middle; otherwise we swap the values of first and last. The code below folds middle and last into a single pointer, second, which is overwritten if a second point is found.

code:
 import java.util.ArrayDeque;
 import java.util.Deque;

 class RecoverBST {
   static class Node {
     int val;
     Node left, right;
     Node(int val) { this.val = val; }
   }

   // swap the values (not the nodes) of a and b
   private void swap(Node a, Node b) {
     if (a == null || b == null) return;
     int tmp = a.val;
     a.val = b.val;
     b.val = tmp;
   }

   public void recoverTree(Node root) {
     Node cur = root, pre = null, first = null, second = null;
     // an inorder traversal of a valid BST visits the values in sorted order
     Deque<Node> stack = new ArrayDeque<>();
     while (cur != null) { // find the leftmost child
       stack.push(cur);
       cur = cur.left;
     }
     while (!stack.isEmpty()) {
       cur = stack.pop();

       // is it out of order?
       if (pre != null && cur.val < pre.val) {
         if (first == null) {
           // the first wrong item is the bigger one
           first = pre;
           second = cur; // if the swapped nodes are adjacent, this is the other one
         } else {
           // the second wrong item is the smaller one
           second = cur;
           break;
         }
       }

       // go to the right child and repeat
       pre = cur;
       cur = cur.right;
       while (cur != null) {
         stack.push(cur);
         cur = cur.left;
       }
     }

     swap(first, second);
   }
 }
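The three-pointer description can also be written as a recursive inorder traversal that keeps first, middle and last exactly as described, instead of the single second pointer used in the iterative code. This is a sketch assuming an int-valued node type; the class name RecoverBstRecursive and the three-argument Node constructor are hypothetical.

```java
class RecoverBstRecursive {
    static class Node {
        int val;
        Node left, right;
        Node(int val) { this.val = val; }
        Node(int val, Node left, Node right) { this.val = val; this.left = left; this.right = right; }
    }

    private Node first, middle, last, prev;

    // Inorder walk that records the violation points described above.
    private void inorder(Node node) {
        if (node == null) return;
        inorder(node.left);
        if (prev != null && node.val < prev.val) {
            if (first == null) {
                first = prev;    // bigger node at the first violation
                middle = node;   // its partner if the swapped nodes are adjacent
            } else {
                last = node;     // smaller node at the second violation
            }
        }
        prev = node;
        inorder(node.right);
    }

    void recoverTree(Node root) {
        first = middle = last = prev = null;
        inorder(root);
        if (first == null) return;                      // already a valid BST
        Node partner = (last != null) ? last : middle;  // non-adjacent vs adjacent case
        int tmp = first.val;
        first.val = partner.val;
        partner.val = tmp;
    }
}
```

On the input tree from the example, the traversal sees 2 5 20 10 8: first = 20 and middle = 10 at the first drop, last = 8 at the second, so 20 and 8 are swapped back. The recursion uses stack space proportional to the tree height, like the explicit stack in the iterative version.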



Given a binary search tree with 2 nodes swapped, find the number of pairs not following the BST property

Problem

Given a binary search tree in which two nodes have been swapped, count the pairs of nodes that do not follow the BST property. Follow-up: fix the BST (covered in the next post).

Example
Consider the following correct BST:
         10
        /  \
       5    20
      / \
     2   8


Now we swap 8 and 20, and the tree is no longer a BST.

Input Tree:
         10
        /  \
       5    8
      / \
     2   20

In the above tree, nodes 20 and 8 must be swapped to fix the tree.  

Now the number of pairs not following the BST property is 3. They are:
  • 10-20
  • 10-8
  • 20-8

Solution

Method 1 - Using in-order traversal
The solution works as follows:
  1. Find the inorder traversal of the input tree. Example - 2, 5, 20, 10, 8
  2. Find the number of inversions in the inorder traversal. That is the answer. Here the inversions are 20-10, 20-8, and 10-8. 
Time complexity - O(n log n): O(n) for the inorder traversal and O(n log n) for counting the inversions in the (unsorted) inorder array, for example with a modified merge sort.
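Both steps fit in a short program. This is a minimal sketch assuming an int-valued node type; the class SwappedPairCount and its method names are hypothetical. The inversion count uses the standard merge-sort technique: whenever an element from the right half is merged before the remaining elements of the left half, each of those remaining elements forms an inversion with it.

```java
import java.util.ArrayList;
import java.util.List;

class SwappedPairCount {
    static class Node {
        int val;
        Node left, right;
        Node(int val) { this.val = val; }
        Node(int val, Node left, Node right) { this.val = val; this.left = left; this.right = right; }
    }

    // Step 1: collect the inorder traversal.
    static void inorder(Node node, List<Integer> out) {
        if (node == null) return;
        inorder(node.left, out);
        out.add(node.val);
        inorder(node.right, out);
    }

    // Step 2: count inversions in a[lo..hi) with merge sort (sorts that range as a side effect).
    static long countInversions(int[] a, int[] tmp, int lo, int hi) {
        if (hi - lo < 2) return 0;
        int mid = (lo + hi) >>> 1;
        long count = countInversions(a, tmp, lo, mid) + countInversions(a, tmp, mid, hi);
        int i = lo, j = mid, k = lo;
        while (i < mid && j < hi) {
            if (a[i] <= a[j]) {
                tmp[k++] = a[i++];
            } else {
                count += mid - i;   // a[j] is smaller than every remaining left element
                tmp[k++] = a[j++];
            }
        }
        while (i < mid) tmp[k++] = a[i++];
        while (j < hi) tmp[k++] = a[j++];
        System.arraycopy(tmp, lo, a, lo, hi - lo);
        return count;
    }

    static long countViolatingPairs(Node root) {
        List<Integer> values = new ArrayList<>();
        inorder(root, values);
        int[] a = values.stream().mapToInt(Integer::intValue).toArray();
        return countInversions(a, new int[a.length], 0, a.length);
    }
}
```

For the input tree above, the inorder list is 2, 5, 20, 10, 8 and countViolatingPairs returns 3.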

For a given node of a binary tree, print the nodes at distance K.

Problem
You are given a function printKDistanceNodes which takes the root node of a binary tree, a start node and an integer K. Complete the function to print, in sorted order and one per line, the values of all nodes at distance K from the given start node. The distance can be measured upwards or downwards.

Example
start node = 18, k = 2, then output = 2, 19, 25
start node = 18, k = 3, then output = -4, 3

Solution

We have already seen similar problems: printing nodes at distance k from the root, and printing nodes at distance k from a leaf. Finding nodes at distance k from the root is easy. When printing nodes k levels above a leaf, we know the direction: we only go up. Here, a node at distance k may lie below the start node or may be reached by going upwards first, so we have to search in both directions.

Note: parent pointers are not given.

Method 1 - Using the recursion

Printing nodes at distance K downwards is easy; it is a simple recursive function. The harder part is the nodes that can only be reached by going upwards.

There are two types of nodes to be considered.
1) Nodes in the subtree rooted at the target node. For example, if the target node is 18 and k is 2, such nodes are 19 and 25.
2) Other nodes, which may be ancestors of the target or nodes in some other subtree. For target node 18 and k = 2, node 2 falls in this category.

Finding the first type of nodes is easy: traverse the subtree rooted at the target node, decrementing k in each recursive call, and when k becomes 0, print the node currently being visited. We call this function printkdistanceNodeDown().

How do we find nodes of the second type? Output nodes that are not in the subtree rooted at the target can only be reached through an ancestor of the target, so we go through all ancestors. For every ancestor, we find its distance d from the target node; we then go to the ancestor's other subtree (if the target was found in the left subtree, we go to the right subtree, and vice versa) and find all nodes at distance k-d from the ancestor.
// Recursive function to print all the nodes at distance k in the
// tree (or subtree) rooted with the given root.
void printkdistanceNodeDown(Node root, int k)
{
    // Base Case
    if (root == null || k < 0)  return;
 
    // If we reach a k distant node, print it
    if (k==0)
    {
        System.out.println(root.data);
        return;
    }
 
    // Recur for left and right subtrees
    printkdistanceNodeDown(root.left, k-1);
    printkdistanceNodeDown(root.right, k-1);
}
 
// Prints all nodes at distance k from a given target node.
// The k distant nodes may be upward or downward.  This function
// Returns distance of root from target node, it returns -1 if target
// node is not present in tree rooted with root.
int printkdistanceNode(Node root, Node target , int k)
{
    // Base Case 1: If tree is empty, return -1
    if (root == null) return -1;
 
    // If target is same as root.  Use the downward function
    // to print all nodes at distance k in subtree rooted with
    // target or root
    if (root == target)
    {
        printkdistanceNodeDown(root, k);
        return 0;
    }
 
    // Recur for left subtree
    int dl = printkdistanceNode(root.left, target, k);
 
    // Check if target node was found in left subtree
    if (dl != -1)
    {
         // If root is at distance k from target, print root
         // Note that dl is Distance of root's left child from target
         if (dl + 1 == k)
            System.out.println(root.data);
 
         // Else go to right subtree and print all k-dl-2 distant nodes
         // Note that the right child is 2 edges away from left child
         else
            printkdistanceNodeDown(root.right, k-dl-2);
 
         // Add 1 to the distance and return value for parent calls
         return 1 + dl;
    }
 
    // MIRROR OF ABOVE CODE FOR RIGHT SUBTREE
    // Note that we reach here only when node was not found in left subtree
    int dr = printkdistanceNode(root.right, target, k);
    if (dr != -1)
    {
         if (dr + 1 == k)
            System.out.println(root.data);
         else
            printkdistanceNodeDown(root.left, k-dr-2);
         return 1 + dr;
    }
 
    // If target was neither present in left nor in right subtree
    return -1;
}
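The problem asks for the output in sorted order, while printkdistanceNode prints nodes in the order it visits them. One way to meet the requirement, sketched here with hypothetical helper names (collectDown, collect, kDistanceNodes) and an assumed int-valued Node, is to run the same recursion but add the values to a list, then sort the list before printing.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

class KDistanceSorted {
    static class Node {
        int data;
        Node left, right;
        Node(int data) { this.data = data; }
        Node(int data, Node left, Node right) { this.data = data; this.left = left; this.right = right; }
    }

    // Adds the values of all nodes exactly k edges below root.
    static void collectDown(Node root, int k, List<Integer> out) {
        if (root == null || k < 0) return;
        if (k == 0) { out.add(root.data); return; }
        collectDown(root.left, k - 1, out);
        collectDown(root.right, k - 1, out);
    }

    // Same recursion as printkdistanceNode, but collecting instead of printing.
    // Returns the distance from root to target, or -1 if target is not below root.
    static int collect(Node root, Node target, int k, List<Integer> out) {
        if (root == null) return -1;
        if (root == target) { collectDown(root, k, out); return 0; }
        int dl = collect(root.left, target, k, out);
        if (dl != -1) {
            if (dl + 1 == k) out.add(root.data);            // root itself is k away
            else collectDown(root.right, k - dl - 2, out);  // right child is dl + 2 away
            return dl + 1;
        }
        int dr = collect(root.right, target, k, out);
        if (dr != -1) {
            if (dr + 1 == k) out.add(root.data);
            else collectDown(root.left, k - dr - 2, out);
            return dr + 1;
        }
        return -1;
    }

    // The values at distance k from start, in ascending order.
    static List<Integer> kDistanceNodes(Node root, Node start, int k) {
        List<Integer> out = new ArrayList<>();
        collect(root, start, k, out);
        Collections.sort(out);
        return out;
    }

    // Prints the values one per line, as the problem statement asks.
    static void printKDistanceNodes(Node root, Node start, int k) {
        for (int v : kDistanceNodes(root, start, k)) System.out.println(v);
    }
}
```

For a small hypothetical tree where 1 has children 2 and 3, 2 has children 4 and 5, and 3 has a right child 6, kDistanceNodes(root, node2, 1) returns [1, 4, 5].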


Method 2 - Using the queue

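Use a queue of size K to store the nodes on the path from the root to the start node (call it NODE), keeping only the last K of them.
Since the queue holds at most K nodes, as soon as we find NODE in the tree, the node at the front of the queue is at distance K from NODE. If NODE is fewer than K levels deep, the front node may be closer than K, so maintain a counter to know the distance of each queued node.

Now pop the nodes from the queue one by one. For a popped node at distance i from NODE, print all nodes at distance K-i below it, but only in its other subtree, the one that does not contain NODE. Searching the branch that contains NODE would measure distances through the ancestor instead of directly and print wrong nodes. The nodes at distance K below NODE itself are found with the downward function, as in Method 1.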

Note: since the nodes must be printed in sorted order, we can store them in a priority queue and print them after processing.
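Method 2 can be sketched as follows. To keep the sketch short, it stores the whole root-to-NODE path in a list rather than a bounded queue of size K, which makes the distance of each ancestor simply its position counted back from NODE. The results go into a PriorityQueue, as the note suggests, so they come out sorted. The class and method names are hypothetical, and the Node type is assumed.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

class KDistancePath {
    static class Node {
        int data;
        Node left, right;
        Node(int data) { this.data = data; }
        Node(int data, Node left, Node right) { this.data = data; this.left = left; this.right = right; }
    }

    // Fills path with the nodes from root down to target; returns false if target is absent.
    static boolean findPath(Node root, Node target, List<Node> path) {
        if (root == null) return false;
        path.add(root);
        if (root == target
                || findPath(root.left, target, path)
                || findPath(root.right, target, path)) {
            return true;
        }
        path.remove(path.size() - 1);   // target is not below this node: backtrack
        return false;
    }

    // Adds nodes exactly k edges below node, never entering the child `blocked`.
    static void collectDown(Node node, int k, Node blocked, PriorityQueue<Integer> out) {
        if (node == null || node == blocked || k < 0) return;
        if (k == 0) { out.add(node.data); return; }
        collectDown(node.left, k - 1, blocked, out);
        collectDown(node.right, k - 1, blocked, out);
    }

    static List<Integer> kDistanceNodes(Node root, Node target, int k) {
        List<Node> path = new ArrayList<>();
        List<Integer> sorted = new ArrayList<>();
        if (!findPath(root, target, path)) return sorted;
        PriorityQueue<Integer> out = new PriorityQueue<>();
        Node cameFrom = null;
        // The node i steps above target (i = 0 is target itself) is at distance i.
        for (int i = 0; i <= k && i < path.size(); i++) {
            Node ancestor = path.get(path.size() - 1 - i);
            collectDown(ancestor, k - i, cameFrom, out);  // skip the branch leading to target
            cameFrom = ancestor;
        }
        while (!out.isEmpty()) sorted.add(out.poll());    // poll() returns the smallest first
        return sorted;
    }
}
```

With a small hypothetical tree where 1 has children 2 and 3, 2 has children 4 and 5, and 3 has a right child 6, starting from node 4 with K = 4 yields [6]: node 2 is searched with K-1 = 3 while skipping node 4, and the root with K-2 = 2 while skipping node 2.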


Print all nodes that are at distance k from a leaf node

Problem

Given a Binary Tree and a positive integer k, print all nodes that are distance k from a leaf node.

Here the meaning of distance differs from the previous post: k distance from a leaf means k levels above a leaf node. For example, if k is greater than the height of the binary tree, nothing should be printed. The expected time complexity is O(n), where n is the number of nodes in the given binary tree.




Example
(Please ignore the empty node, and consider it null)

k = 1, Answer = 2, 19, 21
k = 2, Answer = 5, 18, 19

Solution

The idea is to traverse the tree, storing the ancestors of the current node in an array (path[]) until we hit a leaf node. When we reach a leaf, we print the ancestor at distance k above it. Because a node can be k levels above several leaves, we also keep track of the nodes already printed, using a boolean array visited[].

// This function prints all nodes that are distance k from a leaf node
//   path[] - Store ancestors of a node
//   visited[] - Stores true if a node is printed as output.  A node may be k
//                 distance away from many leaves, we want to print it once 
void kDistantFromLeafUtil(Node node, int[] path, boolean[] visited,
                          int pathLen, int k)
{
    // Base case
    if (node==null) return;
 
    // append this Node to the path array 
    path[pathLen] = node.data;
    visited[pathLen] = false;
    pathLen++;
 
    // it's a leaf, so print the ancestor at distance k only
    // if the ancestor is not already printed  
    if (node.left == null && node.right == null &&
        pathLen-k-1 >= 0 && visited[pathLen-k-1] == false)
    {
        System.out.print(path[pathLen-k-1] + " ");
        visited[pathLen-k-1] = true;
        return;
    }
 
    // If not leaf node, recur for left and right subtrees 
    kDistantFromLeafUtil(node.left, path, visited, pathLen, k);
    kDistantFromLeafUtil(node.right, path, visited, pathLen, k);
}
 
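// Given a binary tree and a number k, print all nodes that are k
//   distant from a leaf
void printKDistantfromLeaf(Node node, int k)
{
    // a root-to-leaf path never holds more nodes than the height of the tree
    int h = height(node);
    int[] path = new int[h];
    boolean[] visited = new boolean[h]; // Java initialises every entry to false
    kDistantFromLeafUtil(node, path, visited, 0, k);
}

// Number of nodes on the longest root-to-leaf path (0 for an empty tree)
int height(Node node)
{
    if (node == null) return 0;
    return 1 + Math.max(height(node.left), height(node.right));
}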
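The same idea can be written without fixed-size arrays. This sketch, with hypothetical names and an assumed int-valued Node, keeps the current root-to-node path in a list of nodes and records every ancestor k levels above a leaf in a set. Because the set holds node references, a node that is k levels above several leaves is reported only once, which replaces the visited[] bookkeeping; a LinkedHashSet keeps the nodes in the order they were first found.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

class KDistanceFromLeaf {
    static class Node {
        int data;
        Node left, right;
        Node(int data) { this.data = data; }
        Node(int data, Node left, Node right) { this.data = data; this.left = left; this.right = right; }
    }

    // path holds the nodes from the root down to the current node.
    static void walk(Node node, List<Node> path, Set<Node> found, int k) {
        if (node == null) return;
        path.add(node);
        if (node.left == null && node.right == null) {
            int idx = path.size() - 1 - k;           // ancestor k levels above this leaf
            if (idx >= 0) found.add(path.get(idx));  // the set ignores repeats
        } else {
            walk(node.left, path, found, k);
            walk(node.right, path, found, k);
        }
        path.remove(path.size() - 1);                // backtrack before returning
    }

    static List<Integer> kDistantFromLeaf(Node root, int k) {
        Set<Node> found = new LinkedHashSet<>();     // keeps first-found order
        walk(root, new ArrayList<>(), found, k);
        List<Integer> result = new ArrayList<>();
        for (Node n : found) result.add(n.data);
        return result;
    }
}
```

For a tree where 1 has children 2 and 3, 2 has leaf children 4 and 5, and 3 has a single leaf child 6, k = 1 gives [2, 3] and k = 2 gives [1].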



Find the distance between 2 nodes in Binary Tree

Problem

Find the distance between two keys in a binary tree; no parent pointers are given. The distance between two nodes is the minimum number of edges that must be traversed to reach one node from the other.

Example
Dist(-4,3) = 2
Dist(-4,19) = 4
Dist(21,-4) = 3
Dist(2,-4) = 1

Solution

The distance between two nodes can be obtained from their lowest common ancestor (LCA) using the following formula:

Dist(n1, n2) = Dist(root, n1) + Dist(root, n2) - 2*Dist(root, lca) 
'n1' and 'n2' are the two given keys
'root' is root of given Binary Tree.
'lca' is lowest common ancestor of n1 and n2
Dist(n1, n2) is the distance between n1 and n2.

For example, take the case of Dist(-4,3):
LCA(-4,3) = 2
Dist(-4,3) = Dist(5,-4) + Dist(5,3) - 2 * Dist(5,2) = 3 + 3 - 2 * 2 = 2

Now let's write the code.

Code

// Returns level of key k if it is present in tree, otherwise returns -1
int findLevel(Node root, int k, int level)
{
    // Base Case
    if (root == null)
        return -1;
 
    // If key is present at root, return its level; otherwise
    // search the left subtree and then the right subtree
    if (root.key == k)
        return level;
 
    int l = findLevel(root.left, k, level+1);
    return (l != -1)? l : findLevel(root.right, k, level+1);
}
 
// This function returns the LCA of two given values n1 and n2.
// It also sets d1, d2 and dist if one key is not an ancestor of the other.
// Java has no out parameters, so each output is a one-element int array
// whose first slot findDistUtil overwrites:
// d1   - distance of n1 from root
// d2   - distance of n2 from root
// dist - distance between n1 and n2
// lvl  - level (distance from root) of the current node
Node findDistUtil(Node root, int n1, int n2, int[] d1, int[] d2,
                   int[] dist, int lvl)
{
    // Base case
    if (root == null) return null;
 
    // If either n1 or n2 matches with root's key, report
    // the presence by returning root (note that if a key is an
    // ancestor of the other, then the ancestor key becomes the LCA)
    if (root.key == n1)
    {
         d1[0] = lvl;
         return root;
    }
    if (root.key == n2)
    {
         d2[0] = lvl;
         return root;
    }
 
    // Look for n1 and n2 in left and right subtrees
    Node left_lca  = findDistUtil(root.left, n1, n2, d1, d2, dist, lvl+1);
    Node right_lca = findDistUtil(root.right, n1, n2, d1, d2, dist, lvl+1);
 
    // If both of the above calls return non-null, then one key
    // is present in one subtree and the other key in the other,
    // so this node is the LCA
    if (left_lca != null && right_lca != null)
    {
        dist[0] = d1[0] + d2[0] - 2*lvl;
        return root;
    }
 
    // Otherwise the LCA (if any) is in the left or right subtree
    return (left_lca != null) ? left_lca : right_lca;
}
 
// The main function that returns the distance between n1 and n2.
// This function returns -1 if either n1 or n2 is not present in
// the binary tree.
int findDistance(Node root, int n1, int n2)
{
    // Initialize d1 (distance of n1 from root), d2 (distance of n2
    // from root) and dist (distance between n1 and n2)
    int[] d1 = {-1}, d2 = {-1}, dist = {-1};
    Node lca = findDistUtil(root, n1, n2, d1, d2, dist, 1);
 
    // If n1 and n2 were both found (in different subtrees), return dist
    if (d1[0] != -1 && d2[0] != -1)
        return dist[0];
 
    // If n1 is an ancestor of n2, consider n1 as root and find the level
    // of n2 in the subtree rooted at n1
    if (d1[0] != -1)
        return findLevel(lca, n2, 0);
 
    // If n2 is an ancestor of n1, consider n2 as root and find the level
    // of n1 in the subtree rooted at n2
    if (d2[0] != -1)
        return findLevel(lca, n1, 0);
 
    return -1;
}

findDistance() is the main function. It calls findDistUtil(), which finds the LCA and, when neither key is an ancestor of the other, also computes the distance.
If n1 is an ancestor of n2 (or vice versa), findDistUtil() stops at the ancestor, so we use findLevel() to find the depth of the other key below it, which is the difference between their levels.

Time complexity - O(n): findDistUtil() visits each node at most once, and findLevel(), when it is needed, only searches the subtree rooted at the LCA.

Note that Java has no out parameters like C#'s out. Passing Integer objects does not work around this: Integer is immutable and Java passes references by value, so assigning a new value to an Integer parameter does not change the caller's variable. That is why d1, d2 and dist are passed as one-element int arrays, whose contents the called function can modify.
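Because Dist(root, n1) - Dist(root, lca) is just the depth of n1 below the LCA, the formula can also be applied in two separate steps: find the LCA, then add the depths of n1 and n2 below it. This sketch takes that route with hypothetical names (NodeDistance, lca, level, distance) and an assumed Node with an int key. It makes a few passes over the tree instead of one, which is still O(n), and it needs no out parameters at all.

```java
class NodeDistance {
    static class Node {
        int key;
        Node left, right;
        Node(int key) { this.key = key; }
        Node(int key, Node left, Node right) { this.key = key; this.left = left; this.right = right; }
    }

    // Classic LCA search; assumes both keys are present in the tree.
    static Node lca(Node root, int n1, int n2) {
        if (root == null || root.key == n1 || root.key == n2) return root;
        Node l = lca(root.left, n1, n2);
        Node r = lca(root.right, n1, n2);
        if (l != null && r != null) return root;   // keys found on both sides
        return (l != null) ? l : r;
    }

    // Number of edges from root down to key, or -1 if key is absent.
    static int level(Node root, int key, int depth) {
        if (root == null) return -1;
        if (root.key == key) return depth;
        int l = level(root.left, key, depth + 1);
        return (l != -1) ? l : level(root.right, key, depth + 1);
    }

    static int distance(Node root, int n1, int n2) {
        // Check presence first, since lca() assumes both keys exist.
        if (level(root, n1, 0) == -1 || level(root, n2, 0) == -1) return -1;
        Node a = lca(root, n1, n2);
        // Depth of n1 below the LCA + depth of n2 below the LCA.
        return level(a, n1, 0) + level(a, n2, 0);
    }
}
```

In a tree where 1 has children 2 and 3, 2 has children 4 and 5, and 3 has a right child 6, distance(root, 4, 6) returns 4 and distance(root, 2, 4) returns 1. The presence check matters: without it, lca() would return the one key it found and the distance would be wrong instead of -1.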


Program to count leaf nodes in a binary tree

Problem

Count leaf nodes in a binary tree

Solution

Method 1 - Recursive
Here is the logic:
  1. If node is NULL then return 0.
  2. Else If left and right child nodes are NULL return 1.
  3. Else recursively calculate leaf count of the tree using below formula.
    Leaf count of a tree = Leaf count of left subtree + Leaf count of right subtree

Here is the recursive solution:
int countLeaves(Node node){
  if( node == null )
    return 0;
  if( node.left == null && node.right == null ) {
    return 1;
  } else {
    return countLeaves(node.left) + countLeaves(node.right);
  }
}

Time complexity - O(n)

Method 2 - Iterative
Here we use a queue to do a level-order traversal, counting every dequeued node that has no children.

import java.util.ArrayDeque;
import java.util.Queue;

int countLeaves(Node root)
{
  int count = 0;
  if (root == null)
    return 0;

  Queue<Node> myqueue = new ArrayDeque<>();
  myqueue.offer(root);

  while (!myqueue.isEmpty())
  {
    Node temp = myqueue.poll(); // take the front element from the queue
    if (temp.left == null && temp.right == null)
      count++; // no children: a leaf
    if (temp.left != null)
      myqueue.offer(temp.left);
    if (temp.right != null)
      myqueue.offer(temp.right);
  }
  return count;
}

Time complexity - O(n)
Space Complexity - O(n)  for queue
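Both methods can be run side by side on a small tree to check that they agree. In this self-contained sketch, the class name LeafCount, the nested Node type and the two method names are hypothetical; the bodies follow the two methods above, with java.util.ArrayDeque as the Queue implementation.

```java
import java.util.ArrayDeque;
import java.util.Queue;

class LeafCount {
    static class Node {
        int data;
        Node left, right;
        Node(int data) { this.data = data; }
        Node(int data, Node left, Node right) { this.data = data; this.left = left; this.right = right; }
    }

    // Method 1: a leaf contributes 1, an inner node the sum of its subtrees.
    static int countLeavesRecursive(Node node) {
        if (node == null) return 0;
        if (node.left == null && node.right == null) return 1;
        return countLeavesRecursive(node.left) + countLeavesRecursive(node.right);
    }

    // Method 2: level-order traversal, counting dequeued nodes without children.
    static int countLeavesIterative(Node root) {
        if (root == null) return 0;
        int count = 0;
        Queue<Node> queue = new ArrayDeque<>();
        queue.offer(root);
        while (!queue.isEmpty()) {
            Node temp = queue.poll();
            if (temp.left == null && temp.right == null) count++;
            if (temp.left != null) queue.offer(temp.left);
            if (temp.right != null) queue.offer(temp.right);
        }
        return count;
    }
}
```

For a tree where 1 has children 2 and 3, 2 has children 4 and 5, and 3 has a right child 6, both methods return 3.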
