Feb 20 2006

A WOW moment with continuations

Published by Brian at 8:34 am under To Be Categorized

OK, I just had a “WOW” moment with Ocaml: one of those moments where the penny drops, the light goes on, the code becomes dramatically simpler, and you just have to sit back and say “WOW!”. These moments happen most often when something you think you understood suddenly reveals itself to have unexpected depth and power. They’re still happening to me on a pretty regular basis with Ocaml.

So anyway, to get into the code: I was playing around with implementing Okasaki’s random access lists. Go read that paper; it’s worth the time. The short form is that a random access list is a slightly more complicated version of a singly linked list, in which the elements are stored in perfect binary trees, each tree holding 2^k - 1 nodes for some k. The trees are then held in a linked list. This allows prepending in O(1), while getting the ith element of the list takes O(log i). Note that for an N-element list this means the worst-case access time is O(log N), but for small i (elements near the head of the list) access is much faster than that.
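Which trees a list holds depends only on its length. As a minimal sketch (the helpers `cons_size` and `sizes_for` are hypothetical, not from the paper or the code below), one can track just the tree sizes and apply the same rule `cons` uses: when the first two trees have equal size c, the new element joins them into one tree of size 2c + 1; otherwise it starts a new tree of size 1.

```ocaml
(* Track only the sizes of the trees in a random access list, smallest first.
   Mirrors the size arithmetic of cons: if the first two trees have the same
   size c, the new element becomes the root of a tree of size 2c + 1;
   otherwise it starts a new tree of size 1. *)
let cons_size = function
  | c1 :: c2 :: rest when c1 = c2 -> (2 * c1 + 1) :: rest
  | sizes -> 1 :: sizes

(* Tree sizes after n conses onto the empty list. *)
let sizes_for n =
  let rec go k acc = if k = 0 then acc else go (k - 1) (cons_size acc) in
  go n []
```

For example, `sizes_for 10` is `[3; 7]` and `sizes_for 15` is `[15]`: fifteen elements collapse into a single perfect tree. The merge step only ever touches the first two trees, which is why prepending stays O(1).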

Here is the basic data structure, and the functions on it that I had written:

type 'a node_t =
    | Node of 'a * 'a node_t * 'a node_t
    | Leaf of 'a
;;

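(* Each element of the list pairs a tree with its node count, which is
   always of the form 2^k - 1. *)
type 'a t = (int * 'a node_t) list;;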

let empty : 'a t = [];;

let is_empty (lst : 'a t) = lst = [];;

let length (lst: 'a t) = List.fold_left (fun s (i, _) -> s + i) 0 lst;;

let head : 'a t -> 'a = function
    | (c, Node(x, _, _)) :: _ when c > 1 -> x
    | (1, Leaf(x)) :: _ -> x
    | [] -> invalid_arg "Ralist.head"
    | _ -> assert false
;;

let tail : 'a t -> 'a t = function
    | (c, Node(_, l, r)) :: t when c > 1 -> (c/2, l) :: (c/2, r) :: t
    | (1, Leaf(_)) :: t -> t
    | [] -> invalid_arg "Ralist.tail"
    | _ -> assert false
;;

let cons x : 'a t -> 'a t = function
    | (c, l) :: (c', r) :: t when c = c' ->
        ((2*c + 1), Node(x, l, r)) :: t
    | t -> (1, Leaf(x)) :: t
;;

let rec nth (lst: 'a t) idx =
    if (idx < 0) then
        invalid_arg "Ralist.nth"
    else
    match lst with
        | (c, _) :: t when c <= idx -> nth t (idx - c)
        | (c, n) :: _ ->
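            (* idx is the preorder index within tree n (the root is 0);
               clen is the node count of each of n's subtrees *)
            let rec loop idx clen = function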
                | Node(x, l, r) ->
                    if (idx = 0) then
                        x
                    else if (idx <= clen) then
                        loop (idx-1) (clen/2) l
                    else
                        loop (idx - clen - 1) (clen/2) r
                | Leaf(x) ->
                    let _ = assert (idx = 0) in
                    x
            in
            loop idx (c/2) n
        | [] -> invalid_arg "Ralist.nth"
;;
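The index arithmetic in `nth`’s inner `loop` can be checked in isolation. This hypothetical `path_to` helper (not part of the post) takes the same steps as `loop` but records the moves instead of returning an element. It assumes a single tree of size c = 2^k - 1 laid out in preorder: index 0 is the root, indices 1 to clen are the left subtree, and clen + 1 to 2 * clen are the right subtree, with clen = c / 2.

```ocaml
(* Hypothetical helper, for illustration only: the moves that nth's inner
   loop makes to reach preorder index idx in one tree of size c = 2^k - 1. *)
type move = Left | Right

let path_to idx c =
  if idx < 0 || idx >= c then invalid_arg "path_to";
  let rec go idx clen acc =
    if idx = 0 then List.rev acc                      (* reached the node *)
    else if idx <= clen then go (idx - 1) (clen / 2) (Left :: acc)
    else go (idx - clen - 1) (clen / 2) (Right :: acc)
  in
  go idx (c / 2) []
```

In a 7-node tree, index 3 is reached by going left and then right. Each move halves clen, so a tree of 2^k - 1 nodes needs at most k - 1 moves.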

So now we’re getting into the WOW moment. I wanted to write a search function. This function would take a compare function and search a random access list that is assumed to be in sorted order. The compare function returns zero for the element being sought, a positive number for elements that come before it, and a negative number for elements that come after it (fun x -> compare key x, for example). The search function would have the type ('a -> int) -> 'a t -> 'a, but the important aspect of it was that it would have O(log N) complexity. The binary tree structure of the random access list should have made this easy, I thought: binary trees lend themselves to binary searches. Unfortunately, the structure of the trees used here complicates things. The root node of each tree is the first element in the tree, not a middle element (as in an ordinary binary search tree). That means the element that splits the tree roughly in half is the root node of the right subtree. So we have to look at the right subtree before we know whether the node we’re looking for is in the right subtree or the left subtree (assuming it exists). Worse yet, we don’t have a tree, we have a whole list of trees. So I started trying various things to see if I could cook up a solution.
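For contrast, the same contract on an ordinary sorted list can only be met by walking element by element, in O(N). A minimal sketch; `linear_search` is a hypothetical name, and the sign convention for f is the one the search code later in the post uses.

```ocaml
(* Baseline: the search contract on a plain sorted list.
   f x = 0: x is the element sought; f x > 0: it comes after x;
   f x < 0: x is already past it. *)
let rec linear_search f = function
  | [] -> raise Not_found
  | x :: rest ->
      let c = f x in
      if c = 0 then x
      else if c > 0 then linear_search f rest
      else raise Not_found  (* overshot: the element is not in the list *)
```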

The problem started to clarify itself when I realized both problems were the same problem in slightly different disguises: I don’t know which tree the node is in until I overshoot. When I’m walking down the list of trees, I always want to go on to the next tree, unless I discover I’ve overshot (that the first node of the tree is after the node I’m looking for), at which point I know the node I’m looking for is in the tree I just passed. Likewise, if I know the node is somewhere in this subtree, I always go right, unless I discover I’ve overshot, in which case the node I’m looking for is in the left subtree. In both cases, what I need is some way to “jump back one step”.
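Taken alone, the list-walking half of the problem can be handled with explicit state: carry the last tree that did not overshoot along the walk, and return it on overshoot. A minimal sketch, with hypothetical sorted “buckets” of values standing in for the trees (`find_bucket` is an illustrative name, not from the post):

```ocaml
(* Returns the bucket the element sought would have to be in, if any.
   f follows the same convention as search: 0 = found, > 0 = the element
   sought comes later, < 0 = overshot. Buckets are assumed sorted, both
   internally and relative to each other; empty buckets are skipped. *)
let find_bucket f buckets =
  let rec walk prev = function
    | [] -> prev                        (* ran off the end: step back *)
    | [] :: rest -> walk prev rest
    | (first :: _ as bucket) :: rest ->
        if f first < 0 then prev        (* overshot: step back one bucket *)
        else walk (Some bucket) rest    (* first <= target: keep going *)
  in
  walk None buckets
```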

At which point the sunbeam burst through the clouds to light on me, the Hallelujah chorus started up, and I started yelling “Continuations, Elwood! Continuations!”

The complexity was coming from the fact that the called function needed to know from whence it was called, so that it could backtrack correctly. The idea was that instead of having the called function know this, I’d simply pass in a backtrack function. When the called function realized it had overshot, it would call the backtrack function. The calling function could simply create a quick, anonymous function that did the right thing when called. Now the called function doesn’t need to know where it was called from; the calling function just needs to pass in what is to be done when backtracking is needed. Instead of a feedback loop of complexity, we have amazing simplicity.
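The same bucket walk can be written in the backtrack-function style described above: instead of carrying the previous bucket, each step hands the next step a function saying what to do on overshoot. A toy sketch under the same assumptions (hypothetical sorted buckets, with a linear scan standing in for the tree descent); only the backtracking between buckets mirrors the post’s `bloop`.

```ocaml
(* f follows the same convention as search: 0 = found, > 0 = the element
   sought comes later, < 0 = overshot. *)
let find_in_buckets f buckets =
  let not_found () = raise Not_found in
  (* Only reached once we know the element, if present at all, is here. *)
  let scan bucket =
    match List.find_opt (fun x -> f x = 0) bucket with
    | Some x -> x
    | None -> not_found ()
  in
  (* back: what to do if the first remaining bucket overshoots, or the
     buckets run out. walk never needs to know where it was called from. *)
  let rec walk back = function
    | [] -> back ()
    | [] :: rest -> walk back rest
    | (first :: _ as bucket) :: rest ->
        let c = f first in
        if c < 0 then back ()
        else if c = 0 then first
        else walk (fun () -> scan bucket) rest
  in
  walk not_found buckets
```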

Here’s the code:

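let search f (lst: 'a t) =
    (* f x = 0: x is the element sought; f x > 0: the element sought comes
       after x; f x < 0: we have overshot, so it can only be before x. *)
    let not_found () = raise Not_found in
    (* Search a single tree.  back is what to do if the root of this tree
       turns out to be past the element sought. *)
    let rec aloop back = function
        | Node(x, l, r) ->
            let c = f x in
            if c < 0 then
                back ()
            else if c = 0 then
                x
            else
                (* Try the right subtree; if its root overshoots, back up
                   into the left subtree. *)
                aloop (fun () -> aloop not_found l) r
        | Leaf(x) ->
            let c = f x in
            if c < 0 then
                back ()
            else if c = 0 then
                x
            else
                not_found ()
    in
    (* Walk the list of trees.  back is what to do if the root of the first
       remaining tree is past the element sought, or the list runs out. *)
    let rec bloop back = function
        | (_, Node(x, l, r)) :: t ->
            let c = f x in
            if c < 0 then
                back ()
            else if c = 0 then
                x
            else
                (* Move on to the next tree; if it overshoots, come back
                   and search this tree's subtrees. *)
                bloop (fun () -> aloop (fun () -> aloop not_found l) r) t
        | (_, Leaf(x)) :: t ->
            let c = f x in
            if c < 0 then
                back ()
            else if c = 0 then
                x
            else
                bloop not_found t
        | [] -> back ()
    in
    bloop not_found lst
;;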

OK, so I’m allocating O(log N) continuations. I’ll take that hit; it’s not that expensive. And it’s worth it for a function where simple code inspection is enough to verify that a) it works, and b) it has worst-case O(log N) behavior.
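One way to check both claims empirically is to count how many times `search` calls `f`. A minimal sketch: it repeats the post’s `node_t`, `t`, `cons` and `search` (with `=` in place of `==` on ints) so that it compiles on its own, and `comparisons` is a hypothetical helper that builds the sorted list 0, 2, 4, ... and counts the calls.

```ocaml
(* Copies of the post's definitions, so this block stands alone. *)
type 'a node_t = Node of 'a * 'a node_t * 'a node_t | Leaf of 'a
type 'a t = (int * 'a node_t) list

let cons x : 'a t -> 'a t = function
  | (c, l) :: (c', r) :: t when c = c' -> (2 * c + 1, Node (x, l, r)) :: t
  | t -> (1, Leaf x) :: t

let search f (lst : 'a t) =
  let not_found () = raise Not_found in
  let rec aloop back = function
    | Node (x, l, r) ->
        let c = f x in
        if c < 0 then back ()
        else if c = 0 then x
        else aloop (fun () -> aloop not_found l) r
    | Leaf x ->
        let c = f x in
        if c < 0 then back () else if c = 0 then x else not_found ()
  in
  let rec bloop back = function
    | (_, Node (x, l, r)) :: t ->
        let c = f x in
        if c < 0 then back ()
        else if c = 0 then x
        else bloop (fun () -> aloop (fun () -> aloop not_found l) r) t
    | (_, Leaf x) :: t ->
        let c = f x in
        if c < 0 then back () else if c = 0 then x else bloop not_found t
    | [] -> back ()
  in
  bloop not_found lst

(* Hypothetical helper: build the sorted list 0, 2, ..., 2(n-1) (so odd
   keys are guaranteed misses), search it for key, and count calls to f. *)
let comparisons n key =
  let lst = List.fold_right cons (List.init n (fun i -> 2 * i)) [] in
  let count = ref 0 in
  let f x = incr count; compare key x in
  let found = try Some (search f lst) with Not_found -> None in
  (found, !count)
```

Running `comparisons` for every key from -1 to 2000 against a 1000-element list is a quick check that even keys are found, odd keys raise `Not_found`, and the number of calls to f stays within a small multiple of log N.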

This is part of what I was talking about when I said it’s not so much what the language makes possible, but what the language makes simple. Now that it’s been explained (well, sorta) and the code written in Ocaml, I could write this in Java fairly easily. But it isn’t a solution a non-functional programmer would ever have come up with, and the Java version would be nigh unto impossible to figure out without knowing functional programming. The Java implementation of this algorithm would not help teach functional programming to Java programmers. This is what Chia and I mean by Ocaml messing with your brain and giving you different ways to think about problems.


4 Responses to “A WOW moment with continuations”

  1. Candide on 20 Feb 2006 at 12:09 pm

    Functional is beautiful.

    BTW, I edited your post to be in the “Whitepaper Links” category, since you’re linking to a whitepaper.

  2. bhurt-aw on 20 Feb 2006 at 12:35 pm

    One of the things you snickered about over lunch had a point. I hadn’t realized until you laughed about it that there was anything exceptional about using as my primary data structure a list of tuples of ints and variant type trees. It was the natural Ocaml data structure to use, and not at all clumsy to work with. Reaching in to grab the left subtree of the tree of the first tuple in the list is a trivial bit of pattern matching, hardly worth commenting.

    But it becomes an issue when I start considering how I’d write this code in Java. Each one of those pattern matches expands into a fair bit of Java code. If I were willing to write my own list class (and I probably would be), I could collapse the tuple and list elements into a single object, for a minor savings in space (O(log N) words).

    But how would I represent the variant type? Should it be a class hierarchy all by itself? Note that making the leaf nodes smaller is a large optimization: better than half of my nodes are going to be leaf nodes with no children, so having leaf nodes without child pointers saves me large hunks of memory, O(N) words. So I want to keep the special-case leaf nodes, but that means at least two classes (one for branch nodes, one for leaf nodes) and an interface. And how do I tell them apart: do I have a special virtual function, or do I use reflection? I suspect the former. But that means these are not simple C struct objects (objects with only public non-static member variables and no member functions).

    And then there are the private, anonymous, first-class, and/or inner functions I’m using all over the place. In the search function, my anonymous inner functions have anonymous inner functions. In Ocaml, my response is “So?” In Java, life becomes interesting, as your anonymous inner classes get anonymous inner classes. Is that even legal in Java? It is certain to cause your standard Java programmer’s head to explode.

  3. bhurt-aw on 23 Feb 2006 at 4:28 pm

    A start on a Java version of the above code, for those who might find Java easier to parse:

    public class Ralist {
    
        /* A tree node */
        private static class Node {
            public final Object data;
            public final Node left;   // null in leaf nodes
            public final Node right;  // null in leaf nodes
            public Node(Object d, Node l, Node r) {
                data = d;
                left = l;
                right = r;
            }
        }
    
        /* A list element- we implement our own, singly linked, applicative
         * list here.
         */
        private static class Elem {
            public final int count;
            public final Node tree;
            public final Elem next;
            public Elem(int c, Node t, Elem n) {
                count = c;
                tree = t;
                next = n;
            }
        }
    
        private final Elem lst;
    
        /* Creates an empty list */
        public Ralist() {
            lst = null;
        }
    
        private Ralist(Elem e) {
            lst = e;
        }
    
        public boolean isEmpty() { return lst == null; }
    
        public int length() {
            Elem curr = lst;
            int len = 0;
            while (curr != null) {
                len += curr.count;
                curr = curr.next;
            }
            return len;
        }
    
        public Object head() {
            if (lst == null) {
                throw new IllegalArgumentException();
            } else {
                return lst.tree.data;
            }
        }
    
        public Ralist tail() {
            if (lst == null) {
                throw new IllegalArgumentException();
            } else if (lst.count == 1) {
                return new Ralist(lst.next);
            } else {
                return new Ralist(new Elem(lst.count/2, lst.tree.left,
                                           (new Elem(lst.count/2, lst.tree.right,
                                                     lst.next))));
            }
        }
    
        public Ralist cons(Object x) {
            if ((lst != null) && (lst.next != null) &&
                (lst.count == lst.next.count))
            {
                return new Ralist(new Elem((2*lst.count) + 1,
                                  (new Node(x, lst.tree, lst.next.tree)),
                                  lst.next.next));
            } else {
                return new Ralist(new Elem(1, new Node(x, null, null), lst));
            }
        }
    
        public Object nth(int idx) {
            if (idx  < 0) {
                throw new IllegalArgumentException();
            }
            Elem curr = lst;
            while ((curr != null) && (curr.count <= idx)) {
                idx -= curr.count;
                curr = curr.next;
            }
            if (curr == null) {
                throw new IndexOutOfBoundsException();
            }
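            Node node = curr.tree;
            // c is the node count of each subtree of node; it halves at
            // every step down, as clen does in the Ocaml loop.
            int c = curr.count/2;
            while (idx != 0) {
                if (idx <= c) {
                    node = node.left;
                    idx -= 1;
                } else {
                    node = node.right;
                    idx -= c + 1;
                }
                c /= 2;
            }
            return node.data;
        }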
    
    }
    

  4. Hans on 25 Dec 2006 at 11:00 am

    Thanks a lot. Now for the first time I really understand continuations.
