“This is a pretty strange piece of code, and it may take a few moments of thought to figure out what’s going on.”
– Real World OCaml
A few weeks ago, fellow Hacker Schooler Chen Lin and I were trying to solve a simple graph problem in Haskell. I was all ready to charge forward with something quite like the Java implementation I learned back in undergrad, but Chen had some doubts about whether that kind of structure would work in Haskell.
After a little bit of Googling, I found out that the canonical solution in Haskell involves something intriguingly dubbed tying the knot. I stared blankly at this HaskellWiki page with my fellow Hacker Schooler, trying to understand it quickly enough to have a useful conversation about it, and failed. We threw a couple of other ideas around and then decided to both pursue other projects. I moved on, Chen moved on, and I’m not sure either of us thought much about it…
…until yesterday, when I ran into tying the knot again. This time, it was hiding deep within (of all things!) the chapter on imperative programming in Real World OCaml, and I was unhurried and determined. “Abstract concept, I am going to understand you so hard,” I thought, jaw set.
The motivation for the Real World OCaml example is a theoretically unrelated concept called memoization. If you were to ask a human to calculate something complicated, and then you were to ask them immediately to calculate the same thing again, chances are good they'd give you the same answer from memory the second time you asked, rather than redoing the calculation. Computers can do that too, if you help them along!
(Like all the code in this post, this code was cribbed from the example code from Real World OCaml and then marked up or reworked by me, in a desperate attempt to figure out what the heck is going on.)
```ocaml
(* wrap a memoizer around an arbitrary function. *)
let memoize f =
  (* make a hash table that accepts whatever (consistent)
   * types of keys/values you might have *)
  let table = Hashtbl.Poly.create () in
  (fun x ->                                   (* and when we actually get called, *)
     match Hashtbl.find table x with          (* look up x in the hash table *)
     | Some y -> y                            (* if we found a result, return it *)
     | None ->                                (* if not, *)
       let y = f x in                         (* calculate the result *)
       Hashtbl.add_exn table ~key:x ~data:y;  (* and store for later *)
       y                                      (* and then return it *)
  )
```
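If you don't have Core installed, here is a minimal sketch of the same memoizer using only the standard library's `Hashtbl` (swapped in for Core's `Hashtbl.Poly`; the stdlib's polymorphic `Hashtbl.create` also uses structural hashing and equality). The important detail is that the table is created once, when `memoize` is applied, and then captured by the closure it returns. The `calls` counter and `slow_square` function are hypothetical additions, there only to show that the wrapped function runs once per distinct input.

```ocaml
(* A stdlib-only memoizer: the table is created once, when memoize is applied,
   and captured by the returned closure. *)
let memoize f =
  let table = Hashtbl.create 16 in
  fun x ->
    match Hashtbl.find_opt table x with
    | Some y -> y                  (* cache hit: reuse the stored result *)
    | None ->
      let y = f x in               (* cache miss: compute... *)
      Hashtbl.add table x y;       (* ...store... *)
      y                            (* ...and return *)

(* Hypothetical helper: counts how many times the underlying function runs. *)
let calls = ref 0
let slow_square x = incr calls; x * x

let fast_square = memoize slow_square

let () =
  let a = fast_square 12 in
  let b = fast_square 12 in
  (* the second lookup is served from the table, so calls stays at 1 *)
  Printf.printf "%d %d, underlying calls: %d\n" a b !calls
```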
So that’s a memoizer for some arbitrary function f. We can do this because functional languages are bad-ass, although we’d get poor results if the function f had side effects:
```ocaml
utop # let printy x = printf "%s" x; 0;;
val printy : string -> int = <fun>
utop # let print_memo = memoize printy;;
val print_memo : string -> int = <fun>
utop # print_memo "I am the very model of a modern major general";;
I am the very model of a modern major general- : int = 0
utop # print_memo "I am the very model of a modern major general";;
- : int = 0
```
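The second call never printed anything: the memoizer handed back the stored 0 without running `printy` again. That's OK, since we're going to be good functional programmers and only write pure functions, which take an input and return an output and that's it. No, that's not the problem. The problem is this: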
```ocaml
let rec naive_fib i =
  if i <= 1 then 1
  else naive_fib (i - 1) + naive_fib (i - 2)
```
If you've ever written any recursive function, odds are pretty good that you've written this one. Its performance, as written, is terrible: each call that doesn't hit the base case generates two more recursive calls, so the running time grows exponentially with the size of the input (by a factor of roughly 1.6 for each increment of `i`, which 2^n bounds from above). Not good! Moreover, most of those calls are recalculating the same thing over and over again: we'll get 2 copies of `naive_fib (i - 2)`, each of those will spawn a copy of `naive_fib (i - 3)` and `naive_fib (i - 4)`, and on and on until we get down to `naive_fib 1` and `naive_fib 0`, at which point the recursion bottoms out and we can finally start returning values up the call stack. This is a pretty obvious candidate for memoizing, but it's not quite so simple to memoize as we might hope:
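To put a number on "terrible", here's a sketch that counts how many times the function is entered; `counted_fib` and its counter are illustrative additions, not part of the book's code. Every call either bottoms out or makes two more calls, so the count obeys calls(n) = calls(n - 1) + calls(n - 2) + 1, which works out to exactly 2 * `naive_fib n` - 1 calls: the cost grows like the Fibonacci numbers themselves, about 1.6 times (the golden ratio) per step.

```ocaml
let calls = ref 0

(* Same definition as naive_fib, plus a counter bump on every entry. *)
let rec counted_fib i =
  incr calls;
  if i <= 1 then 1
  else counted_fib (i - 1) + counted_fib (i - 2)

let () =
  List.iter
    (fun n ->
       calls := 0;
       let v = counted_fib n in
       Printf.printf "n = %2d  result = %6d  calls = %7d\n" n v !calls)
    [10; 15; 20; 25]
```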
```ocaml
utop # let not_quite_memoized_enough x = (memoize naive_fib) x ;;
val not_quite_memoized_enough : int -> int = <fun>
utop # not_quite_memoized_enough 40;;
(* loooooooong pause *)
- : int = 165580141
utop # not_quite_memoized_enough 40;;
(* loooooooong pause *)
- : int = 165580141
utop # not_quite_memoized_enough 41;;
(* looooooooooooooooooong pause *)
- : int = 267914296
```
Not only did we not get a quicker computation the second time we tried to calculate the 40th Fibonacci number, we also didn't get any of the benefit of memoization within the function. And the computation of the 41st number didn't use the already-known value of the 40th number! I'd say this isn't even `not_quite_memoized_enough`; it isn't memoized at all.
Every time we run `not_quite_memoized_enough`, we're calling `memoize` for what seems, within the scope of `not_quite_memoized_enough`, to be the first time. Each call memoizes the computed result, but that memory is immediately discarded; subsequent calls start over with a new, empty hash table.
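The difference comes down to *when* `memoize` is applied. Here's a sketch, using a stdlib memoizer in place of Core's and a hypothetical counting function `slow_double`: applying `memoize` inside the function body builds a fresh, empty table on every call, while applying it once at the top level builds one table that every later call shares.

```ocaml
let memoize f =
  let table = Hashtbl.create 16 in
  fun x ->
    match Hashtbl.find_opt table x with
    | Some y -> y
    | None -> let y = f x in Hashtbl.add table x y; y

(* Hypothetical stand-in for an expensive function; counts its own runs. *)
let runs = ref 0
let slow_double x = incr runs; 2 * x

(* memoize runs on every call, so every call gets a new, empty table. *)
let forgetful x = (memoize slow_double) x

(* memoize runs once, here; the resulting closure keeps its table. *)
let remembering = memoize slow_double

let () =
  runs := 0;
  ignore (forgetful 21); ignore (forgetful 21);
  Printf.printf "forgetful:   %d runs\n" !runs;   (* 2: nothing was reused *)
  runs := 0;
  ignore (remembering 21); ignore (remembering 21);
  Printf.printf "remembering: %d runs\n" !runs    (* 1: second call hit the cache *)
```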
How can we keep one hash table in scope for all calls to `naive_fib`? If we keep a memoized version of `naive_fib` around, and parameterize our function to accept the function it should call for the smaller Fibonacci numbers, maybe we can just use that.
```ocaml
utop # let memoized = (memoize naive_fib);;
val memoized : int -> int = <fun>
utop # let fib_caller input_function i =
  if i <= 1 then 1
  else input_function (i-1) + input_function (i-2);;
val fib_caller : (int -> int) -> int -> int = <fun>
utop # fib_caller memoized 40;;
(* loooooong pause *)
- : int = 165580141
utop # fib_caller memoized 40;;
(* instantaneous! *)
- : int = 165580141
utop # fib_caller memoized 41;;
(* looooooooooooooooong pause. *)
- : int = 267914296
```
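`fib_caller` is the Fibonacci recurrence with the recursive call pulled out into a parameter (a style often called open recursion). As a sketch, handing it back to itself with an ordinary `let rec` recovers `naive_fib` exactly, exponential cost and all; the goal from here on is to tie that same knot with a memoizer in the middle. The name `tied_fib` is hypothetical.

```ocaml
(* The recurrence, with "how to recurse" supplied by the caller. *)
let fib_caller input_function i =
  if i <= 1 then 1
  else input_function (i - 1) + input_function (i - 2)

(* Tie the knot with no memoization: tied_fib passes itself in.
   This is a legal let rec because the right-hand side is a function definition. *)
let rec tied_fib i = fib_caller tied_fib i

let () = Printf.printf "%d\n" (tied_fib 20)
```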
Hm. It seems like we're memoizing the results of the calls that `fib_caller` makes directly (for `fib_caller memoized 40`, those are `memoized 39` and `memoized 38`), but not the results of the internal calls `naive_fib` makes to itself: the `naive_fib 37`, `naive_fib 36`, etc. that are generated as we calculate the value of `fib_caller memoized 40`. This means we're also not consulting our hash table for the results of internal calculations, even at the top level:
```ocaml
utop # fib_caller memoized 37;;
(* loooooong pause. *)
- : int = 39088169
utop # fib_caller memoized 38;;
(* looooooooong pause. *)
- : int = 63245986
utop # fib_caller memoized 39;;
(* looooooooooong pause. *)
- : int = 102334155
```
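One way to see what's actually being cached is to peek at the table. Here's a sketch with a hypothetical variant, `memoize_with_table`, that returns the table alongside the memoized function (stdlib `Hashtbl` in place of Core's). After `fib_caller memoized 20`, only the two direct calls, for 19 and 18, have entries; none of the thousands of internal `naive_fib` calls left anything behind.

```ocaml
(* Hypothetical variant of memoize that also hands back its table. *)
let memoize_with_table f =
  let table = Hashtbl.create 16 in
  let g x =
    match Hashtbl.find_opt table x with
    | Some y -> y
    | None -> let y = f x in Hashtbl.add table x y; y
  in
  (g, table)

let rec naive_fib i =
  if i <= 1 then 1 else naive_fib (i - 1) + naive_fib (i - 2)

let fib_caller input_function i =
  if i <= 1 then 1 else input_function (i - 1) + input_function (i - 2)

let memoized, table = memoize_with_table naive_fib

let () =
  ignore (fib_caller memoized 20);
  (* Only the keys fib_caller asked for directly are present (order unspecified). *)
  Hashtbl.iter (fun k v -> Printf.printf "cached: %d -> %d\n" k v) table
```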
We can't get very far by just wrapping `naive_fib` in this memoizer; we need to break into the function itself and inject memoization directly into its cold, black, dead heart. Every call to the function, including the recursive ones, has to run through the memoizer. We can try to programmatically construct a function that does this:
```ocaml
utop # let try_to_memoize_and_recurse fib_caller x =
  let rec f = memoize (fun x -> fib_caller f x) in
  f x
;;
Error: This kind of expression is not allowed as right-hand side of `let rec'
```
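…but we can't write `let rec f =` followed by anything that isn't

- a function definition,
- a constructor,
- or the `lazy` keyword,

and `memoize (fun x -> ...)` is a function application, which isn't on the list.

…although that “or the lazy keyword” is pretty intriguing.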
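A small sketch of the three kinds of right-hand side listed above, each of which the compiler accepts: a function, a constructor applied to the very value being defined (here a cyclic list), and a lazy value whose body mentions itself. The names are illustrative only.

```ocaml
(* 1. A function definition: the usual case. *)
let rec countdown n = if n <= 0 then [] else n :: countdown (n - 1)

(* 2. A constructor: the list cell refers to itself, giving a cyclic list.
   (Don't print it or compare it with =; it never ends.) *)
let rec ones = 1 :: ones

(* 3. The lazy keyword: lazy_fact can mention itself because the recursive
   Lazy.force only happens later, when the function is actually called. *)
let rec lazy_fact =
  lazy (fun n -> if n <= 1 then 1 else n * (Lazy.force lazy_fact) (n - 1))

let () =
  Printf.printf "%d %d %d\n"
    (List.length (countdown 3)) (List.hd (List.tl ones)) ((Lazy.force lazy_fact) 5)
```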
```ocaml
utop # let lazily_memoize_and_recurse nonrecursive x =
  let rec f = lazy (memoize (fun x -> nonrecursive f x)) in
  f x
;;
Error: This expression has type ('a -> 'b) lazy_t
This is not a function; it cannot be applied.
```
`lazy`, in OCaml, explicitly defers the value of a computation until you manually force it later. We're allowed to use `lazy` here because reasons: by the time we actually have to do something with `f`, rather than just refer to it, it will have been defined.
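Here's a quick sketch of that deferral using only the standard library's `Lazy` module: the body of a `lazy` expression doesn't run when the value is created, runs on the first `Lazy.force`, and its result is cached for every later force. The `evaluations` counter is just for illustration.

```ocaml
let evaluations = ref 0

(* Creating the suspension runs nothing yet. *)
let deferred = lazy (incr evaluations; 6 * 7)

let () =
  Printf.printf "before forcing: %d evaluations\n" !evaluations;  (* 0 *)
  let a = Lazy.force deferred in   (* runs the body once *)
  let b = Lazy.force deferred in   (* reuses the stored result *)
  Printf.printf "a = %d, b = %d, evaluations = %d\n" a b !evaluations
```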
We can't just make something lazy (i.e., postpone the value of its computation until it's forced) and then call it as if it were a function; the type system rightly tells us that a promise to compute something is not the same as computing it. Forcing just the result looks like it works, until you try to actually compute something:
```ocaml
utop # let lazily_memoize_and_recurse nonrecursive x =
  let rec f = lazy (memoize (fun x -> nonrecursive f x)) in
  (Lazy.force f) x
;;
val lazily_memoize_and_recurse :
  (('a -> 'b) lazy_t -> 'a -> 'b) -> 'a -> 'b = <fun>
utop # lazily_memoize_and_recurse fib_caller 40;;
Error: This expression has type (int -> int) -> int -> int
       but an expression was expected of type ('a -> 'b) lazy_t -> 'a -> 'b
       Type int -> int is not compatible with type ('a -> 'b) lazy_t
```
We need to also force the value of `f` in our call to `nonrecursive`, so that it receives an ordinary function rather than a lazy one:
```ocaml
utop # let lazily_memoize_and_recurse nonrecursive x =
  let rec f = lazy (memoize (fun x -> nonrecursive (Lazy.force f) x)) in
  (Lazy.force f) x
;;
val lazily_memoize_and_recurse : (('a -> 'b) -> 'a -> 'b) -> 'a -> 'b = <fun>
utop # lazily_memoize_and_recurse fib_caller 40;;
(* instantaneous! *)
- : int = 165580141
```
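For anyone who wants to poke at the working version outside utop, here is a self-contained sketch. It swaps Core's `Hashtbl.Poly` for the standard library's `Hashtbl`, and adds a hypothetical `evaluations` counter inside the non-recursive Fibonacci body to confirm that each value of `i` from 0 to 40 is computed exactly once.

```ocaml
let memoize f =
  let table = Hashtbl.create 16 in
  fun x ->
    match Hashtbl.find_opt table x with
    | Some y -> y
    | None -> let y = f x in Hashtbl.add table x y; y

let lazily_memoize_and_recurse nonrecursive x =
  (* Forcing f builds the memoized function; the inner Lazy.force f only runs
     later, when that function is called, by which point f is already forced. *)
  let rec f = lazy (memoize (fun x -> nonrecursive (Lazy.force f) x)) in
  (Lazy.force f) x

let evaluations = ref 0

let fib_caller input_function i =
  incr evaluations;  (* illustrative: counts how often the body actually runs *)
  if i <= 1 then 1 else input_function (i - 1) + input_function (i - 2)

let () =
  let v = lazily_memoize_and_recurse fib_caller 40 in
  Printf.printf "fib 40 = %d after %d evaluations\n" v !evaluations
```

Each top-level call still starts from a fresh table, since `memoize` only runs once `x` has been supplied; within a single call, though, the recursion is fully memoized.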
There are a few other ways to do this, including, of all things, a C# implementation. The Real World OCaml explanation makes reference first to an imperative version (not too surprising, since the explanation is in the imperative programming section!), which is also instructive:
```ocaml
let memoize_and_recurse nonrecursive x =
  let function_reference = ref (fun _ -> assert false) in
  let f = memoize (fun x -> nonrecursive !function_reference x) in
  function_reference := f;
  f x
```
This is another way of referring to something before we've defined it: `function_reference` starts out holding a placeholder function (one that would just fail with `assert false` if it were ever called), and once `f` exists we update it to point at `f`, whose body looks `function_reference` up every time it runs. That's the knot: `f` ends up calling itself through the reference. It's not much different, down in the weeds, from the more purely functional version above that uses laziness, and one can use either definition as an aid for understanding the other. Or maybe one could begin by understanding neither, use both definitions, tie them together, and somehow arrive at an understanding of both?
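And the same exercise for the ref-based version, again as a sketch with the standard library's `Hashtbl` standing in for Core's and a hypothetical `evaluations` counter. The placeholder only exists to give `function_reference` the right type until the real function is ready; if the knot were ever left untied, calling it would raise `Assert_failure` rather than quietly compute something wrong.

```ocaml
let memoize f =
  let table = Hashtbl.create 16 in
  fun x ->
    match Hashtbl.find_opt table x with
    | Some y -> y
    | None -> let y = f x in Hashtbl.add table x y; y

let memoize_and_recurse nonrecursive x =
  (* Placeholder: never called, because it is overwritten before first use. *)
  let function_reference = ref (fun _ -> assert false) in
  (* f's body reads function_reference when it is called, not when it is defined. *)
  let f = memoize (fun x -> nonrecursive !function_reference x) in
  function_reference := f;  (* tie the knot: the reference now points at f *)
  f x

let evaluations = ref 0

let fib_caller input_function i =
  incr evaluations;  (* illustrative: counts how often the body actually runs *)
  if i <= 1 then 1 else input_function (i - 1) + input_function (i - 2)

let () =
  let v = memoize_and_recurse fib_caller 40 in
  Printf.printf "fib 40 = %d after %d evaluations\n" v !evaluations
```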