Remember that unification is a syntactic operation, which in languages like Idris is augmented with straightforward reductions according to pattern-matching. It doesn't know all of the facts that we can prove!
We can certainly prove in Idris that if n+m=0 then m = 0 and n = 0:
```idris
sumZero : (n, m : Nat) -> plus n m = Z -> (n=Z, m=Z)
sumZero Z m prf = (refl, prf)
sumZero (S k) m refl impossible
```
But proving this doesn't teach the unifier the fact. If unification had to take every provable equation into account, type checking would become undecidable.
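The same boundary shows up with `plus` itself. The sketch below assumes Idris 2, where the equality constructor is spelled `Refl` rather than `refl`. The names `plusZeroLeft` and `plusZeroRight` are made up for illustration. The first equation holds by computation alone. The second is just as true, but it has to be proved by induction, because `plus n Z` cannot reduce while `n` is a variable.

```idris
-- plus recurses on its first argument, so plus Z n reduces to n by
-- computation and Refl is accepted without any proof effort.
plusZeroLeft : (n : Nat) -> plus Z n = n
plusZeroLeft n = Refl

-- plus n Z is stuck while n is a variable, so the unifier cannot see
-- that it equals n; we prove it by induction on n instead.
plusZeroRight : (n : Nat) -> plus n Z = n
plusZeroRight Z = Refl
plusZeroRight (S k) = cong S (plusZeroRight k)
```

Once `plusZeroRight` exists, the unifier still won't use it on its own. It has to be applied explicitly wherever the equation is needed, just like `sumZero`.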
Returning to your original issue: if I translate the type of your `partition` into English, it says "for all natural numbers `m` and `n`, for all Boolean predicates `p` over `a`, given a vector of length `plus m n`, I can produce a pair consisting of a vector of length `m` and a vector of length `n`". In other words, to call your function, I would need to know in advance how many elements of the vector satisfy the predicate, because I need to supply `m` and `n` at the call site!
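To make this concrete, here is a hypothetical signature matching that English reading, written for Idris 2. The question's own code is not reproduced here. Because the caller has already chosen `m` and `n`, a function of this type never needs to consult the predicate at all. Splitting the vector at position `m` with `splitAt` from `Data.Vect` already satisfies it.

```idris
import Data.Vect

-- Hypothetical signature following the English reading above:
-- the caller picks both lengths before the data is examined.
badPartition : (m, n : Nat) -> (a -> Bool) -> Vect (plus m n) a -> (Vect m a, Vect n a)
-- The predicate p is ignored: the type is already satisfied by
-- cutting the vector after its first m elements.
badPartition m n p xs = splitAt m xs
```

For instance, `badPartition 1 2 p [1, 5, 2]` returns `([1], [5, 2])` whatever `p` is. The lengths come from the caller, not from the predicate.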
What I think you really want is a `partition` function that, given a vector of length `n`, returns a pair of vectors whose lengths add up to `n`. We can express this using a dependent pair, which is the type theory version of existential quantification. The translation of "a pair of vectors whose lengths add up to `n`" is "there exists some `m` and `m'` and vectors with these lengths such that the sum of `m` and `m'` is my input `n`."
This type looks like this:
```idris
partition : (a -> Bool) -> Vect n a -> (m ** (m' ** (Vect m a, Vect m' a, m+m'=n)))
```
and the complete implementation looks like this:
```idris
partition : (a -> Bool) -> Vect n a -> (m ** (m' ** (Vect m a, Vect m' a, m+m'=n)))
partition p [] = (Z ** (Z ** ([], [], refl)))
partition p (x :: xs) with (p x, partition p xs)
  | (True, (m ** (m' ** (ys, ns, refl)))) = (S m ** (m' ** (x::ys, ns, refl)))
  | (False, (m ** (m' ** (ys, ns, refl)))) =
      (m ** (S m' ** (ys, x::ns, sym (plusSuccRightSucc m m'))))
```
That's a bit of a mouthful, so let's dissect it.
To implement the function, we begin by pattern-matching on the input Vect:
```idris
partition p [] = (Z ** (Z ** ([], [], refl)))
```
Note that the only possible output is what's on the right-hand side, or else we couldn't have constructed the `refl`. We know that `n` is `Z` due to the unification of `n` with the index of `Vect`'s `Nil` constructor.
In the recursive case, we examine the first element of the vector. Here, I use the `with` rule because it's readable, but we could have used an `if` on the right-hand side instead of matching on `p x` on the left.
```idris
partition p (x :: xs) with (p x, partition p xs)
```
In the `True` case, we add the element to the first subvector. Because `plus` reduces on its first argument, the required length `plus (S m) m'` computes to `S (plus m m')`, which is exactly the length of the input. That means we can construct the equality proof using `refl`:
```idris
  | (True, (m ** (m' ** (ys, ns, refl)))) = (S m ** (m' ** (x::ys, ns, refl)))
```
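Stated on its own, the length equation the `True` branch relies on is closed by computation alone. Here is a small Idris 2 sketch; the name `trueBranchLen` is hypothetical.

```idris
-- By the definition of plus, plus (S m) m' reduces to S (plus m m'),
-- so both sides are identical after evaluation and Refl suffices.
trueBranchLen : (m, m' : Nat) -> plus (S m) m' = S (plus m m')
trueBranchLen m m' = Refl
```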
In the `False` case, we need to do a bit more work, because `plus m (S m')` can't unify with `S (plus m m')`. Remember how I said unification doesn't have access to the equalities that we can prove? The library function `plusSuccRightSucc` does what we need, though:
```idris
  | (False, (m ** (m' ** (ys, ns, refl)))) =
      (m ** (S m' ** (ys, x::ns, sym (plusSuccRightSucc m m'))))
```
For reference, the type of `plusSuccRightSucc` is:
```idris
plusSuccRightSucc : (left : Nat) ->
                    (right : Nat) ->
                    S (plus left right) = plus left (S right)
```
and the type of `sym` is:
```idris
sym : (l = r) -> r = l
```
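Nothing about `plusSuccRightSucc` is built into the unifier; it is an ordinary theorem proved by induction. Below is a sketch, in Idris 2 syntax, of how one might re-derive it. The name `succRightSucc` is hypothetical, as is `falseBranchLen`. The second function uses `sym` to flip the equation into the orientation the `False` branch needs.

```idris
-- Hypothetical re-derivation of the library lemma, by induction on m.
succRightSucc : (m, r : Nat) -> S (plus m r) = plus m (S r)
succRightSucc Z r = Refl
-- Goal reduces to S (S (plus k r)) = S (plus k (S r)):
-- apply S to both sides of the induction hypothesis.
succRightSucc (S k) r = cong S (succRightSucc k r)

-- The False branch needs the equation the other way round, hence sym,
-- mirroring sym (plusSuccRightSucc m m') in the answer.
falseBranchLen : (m, m' : Nat) -> plus m (S m') = S (plus m m')
falseBranchLen m m' = sym (succRightSucc m m')
```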
One thing the above `partition` doesn't capture is that it actually partitions the `Vect`. We can add this by making the result vectors consist of dependent pairs of elements and evidence that each element satisfies either `p` or `not p`:
```idris
partition' : (p : a -> Bool) ->
             (xs : Vect n a) ->
             (m ** (m' ** (Vect m (y : a ** so (p y)),
                           Vect m' (y : a ** so (not (p y))),
                           m+m'=n)))
partition' p [] = (0 ** (0 ** ([], [], refl)))
-- isP and notP are pattern variables holding the evidence from choose;
-- they are not named oh, which is the constructor of so.
partition' p (x :: xs) with (choose (p x), partition' p xs)
  partition' p (x :: xs) | (Left isP, (m ** (m' ** (ys, ns, refl)))) =
    (S m ** (m' ** ((x ** isP)::ys, ns, refl)))
  partition' p (x :: xs) | (Right notP, (m ** (m' ** (ys, ns, refl)))) =
    (m ** (S m' ** (ys, (x ** notP)::ns, sym (plusSuccRightSucc m m'))))
```
If you want to get even crazier, you can also have each element prove that it was an element of the input vector, and that all elements of the input vector are in the output exactly once, and so forth. Dependent types give you the tools to do these things, but it's worth considering the complexity tradeoff in each case.
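The code in this answer uses the spelling `refl` from early Idris releases; current versions, including Idris 2, spell the equality constructor `Refl`. What follows is a minimal sketch of the first `partition` adapted to Idris 2 under a few assumptions. The name `partitionV` is hypothetical and avoids a clash with `Data.Vect.partition`. Instead of matching the recursive proof against `Refl`, the sketch keeps it as `prf` and adjusts it explicitly with `cong`, `sym`, `trans` and `plusSuccRightSucc`. It assumes `plusSuccRightSucc` is exported by `Data.Nat`.

```idris
import Data.Vect
import Data.Nat

-- Same result type as partition above, with the length binders annotated.
partitionV : (a -> Bool) -> Vect n a ->
             (m : Nat ** (m' : Nat ** (Vect m a, Vect m' a, m + m' = n)))
partitionV p [] = (0 ** (0 ** ([], [], Refl)))
partitionV p (x :: xs) with (partitionV p xs)
  -- prf : m + m' = len, where len is the length of xs.
  -- True: S m + m' reduces to S (m + m'), so cong S prf fits directly.
  -- False: m + S m' is stuck, so rewrite it via plusSuccRightSucc first.
  partitionV p (x :: xs) | (m ** (m' ** (ys, ns, prf))) =
    if p x
       then (S m ** (m' ** (x :: ys, ns, cong S prf)))
       else (m ** (S m' ** (ys, x :: ns,
               trans (sym (plusSuccRightSucc m m')) (cong S prf))))
```

With a predicate such as `\x => x < 3`, the input `[1, 5, 2, 7]` is split into `[1, 2]` and `[5, 7]`, and the proof component has type `2 + 2 = 4`. Keeping `prf` as a variable trades the `refl` pattern for explicit proof terms, which makes it visible exactly where the library lemma is needed.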