Assume I have the following Idris source code:
module Source
import Data.Vect
--in order to avoid compiler confusion between Prelude.List.(++), Prelude.String.(++) and Data.Vect.(++)
infixl 0 +++
(+++) : Vect n a -> Vect m a -> Vect (n+m) a
v +++ w = v ++ w
--NB: further down in the question I'll assume this definition isn't needed because the compiler
-- will have enough context to disambiguate between these and figure out that Data.Vect.(++)
-- is the "correct" one to use.
lemma : reverse (n :: ns) +++ (n :: ns) = reverse ns +++ (n :: n :: ns)
lemma {ns = []} = Refl
lemma {ns = n' :: ns} = ?lemma_rhs
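Writing out the length indices makes the shape of the problem clearer. Let k be a name for the length of ns, so ns : Vect k a (k is not in the original code). The two sides of lemma then have the following types, where ≡ means "reduces to":

```latex
\begin{aligned}
\texttt{reverse (n :: ns) +++ (n :: ns)} &: \mathrm{Vect}\,(S\,k + S\,k)\,a \;\equiv\; \mathrm{Vect}\,(S\,(k + S\,k))\,a \\
\texttt{reverse ns +++ (n :: n :: ns)} &: \mathrm{Vect}\,(k + S\,(S\,k))\,a
\end{aligned}
```

Since plus recurses on its first argument, S (k + S k) and k + S (S k) are equal only propositionally, not by reduction. One witness is plusSuccRightSucc k (S k) from the Prelude. As far as I can tell, lemma therefore relates values of two different types, and it is accepted only because = in Idris 1 is heterogeneous.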
As shown, the base case for lemma is trivially Refl. But I can't find a way to prove the inductive case: the REPL "just" spits out the following:
*source> :t lemma_rhs
phTy : Type
n1 : phTy
len : Nat
ns : Vect len phTy
n : phTy
-----------------------------------------
lemma_rhs : Data.Vect.reverse, go phTy
(S (S len))
(n :: n1 :: ns)
[n1, n]
ns ++
n :: n1 :: ns =
Data.Vect.reverse, go phTy (S len) (n1 :: ns) [n1] ns ++
n :: n :: n1 :: ns
I understand that phTy stands for "phantom type", the implicit type of the vectors I'm considering. I also understand that go is the name of the function defined in the where clause of the library function reverse.
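For comparison, suppose the length indices are ignored and the vectors are treated as plain lists. Then the equation itself needs no induction and follows from the standard facts sketched below.

In Data.Vect, however, the first step is not a definitional unfolding, because reverse goes through the accumulator-based go helper visible in the goal above. The statement reverse (n :: ns) = reverse ns ++ [n] would also relate vectors of lengths S k and k + 1, which again agree only propositionally.

```latex
\begin{aligned}
\mathrm{reverse}\,(n :: ns) \mathbin{+\!\!+} (n :: ns)
  &= (\mathrm{reverse}\,ns \mathbin{+\!\!+} [n]) \mathbin{+\!\!+} (n :: ns)
     && \text{reverse of a cons} \\
  &= \mathrm{reverse}\,ns \mathbin{+\!\!+} ([n] \mathbin{+\!\!+} (n :: ns))
     && \text{associativity of append} \\
  &= \mathrm{reverse}\,ns \mathbin{+\!\!+} (n :: n :: ns)
     && \text{definition of append}
\end{aligned}
```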
Question
How can I continue the proof? Is my inductive strategy sound? Is there a better one?
Context
This came up in one of my toy projects, where I try to define arbitrary tensors; specifically, it seems to be needed in order to define "full index contraction". I'll elaborate a little on that.
I define tensors in a way that's roughly equivalent to:
data Tensor : (rank : Nat) -> (shape : Vect rank Nat) -> Type -> Type where
  Scalar : a -> Tensor Z [] a
  Vector : Vect n (Tensor rank shape a) -> Tensor (S rank) (n :: shape) a
Glossing over the rest of the source code (it isn't relevant, and it's quite long and uninteresting as of now), I was able to define the following functions:
contractIndex : Num a =>
  Tensor (r1 + (2 + r2)) (s1 ++ (n :: n :: s2)) a ->
  Tensor (r1 + r2) (s1 ++ s2) a
tensorProduct : Num a =>
  Tensor r1 s1 a ->
  Tensor r2 s2 a ->
  Tensor (r1 + r2) (s1 ++ s2) a
contractProduct : Num a =>
  Tensor (S r1) s1 a ->
  Tensor (S r2) ((last s1) :: s2) a ->
  Tensor (r1 + r2) ((take r1 s1) ++ s2) a
and I'm working on this other one:
fullIndexContraction : Num a =>
  Tensor r (reverse ns) a ->
  Tensor r ns a ->
  Tensor 0 [] a
fullIndexContraction {r = Z} {ns = []} t s = t * s
fullIndexContraction {r = S r} {ns = n :: ns} t s = ?rhs
It should "iterate contractProduct as many times as possible (that is, r times)". Equivalently, it could be defined as tensorProduct composed with as many applications of contractIndex as possible (again, r of them).
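To make the shape bookkeeping concrete, here is a hand-worked trace of the tensorProduct-then-contractIndex route. It uses ns = [2, 3, 4], so r = 3, t : Tensor 3 [4, 3, 2] a and s : Tensor 3 [2, 3, 4] a; these values are chosen only for illustration.

Each row's shape is written in the s1 ++ (n :: n :: s2) form that contractIndex expects. The first rewriting is exactly an instance of lemma, with n = 2 and ns = [3, 4].

```latex
\begin{array}{lcl}
\text{step} & \text{rank} & \text{shape} \\
\hline
\texttt{tensorProduct t s} & 6 & [4,3,2] \mathbin{+\!\!+} [2,3,4] = [4,3] \mathbin{+\!\!+} (2 :: 2 :: [3,4]) \\
\text{1st } \texttt{contractIndex} & 4 & [4,3] \mathbin{+\!\!+} [3,4] = [4] \mathbin{+\!\!+} (3 :: 3 :: [4]) \\
\text{2nd } \texttt{contractIndex} & 2 & [4] \mathbin{+\!\!+} [4] = [\,] \mathbin{+\!\!+} (4 :: 4 :: [\,]) \\
\text{3rd } \texttt{contractIndex} & 0 & [\,]
\end{array}
```

With concrete numbers all ranks reduce (2 + (2 + 2) reduces to 6, and so on). With a variable rank S r', however, the two ends don't line up by reduction:

- tensorProduct produces rank S r' + S r', which reduces to S (r' + S r').
- contractIndex expects r' + (2 + r'), which reduces to r' + S (S r').

This is the same propositional-only equality that shows up in the lengths of lemma.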
I'm including all this because it may be easier to solve this problem without proving the lemma above; if that were the case, I'd be fully satisfied as well. I just thought the "shorter" version above might be easier to deal with, and I'm fairly sure I'll be able to figure out the missing pieces myself.
The version of Idris I'm using is 1.3.2-git:PRE (that's what the REPL reports when invoked from the command line).
Edit: xash's answer covers almost everything, and I was able to write the following functions:
nreverse_id : (k : Nat) -> nreverse k = k
contractAllIndices : Num a =>
  Tensor (nreverse k + k) (reverse ns ++ ns) a ->
  Tensor Z [] a
contractAllProduct : Num a =>
  Tensor (nreverse k) (reverse ns) a ->
  Tensor k ns a ->
  Tensor Z [] a
I also wrote a "fancy" version of reverse (let's call it fancy_reverse) that automatically rewrites nreverse k = k in its result. So I tried to write a function that doesn't have nreverse in its signature, something like:
fancy_reverse : Vect n a -> Vect n a
fancy_reverse {n} xs =
  rewrite sym $ nreverse_id n in
  reverse xs
contract : Num a =>
  {auto eql : fancy_reverse ns1 = ns2} ->
  Tensor k ns1 a ->
  Tensor k ns2 a ->
  Tensor Z [] a
contract {eql} {k} {ns1} {ns2} t s =
  flip contractAllProduct s $
  rewrite sym $ nreverse_id k in
  ?rhs
Now, the inferred type for rhs is Tensor (nreverse k) (reverse ns2), and I have a rewrite rule for k = nreverse k in scope. However, I can't work out how to rewrite the implicit eql proof so that this type-checks. Am I doing something wrong?