
I've been messing around with a simple tensor library, in which I have defined the following type.

import Data.Vect

data Tensor : Vect n Nat -> Type -> Type where
  Scalar : a -> Tensor [] a
  Dimension : Vect n (Tensor d a) -> Tensor (n :: d) a

The vector parameter of the type describes the tensor's "dimensions" or "shape". I am currently trying to define a function to safely index into a Tensor. I had planned to do this using Fins, but I ran into an issue. Because the Tensor is of unknown order, I could need any number of indices, each with a different upper bound. This means that a Vect of indices would be insufficient, because each index would have a different type. That drove me to look at using tuples (called "pairs" in Idris?) instead. I wrote the following function to compute the necessary type.

TensorIndex : Vect n Nat -> Type
TensorIndex []      = ()
TensorIndex (d::[]) = Fin d
TensorIndex (d::ds) = (Fin d, TensorIndex ds)

This function worked as I expected, calculating the appropriate index type from a dimension vector.

> TensorIndex [4,4,3] -- (Fin 4, Fin 4, Fin 3)
> TensorIndex [2] -- Fin 2
> TensorIndex [] -- ()

But when I tried to define the actual index function...

index : {d : Vect n Nat} -> TensorIndex d -> Tensor d a -> a
index () (Scalar x) = x
index (a,as) (Dimension xs) = index as $ index a xs
index a (Dimension xs) with (index a xs) | Scalar x = x

...Idris raised the following error on the second case (oddly enough, it seemed perfectly okay with the first).

Type mismatch between
         (A, B) (Type of (a,as))
and
         TensorIndex (n :: d) (Expected type)

The error seems to imply that, instead of treating TensorIndex as an extremely convoluted type synonym and evaluating it as I had hoped, Idris treated it as though it were defined with a data declaration: a "black-box type", so to speak. Where does Idris draw the line on this? Is there some way for me to rewrite TensorIndex so that it works the way I want it to? If not, can you think of any other way to write the index function?


2 Answers


Your definitions will be cleaner if you define the tensor type by induction over the vector of dimensions (the function tensor below), whilst Index is defined as a datatype.

Indeed, at the moment you are forced to pattern-match on the implicit argument of type Vect n Nat to see what shape the index has. But if the index is defined directly as a piece of data, it then constrains the shape of the structure it indexes into and everything falls into place: the right piece of information arrives at the right time for the typechecker to be happy.

module Tensor

import Data.Fin
import Data.Vect

tensor : Vect n Nat -> Type -> Type
tensor []        a = a
tensor (m :: ms) a = Vect m (tensor ms a)

data Index : Vect n Nat -> Type where
  Here : Index []
  At   : Fin m -> Index ms -> Index (m :: ms)

index : Index ms -> tensor ms a -> a
index Here     a = a
index (At k i) v = index i $ index k v
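A usage sketch that is not part of the original answer: the block repeats the definitions above (renaming the pattern variable in the Here case to x so it is not confused with the type variable a) and indexes into a 2-by-3 tensor. The names grid, ix and main are illustrative assumptions. Because tensor [2, 3] Nat reduces to Vect 2 (Vect 3 Nat), grid can be written as an ordinary nested vector literal, and each Fin inside ix is bounded by its own dimension.

```idris
import Data.Fin
import Data.Vect

-- Definitions from the answer above.
tensor : Vect n Nat -> Type -> Type
tensor []        a = a
tensor (m :: ms) a = Vect m (tensor ms a)

data Index : Vect n Nat -> Type where
  Here : Index []
  At   : Fin m -> Index ms -> Index (m :: ms)

index : Index ms -> tensor ms a -> a
index Here     x = x
index (At k i) v = index i $ index k v

-- Hypothetical example data: the type reduces to Vect 2 (Vect 3 Nat).
grid : tensor [2, 3] Nat
grid = [[1, 2, 3], [4, 5, 6]]

-- Hypothetical index: row 1, column 2 (zero-based).
-- The literal 1 is checked against Fin 2, and 2 against Fin 3.
ix : Index [2, 3]
ix = At 1 (At 2 Here)

main : IO ()
main = printLn (index ix grid) -- prints 6
```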

Your life becomes so much easier if you allow for a trailing () in your TensorIndex, since then you can just write:

TensorIndex : Vect n Nat -> Type
TensorIndex []      = ()
TensorIndex (d::ds) = (Fin d, TensorIndex ds)

index : {ds : Vect n Nat} -> TensorIndex ds -> Tensor ds a -> a
index {ds = []} () (Scalar x) = x
index {ds = _ :: ds} (i, is) (Dimension xs) = index is (index i xs)
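A minimal sketch of using this version, assuming the Tensor type from the question; the tensor t, the index ix and main are hypothetical names added for illustration. With the trailing-() definition, TensorIndex [2, 3] reduces to (Fin 2, (Fin 3, ())), so an index is a right-nested tuple terminated by ().

```idris
import Data.Fin
import Data.Vect

-- Tensor type from the question.
data Tensor : Vect n Nat -> Type -> Type where
  Scalar    : a -> Tensor [] a
  Dimension : Vect n (Tensor d a) -> Tensor (n :: d) a

-- Trailing-() version of TensorIndex and index from this answer.
TensorIndex : Vect n Nat -> Type
TensorIndex []      = ()
TensorIndex (d::ds) = (Fin d, TensorIndex ds)

index : {ds : Vect n Nat} -> TensorIndex ds -> Tensor ds a -> a
index {ds = []}      ()      (Scalar x)     = x
index {ds = _ :: ds} (i, is) (Dimension xs) = index is (index i xs)

-- Hypothetical 2x3 example tensor.
t : Tensor [2, 3] Nat
t = Dimension [ Dimension [Scalar 1, Scalar 2, Scalar 3]
              , Dimension [Scalar 4, Scalar 5, Scalar 6] ]

-- Row 1, column 2 (zero-based), terminated by ().
ix : TensorIndex [2, 3]
ix = (1, 2, ())

main : IO ()
main = printLn (index {ds = [2, 3]} ix t) -- shape passed explicitly; prints 6
```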

If you want to keep your definition of TensorIndex, you'll need to have separate cases for ds = [_] and ds = _::_::_ to match the structure of TensorIndex:

TensorIndex : Vect n Nat -> Type
TensorIndex []      = ()
TensorIndex (d::[]) = Fin d
TensorIndex (d::ds) = (Fin d, TensorIndex ds)

index : {ds : Vect n Nat} -> TensorIndex ds -> Tensor ds a -> a
index {ds = []} () (Scalar x) = x
index {ds = _ :: []} i (Dimension xs) with (index i xs) | (Scalar x) = x
index {ds = _ :: _ :: _} (i, is) (Dimension xs) = index is (index i xs)

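The reason this works and yours didn't is that here each case of index corresponds to exactly one TensorIndex case, so TensorIndex ds can be reduced. In your second case, Idris only knows that the shape is n :: d for some unknown d, which could match either the (d::[]) clause or the (d::ds) clause of TensorIndex. Since it cannot choose between them, TensorIndex (n :: d) does not reduce, and the pair pattern (a,as) is rejected.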
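One way to see this concretely, as a sketch assuming the trailing-() definition: the lemma below (consStep is a hypothetical name) states that TensorIndex (d :: ds) is definitionally a pair even when the tail ds is an unknown variable, and it is proved by Refl. With the question's extra (d::[]) clause added back, the same Refl proof does not go through.

```idris
import Data.Fin
import Data.Vect

-- Trailing-() version of TensorIndex from this answer.
TensorIndex : Vect n Nat -> Type
TensorIndex []      = ()
TensorIndex (d::ds) = (Fin d, TensorIndex ds)

-- Only the outermost constructor of the shape is inspected, so exactly one
-- clause applies and TensorIndex (d :: ds) reduces for an abstract ds.
consStep : (d : Nat) -> (ds : Vect n Nat) ->
           TensorIndex (d :: ds) = Pair (Fin d) (TensorIndex ds)
consStep d ds = Refl
```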