Understand Recursion & Memoization

Let's understand recursion and memoization with a few examples.

Recursion is a technique to solve problems by

  1. dividing the original problem into one or more similar sub-problems,

  2. solving the sub-problems by calling the same method again, and then

  3. using the solutions of sub-problems to get the solution to the actual problem.

Sometimes, recursion can be slow because we might end up solving the same sub-problem multiple times. To avoid this, we use memoization. The idea is to solve each sub-problem once and remember its solution. If we encounter the same sub-problem again later, we simply return the stored solution instead of solving it again.

When to use recursion?

When a problem can be broken down into smaller sub-problems of the same kind, and their solutions can be combined to solve the actual problem, you can use recursion. If the plain recursive solution is slow because it solves the same sub-problems repeatedly (as happens in many cases), you can combine recursion with memoization.

Although many problems that can be solved with recursion can also be solved without it, a recursive solution is often more readable and expressive.

How to use recursion?

In programming, recursion basically means a function or method calling itself. In general, each recursive call changes the input so that it gets closer to a case that can be answered directly; otherwise, the function might keep calling itself forever, and you will get a stack-overflow error.
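A minimal sketch of this idea, using a hypothetical countDown method that is not part of the original article: each call handles one number and passes a smaller n to the next call, and the stopping condition n <= 0 ends the recursion.

```java
import java.util.ArrayList;
import java.util.List;

class CountDown {
    // Hypothetical example: collects n, n - 1, ..., 1 into `out`.
    // Every call passes a smaller n, so the recursion eventually
    // reaches the stopping condition n <= 0.
    static void countDown(int n, List<Integer> out) {
        if (n <= 0) return;      // stopping condition: nothing left to count
        out.add(n);              // this call's share of the work
        countDown(n - 1, out);   // recurse on a smaller input
    }

    public static void main(String[] args) {
        List<Integer> out = new ArrayList<>();
        countDown(3, out);
        System.out.println(out); // [3, 2, 1]
    }
}
```

If the recursive call were countDown(n, out) instead, the input would never change and the program would end with a StackOverflowError.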

Let's take some examples to understand it better.

Factorial (n!)

The factorial of a non-negative integer n is the product of all positive integers less than or equal to n (by convention, 0! = 1). That means,

fact(n) = 1 x 2 x ... x (n - 1) x n

Let's solve this problem using 2 techniques - iteration and recursion.

Factorial using iteration

We can simply iterate from 1 to n and multiply each number to get the value of fact(n).

public static int fact(int n) {
    int answer = 1;
    for (int i = 1; i <= n; i++) {
        answer = answer * i;
    }
    return answer;
}

This works fine and is readable as well. Still, we can also calculate fact(n) using recursion, and the recursive version mirrors the mathematical definition even more directly.
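One caveat applies to both the iterative and the recursive version: a Java int holds values only up to 2,147,483,647, so fact(n) silently overflows for n > 12. The sketch below (the class name IterativeFactorial and the method names factInt and factLong are hypothetical) repeats the loop above with an int and with a long accumulator to show the difference.

```java
class IterativeFactorial {
    // The loop from the article, with an int accumulator.
    static int factInt(int n) {
        int answer = 1;
        for (int i = 1; i <= n; i++) {
            answer *= i;
        }
        return answer;
    }

    // Same loop with a 64-bit long accumulator.
    static long factLong(int n) {
        long answer = 1;
        for (int i = 1; i <= n; i++) {
            answer *= i;
        }
        return answer;
    }

    public static void main(String[] args) {
        System.out.println(factInt(12));  // 479001600, still fits in an int
        System.out.println(factInt(13));  // 1932053504, wrong: 13! = 6227020800 overflows an int
        System.out.println(factLong(13)); // 6227020800
        System.out.println(factLong(20)); // 2432902008176640000, the largest factorial that fits in a long
    }
}
```

Java integer multiplication wraps around instead of throwing, which is why factInt(13) returns a plausible-looking but wrong number.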

Factorial using recursion

To calculate the value of fact(n) using recursion, we need to break this problem into a smaller sub-problem. If we pay close attention, we see that:

fact(n) = n x (n - 1) x ... x 2 x 1
fact(n) = n x fact(n - 1)

Here, we break down the problem of calculating the factorial of n into calculating the factorial of n - 1, which is a subproblem of the original problem. So, we can use recursion to solve this problem. Let's look at the code:

public static int fact(int n) {
    return n * fact(n - 1);
}

Here, we call the fact method again, but with the argument n - 1. We assume that it will give us the correct value (n - 1)! and then we multiply this value by n to get n!. Let's call this method from the main method with n = 5.

class Main {
    public static int fact(int n) {
        return n * fact(n - 1);
    }
    public static void main(String[] args) {
        System.out.println(fact(5));
    }
}

Running this program fails with an error that says:

Exception in thread "main" java.lang.StackOverflowError

This is because the fact method keeps calling itself and never stops:

fact(5) = 5 * fact(4)
        = 5 * 4 * fact(3)
        = 5 * 4 * 3 * fact(2)
        = 5 * 4 * 3 * 2 * fact(1)
        = 5 * 4 * 3 * 2 * 1 * fact(0)
        = 5 * 4 * 3 * 2 * 1 * 0 * fact(-1)
        = 5 * 4 * 3 * 2 * 1 * 0 * -1 * fact(-2)
... and so on

Since we have not specified when to stop, we end up with more recursive calls of the fact function than the call stack can hold. We need to tell the fact method when to stop the recursion, that is, what the base case is. For the factorial function, the base case is n == 0, since 0! = 1. Let's rewrite the fact method so that it returns 1 for n = 0.

public static int fact(int n) {
    if (n == 0) return 1;
    return n * fact(n - 1);
}

Now, running the main method prints the correct output:

120

Here is how it works:

fact(5) = 5 * fact(4)
        = 5 * (4 * fact(3))
        = 5 * (4 * (3 * fact(2)))
        = 5 * (4 * (3 * (2 * fact(1))))
        = 5 * (4 * (3 * (2 * (1 * fact(0)))))
        = 5 * (4 * (3 * (2 * (1 * 1))))
        = 5 * (4 * (3 * (2 * 1)))
        = 5 * (4 * (3 * 2))
        = 5 * (4 * 6)
        = 5 * 24
        = 120
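The base case n == 0 only stops non-negative inputs: fact(-1) never reaches it and still recurses until a StackOverflowError. A minimal sketch of one way to guard against that, assuming we prefer to reject negative input with an IllegalArgumentException (the class name SafeFactorial is hypothetical, and the result type is widened to long):

```java
class SafeFactorial {
    // Recursive factorial with an explicit check for negative input.
    // Without the check, a negative n would never reach the base case.
    static long fact(int n) {
        if (n < 0) {
            throw new IllegalArgumentException("n must be non-negative, got " + n);
        }
        if (n == 0) return 1;      // base case: 0! = 1
        return n * fact(n - 1);    // recursive case: n! = n * (n - 1)!
    }

    public static void main(String[] args) {
        System.out.println(fact(5)); // 120
        try {
            fact(-1);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // n must be non-negative, got -1
        }
    }
}
```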

Now, let's take another example.

Fibonacci sequence

The Fibonacci sequence is a sequence in which each number is the sum of the two preceding ones, that is:

fib(n) = fib(n - 1) + fib(n - 2) for n >= 2, where fib(0) = 0 and fib(1) = 1.

The task here is to calculate the nth Fibonacci number, fib(n). We won't look at the iterative way to solve this problem; we jump directly to a recursive solution.

Fibonacci Number using recursion

Since we know the base case and we know how to divide the actual problem into smaller problems, we can easily write the method to do that recursively:

public static int fib(int n) {
    if (n == 0 || n == 1) return n;
    return fib(n - 1) + fib(n - 2);
}

The above solution is clean and correct, but it is inefficient. Let's look at the calls made to the fib method when we call fib(6).

[Figure: calls made to the fib method to calculate the 6th Fibonacci number, fib(6)]

Referring to the above image, let's answer some questions.

  1. How many times do we calculate the value of fib(6)?
    Answer: Once

  2. How many times do we calculate the value of fib(5)?
    Answer: Once

  3. How many times do we calculate the value of fib(4)?
    Answer: 2 times

  4. How many times do we calculate the value of fib(3)?
    Answer: 3 times

  5. How many times do we calculate the value of fib(2)?
    Answer: 5 times

  6. How many times do we calculate the value of fib(1)?
    Answer: 8 times

  7. How many times do we calculate the value of fib(0)?
    Answer: 5 times
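These counts can be checked with a small instrumented copy of the fib method. The class name FibCallCounter and the calls array are hypothetical additions: calls[k] records how many times fib(k) is invoked.

```java
import java.util.Arrays;

class FibCallCounter {
    // calls[k] = number of times fib(k) has been invoked
    static int[] calls;

    // Same recursion as the article's fib, plus one counter update per call.
    static int fib(int n) {
        calls[n]++;
        if (n == 0 || n == 1) return n;
        return fib(n - 1) + fib(n - 2);
    }

    // Resets the counters, runs fib(n), and returns the per-argument call counts.
    static int[] countCalls(int n) {
        calls = new int[n + 1];
        fib(n);
        return calls;
    }

    public static void main(String[] args) {
        int[] counts = countCalls(6);
        System.out.println(Arrays.toString(counts));     // [5, 8, 5, 3, 2, 1, 1]
        System.out.println(Arrays.stream(counts).sum()); // 25
    }
}
```

Index k of the printed array corresponds to fib(k), so the output matches the answers above. Running it for n = 30 gives 2,692,537 calls in total; in general the total is 2 x fib(n + 1) - 1, so the work grows exponentially with n.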

As we can see, we calculate fib(x) multiple times for many values of x, which makes the method run very slowly: the number of calls grows exponentially with n. To avoid calculating the same value multiple times, we can use memoization.

Fibonacci Number using Memoization

Basically, we calculate each value once and remember it for the lifetime of the program. To remember the values, we can use a Map or any other similar data structure. While calculating fib(n), we first check the map to see whether we have already calculated fib for the given n. If yes, we return the stored value. Otherwise, we calculate fib(n) using recursion and store the result in the map, so that the next time we need fib(n), we can simply look it up. Let's implement it in code:

import java.util.Map;
import java.util.HashMap;

class Main {
    // Stores already calculated values: n -> fib(n)
    private static final Map<Integer, Integer> map = new HashMap<>();

    public static int fib(int n) {
        if (n == 0 || n == 1) return n;             // base cases
        if (map.containsKey(n)) return map.get(n);  // already calculated: reuse it
        int answer = fib(n - 1) + fib(n - 2);       // calculate it once...
        map.put(n, answer);                         // ...and remember it
        return answer;
    }
    ...
}

In the above code snippet, we check whether the map has n as a key (if it does, we have already calculated fib(n) and stored it in the map). If yes, we return the corresponding value. Otherwise, we calculate fib(n) using the recursive formula and store the (n, fib(n)) key-value pair in the map for future reference.
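For reference, here is one way to turn the snippet above into a complete program. This is a sketch under a few assumptions: the class name MemoFib is hypothetical, the values are stored as long (an int overflows beyond fib(46)), and a single get() call replaces the containsKey()/get() pair.

```java
import java.util.HashMap;
import java.util.Map;

class MemoFib {
    // Cache of already calculated values: n -> fib(n)
    private static final Map<Integer, Long> memo = new HashMap<>();

    static long fib(int n) {
        if (n == 0 || n == 1) return n;          // base cases
        Long cached = memo.get(n);
        if (cached != null) return cached;       // reuse a stored result
        long answer = fib(n - 1) + fib(n - 2);   // calculate it once...
        memo.put(n, answer);                     // ...and remember it
        return answer;
    }

    public static void main(String[] args) {
        System.out.println(fib(6));  // 8
        System.out.println(fib(50)); // 12586269025
    }
}
```

Using Map.computeIfAbsent with a recursive fib call inside it is tempting, but HashMap's documentation (Java 9 and later) warns that it may throw ConcurrentModificationException if the mapping function modifies the map, so this sketch keeps the explicit get/put.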

That's how recursion and memoization work. Let's take a final example and solve it using iteration, recursion, and memoization.

Pascal's triangle

It's hard for me to write the definition of Pascal's Triangle formally, so I am attaching 2 images of it from Wikipedia:

[Figure: Pascal's Triangle]

In Pascal's triangle, each number is the sum of the two numbers directly above it.

In programming, it is represented as a matrix or 2-d array:

[Figure: Pascal's Triangle represented as a 2-D array]

Writing P(r, c) for the value in row r and column c (where r >= c >= 0), we can see that P(r, c) is:

  • 1 if c = 0

  • 1 if r = c

  • P(r - 1, c - 1) + P(r - 1, c), otherwise
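As a worked example of these rules, here is P(4, 2) expanded two levels (with P(2, 1) = P(1, 0) + P(1, 1) = 2). Note that P(2, 1) appears twice:

$$
\begin{aligned}
P(4, 2) &= P(3, 1) + P(3, 2) \\
        &= \big(P(2, 0) + P(2, 1)\big) + \big(P(2, 1) + P(2, 2)\big) \\
        &= (1 + 2) + (2 + 1) = 6
\end{aligned}
$$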

Now, the problem: we are given an integer n, and we need to print Pascal's Triangle up to row n (the first row is row 0).

Iterative way to Pascal's Triangle

Let's create a 2-D array P with n + 1 rows and n + 1 columns. For each row r, the first column (c = 0) and the last column (c = r) are 1. The rest can be calculated using the formula above. Let's write the code for it.

private static void printPascalTriangle(int n) {
    int[][] P = new int[n + 1][n + 1];
    // Fill row by row; row r only uses values from row r - 1.
    for (int r = 0; r <= n; r++) {
        P[r][0] = 1; // first column
        P[r][r] = 1; // last column of row r
        for (int c = 1; c < r; c++) {
            P[r][c] = P[r - 1][c - 1] + P[r - 1][c];
        }
    }
    // Print only the filled part (c <= r) of each row.
    for (int r = 0; r <= n; r++) {
        for (int c = 0; c <= r; c++) {
            System.out.print(P[r][c] + " ");
        }
        System.out.println();
    }
}

This method prints Pascal's Triangle from row 0 to row n. Now, let's solve it using recursion.

Recursive way to Pascal's Triangle

Let's create a method pascalNumber(int r, int c) that returns the number at row r and column c of Pascal's Triangle. We already have a recursive formula, with base conditions, from the definition of Pascal's Triangle, so it is quite easy to implement this method.

private static int pascalNumber(int r, int c) {
    if (c == 0) return 1;
    if (r == c) return 1;
    return pascalNumber(r - 1, c - 1) + pascalNumber(r - 1, c);
}

private static void printPascalTriangle(int n) {
    for (int r = 0; r <= n; r++) {
        for (int c = 0; c <= r; c++) {
            System.out.print(pascalNumber(r, c) + " ");
        }
        System.out.println();
    }
}

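In the above code snippet, we make recursive calls to solve sub-problems, but we end up solving the same sub-problems again and again. For example, pascalNumber(r - 1, c - 1) and pascalNumber(r - 1, c) both call pascalNumber(r - 2, c - 1). On top of that, printPascalTriangle computes every entry from scratch. As a result, the program becomes very slow for larger n. So, let's use memoization to improve the performance.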
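To get a feel for how much repeated work this is, here is a sketch that counts calls to the recursive pascalNumber method above. The class name PascalCallCounter and the calls counter are hypothetical additions. Because every call either returns 1 at a base case or adds up two recursive results, the call tree has exactly P(r, c) leaves, so computing P(r, c) this way takes 2 x P(r, c) - 1 calls.

```java
class PascalCallCounter {
    static long calls = 0; // number of pascalNumber invocations

    // Same recursion as the article's pascalNumber, plus a call counter.
    static int pascalNumber(int r, int c) {
        calls++;
        if (c == 0) return 1;
        if (r == c) return 1;
        return pascalNumber(r - 1, c - 1) + pascalNumber(r - 1, c);
    }

    // Resets the counter, computes P(r, c), and returns the number of calls made.
    static long countCalls(int r, int c) {
        calls = 0;
        pascalNumber(r, c);
        return calls;
    }

    public static void main(String[] args) {
        System.out.println(pascalNumber(20, 10)); // 184756
        System.out.println(countCalls(20, 10));   // 369511
    }
}
```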

Memoization way to Pascal's Triangle

To store the calculated value for each (r, c) pair, we can use a 2-D array P, which we initialize in the printPascalTriangle method. Since every number in Pascal's Triangle is positive, a zero in P[r][c] means that we have not calculated the value for (r, c) yet; in that case, we calculate it using recursion and store it. Otherwise, we simply return the stored value.

private static int[][] P;
private static int pascalNumber(int r, int c) {
    if (P[r][c] != 0) return P[r][c];
    if (c == 0) return 1;
    if (r == c) return 1;
    P[r][c] = pascalNumber(r - 1, c - 1) + pascalNumber(r - 1, c);
    return P[r][c];
}

private static void printPascalTriangle(int n) {
    P = new int[n + 1][n + 1];
    for (int r = 0; r <= n; r++) {
        for (int c = 0; c <= r; c++) {
            System.out.print(pascalNumber(r, c) + " ");
        }
        System.out.println();
    }
}

This way, we never calculate the same value twice: each entry P(r, c) is computed at most once, so printing rows 0 to n takes O(n^2) time instead of exponential time.
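One way to check this claim is to count how often the recursive formula is actually evaluated. The sketch below copies the memoized method and adds a hypothetical computed counter, plus a hypothetical sumOfTriangle(n) that visits the same entries as printPascalTriangle but adds them up instead of printing. For rows 0 to n, the formula runs exactly once per inner entry, (n - 1) x n / 2 times in total.

```java
class MemoPascalCount {
    private static int[][] P;  // P[r][c] == 0 means "not computed yet"
    static int computed = 0;   // how many times the recursive formula was evaluated

    private static int pascalNumber(int r, int c) {
        if (P[r][c] != 0) return P[r][c];  // already computed
        if (c == 0 || r == c) return 1;    // edges of the triangle
        computed++;
        P[r][c] = pascalNumber(r - 1, c - 1) + pascalNumber(r - 1, c);
        return P[r][c];
    }

    // Visits every entry of rows 0..n (like printPascalTriangle) and returns their sum.
    static long sumOfTriangle(int n) {
        P = new int[n + 1][n + 1];
        computed = 0;
        long sum = 0;
        for (int r = 0; r <= n; r++) {
            for (int c = 0; c <= r; c++) {
                sum += pascalNumber(r, c);
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        System.out.println(sumOfTriangle(30)); // 2147483647, that is 2^31 - 1
        System.out.println(computed);          // 435
    }
}
```

Row r of Pascal's Triangle sums to 2^r, so rows 0 to 30 sum to 2^31 - 1, while only 435 formula evaluations were needed to produce them.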

Conclusion

I hope you have learned something new from this article. The source code can be found here. Please give your suggestions and subscribe to my newsletter; it will be a great help. Thank you for your time. See you soon.