Pattern Matching in Java

If you have programmed in Scala, you are no doubt familiar with case classes and pattern matching, especially on sealed classes (classes whose direct subclasses must all be declared in the same file, so that the compiler knows every case).

Pattern matching allows writing code like this:

/** A List is either empty or an element followed by a List */
abstract sealed class List[T]
case class Empty[T]() extends List[T]
case class Cons[T](head: T, tail: List[T]) extends List[T]

/** Use pattern matching to calculate the length of a List */
def length[T](l: List[T]): Int = l match {
  case Empty() => 0
  case Cons(_, tail) => 1 + length(tail)
}
 
/** Use pattern matching to invoke a callback on each list element. */
def forEach[T](l: List[T], fun: T => Unit): Unit = l match {
  case Cons(head, tail) => {
    fun(head);
    forEach(tail, fun);
  }
  case _ => () // Nothing to do
}

We will now analyse a few ways to do the same in Java, and evaluate their robustness against three types of errors:

  1. Can we guarantee that the pattern-matching code is consistent? (Executes code expecting case A when meeting an A instance)
  2. Can we guarantee that the technical implementation in the matched class itself is correct?
  3. Can we guarantee that all pattern-matching code covers all cases?

First define the classes:

public abstract class List<T> {
  /** Private constructor to effectively "seal" the class. */
  private List() {}
 
  public static class Empty<T> extends List<T> {}
  public static class Cons<T> extends List<T> {
    public final T head;
    public final List<T> tail;
    public Cons(T head, List<T> tail) {
      this.head = head;
      this.tail = tail;
    }
  }
}

Using instanceof (Levels 2, 1, 7)

A straightforward way to compute the length is as follows:

public static <T> int length(List<T> l) {
  if (l instanceof Empty) {
    return 0;
  } else if (l instanceof Cons) {
    return 1 + length(((Cons<T>)l).tail);
  } else {
    throw new IllegalArgumentException("Unexpected list type!");
  }
}

However, having to cast right after checking the type feels awkward. It is awkward but still Level 2, because the compiler flags a cast to an incompatible type. Moreover, that else clause should be unnecessary (provided l is never null), yet it is unavoidable: without it the code does not compile, because the compiler doesn’t understand that List is effectively “sealed”. As forEach returns no value, that is not an issue there:

public static <T> void forEach(List<T> l, Consumer<T> fun) {
  if (l instanceof Cons) {
    Cons<T> cons = (Cons<T>) l;
    fun.accept(cons.head); // assuming java.util.function.Consumer
    forEach(cons.tail, fun);
  } // else: empty list, nothing to do
}

… however we are now forced to create a new local variable to avoid casting twice.

There’s no “technical implementation” at all on the List-side, so I mark it Level 1.

Can we guarantee the matching code is complete? No, we can’t: if we add a case to List, there are no errors and no warnings, only runtime exceptions, and only when that particular case is met by that particular incomplete matcher. No existing unit test can cover that combination, since the case did not exist when the tests were written, so I mark that property as Level 7 for this implementation.
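To make that failure mode concrete, here is a minimal sketch. The wrapper class InstanceofDemo and the third case Lazy are hypothetical names introduced only for this illustration; Lazy stands for any case added after length() was written.

```java
class InstanceofDemo {
  abstract static class List<T> {
    /** Private constructor to effectively "seal" the class. */
    private List() {}

    static final class Empty<T> extends List<T> {}

    static final class Cons<T> extends List<T> {
      final T head;
      final List<T> tail;
      Cons(T head, List<T> tail) {
        this.head = head;
        this.tail = tail;
      }
    }

    /** Hypothetical case added later; length() below was never updated. */
    static final class Lazy<T> extends List<T> {}
  }

  static <T> int length(List<T> l) {
    if (l instanceof List.Empty) {
      return 0;
    } else if (l instanceof List.Cons) {
      return 1 + length(((List.Cons<T>) l).tail);
    } else {
      throw new IllegalArgumentException("Unexpected list type!");
    }
  }

  public static void main(String[] args) {
    List<String> two = new List.Cons<>("a", new List.Cons<>("b", new List.Empty<>()));
    System.out.println(length(two)); // prints 2

    try {
      length(new List.Lazy<String>());
    } catch (IllegalArgumentException e) {
      // The compiler accepted everything; the gap only shows up here.
      System.out.println("Only detected at run time: " + e.getMessage());
    }
  }
}
```

Adding Lazy produces no compiler diagnostic in length(); the missing branch is only noticed when a Lazy instance actually reaches it.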

Using Throwables (Levels 1, 1, 7)

One way to avoid having to test type then cast is to make the class extend Exception:

public abstract class List extends Exception {
  /** Private constructor to effectively "seal" the class. */
  private List() {}
  /** Emptied, to avoid performance penalty */
  @Override
  public Throwable fillInStackTrace() { return this; } // must stay public, as in Throwable
  /* Rest of the class as before ... */
}

Matching is then done as follows:

public int length(List l) {
  try {
    throw l;
  } catch (Empty e) {
    return 0;
  } catch (Cons c) {
    return 1 + length(c.tail);
  } catch (List other) {
    throw new IllegalArgumentException("Unexpected list type!");
  }
}

Note how the Java compiler is smart enough to know that the try statement never completes normally: the try block always throws, and every catch clause returns or throws, so we don’t need a return or throw statement after it. This property is Level 1: if no default clause is provided, the code does not compile, because List is a checked exception that would be neither caught nor declared. However, again, that is rather weak, in that the compiler doesn’t understand that the two catch clauses above are complete because List is effectively sealed.

public void forEach(List l, Consumer<Object> fun) {
  try {
    throw l;
  } catch (Cons c) {
    fun.accept(c.head); // assuming java.util.function.Consumer
    forEach(c.tail, fun);
  } catch (List other) {
    /* Nothing to do */
  }
}

The bad news: notice that the <T> is gone? The type can no longer be generic, because Java does not allow a generic class to extend Throwable. That is disastrous for a collection type like List, but would be acceptable for non-generic structured types like arithmetic expressions or XML/JSON trees.
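For reference, the pieces of this section can be assembled into one self-contained sketch. The wrapper class ThrowableDemo is a hypothetical name, Consumer is assumed to be java.util.function.Consumer, and the element type falls back to Object because the type parameter is not available.

```java
import java.util.function.Consumer;

class ThrowableDemo {
  /** Non-generic on purpose: a generic class may not extend Throwable. */
  abstract static class List extends Exception {
    private List() {}

    /** Skip stack-trace capture: these objects are data, not errors. */
    @Override
    public Throwable fillInStackTrace() { return this; }

    static final class Empty extends List {}

    static final class Cons extends List {
      final Object head;
      final List tail;
      Cons(Object head, List tail) {
        this.head = head;
        this.tail = tail;
      }
    }
  }

  static int length(List l) {
    try {
      throw l;
    } catch (List.Empty e) {
      return 0;
    } catch (List.Cons c) {
      return 1 + length(c.tail);
    } catch (List other) {
      throw new IllegalArgumentException("Unexpected list type!");
    }
  }

  static void forEach(List l, Consumer<Object> fun) {
    try {
      throw l;
    } catch (List.Cons c) {
      fun.accept(c.head);
      forEach(c.tail, fun);
    } catch (List other) {
      // Empty list: nothing to do
    }
  }

  public static void main(String[] args) {
    List abc = new List.Cons("a", new List.Cons("b", new List.Cons("c", new List.Empty())));
    System.out.println(length(abc)); // prints 3
    forEach(abc, System.out::println); // prints a, b and c on separate lines
  }
}
```

Callers get each head back as Object and must cast it themselves, which is exactly the cost of losing <T>.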

The good news: type-testing and casting are now grouped into a single “catch” construct, raising matching code correctness from Level 2 to Level 1. You can even use the multi-catch introduced in Java 7. Suppose you are matching a JSON structure and want to count the number of leaves. You could write something like this:

int countLeaves(Node tree) {
  try {
    throw tree;
  } catch (NumberNode | StringNode | NullNode n) {
    return 1;
  } catch (ArrayNode | ObjectNode a) {
    int total=0;
    // suppose "values()" is a method in common between ArrayNode and ObjectNode...
    // if not, need two separate catch-clauses
    for (Node n : a.values()) {
      total += countLeaves(n);
    }
    return total;
  } catch (Node other) {
    throw new IllegalArgumentException("Unexpected node type: " + other.getClass());
  }
}

As there’s no technical boilerplate on the side of case classes (aside from “extends Exception”), their implementation is still Level 1.

However, like when using instanceof, we still have no way to ensure we covered all cases (this design is really a safer instanceof-then-cast), so that part is still at Level 7.
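The countLeaves example above leaves the node classes undefined. The sketch below fills them in with hypothetical definitions so that the multi-catch can be tried out; in particular, ContainerNode is an assumed common parent that gives ArrayNode and ObjectNode their shared values() method, and JsonDemo is just a wrapper name.

```java
import java.util.Arrays;
import java.util.List;

class JsonDemo {
  abstract static class Node extends Exception {
    @Override
    public Throwable fillInStackTrace() { return this; }
  }

  static final class NumberNode extends Node {}
  static final class StringNode extends Node {}
  static final class NullNode extends Node {}

  /** Assumed common parent, so that values() is visible in the multi-catch. */
  abstract static class ContainerNode extends Node {
    private final List<Node> children;
    ContainerNode(Node... children) { this.children = Arrays.asList(children); }
    List<Node> values() { return children; }
  }

  static final class ArrayNode extends ContainerNode {
    ArrayNode(Node... children) { super(children); }
  }

  static final class ObjectNode extends ContainerNode {
    ObjectNode(Node... children) { super(children); }
  }

  static int countLeaves(Node tree) {
    try {
      throw tree;
    } catch (NumberNode | StringNode | NullNode leaf) {
      return 1;
    } catch (ArrayNode | ObjectNode a) {
      // The static type of 'a' is the closest common supertype, ContainerNode.
      int total = 0;
      for (Node n : a.values()) {
        total += countLeaves(n);
      }
      return total;
    } catch (Node other) {
      throw new IllegalArgumentException("Unexpected node type: " + other.getClass());
    }
  }

  public static void main(String[] args) {
    Node tree = new ArrayNode(
        new NumberNode(),
        new ObjectNode(new StringNode(), new NullNode()));
    System.out.println(countLeaves(tree)); // prints 3
  }
}
```

If the two container types had no common supertype declaring values(), the second clause would have to be split into two separate catch clauses, as the comment in the original listing warns.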

Visitor Pattern (Levels 1, 3, 1)

This design is the commonly accepted way of doing pattern matching in Java and other object-oriented languages that lack native general pattern matching. Define a visitor interface:

public interface ListVisitor<R, T> {
  R empty();
  R cons(T head, List<T> tail);
}

Add a method to the abstract List class:

public abstract class List<T> {
  public abstract <R> R match(ListVisitor<R, T> v);
}

Implement it as follows:

public static class Empty<T> extends List<T> {
  public <R> R match(ListVisitor<R, T> v) {
    return v.empty();
  }
}
public static class Cons<T> extends List<T> {
  /* ... fields and constructor as above ... */
  public <R> R match(ListVisitor<R, T> v) {
    return v.cons(head, tail);
  }
}

And use it as follows:

static <T> int length(List<T> l) {
  return l.match(new ListVisitor<Integer, T>() {
    public Integer empty() {
      return 0;
    }
    public Integer cons(T head, List<T> tail) {
      return 1 + length(tail);
    }
  });
}

static <T> void forEach(List<T> l, final Consumer<T> fun) {
  l.match(new ListVisitor<Void, T>() {
    public Void empty() {
      return null;
    }
    public Void cons(T head, List<T> tail) {
      fun.accept(head); // assuming java.util.function.Consumer
      forEach(tail, fun);
      return null;
    }
  });
}

This design is extremely verbose compared to the other approaches. However, it is Level 1 at guaranteeing that all cases are covered, and Level 1 at guaranteeing that client (pattern-matching) code works as it should.

However, no guarantee exists that the List implementation is correct: if, by mistake (such as a copy-paste error in a class having lots of cases), I implement Cons.match(v) to call v.empty(), everything compiles with no warning, and programmers using my visitor API will be left wondering why their clearly correct code doesn’t work.

Correctness of the List.match() implementations is therefore Level 3 only because errors are obvious when inspecting them. Can we do better?
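The copy-paste mistake described above can be reproduced in a few lines. This is a sketch with a deliberately wrong Cons.match(); the wrapper class BuggyVisitorDemo is a hypothetical name.

```java
class BuggyVisitorDemo {
  interface ListVisitor<R, T> {
    R empty();
    R cons(T head, List<T> tail);
  }

  abstract static class List<T> {
    private List() {}

    abstract <R> R match(ListVisitor<R, T> v);

    static final class Empty<T> extends List<T> {
      @Override
      <R> R match(ListVisitor<R, T> v) {
        return v.empty();
      }
    }

    static final class Cons<T> extends List<T> {
      final T head;
      final List<T> tail;
      Cons(T head, List<T> tail) {
        this.head = head;
        this.tail = tail;
      }

      @Override
      <R> R match(ListVisitor<R, T> v) {
        // Deliberate copy-paste mistake: should be v.cons(head, tail).
        return v.empty();
      }
    }
  }

  /** Perfectly correct client code, yet it gives the wrong answer. */
  static <T> int length(List<T> l) {
    return l.match(new ListVisitor<Integer, T>() {
      public Integer empty() {
        return 0;
      }
      public Integer cons(T head, List<T> tail) {
        return 1 + length(tail);
      }
    });
  }

  public static void main(String[] args) {
    List<String> two = new List.Cons<>("a", new List.Cons<>("b", new List.Empty<>()));
    System.out.println(length(two)); // prints 0, although the list has two elements
  }
}
```

The compiler has nothing to object to, because v.empty() is a valid call from any subclass.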

Non-destructuring Visitors (Levels 1, 1, 1)

The answer is yes, of course, with a very minor change: have the Visitor interface take full List subclass instances instead of destructuring them. I chose here to give all methods the same name to avoid redundancy, which makes the code more pleasant to read:

public interface ListVisitor<R, T> {
  R when(Empty<T> e);
  R when(Cons<T> c);
}
 
public abstract class List<T> {
  public abstract <R> R match(ListVisitor<R, T> v);
}
 
public static class Empty<T> extends List<T> {
  public <R> R match(ListVisitor<R, T> v) {
    return v.when(this);
  }
}
public static class Cons<T> extends List<T> {
  /* ... fields and constructor as above ... */
  public <R> R match(ListVisitor<R, T> v) {
    return v.when(this);
  }
}

Used as follows:

static <T> int length(List<T> l) {
  return l.match(new ListVisitor<Integer, T>() {
    public Integer when(Empty<T> e) {
      return 0;
    }
    public Integer when(Cons<T> cons) {
      return 1 + length(cons.tail);
    }
  });
}

static <T> void forEach(List<T> l, final Consumer<T> fun) {
  l.match(new ListVisitor<Void, T>() {
    public Void when(Empty<T> e) {
      return null;
    }
    public Void when(Cons<T> cons) {
      fun.accept(cons.head); // assuming java.util.function.Consumer
      forEach(cons.tail, fun);
      return null;
    }
  });
}

The advantage of giving all methods the same name is that it eliminates the risk of copy-paste accidents: the match() implementations are meant to be copy-pasted, since each one is the same return v.when(this). It is not even possible to call the wrong visitor method, because the compiler selects the correct overload from the static type of this.
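Assembled into one compilable unit, the design might look like the sketch below. The wrapper class VisitorDemo is a hypothetical name and Consumer is assumed to be java.util.function.Consumer; the rest follows the listings above.

```java
import java.util.function.Consumer;

class VisitorDemo {
  interface ListVisitor<R, T> {
    R when(List.Empty<T> e);
    R when(List.Cons<T> c);
  }

  abstract static class List<T> {
    private List() {}

    abstract <R> R match(ListVisitor<R, T> v);

    static final class Empty<T> extends List<T> {
      // 'this' is an Empty<T>, so only when(Empty<T>) is applicable here.
      @Override
      <R> R match(ListVisitor<R, T> v) { return v.when(this); }
    }

    static final class Cons<T> extends List<T> {
      final T head;
      final List<T> tail;
      Cons(T head, List<T> tail) {
        this.head = head;
        this.tail = tail;
      }

      // Same text as in Empty, but resolved to when(Cons<T>) by the compiler.
      @Override
      <R> R match(ListVisitor<R, T> v) { return v.when(this); }
    }
  }

  static <T> int length(List<T> l) {
    return l.match(new ListVisitor<Integer, T>() {
      public Integer when(List.Empty<T> e) { return 0; }
      public Integer when(List.Cons<T> c) { return 1 + length(c.tail); }
    });
  }

  static <T> void forEach(List<T> l, final Consumer<T> fun) {
    l.match(new ListVisitor<Void, T>() {
      public Void when(List.Empty<T> e) { return null; }
      public Void when(List.Cons<T> c) {
        fun.accept(c.head);
        forEach(c.tail, fun);
        return null;
      }
    });
  }

  public static void main(String[] args) {
    List<String> two = new List.Cons<>("a", new List.Cons<>("b", new List.Empty<>()));
    System.out.println(length(two)); // prints 2
    forEach(two, System.out::println); // prints a, then b
  }
}
```

Adding a third subclass forces a third when() overload into the interface, at which point every anonymous visitor stops compiling until it handles the new case.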

Summary

Although unpleasantly verbose, the non-destructuring Visitor approach is clearly the winner in terms of robustness, especially as it forces, at Level 1, all callers to adapt their code when you add a new case.

However, the very second you yield to the temptation of adding support for a “default” implementation (typically by having an abstract implementation of the visitor interface that calls a single abstract method from all methods), be warned that your visitor is no longer safer than the exception-based approach; it is just more verbose.
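A sketch of that temptation, under assumed names: DefaultListVisitor and its otherwise() method are hypothetical, as is the wrapper class DefaultVisitorDemo. The length() below forgets the Cons case, and the compiler no longer notices.

```java
class DefaultVisitorDemo {
  interface ListVisitor<R, T> {
    R when(List.Empty<T> e);
    R when(List.Cons<T> c);
  }

  /** Hypothetical convenience base class: every case falls back to otherwise(). */
  abstract static class DefaultListVisitor<R, T> implements ListVisitor<R, T> {
    abstract R otherwise(List<T> other);

    @Override
    public R when(List.Empty<T> e) { return otherwise(e); }

    @Override
    public R when(List.Cons<T> c) { return otherwise(c); }
  }

  abstract static class List<T> {
    private List() {}

    abstract <R> R match(ListVisitor<R, T> v);

    static final class Empty<T> extends List<T> {
      @Override
      <R> R match(ListVisitor<R, T> v) { return v.when(this); }
    }

    static final class Cons<T> extends List<T> {
      final T head;
      final List<T> tail;
      Cons(T head, List<T> tail) {
        this.head = head;
        this.tail = tail;
      }

      @Override
      <R> R match(ListVisitor<R, T> v) { return v.when(this); }
    }
  }

  /** The Cons case is missing, yet this compiles without any warning. */
  static <T> int length(List<T> l) {
    return l.match(new DefaultListVisitor<Integer, T>() {
      @Override
      public Integer when(List.Empty<T> e) { return 0; }

      @Override
      Integer otherwise(List<T> other) {
        throw new IllegalArgumentException("Unexpected list type!");
      }
    });
  }

  public static void main(String[] args) {
    System.out.println(length(new List.Empty<String>())); // prints 0
    try {
      length(new List.Cons<>("a", new List.Empty<String>()));
    } catch (IllegalArgumentException e) {
      System.out.println("Only detected at run time: " + e.getMessage());
    }
  }
}
```

The default branch plays the same role as the catch-all clause in the exception-based version: it silences the compiler and moves the completeness check to run time.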