
I am trying to print the kth smallest element in a BST. The first solution uses in-order traversal. The next solution finds the index of the current node by calculating the size of its left subtree. Complete algorithm:

Find the size of the left subtree, then:
   1. If size == k-1, return the current node.
   2. If size < k-1, return the (k-size-1)th node in the right subtree.
   3. If size > k-1, return the kth node in the left subtree.

This can be implemented using a separate count function, something like this:

// TreeNode is LeetCode's class with fields int val, TreeNode left, TreeNode right.
public class Solution {
    public int kthSmallest(TreeNode root, int k) {
        // what happens if root == null?
        // what happens if k > total size of tree?
        return kthSmallestNode(root, k).val;
    }

    // k is 1-based: k == 1 is the smallest element.
    public static TreeNode kthSmallestNode(TreeNode root, int k) {
        if (root == null) return root;
        int numberOfNodes = countNodes(root.left); // nodes smaller than root

        if (k == numberOfNodes + 1) return root;
        if (k <= numberOfNodes) return kthSmallestNode(root.left, k);
        else return kthSmallestNode(root.right, k - numberOfNodes - 1);
    }

    private static int countNodes(TreeNode node) {
        if (node == null) return 0;
        else return 1 + countNodes(node.left) + countNodes(node.right);
    }
}

But I see that we count the sizes of the same subtrees multiple times, so one option is to maintain an array that stores these sizes, in the style of dynamic programming.
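A minimal sketch of that caching idea, assuming a map keyed by node identity stands in for the array: one post-order pass computes every subtree size exactly once, and the descent then reuses those sizes instead of calling countNodes repeatedly. The class names `MemoKth` and the methods `fillSizes`, `sizeOf`, `find` and `kthSmallest` are hypothetical; k is assumed to be 1-based.

```java
import java.util.IdentityHashMap;
import java.util.Map;

// Same shape as the question's Node class.
class Node {
    int data;
    Node left;
    Node right;

    Node(int data, Node left, Node right) {
        this.data = data;
        this.left = left;
        this.right = right;
    }
}

class MemoKth {
    // Post-order pass: computes every subtree size exactly once and caches it.
    static int fillSizes(Node node, Map<Node, Integer> sizes) {
        if (node == null) return 0;
        int size = 1 + fillSizes(node.left, sizes) + fillSizes(node.right, sizes);
        sizes.put(node, size);
        return size;
    }

    // Cached size lookup; an empty subtree has size 0.
    static int sizeOf(Node node, Map<Node, Integer> sizes) {
        return node == null ? 0 : sizes.get(node);
    }

    // Recursive descent using cached sizes; returns null if k is out of range.
    static Node find(Node node, int k, Map<Node, Integer> sizes) {
        if (node == null) return null;
        int leftSize = sizeOf(node.left, sizes);
        if (k == leftSize + 1) return node;                  // current node is the kth
        if (k <= leftSize) return find(node.left, k, sizes); // kth lies in the left subtree
        return find(node.right, k - leftSize - 1, sizes);    // skip left subtree and current node
    }

    // Hypothetical entry point: builds the cache, then descends.
    static Node kthSmallest(Node root, int k) {
        Map<Node, Integer> sizes = new IdentityHashMap<>();
        fillSizes(root, sizes);
        return find(root, k, sizes);
    }
}
```

As written, kthSmallest rebuilds the cache on every call (O(n)); to answer many queries on an unchanged tree, build the map once and call find directly, which costs O(h) per query for a tree of height h.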

But I want to write a recursive solution for this. Here is the code I have written:

class Node {
    int data;
    Node left;
    Node right;

    public Node(int data, Node left, Node right) {
        this.left = left;
        this.data = data;
        this.right = right;
    }
}

public class KthInBST
{
    public static Node createBST(int headData)
    {
        Node head = new Node(headData, null, null);
        //System.out.println(head.data);
        return head;
    }

    public static void insertIntoBst(Node head, int data)
    {
        Node newNode = new Node(data, null, null);
        while (true) {
            if (data > head.data) {
                if (head.right == null) {
                    head.right = newNode;
                    break;
                } else {
                    head = head.right;
                }
            } else {
                if (head.left == null) {
                    head.left = newNode;
                    break;
                } else {
                    head = head.left;
                }
            }
        }
    }

    public static void main(String[] args)
    {
        Node head = createBST(5);
        insertIntoBst(head, 7);
        insertIntoBst(head, 6);

        insertIntoBst(head, 2);
        insertIntoBst(head, 1);
        insertIntoBst(head, 21);
        insertIntoBst(head, 11);
        insertIntoBst(head, 14);
        insertIntoBst(head, 3);

        printKthElement(head, 3);
    }

    public static int printKthElement(Node head, int k)
    {
        if (head == null) {
            return 0;
        }

        int leftIndex = printKthElement(head.left, k);

        int index = leftIndex + 1;

        if (index == k) {
            System.out.println(head.data);
        } else if (k > index) {
            k = k - index;
            printKthElement(head.right, k);
        } else {
            printKthElement(head.left, k);
        }
        return index;
    }
}

This prints the right answer, but multiple times. I figured out why it prints multiple times, but I don't understand how to avoid it. Also, if I want to return the node instead of just printing it, how do I do that? Can anyone please help me with this?
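One way to return the node instead of printing it, without a class-level variable, is to carry the remaining count in a mutable holder shared by all recursive calls. A minimal sketch, assuming a one-element int array as that holder; note this is an in-order traversal with early exit, not the left-subtree-size method from the question, and `ReturnKth`, `kth` and `kthSmallest` are hypothetical names. Because each node is visited at most once, nothing can be reported twice.

```java
// Same shape as the question's Node class.
class Node {
    int data;
    Node left;
    Node right;

    Node(int data, Node left, Node right) {
        this.data = data;
        this.left = left;
        this.right = right;
    }
}

class ReturnKth {
    // In-order traversal that stops as soon as the kth node is reached.
    // remaining[0] is a counter shared by all recursive calls: Java passes the
    // array reference by value, so a decrement made deep in the recursion is
    // seen by every caller, and each node is counted exactly once.
    static Node kth(Node node, int[] remaining) {
        if (node == null) return null;
        Node found = kth(node.left, remaining); // smaller elements first
        if (found != null) return found;        // answer was in the left subtree
        if (--remaining[0] == 0) return node;   // this node is the kth smallest
        return kth(node.right, remaining);      // then the larger elements
    }

    // k is 1-based; returns null if k < 1 or k exceeds the tree size.
    static Node kthSmallest(Node root, int k) {
        return kth(root, new int[] {k});
    }
}
```

With the tree built in main above, ReturnKth.kthSmallest(head, 3) would return the node holding 3 (the classes would need to share one Node type).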

"finding the index of the current node by calculation the size of its left subtree". This is no better than in-order traversal, unless each node always keeps the size of its subtree. - n. 1.8e9-where's-my-share m.
Yeah, you are right. But this is just another way to do it, and I wanted to implement it using recursion. - MaPY
You can avoid printing the value multiple times by removing the else part [else { printKthElement(head.left, k); }]. As far as returning the node is concerned, I could not come up with anything, as I am not that good at Java, and Java is strictly pass-by-value. I thought of two things: 1. Use a class variable (I personally don't like this way). 2. Use a callback (overkill). - vishal-wadhwa
Is the OP still interested? I have a solution in mind which might just work. - vishal-wadhwa
@vishal-wadhwa Yes, I'm still looking for a solution. - MaPY
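The first comment compares the left-subtree-size method with plain in-order traversal. For reference, a minimal iterative in-order sketch that stops at the kth visited node, assuming an explicit java.util.ArrayDeque as the stack; `InorderKth` and `kthSmallest` are hypothetical names and k is assumed to be 1-based.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Same shape as the question's Node class.
class Node {
    int data;
    Node left;
    Node right;

    Node(int data, Node left, Node right) {
        this.data = data;
        this.left = left;
        this.right = right;
    }
}

class InorderKth {
    // Iterative in-order traversal with an explicit stack; stops at the kth
    // node visited. Returns null if the tree has fewer than k nodes.
    static Node kthSmallest(Node root, int k) {
        Deque<Node> stack = new ArrayDeque<>();
        Node current = root;
        while (current != null || !stack.isEmpty()) {
            while (current != null) { // walk left, remembering the path
                stack.push(current);
                current = current.left;
            }
            current = stack.pop();    // next node in sorted order
            if (--k == 0) return current;
            current = current.right;  // then visit its right subtree
        }
        return null;
    }
}
```

This visits O(h + k) nodes for a tree of height h and needs no extra per-node data, which is why the size-based method only pays off when each node stores its subtree size.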

1 Answer


Objective:

Recursively finding the kth smallest element in a binary search tree and returning the node corresponding to that element.

Observation:

Within the subtree rooted at the current node, the number of elements smaller than the current element is the size of its left subtree. So instead of recursively calculating that size, we add a new member to class Node, lsize, which stores the size of the left subtree of the current node.
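The insertIntoBst in this answer keeps lsize up to date from the start. If a tree was instead built with the question's original insertIntoBst (so every lsize is still 0), the field could be filled in once with a post-order pass. A minimal sketch; `LsizeInit` and `initSizes` are hypothetical names.

```java
// The question's Node class plus the lsize field introduced in this answer.
class Node {
    int data;
    Node left;
    Node right;
    int lsize; // number of nodes in the left subtree

    Node(int data, Node left, Node right) {
        this.data = data;
        this.left = left;
        this.right = right;
    }
}

class LsizeInit {
    // Post-order pass: sets lsize on every node and returns the size of the
    // subtree rooted at node. O(n) time, done once for the whole tree.
    static int initSizes(Node node) {
        if (node == null) return 0;
        int leftSize = initSizes(node.left);
        int rightSize = initSizes(node.right);
        node.lsize = leftSize;
        return leftSize + rightSize + 1;
    }
}
```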

Solution:

At each node we compare the size of the left subtree with the current value of k:

  1. If head.lsize + 1 == k: the current node is our answer.
  2. If head.lsize + 1 > k: the left subtree contains at least k elements, so the kth smallest element lies in the left subtree. So, we go left.
  3. If head.lsize + 1 < k: the current element, along with all the elements in its left subtree, is smaller than the kth element we need to find. So, we go to the right subtree, but also reduce k by the number of elements in the left subtree + 1 (the current element). Subtracting this from k accounts for the elements that are smaller than the kth element and lie in the left subtree of the current node, including the current node itself.

Code:

class Node {
    int data;
    Node left;
    Node right;
    int lsize; // number of nodes in the left subtree of this node

    public Node(int data, Node left, Node right) {
        this.left = left;
        this.data = data;
        this.right = right;
        lsize = 0;
    }
}

// The two methods below replace the ones in class KthInBST.
public static void insertIntoBst(Node head, int data) {
    Node newNode = new Node(data, null, null);
    while (true) {
        if (data > head.data) {
            if (head.right == null) {
                head.right = newNode;
                break;
            } else {
                head = head.right;
            }
        } else {
            head.lsize++; // as we go left, the left subtree rooted at the
                          // current node grows by one, hence the increment.
            if (head.left == null) {
                head.left = newNode;
                break;
            } else {
                head = head.left;
            }
        }
    }
}

// Returns the kth smallest node (k is 1-based), or null if k is out of range.
public static Node printKthElement(Node head, int k) {
    if (head == null) {
        return null;
    }

    if (head.lsize + 1 == k) return head;                              // current node is the answer
    else if (head.lsize + 1 > k) return printKthElement(head.left, k); // answer is in the left subtree
    return printKthElement(head.right, k - head.lsize - 1);            // skip left subtree and current node
}

Changes:

  1. A new member lsize has been introduced in class Node.
  2. A slight modification to insertIntoBst.
  3. Major changes to printKthElement, which now returns the node instead of printing it.

Corner case:

Add a check to ensure that k is between 1 and the size of the tree; otherwise printKthElement returns null, and using the result (e.g. reading .data) throws a NullPointerException.
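A minimal sketch of such a check, assuming lsize is maintained as above: the tree size can be read off lsize alone by walking the right spine, since each node on it contributes itself plus its left subtree. `CheckedKth`, `insert`, `size` and `kthOrThrow` are hypothetical names; insert is a recursive stand-in for the answer's insertIntoBst (duplicates also go left), included only so the sketch is self-contained.

```java
// The question's Node class plus the lsize field introduced in this answer.
class Node {
    int data;
    Node left;
    Node right;
    int lsize; // number of nodes in the left subtree

    Node(int data, Node left, Node right) {
        this.data = data;
        this.left = left;
        this.right = right;
    }
}

class CheckedKth {
    // Hypothetical recursive insert that maintains lsize.
    static Node insert(Node head, int data) {
        if (head == null) return new Node(data, null, null);
        if (data > head.data) {
            head.right = insert(head.right, data);
        } else {
            head.lsize++; // new node lands in head's left subtree
            head.left = insert(head.left, data);
        }
        return head;
    }

    // Tree size from lsize alone: sum of (lsize + 1) along the right spine. O(h).
    static int size(Node head) {
        int total = 0;
        for (Node n = head; n != null; n = n.right) {
            total += n.lsize + 1;
        }
        return total;
    }

    // Validates k first, so the descent never falls off the tree.
    static Node kthOrThrow(Node head, int k) {
        int n = size(head);
        if (k < 1 || k > n) {
            throw new IllegalArgumentException("k must be in [1, " + n + "], got " + k);
        }
        Node current = head;
        while (current.lsize + 1 != k) { // same three cases as printKthElement
            if (current.lsize + 1 > k) {
                current = current.left;
            } else {
                k -= current.lsize + 1;
                current = current.right;
            }
        }
        return current;
    }
}
```

On the question's tree (5, 7, 6, 2, 1, 21, 11, 14, 3), size returns 9 and kthOrThrow(head, 3) returns the node holding 3, while k = 0 or k = 10 raises an IllegalArgumentException instead of a NullPointerException later on.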

This is working on the test cases I have tried, so far. Any suggestions or corrections are most welcome. :)