Skip to content

Commit

Permalink
Misc changes following feedback from HKT
Browse files Browse the repository at this point in the history
  • Loading branch information
noelwelsh committed Oct 13, 2023
1 parent 3a45542 commit b136170
Show file tree
Hide file tree
Showing 5 changed files with 309 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/pages/adt/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@ So algebraic data types consist of sum and product types.

Algebraic data types are closed worlds, which means they cannot be extended after they have been defined. In practical terms this means we have to modify the source code where we define the algebraic data type if we want to add or remove elements.

The closed world property is important because it gives us guarantees we would not otherwise have. In particular, it allows the compiler to check, when we use an algebraic data type, that we handle all possible cases and alert us if we don't. This is known as **exhaustivity checking**. This is an example of how functional programming prioritizes reasoning about code---in this case automated reasoning by the compiler---over other properties such as extensibility.
The closed world property is important because it gives us guarantees we would not otherwise have. In particular, it allows the compiler to check, when we use an algebraic data type, that we handle all possible cases and alert us if we don't. This is known as **exhaustivity checking**. This is an example of how functional programming prioritizes reasoning about code---in this case automated reasoning by the compiler---over other properties such as extensibility. We'll learn more about exhaustivity checking soon.
29 changes: 29 additions & 0 deletions src/pages/adt/scala.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,32 @@ We've seen that the Scala 3 representation of algebraic data types, using `enum`
- Scala 3's doesn't currently support nested `enums` (`enums` within `enums`). This may change in the future, but right now it can be more convenient to use the Scala 2 representation to express this without having to convert to disjunctive normal form.

- Scala 2's representation can express things that are almost, but not quite, algebraic data types. For example, if you define a method on an `enum` you must be able to define it for all the members of the `enum`. Sometimes you want a case of an `enum` to have methods that are only defined for that case. To implement this you'll need to use the Scala 2 representation instead.


#### Exercise: Tree {-}

To gain a bit of practice defining algebraic data types, code the following description in Scala (your choice of version, or do both.)

A `Tree` with elements of type `A` is:

- a `Leaf` with a value of type `A`; or
- a `Node` with a left and right child, which are both `Trees` with elements of type `A`.

<div class="solution">
We can directly translate this binary tree into Scala. Here's the Scala 3 version.

```scala mdoc:silent
enum Tree[A] {
case Leaf(value: A)
case Node(left: Tree[A], right: Tree[A])
}
```

In the Scala 2 encoding we write

```scala mdoc:reset:silent
sealed abstract class Tree[A] extends Product with Serializable
final case class Leaf[A](value: A) extends Tree[A]
final case class Node[A](left: Tree[A], right: Tree[A]) extends Tree[A]
```
</div>
8 changes: 6 additions & 2 deletions src/pages/adt/structural-corecursion.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

Structural corecursion is the opposite—more correctly, the dual—of structural recursion.
Whereas structural recursion tells us how to take apart an algebraic data type,
structural corecursion tells us how to build up an algebraic data type.
structural corecursion tells us how to build up, or construct, an algebraic data type.
We can use structural corecursion whenever the output of a method or function is an algebraic data type.


Expand Down Expand Up @@ -48,6 +48,8 @@ enum MyList[A] {
}
```

The output of this method is a `MyList`.
Since we need to construct a `MyList` we can use structural corecursion.
The structural corecursion strategy says we write down all the constructors and then consider the conditions that will cause us to call each constructor.
So our starting point is to just write down the two constructors, and put in dummy conditions.

Expand Down Expand Up @@ -104,7 +106,9 @@ We recognised that we were producing a `List`, that there were two possibilities
Formalizing structural corecursion as a separate strategy allows us to be more conscious of where we apply it.
Finally, notice how I switched from an `if` expression to a pattern match expression as we progressed through defining `map`.
This is perfectly fine.
Both kinds of expression can achieve the same effect, though if we wanted to continue using an `if` we'd have to define a method (for example, `isEmpty`) that allows us to distinguish an `Empty` element from a `Pair`.
Both kinds of expression achieve the same effect.
Pattern matching is a little bit safer due to exhaustivity checking.
If we wanted to continue using an `if` we'd have to define a method (for example, `isEmpty`) that allows us to distinguish an `Empty` element from a `Pair`.
This method would have to use pattern matching in its implementation, so avoiding pattern matching directly is just pushing it elsewhere.


Expand Down
282 changes: 272 additions & 10 deletions src/pages/adt/structural-recursion.md
Original file line number Diff line number Diff line change
Expand Up @@ -309,30 +309,195 @@ In these situations we can use dynamic dispatch instead.
We'll learn more about this when we look at generalized algebraic data types.


#### Exercise: Methods for Tree {-}

In a previous exercise we created a `Tree` algebraic data type:

```scala mdoc:silent
enum Tree[A] {
case Leaf(value: A)
case Node(left: Tree[A], right: Tree[A])
}
```

Or, in the Scala 2 encoding:

```scala mdoc:reset:silent
sealed abstract class Tree[A] extends Product with Serializable
final case class Leaf[A](value: A) extends Tree[A]
final case class Node[A](left: Tree[A], right: Tree[A]) extends Tree[A]
```

Let's get some practice with structural recursion and write some methods for `Tree`. Implement

* `size`, which returns the number of values (`Leafs`) stored in the `Tree`;
* `contains`, which returns `true` if the `Tree` contains a given element of type `A`, and `false` otherwise; and
* `map`, which creates a `Tree[B]` given a function `A => B`

Use whichever you prefer of pattern matching or dynamic dispatch to implement the methods.

<div class="solution">
I chose to use pattern matching to implement these methods. I'm using the Scala 3 encoding so I have no choice.

I start by creating the method declarations with empty bodies.

```scala mdoc:reset:silent
enum Tree[A] {
case Leaf(value: A)
case Node(left: Tree[A], right: Tree[A])

def size: Int =
???

def contains(element: A): Boolean =
???

def map[B](f: A => B): Tree[B] =
???
}
```

Now these methods all transform an algebraic data type so I can implement them using structural recursion. I write down the structural recursion skeleton for `Tree`, remembering to apply the recursion rule.

```scala
enum Tree[A] {
case Leaf(value: A)
case Node(left: Tree[A], right: Tree[A])

def size: Int =
this match {
case Leaf(value) => ???
case Node(left, right) => left.size ??? right.size
}

def contains(element: A): Boolean =
this match {
case Leaf(value) => ???
case Node(left, right) => left.contains(element) ??? right.contains(element)
}

def map[B](f: A => B): Tree[B] =
this match {
case Leaf(value) => ???
case Node(left, right) => left.map(f) ??? right.map(f)
}
}
```

Now I can use the other reasoning techniques to complete the method declarations.
Let's work through `size`.

```scala
def size: Int =
this match {
case Leaf(value) => 1
case Node(left, right) => left.size ??? right.size
}
```

I can reason independently by case.
The size of a `Leaf` is, by definition, 1.

```scala
def size: Int =
this match {
case Leaf(value) => 1
case Node(left, right) => left.size ??? right.size
}
```

Now I can use the rule for reasoning about recursion: I assume the recursive calls successfully compute the size of the left and right children. What is the size then of the combined tree? It must be the sum of the size of the children. With this, I'm done.

```scala
def size: Int =
this match {
case Leaf(value) => 1
case Node(left, right) => left.size + right.size
}
```

I can use the same process to work through the other two methods, giving me the complete solution shown below.

```scala mdoc:reset:silent
enum Tree[A] {
case Leaf(value: A)
case Node(left: Tree[A], right: Tree[A])

def size: Int =
this match {
case Leaf(value) => 1
case Node(left, right) => left.size + right.size
}

def contains(element: A): Boolean =
this match {
case Leaf(value) => element == value
case Node(left, right) => left.contains(element) || right.contains(element)
}

def map[B](f: A => B): Tree[B] =
this match {
case Leaf(value) => Leaf(f(value))
case Node(left, right) => Node(left.map(f), right.map(f))
}
}
```
</div>


### Folds as Structural Recursions

Let's finish by looking at the fold method as an abstraction over structural recursion.
We know that every algebraic data type has a structural recursion skeleton that is determined entirely by the structure of the algebraic data type.
For `MyList`, defined as
If you did the `Tree` exercise above, you will have noticed that we wrote the same kind of code again and again.
Here are the methods we wrote.
Notice the left-hand sides of the pattern matches are all the same, and the right-hand sides are very similar.

```scala mdoc:reset:silent
```scala
def size: Int =
this match {
case Leaf(value) => 1
case Node(left, right) => left.size + right.size
}

def contains(element: A): Boolean =
this match {
case Leaf(value) => element == value
case Node(left, right) => left.contains(element) || right.contains(element)
}

def map[B](f: A => B): Tree[B] =
this match {
case Leaf(value) => Leaf(f(value))
case Node(left, right) => Node(left.map(f), right.map(f))
}
```

This is the point of structural recursion: we recognize and formalize this similarity.
However, as programmers we might want to abstract over this repetition.
Can we write a method that captures everything that doesn't change in a structural recursion, and allows the caller to pass arguments for everything that does change?
It turns out we can. For any algebraic data type we can define at least one method, called a fold, that captures all the parts of structural recursion that don't change and allows the caller to specify all the problem specific parts.

Let's see how this is done using the example of `MyList`.
Recall the definition of `MyList` is

```scala mdoc:silent
enum MyList[A] {
case Empty()
case Pair(head: A, tail: MyList[A])
}
```

the skeleton is
We know the structural recursion skeleton for `MyList` is

```scala
aList match {
case Empty() => ???
case Pair(head, tail) => ??? recursion(tail)
}
def doSomething[A](list: MyList[A]) =
list match {
case Empty() => ???
case Pair(head, tail) => ??? doSomething(tail)
}
```

For any algebraic data type we can define at least one method, called a fold, that captures all the parts of structural recursion that don't change and allows the caller to specify all the problem specific parts.
For `MyList` this means defining a method
Implementing fold for `MyList` means defining a method

```scala
def fold[A, B](list: MyList[A]): B =
Expand Down Expand Up @@ -382,6 +547,7 @@ def foldLeft[A,B](list: MyList[A], empty: B, f: (A, B) => B): B =
```

which is `foldLeft`, the tail-recursive variant of fold for a list.
(We'll talk about tail-recursion in the next chapter.)

We can follow the same process for any algebraic data type to create its folds.
The rules are:
Expand All @@ -396,3 +562,99 @@ Returning to `MyList`, it has:
- two cases, and hence two parameters to fold (other than the parameter that is the list itself);
- `Empty` is a constructor with no arguments and hence we use a parameter of type `B`; and
- `Pair` is a constructor with one parameter of type `A` and one recursive parameter, and hence the corresponding function has type `(A, B) => B`.


#### Exercise: Tree Fold {-}

Implement a fold for `Tree` defined earlier.
There are several different ways to traverse a tree (pre-order, post-order, and in-order).
Just choose whichever seems easiest.

<div class="solution">
I start by add the method declaration without a body.

```scala mdoc:reset:silent
enum Tree[A] {
case Leaf(value: A)
case Node(left: Tree[A], right: Tree[A])

def fold[B]: B =
???
}
```

Next step is to add the structural recursion skeleton.

```scala
enum Tree[A] {
case Leaf(value: A)
case Node(left: Tree[A], right: Tree[A])

def fold[B]: B =
this match {
case Leaf(value) => ???
case Node(left, right) => left.fold ??? right.fold
}
}
```

Now I follow the types to add the method parameters.
For the `Leaf` case we want a function of type `A => B`.

```scala
enum Tree[A] {
case Leaf(value: A => B)
case Node(left: Tree[A], right: Tree[A])

def fold[B](leaf: A => B): B =
this match {
case Leaf(value) => leaf(value)
case Node(left, right) => left.fold ??? right.fold
}
}
```

For the `Node` case we want a function that combines the two recursive results, and therefore has type `(B, B) => B`.

```scala mdoc:reset:silent
enum Tree[A] {
case Leaf(value: A)
case Node(left: Tree[A], right: Tree[A])

def fold[B](leaf: A => B)(node: (B, B) => B): B =
this match {
case Leaf(value) => leaf(value)
case Node(left, right) => node(left.fold(leaf)(node), right.fold(leaf)(node))
}
}
```
</div>


#### Exercise: Using Fold {-}

Prove to yourself that you can replace structural recursion with calls to fold, by redefining `size`, `contains`, and `map` for `Tree` using only fold.

<div class="solution">
```scala mdoc:reset:silent
enum Tree[A] {
case Leaf(value: A)
case Node(left: Tree[A], right: Tree[A])

def fold[B](leaf: A => B)(node: (B, B) => B): B =
this match {
case Leaf(value) => leaf(value)
case Node(left, right) => node(left.fold(leaf)(node), right.fold(leaf)(node))
}

def size: Int =
this.fold(_ => 1)(_ + _)

def contains(element: A): Boolean =
this.fold(_ == element)(_ || _)

def map[B](f: A => B): Tree[B] =
this.fold(v => Leaf(f(v)))((l, r) => Node(l, r))
}
```
</div>
1 change: 1 addition & 0 deletions src/pages/preface/contributors.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Francis Devereux,
Ghislain Vaillant,
Gregor Ihmor,
Henk-Jan Meijer,
HigherKindedType,
Janne Pelkonen,
Jason Scott,
Javier Arrieta,
Expand Down

0 comments on commit b136170

Please sign in to comment.