8
votes

This may seem like a basic question, but I am unable to work it through.

In the forward pass of my neural network, I have an output tensor of shape 8x3x3, where 8 is my batch size. We can assume each 3x3 tensor to be a non-singular matrix. I need to find the inverse of these matrices. The PyTorch inverse() function only works on square matrices. Since I now have 8x3x3, how do I apply this function to every matrix in the batch in a differentiable manner?

If I iterate through the samples, append the inverses to a Python list, and then convert that list to a PyTorch tensor, will that cause a problem during backprop? (I am asking because converting PyTorch tensors to NumPy, performing some operations, and converting back to a tensor means no gradients are computed for those operations during backprop.)

I also get the following error when I try to do something like that.

a = torch.arange(0,8).view(-1,2,2)
b = [m.inverse() for m in a]
c = torch.FloatTensor(b)

TypeError: 'torch.FloatTensor' object does not support indexing


2 Answers

9
votes

EDIT:

As of PyTorch version 1.0, torch.inverse supports batches of matrices, so you can simply call the built-in function torch.inverse on the whole batch.
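A minimal sketch of the batched call, assuming PyTorch 1.0 or later. The input batch here is made up for illustration (diagonally dominant, so each matrix is invertible); in the question it would be the 8x3x3 network output.

```python
import torch

torch.manual_seed(0)

# Illustrative input: 8 diagonally dominant 3x3 matrices, hence non-singular.
batch = (2.0 * torch.eye(3) + 0.1 * torch.rand(8, 3, 3)).requires_grad_()

# torch.inverse treats the leading dimension as a batch dimension.
inv = torch.inverse(batch)  # shape (8, 3, 3)

# Sanity check: each matrix times its inverse should be close to the identity.
identity = torch.eye(3).expand(8, 3, 3)
max_err = (batch @ inv - identity).abs().max().item()

# The operation is differentiable, so gradients reach the input.
inv.sum().backward()
grad_shape = tuple(batch.grad.shape)
```

After the backward call, batch.grad has the same shape as the input, which is what the forward pass in the question needs.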

OLD ANSWER

There are plans to implement a batched inverse soon. For discussion, see for example issue 7500 or issue 9102. However, at the time of writing, the current stable version (0.4.1) has no batched inverse operation.

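That said, batch support was recently added to torch.gesv. Since solving A X = I for X yields the inverse of A, this can be (ab)used to define your own batched inverse operation along the following lines:

import torch

def b_inv(b_mat):
    # Batch of identity matrices with the same shape, dtype and device as b_mat
    eye = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat)
    # torch.gesv(B, A) solves A X = B; with B = I the solution X is the inverse of A
    inv, _ = torch.gesv(eye, b_mat)
    return inv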

I found that this gives good speed-ups over a for loop when running on GPU.
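For readers on newer releases: torch.gesv was later deprecated and is no longer available. The same solve-against-identity idea can be sketched with torch.linalg.solve, assuming a PyTorch version that provides the torch.linalg module. The function name batched_inverse_via_solve and the test input are hypothetical, chosen only for this illustration.

```python
import torch

def batched_inverse_via_solve(b_mat):
    # Hypothetical helper: invert each matrix in a (batch, n, n) tensor
    # by solving A X = I, mirroring the gesv-based approach.
    n = b_mat.size(-1)
    eye = torch.eye(n, dtype=b_mat.dtype, device=b_mat.device).expand_as(b_mat)
    # Note the argument order: torch.linalg.solve(A, B) solves A X = B.
    return torch.linalg.solve(b_mat, eye)

torch.manual_seed(0)
# Illustrative input: diagonally dominant matrices, so each one is invertible.
mats = 2.0 * torch.eye(3) + 0.1 * torch.rand(8, 3, 3)
inv = batched_inverse_via_solve(mats)

# Largest deviation of A @ A^-1 from the identity across the batch.
residual = (mats @ inv - torch.eye(3)).abs().max().item()
```

Note that the argument order is reversed relative to torch.gesv, which took the right-hand side first.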

1
votes

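You could split the tensor along the batch dimension using torch.unbind(), apply inverse() to every element of the result, and then stack the results back together with torch.stack(). Note that the input must be a floating-point tensor, since inverse() does not work on integer tensors:

import torch

a = torch.arange(0., 8.).view(-1, 2, 2)  # float input, not integer
b = [t.inverse() for t in torch.unbind(a)]
c = torch.stack(b)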
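On the backprop concern from the question: a rough check, assuming a PyTorch version with autograd on tensors (0.4 or later), suggests that collecting the inverses in a Python list is fine as long as the list is recombined with torch.stack, which is a differentiable operation, rather than passed to a tensor constructor.

```python
import torch

# Same 2x2 example as in the question, but floating point and tracking gradients.
a = torch.arange(0., 8.).view(-1, 2, 2).requires_grad_()

# Iterating over a tensor yields its slices along the first dimension.
inverses = [m.inverse() for m in a]

# torch.stack keeps the autograd graph intact.
c = torch.stack(inverses)

c.sum().backward()
has_grad = a.grad is not None
```

Here c.grad_fn is set and a.grad is populated after the backward call, so the loop does not cut the graph; it is only slower than a batched call.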