I'm learning Idris and I thought I'd try to implement Quicksort for Vect types.

But I'm having a hard time with the utility function that, given a pivot element and a vector, splits the vector in two: one with the elements ≤ pivot and another with those > pivot.

This is trivial for lists:

splitListOn : Ord e => (pivot : e) -> List e -> (List e, List e)
splitListOn pivot [] = ([], [])
splitListOn pivot (x :: xs) = let (ys, zs) = splitListOn pivot xs in
                              if x <= pivot then (x :: ys, zs) 
                                            else (ys, x :: zs)

*Test> splitListOn 3 [1..10]
([1, 2, 3], [4, 5, 6, 7, 8, 9, 10]) : (List Integer, List Integer)
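For context, here is a minimal sketch of how splitListOn would slot into a plain list Quicksort. The name qsortList is hypothetical. The function is not total as written, because the recursive calls are on the results of the split rather than on structural subterms of the input, so Idris's totality checker would not accept a `total` annotation on it.

```idris
-- splitListOn as defined above: elements <= pivot go left, the rest go right
splitListOn : Ord e => (pivot : e) -> List e -> (List e, List e)
splitListOn pivot [] = ([], [])
splitListOn pivot (x :: xs) = let (ys, zs) = splitListOn pivot xs in
                              if x <= pivot then (x :: ys, zs)
                                            else (ys, x :: zs)

-- hypothetical list Quicksort: take the head as the pivot, split the tail,
-- sort both parts recursively and put the pivot between them
qsortList : Ord e => List e -> List e
qsortList [] = []
qsortList (x :: xs) =
  let (smaller, larger) = splitListOn x xs in
  qsortList smaller ++ (x :: qsortList larger)
```

The Vect version should have the same shape, except that its type also promises that the output has the same length as the input.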

But for Vect I need to express the fact that the sum of the lengths of the two returned Vects is equal to the length of the input Vect.

I clearly need to return a dependent pair. The number of elements ≤ pivot seems a good candidate for the first value, but my first try:

splitVectOn : Ord e => e -> Vect n e -> (k ** (Vect k e, Vect (n - k) e))

complains (rightly so) that it doesn't know whether k ≤ n:

When checking type of Main.splitVectOn:
When checking argument smaller to function Prelude.Nat.-:
        Can't find a value of type 
                LTE k n

I can add a constraint such as LTE k n to the type signature to reassure the type checker, but then I don't know how to recursively construct a return value k that satisfies it.

I can't even get the base case to work, where n = k = 0:

splitVectOn : Ord e => LTE k n =>
              e -> Vect n e -> (k ** (Vect k e, Vect (n - k) e))
splitVectOn _ [] = (_ ** ([], []))

The error mentions both a k1 and a k, which suggests that there might be something wrong with the type signature:

When checking right hand side of splitVectOn with expected type
        (k1 : Nat ** (Vect k e, Vect (0 - k) e))

When checking argument a to constructor Builtins.MkPair:
        Type mismatch between
                Vect 0 e (Type of [])
        and
                Vect k e (Expected type)

        Specifically:
                Type mismatch between
                        0
                and
                        k

I also thought of using a Fin to express the invariant:

splitVectOn : Ord e => e -> Vect n e ->
              (k : Fin (S n) ** (Vect (finToNat k) e, Vect (??? (n - k)) e))

but then I don't know how to perform the subtraction (which should be possible, because the value of a Fin (S n) is always ≤ n).

Answer

You can add the required proof to the output type like so:

(k ** pf : LTE k n ** (Vect k e, Vect (n - k) e))

Here is how we can define this function:

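import Data.Vect -- needed in Idris 1.x, where Vect lives in Data.Vect

-- auxiliary lemma: if n <= m, then S (m - n) = (S m) - n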
total
minusSuccLte : n `LTE` m -> S (m `minus` n) = (S m) `minus` n
minusSuccLte {m} LTEZero = cong $ minusZeroRight m
minusSuccLte (LTESucc pf) = minusSuccLte pf

total
splitVectOn : Ord e => (pivot : e) -> Vect n e ->
                        (k ** pf : LTE k n ** (Vect k e, Vect (n - k) e))
splitVectOn pivot [] = (0 ** LTEZero ** ([], []))
splitVectOn pivot (x :: xs) = 
  let (k ** lte ** (ys, zs)) = splitVectOn pivot xs in
  if x <= pivot then (S k ** LTESucc lte ** (x :: ys, zs))
  else
    let xzs = replace {P = \n => Vect n e} (minusSuccLte lte) (x :: zs) in
    (k ** lteSuccRight lte ** (ys, xzs))

Another approach to the same problem is to give the following spec to the splitVectOn function:

total
splitVectOn : Ord e => (pivot : e) -> Vect n e -> 
              (k1 : Nat ** k2 : Nat ** (k1 + k2 = n, Vect k1 e, Vect k2 e))

i.e. we (existentially) quantify over the lengths of the output vectors and add the condition that the sum of those lengths must be equal to the length of the input vector. This k1 + k2 = n condition can be omitted, of course, which would simplify the implementation a lot.
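As a sketch of that simplified variant (the name splitVectOn' is hypothetical): without the equation, each branch only has to bump one of the two lengths, and no lemmas are needed. The price is that the type no longer says anything about how k1 and k2 relate to n, so a Quicksort built on it could not be given the type Vect n e -> Vect n e.

```idris
import Data.Vect

-- hypothetical splitVectOn': the same split, but without the k1 + k2 = n
-- equation, so nothing ties the output lengths to the input length
total
splitVectOn' : Ord e => (pivot : e) -> Vect n e ->
               (k1 : Nat ** k2 : Nat ** (Vect k1 e, Vect k2 e))
splitVectOn' pivot [] = (0 ** 0 ** ([], []))
splitVectOn' pivot (x :: xs) =
  let (k1 ** k2 ** (ys, zs)) = splitVectOn' pivot xs in
  if x <= pivot then (S k1 ** k2 ** (x :: ys, zs))  -- left vector grows
                else (k1 ** S k2 ** (ys, x :: zs))  -- right vector grows
```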

Here is an implementation of the function with the modified spec:

total
splitVectOn : Ord e => (pivot : e) -> Vect n e -> 
              (k1 : Nat ** k2 : Nat ** (k1 + k2 = n, Vect k1 e, Vect k2 e))
splitVectOn pivot [] = (0 ** 0 ** (Refl, [], []))
splitVectOn pivot (x :: xs) =
  let (k1 ** k2 ** (eq, ys, zs)) = splitVectOn pivot xs in
  if x <= pivot then (S k1 ** k2 ** (cong eq, x :: ys, zs))
  else let eq1 = sym $ plusSuccRightSucc k1 k2 in
       let eq2 = cong {f = S} eq in
       (k1 ** S k2 ** (trans eq1 eq2, ys, x :: zs))
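To close the loop with the original goal, here is a sketch of a length-preserving Quicksort on Vect built on the second spec. The name qsortVect is hypothetical. Concatenating the sorted parts around the pivot gives a vector of length k1 + S k2, and the same two facts used in the else branch above (plusSuccRightSucc and cong {f = S} on the returned equation) turn that into S n. Like the list version, it is not total as written.

```idris
import Data.Vect

-- splitVectOn with the k1 + k2 = n spec, as defined above
total
splitVectOn : Ord e => (pivot : e) -> Vect n e ->
              (k1 : Nat ** k2 : Nat ** (k1 + k2 = n, Vect k1 e, Vect k2 e))
splitVectOn pivot [] = (0 ** 0 ** (Refl, [], []))
splitVectOn pivot (x :: xs) =
  let (k1 ** k2 ** (eq, ys, zs)) = splitVectOn pivot xs in
  if x <= pivot then (S k1 ** k2 ** (cong eq, x :: ys, zs))
  else let eq1 = sym $ plusSuccRightSucc k1 k2 in
       let eq2 = cong {f = S} eq in
       (k1 ** S k2 ** (trans eq1 eq2, ys, x :: zs))

-- hypothetical Vect Quicksort; not total, because the recursive calls are
-- on ys and zs, which are not structural subterms of the input
qsortVect : Ord e => Vect n e -> Vect n e
qsortVect [] = []
qsortVect (x :: xs) =
  let (k1 ** k2 ** (eq, ys, zs)) = splitVectOn x xs in
  -- sorted has length k1 + S k2
  let sorted = qsortVect ys ++ (x :: qsortVect zs) in
  -- k1 + S k2 = S (k1 + k2) = S (length of xs)
  let lenEq = trans (sym (plusSuccRightSucc k1 k2)) (cong {f = S} eq) in
  replace {P = \len => Vect len e} lenEq sorted
```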