
I would like to perform batched matrix multiplication on slices of large tensors.

Say I have A of shape [N, 1, 4] and B of shape [N, 4, 4]. I would like to first slice them along the batch dimension, getting [b, 1, 4] and [b, 4, 4] (not necessarily contiguous), and then obtain a result of shape [b, 4] by doing the matrix multiplications in batches. Is there a way to do that using Eigen?

Why would you want to do this rather than multiply the tensors and then extract the submatrices? - Mansoor
@Mansoor Sorry, I guess I wasn't clear. The slicing part does not really matter (I should have left it out); what I essentially want is to multiply tensors of shape (b, 1, 4) and (b, 4, 4) in batches. - jack
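To pin down the semantics of that operation, here is a naive sketch in plain C++ (no Eigen) that can also serve as a reference when checking a faster version. It assumes each tensor is stored as a flat std::vector in row-major [batch][row][col] order; the function name naiveBatchedMultiply is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Naive reference for a batched product: for every batch k,
// C[k] (rows x cols) = A[k] (rows x inner) * B[k] (inner x cols).
// Assumption: all tensors are flat vectors in row-major [batch][row][col] order.
std::vector<double> naiveBatchedMultiply(const std::vector<double>& A,
                                         const std::vector<double>& B,
                                         std::size_t batch, std::size_t rows,
                                         std::size_t inner, std::size_t cols)
{
    assert(A.size() == batch * rows * inner);
    assert(B.size() == batch * inner * cols);

    std::vector<double> C(batch * rows * cols, 0.0);
    for (std::size_t k = 0; k < batch; ++k)
        for (std::size_t i = 0; i < rows; ++i)
            for (std::size_t j = 0; j < cols; ++j)
            {
                // Dot product of row i of A[k] with column j of B[k].
                double sum = 0.0;
                for (std::size_t p = 0; p < inner; ++p)
                    sum += A[(k * rows + i) * inner + p] * B[(k * inner + p) * cols + j];
                C[(k * rows + i) * cols + j] = sum;
            }
    return C;
}
```

For the question's shapes you would call it with rows = 1, inner = 4 and cols = 4; the result then holds b * 1 * 4 values, i.e. shape [b, 1, 4], which is the same data as [b, 4].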

1 Answer


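I'm not sure whether this is an efficient way to perform batched matrix multiplication with Eigen tensors, but one solution is to map each tensor page as a matrix and perform an ordinary matrix multiplication. Because Eigen tensors are column-major by default, the batch dimension is placed last here (e.g. A has shape (1, 4, N) rather than [N, 1, 4]), which makes every page a contiguous block that Eigen::Map can wrap without copying: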

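#include <vector>
#include <Eigen/Dense>
#include <unsupported/Eigen/CXX11/Tensor>

// Eigen tensors are column-major by default, so with the batch dimension last,
// each page T(:, :, k) is a contiguous block of dimension(0) * dimension(1)
// scalars that can be viewed as a matrix without copying.
typedef Eigen::Tensor<double, 3> Tensor3d;
typedef Eigen::Map<const Eigen::MatrixXd> ConstMatrixMap;
typedef Eigen::Map<Eigen::MatrixXd> MatrixMap;

// Computes C(:, :, k) = A(:, :, batchIndices[k]) * B(:, :, batchIndices[k]).
inline void batchedTensorMultiplication(const Tensor3d& A, const Tensor3d& B, const std::vector<int>& batchIndices, Tensor3d& C)
{
    eigen_assert(A.dimension(1) == B.dimension(0));
    eigen_assert(C.dimension(0) == A.dimension(0) && C.dimension(1) == B.dimension(1));
    eigen_assert(C.dimension(2) == static_cast<Eigen::DenseIndex>(batchIndices.size()));

    // Number of scalars in one page of each tensor.
    const Eigen::DenseIndex memStepA = A.dimension(0) * A.dimension(1);
    const Eigen::DenseIndex memStepB = B.dimension(0) * B.dimension(1);
    const Eigen::DenseIndex memStepC = C.dimension(0) * C.dimension(1);
    Eigen::DenseIndex outputBatchIndex = 0;

    for (int batchIndex : batchIndices)
    {
        ConstMatrixMap pageA(A.data() + batchIndex * memStepA, A.dimension(0), A.dimension(1));
        ConstMatrixMap pageB(B.data() + batchIndex * memStepB, B.dimension(0), B.dimension(1));
        MatrixMap pageC(C.data() + outputBatchIndex * memStepC, C.dimension(0), C.dimension(1));
        ++outputBatchIndex;

        // The pages do not overlap, so noalias() lets Eigen skip the temporary.
        pageC.noalias() = pageA * pageB;
    }
}

int main()
{
    constexpr int N = 50;
    std::vector<int> batchIndices = { 0, 1, 2, 3, 4, 9, 10, 11, 12, 13 };

    // Shapes are (rows, cols, batch): each page of A is 1x4, of B 4x4, of C 1x4.
    Tensor3d A(1, 4, N), B(4, 4, N), C(1, 4, static_cast<int>(batchIndices.size()));
    A.setRandom();
    B.setRandom();

    batchedTensorMultiplication(A, B, batchIndices, C);

    return 0;
}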