
I am trying to learn Elixir and functional programming, and I have had trouble understanding this example from the book Elixir in Action.

defmodule ListHelper do
  def sum([]), do: 0
  def sum([head | tail]) do
    head + sum(tail)
  end
end

ListHelper.sum([12,3,4])

The return value of this is 19, but what I don't understand is how the values are being accumulated.

I thought head was being continuously updated, and that when the pattern matched [] the accumulated head would be added to 0 and the function would exit. After playing with it, though, I'm now thinking that's not what's going on. Can someone offer an alternative explanation for what's happening in this example? If I need to explain more, I can try to revisit this.


2 Answers


sum([head | tail]) is head + sum(tail), so sum([12,3,4]) is 12 + sum([3,4]), sum([3,4]) is 3 + sum([4]) and sum([4]) is 4 + sum([]). sum([]) is 0, so in total we get:

sum([12,3,4]) = 12 + sum([3,4])
              = 12 + 3 + sum([4])
              = 12 + 3 + 4 + sum([])
              = 12 + 3 + 4 + 0
              = 19
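The substitution above can be made visible in code. The sketch below uses a hypothetical Expand module (not from the book) with the same two clauses as sum/1, but instead of adding it builds the expression as a string, so you can see exactly what sum/1 ends up evaluating.

```elixir
defmodule Expand do
  # Same clause structure as ListHelper.sum/1, but builds the expression
  # as a string instead of evaluating it: each [head | tail] clause
  # contributes "head + ", and the [] clause contributes the final "0".
  def expand([]), do: "0"
  def expand([head | tail]), do: "#{head} + " <> expand(tail)
end

IO.puts(Expand.expand([12, 3, 4]))
# prints: 12 + 3 + 4 + 0
```

Nothing is added until the whole expression has been built, which is why the 0 from sum([]) ends up at the far right.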

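The recursive call to sum isn't a tail call (the addition head + ... still has to happen after it returns), so no tail-call optimization happens here. To make it tail-recursive, the recursive call to sum has to be the last thing the function does, which means carrying the running total along in an accumulator argument: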

defmodule ListHelper do
  def sum([], acc), do: acc
  def sum([head | tail], acc),
    do: sum(tail, acc + head)
end

ListHelper.sum([12,3,4], 0)
#⇒ 19
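To see the difference the tail call makes, here is a rough sketch (a hypothetical StackDepth module, not from the book) that records how large the process stack is when each version reaches its base case. It assumes Process.info(self(), :stack_size), which reports the current stack size in words; the exact numbers depend on the OTP version, so none are shown here.

```elixir
defmodule StackDepth do
  # Body-recursive, like the question's sum/1: the addition happens after
  # the recursive call returns, so every level keeps its frame.
  def body_sum([]), do: report_stack(0)
  def body_sum([head | tail]), do: head + body_sum(tail)

  # Tail-recursive, like the accumulator version: the recursive call is
  # the last thing each clause does, so no frame is left waiting.
  def tail_sum([], acc), do: report_stack(acc)
  def tail_sum([head | tail], acc), do: tail_sum(tail, acc + head)

  # Sends the current stack size (in words) to this process's mailbox,
  # then returns `value` unchanged so the sum itself is unaffected.
  defp report_stack(value) do
    {:stack_size, words} = Process.info(self(), :stack_size)
    send(self(), {:stack_at_base, words})
    value
  end

  # Runs `fun` and returns {result, stack size seen at the base case}.
  def measure(fun) do
    result = fun.()

    receive do
      {:stack_at_base, words} -> {result, words}
    end
  end
end

list = Enum.to_list(1..10_000)
IO.inspect(StackDepth.measure(fn -> StackDepth.body_sum(list) end))
IO.inspect(StackDepth.measure(fn -> StackDepth.tail_sum(list, 0) end))
```

Both calls return 50005000 as the sum; the second number in each tuple should be much larger for body_sum/1 than for tail_sum/2.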

I thought head was being continuously updated...

Nope. A new, separate head variable is created for each call to sum(). That is accomplished by creating a new stack frame, where a frame contains all the local variables created by the function call, including parameter variables like head.
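One way to convince yourself of this: the sketch below (a hypothetical HeadDemo module) has the same shape as sum/1, but instead of adding, each call records its own head after the recursive call has returned. If there were a single head being updated, every entry would be the same value.

```elixir
defmodule HeadDemo do
  def heads_after_return([]), do: []

  def heads_after_return([head | tail]) do
    deeper = heads_after_return(tail)
    # `head` here is still this call's own value; the deeper calls
    # bound their own, separate `head` variables.
    deeper ++ [head]
  end
end

IO.inspect(HeadDemo.heads_after_return([5, 6, 7]), charlists: :as_lists)
# prints: [7, 6, 5]
```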

If you write something like:

x = 3 + func1()

Elixir has to evaluate func1() in order to calculate a value for the x variable. The definition of func1() may also create its own x variable, so Elixir allocates a new stack frame to calculate the return value of func1(). Once Elixir has calculated the return value of func1(), that value is substituted into the line above:

           42
            |
            V

x = 3 + func1()

x = 3 + 42
x = 45

...and Elixir can calculate a value for the x variable.
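Here is that example as runnable code, a small sketch where the module name Frames and the body of func1 (returning 42, as in the picture above) are made up for illustration. func1 binds its own x, and that binding has no effect on the caller's x.

```elixir
defmodule Frames do
  # func1/0 binds its own local x; this x lives only in func1's frame.
  def func1 do
    x = 42
    x
  end

  def caller do
    # func1() is evaluated first, in its own frame; its return value (42)
    # is then substituted here, giving x = 3 + 42.
    x = 3 + func1()
    x
  end
end

IO.puts(Frames.caller())
# prints: 45
```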

The same thing happens if you write:

4 + sum([5,6,7])

The only difference is that Elixir will allocate many stack frames on the stack to calculate the return value of sum([5,6,7]). With a recursive function call, the return value of each stack frame depends on the return value of another stack frame. Only when the base case is reached and sum([]) returns 0 can Elixir start filling in the required values inside each stack frame to calculate a return value.

  4 + sum([5,6,7]) 
           |  
           |
     #1    V  sum([5,6,7])
   +--------------------+              
   | head = 5           |              
   | tail = [6,7]       |              
   |                    |             
   | return:            |            
   |   head + sum(tail) |                    
   |    5 +  sum([6,7]) |
   +-------------|------+
                 |
        #2       V  sum([6,7])
      +----------------------+   
      | head = 6             |    
      | tail = [7]           |   
      |                      |        
      | return:              |             
      |     head + sum(tail) |         
      |       6  + sum([7])  |
      +-----------------|----+
                        |        
            #3          V  sum([7])
          +----------------------+     
          | head = 7             |     
          | tail = []            |     
          |                      |     
          | return:              |     
          |     head + sum(tail) |     
          |       7  + sum([])   |     
          +----------------|-----+      
                           |
                #4         V   sum([])
             +-----------------------+
             |  return: 0            |
             +-----------------------+

Note that there are three separate head variables in existence at the same time. Once the bottom stack frame returns, that sets in motion the following steps:

 4 + sum([5,6,7]) 
           ^
           | 
           +---18-----------<----------+
   +--------------------+              |
   | head = 5           |              |
   | tail = [6,7]       |              |
   |                    |              ^
   | return:            |              |
   |   head + sum(tail) |              |      
   |    5 +  sum([6,7]) ---->---18-----+  #7
   +-------------^------+
                 |
                 +--13------<----------+
      +----------------------+         |
      | head = 6             |         |
      | tail = [7]           |         |
      |                      |         ^
      | return:              |         |     
      |     head + sum(tail) |         |
      |       6  + sum([7]) -|-->--13--+  #6
      +-----------------^----+
                        |        
                        +---7--<-------+
          +----------------------+     |
          | head = 7             |     ^
          | tail = []            |     |
          |                      |     |
          | return:              |     ^
          |     head + sum(tail) |     |
          |       7  + sum([]) --|>--7-+     
          +----------------^-----+      
                           |
                           +----0--<---+
               +-----------------+     |
               |  return: 0    --|>--0-+  #5
               +-----------------+
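The unwinding in the diagram (steps #5 through #7) can also be captured as data. This is a sketch with a hypothetical Unwind module: it computes the same sum, but also collects the value each frame hands back, in the order the frames return.

```elixir
defmodule Unwind do
  # Returns {sum, returned}, where `returned` lists the value each frame
  # returned, deepest frame (sum([])) first.
  def sum([]), do: {0, [0]}

  def sum([head | tail]) do
    {tail_sum, returned} = sum(tail)
    val = head + tail_sum
    {val, returned ++ [val]}
  end
end

{total, returned} = Unwind.sum([5, 6, 7])
IO.inspect(returned, charlists: :as_lists)  # prints: [0, 7, 13, 18]
IO.puts(4 + total)                          # prints: 22
```

The returned values 0, 7, 13, 18 are the same numbers travelling up the arrows in the diagram.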

If you add some print statements to your code, you can see how the steps are performed:

defmodule My do
  def sum([]) do
    IO.puts("Inside sum([]):\n\treturning 0")
    0
  end
  def sum([head | tail]) do
    IO.puts("Inside sum([head|tail]), head=#{head} tail=#{inspect(tail, charlists: :as_lists)}")
    val = head + sum(tail)
    IO.puts("Inside sum([head|tail]), head=#{head} tail=#{inspect(tail, charlists: :as_lists)}")
    IO.puts("\treturning #{val}")
    val
  end
end

My.sum([5,6,7])

At the command line:

~/elixir_programs$ elixir a.exs
Inside sum([head|tail]), head=5 tail=[6, 7]
Inside sum([head|tail]), head=6 tail=[7]
Inside sum([head|tail]), head=7 tail=[]
Inside sum([]):
    returning 0
Inside sum([head|tail]), head=7 tail=[]
    returning 7
Inside sum([head|tail]), head=6 tail=[7]
    returning 13
Inside sum([head|tail]), head=5 tail=[6, 7]
    returning 18
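For contrast, here is the accumulator version from the other answer with similar print statements. It is a sketch that also wraps the two-argument function behind a one-argument sum/1, so callers don't have to pass the initial 0; the module name TailTrace and the helper name do_sum are hypothetical.

```elixir
defmodule TailTrace do
  # Public entry point: starts the accumulator at 0.
  def sum(list), do: do_sum(list, 0)

  defp do_sum([], acc) do
    IO.puts("Inside do_sum([], acc), acc=#{acc}\n\treturning #{acc}")
    acc
  end

  defp do_sum([head | tail], acc) do
    IO.puts("Inside do_sum([head|tail], acc), head=#{head} acc=#{acc}")
    # Last operation: nothing is left to do in this call after it returns.
    do_sum(tail, acc + head)
  end
end

TailTrace.sum([5, 6, 7])
```

This prints:

```
Inside do_sum([head|tail], acc), head=5 acc=0
Inside do_sum([head|tail], acc), head=6 acc=5
Inside do_sum([head|tail], acc), head=7 acc=11
Inside do_sum([], acc), acc=18
	returning 18
```

Every line appears on the way down. The running total is already complete when the base case is reached, so there is nothing left to print on the way back up.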