Recursion: How to Write a ‘sum’ Function
With all of the images of the previous lesson firmly ingrained in your brain, let’s write a sum function using recursion!
You can follow along with the source code in this lesson by cloning my project from this GitHub URL:
Given a List of integers, such as this one:
val list = List(1, 2, 3, 4)
let’s start tackling the problem in the usual way, by thinking, “Write the function signature first.”
What do we know about the sum function we want to write? Well, we know a couple of things:
It will take a list of integers as input.
Because it returns a sum of those integers, the function will return a single value, an Int.
Armed with only those two pieces of information, I can sketch the signature for a sum function like this:
def sum(list: List[Int]): Int = ???
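A signature written this way already compiles, because ??? is defined in scala.Predef. The following is a minimal sketch (the object name SumStub is made up for illustration) showing that you can type-check the stub before writing a body. Calling the stub throws a scala.NotImplementedError.

```scala
object SumStub {
  // Compiles as-is: ??? has type Nothing, which conforms to Int.
  // Calling this method throws scala.NotImplementedError.
  def sum(list: List[Int]): Int = ???

  def main(args: Array[String]): Unit = {
    try {
      sum(List(1, 2, 3, 4))
      println("unexpected: sum returned")
    } catch {
      case _: NotImplementedError => println("sum is not implemented yet")
    }
  }
}
```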
Note: For the purposes of this exercise I’m assuming that the integer values will be small, and the list size will also be small. That way we don’t have to worry about all of the Ints adding up to a total too large for an Int (a total that would need a Long).
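The note sets two concerns aside. If you needed to handle them, the following sketch shows one possible approach. The names SafeSum, sumLong, and loop are hypothetical, and the accumulator technique is different from the one this lesson builds. The sketch returns a Long so that large totals don’t wrap around. It also keeps the recursive call in tail position so the compiler can turn it into a loop; the @tailrec annotation makes compilation fail if that isn’t possible.

```scala
import scala.annotation.tailrec

object SafeSum {
  // Hypothetical variant: a Long result avoids Int overflow, and the
  // accumulator keeps the recursive call in tail position.
  def sumLong(list: List[Int]): Long = {
    @tailrec
    def loop(remaining: List[Int], acc: Long): Long = remaining match {
      case Nil     => acc
      case x :: xs => loop(xs, acc + x) // Int x is widened to Long here
    }
    loop(list, 0L)
  }

  def main(args: Array[String]): Unit = {
    println(Int.MaxValue + 1)               // -2147483648: plain Int addition wraps around
    println(sumLong(List(Int.MaxValue, 1))) // 2147483648
    println(sumLong(List.fill(1000000)(1))) // 1000000, with no stack overflow
  }
}
```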
At this point a functional programmer will think of a “sum” algorithm as follows:
If the sum function is given an empty list of integers, it should return 0. (Because the sum of nothing is zero.)
Otherwise, if the list is not empty, the result of the function is the combination of (a) the value of its head element (1, in this case), and (b) the sum of the remaining elements in the list (2, 3, 4).
A slight restatement of that second sentence is:
“The sum of a list of integers is the sum of the head element, plus the sum of the tail elements.”
As Eckhart Tolle is fond of saying, “That statement is true, is it not?”
(Yes, it is.)
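If you apply that sentence repeatedly to List(1, 2, 3, 4), you get the following expansion. Here [] stands for the empty list, whose sum is 0:

```latex
\begin{aligned}
\mathrm{sum}([1,2,3,4]) &= 1 + \mathrm{sum}([2,3,4]) \\
&= 1 + (2 + \mathrm{sum}([3,4])) \\
&= 1 + (2 + (3 + \mathrm{sum}([4]))) \\
&= 1 + (2 + (3 + (4 + \mathrm{sum}([])))) \\
&= 1 + (2 + (3 + (4 + 0))) \\
&= 10
\end{aligned}
```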
Thinking about a List in terms of its head and tail elements is a standard way of thinking when writing recursive functions.
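The following is a small sketch of the head/tail view (HeadTailDemo is an illustrative name). The standard List methods head and tail return exactly those two parts, and the :: operator puts them back together. Calling head on an empty list throws an exception, so it’s a good idea to handle the empty case before touching head.

```scala
object HeadTailDemo {
  val list: List[Int] = List(1, 2, 3, 4)

  val first: Int = list.head              // the head element: 1
  val rest: List[Int] = list.tail         // the tail elements: List(2, 3, 4)
  val rebuilt: List[Int] = first :: rest  // prepend the head to the tail

  def main(args: Array[String]): Unit = {
    println(first)           // 1
    println(rest)            // List(2, 3, 4)
    println(rebuilt == list) // true
  }
}
```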
Now that we have a little idea of how to think about the problem recursively, let’s see how to implement those sentences in Scala code.
The first sentence above states:
If the sum function is given an empty list of integers, it should return 0.
Recursive Scala functions are often implemented using match expressions. Using (a) that information and (b) remembering that an empty list contains only the Nil element, you can start writing the body of the sum function like this:
def sum(list: List[Int]): Int = list match {
    case Nil => 0
    // the case for a non-empty list comes next
}
This is a Scala way of saying, “If the List is empty, return 0.” If you’re comfortable with match expressions and the List class, I think you’ll agree that this makes sense.
If you prefer using return statements at this point in your programming career, you can write that code like this:
def sum(list: List[Int]): Int = list match {
    case Nil => return 0
    // the case for a non-empty list comes next
}
Because a pure function doesn’t “return” a value as much as it “evaluates” to a result, you’ll want to quickly drop return from your vocabulary, but … I also understand that using return can help when you first start writing recursive functions.
You can also write this function using an if/then expression, but because pattern matching is such a big part of functional programming, I prefer match expressions.
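For comparison, here is a sketch of the if/then form (IfThenSum is an illustrative name). It implements the same two sentences, using the standard isEmpty, head, and tail methods instead of a pattern match:

```scala
object IfThenSum {
  // Same algorithm as the match version, written with if/else.
  // head and tail are only called after isEmpty shows the list is non-empty.
  def sum(list: List[Int]): Int =
    if (list.isEmpty) 0
    else list.head + sum(list.tail)

  def main(args: Array[String]): Unit =
    println(sum(List(1, 2, 3, 4))) // 10
}
```

The match version reads more directly as “empty case, then head-and-tail case.” The if/then version has to call head and tail explicitly, and it relies on the isEmpty check coming first.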
Because Nil is equivalent to List(), you can also write that case expression like this:
case List() => 0
However, most functional programmers use Nil, and I’ll continue to use Nil in this lesson.
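The following quick sketch checks that equivalence; the object and method names are hypothetical. Nil compares equal to List(), and each pattern matches an empty list built either way:

```scala
object NilDemo {
  // Matches the empty list using the Nil pattern
  def isEmptyNil(list: List[Int]): Boolean = list match {
    case Nil => true
    case _   => false
  }

  // Matches the empty list using the List() pattern
  def isEmptyListPattern(list: List[Int]): Boolean = list match {
    case List() => true
    case _      => false
  }

  def main(args: Array[String]): Unit = {
    println(Nil == List())           // true
    println(isEmptyNil(List()))      // true
    println(isEmptyListPattern(Nil)) // true
  }
}
```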
That case expression is a Scala/FP implementation of the first sentence, so let’s move on to the second sentence.
The second sentence says, “If the list is not empty, the result of the algorithm is the combination of (a) the value of its head element, and (b) the sum of its tail elements.”
To split the list into head and tail components, I start writing the second case expression like this:
case head :: tail => ???
If you know your case expressions, you know that if sum is given a list like List(1,2,3,4), this pattern binds head to the value 1 and tail to the value List(2,3,4):
head = 1
tail = List(2,3,4)
(If you don’t know your case expressions, please refer to the match/case lessons in Chapter 3 of the Scala Cookbook.)
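To see those bindings directly, the following sketch uses a hypothetical split helper that returns what head :: tail captures. A one-element list also matches the pattern, with its tail bound to the empty list. That is why the recursion eventually reaches the Nil case.

```scala
object HeadTailPattern {
  // Hypothetical helper: returns the (head, tail) pair bound by the pattern,
  // or None when the list is empty and the pattern does not match.
  def split(list: List[Int]): Option[(Int, List[Int])] = list match {
    case head :: tail => Some((head, tail))
    case Nil          => None
  }

  def main(args: Array[String]): Unit = {
    println(split(List(1, 2, 3, 4))) // Some((1,List(2, 3, 4)))
    println(split(List(4)))          // Some((4,List()))
    println(split(Nil))              // None
  }
}
```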
This case expression is a start, but how do we finish it? Again I go back to the second sentence:
If the list is not empty, the result of the algorithm is the combination of (a) the value of its head element, and (b) the sum of the tail elements.
The “value of its head element” is easy to add to the case expression:
case head :: tail => head ...
But then what? As the sentence says, “the value of its head element, and the sum of the tail elements,” which tells us we’ll be adding something to head:
case head :: tail => head + ???
What are we adding to head? The sum of the list’s tail elements. Hmm, now how can we get the sum of a list of tail elements? How about this:
case head :: tail => head + sum(tail)
Whoa. That code is a straightforward implementation of the sentence, isn’t it?
(I’ll pause here to let that sink in.)
If you combine this new case expression with the existing code, you get the following sum function:
def sum(list: List[Int]): Int = list match {
    case Nil => 0
    case head :: tail => head + sum(tail)
}
And that is a recursive “sum the integers in a List” function in Scala/FP. No vars, no for loop.
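If you want a few quick checks without adding a test library, the following sketch exercises the finished function with plain assert calls. The object name SumChecks and the particular test values are illustrative. The checks cover the base case, a single recursive step, and negative values.

```scala
object SumChecks {
  def sum(list: List[Int]): Int = list match {
    case Nil          => 0
    case head :: tail => head + sum(tail)
  }

  def main(args: Array[String]): Unit = {
    assert(sum(Nil) == 0)                // base case only
    assert(sum(List(5)) == 5)            // one recursive call, then Nil
    assert(sum(List(1, 2, 3, 4)) == 10)  // the lesson's example
    assert(sum(List(-1, 1, -2, 2)) == 0) // negative values work too
    println("all checks passed")
  }
}
```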
If you’re new to case expressions, it’s important to note that the head and tail variable names in the second case expression can be anything you want. I wrote it like this:
case head :: tail => head + sum(tail)
but I could have written this:
case h :: t => h + sum(t)
or this:
case x :: xs => x + sum(xs)
This last example uses variable names that are commonly used with FP, lists, and recursive programming. When working with a list, a single element is often referred to as x, and multiple elements are referred to as xs. It’s a way of indicating that x is singular and xs is plural, like referring to a single “pizza” or multiple “pizzas.” With lists, the head element is definitely singular, while the tail can contain zero or more elements. I’ll generally use this naming convention in this book.
To demonstrate that sum works, you can clone my RecursiveSum project on GitHub (which uses ScalaTest to test sum), or you can copy the following source code, which extends a Scala App, to test sum:
object RecursiveSum extends App {
    def sum(list: List[Int]): Int = list match {
        case Nil => 0
        case x :: xs => x + sum(xs)
    }
    val list = List(1, 2, 3, 4)
    // a val named "sum" would clash with the sum method, so use another name
    val result = sum(list)
    println(result)
}
When you run this application you should see the output, 10. If so, congratulations on your first recursive function!
“That’s great,” you say, “but how exactly did that end up printing 10?”
To which I say, “Excellent question. Let’s dig into that!”
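A quick way to watch the calls happen is a tracing variant. This is only a sketch: sumTrace and its depth parameter are hypothetical additions, not part of the lesson’s sum. The variant prints each call as it is entered, indented one level deeper per recursive call.

```scala
object TraceSum {
  // Hypothetical tracing variant: same recursion as sum, plus a println
  // on entry to each call, indented by the current recursion depth.
  def sumTrace(list: List[Int], depth: Int = 0): Int = {
    println(("  " * depth) + s"sum($list)")
    list match {
      case Nil     => 0
      case x :: xs => x + sumTrace(xs, depth + 1)
    }
  }

  def main(args: Array[String]): Unit =
    println(sumTrace(List(1, 2, 3, 4)))
}
```

Running it prints sum(List(1, 2, 3, 4)), then the calls for List(2, 3, 4), List(3, 4), List(4), and List(), each indented one level further. Finally it prints 10, after the innermost call returns 0 and each pending addition completes.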
As I’ve noted before, I tend to write verbose code that’s hopefully easy to understand, especially in books, but you can shrink the last three lines of code to this, if you prefer:
println(sum(List(1,2,3,4)))