I have a tensor M of dimensions [N x Q x D] and a 1-D tensor of indices idx (of size N). I want to efficiently create a tensor mask of dimensions [N x Q x D] such that mask[i,j,k] = 1 iff j < idx[i] (0-based j), i.e. I want to keep only the first idx[i] entries out of Q along the second dimension (dim=1) of M, for every row i.
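For reference, the desired mask can be spelled out with an explicit (slow) loop. This is a minimal sketch assuming 0-based indexing and reading "keep the first idx[i] entries" as j < idx[i]; the helper name `loop_mask` is hypothetical. It is only meant as a definition or a test oracle, not for real workloads.

```python
import torch

def loop_mask(idx: torch.Tensor, Q: int, D: int) -> torch.Tensor:
    """Hypothetical reference: mask[i, j, k] = 1 iff j < idx[i] (0-based)."""
    N = idx.shape[0]
    mask = torch.zeros(N, Q, D)
    for i in range(N):
        # Keep the first idx[i] positions along dim=1, for every k along dim=2.
        mask[i, : int(idx[i])] = 1
    return mask

idx = torch.tensor([0, 2, 3])
# Rows of mask[:, :, 0]: [0, 0, 0], [1, 1, 0], [1, 1, 1]
print(loop_mask(idx, Q=3, D=2)[:, :, 0])
```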

Thanks!

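It turns out this can be done with a broadcasting trick: compare a row of positions 0..Q-1 against idx as a column, then add a trailing axis so the mask broadcasts over D: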

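import torch

mask_2d = torch.arange(Q, device=idx.device)[None, :] < idx[:, None]  # (N, Q), bool
mask_3d = mask_2d[..., None]  # (N, Q, 1); .expand(-1, -1, D) gives a full (N, Q, D) view
masked = mask_3d.to(M.dtype) * M  # (N, Q, D), rows past idx[i] zeroed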
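A self-contained version of the same broadcasting idea, wrapped in a hypothetical helper `mask_first` and checked against a naive loop. It assumes idx is an integer tensor on the same device as M and uses the `<` convention (keep idx[i] entries); switch to `<=` if you want j <= idx[i] inclusive. It uses `masked_fill` instead of multiplying by a float mask, which writes zeros directly and so avoids turning masked-out inf values into NaN (0 * inf = nan).

```python
import torch

def mask_first(M: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: zero out M[i, j, :] for every j >= idx[i]."""
    Q = M.shape[1]
    # (1, Q) < (N, 1) broadcasts to an (N, Q) boolean mask.
    keep = torch.arange(Q, device=M.device)[None, :] < idx[:, None]
    # Add a trailing axis so the mask broadcasts over D: (N, Q, 1).
    keep = keep[..., None]
    # masked_fill writes 0 where keep is False and returns a new tensor.
    return M.masked_fill(~keep, 0.0)

torch.manual_seed(0)
N, Q, D = 4, 5, 3
M = torch.randn(N, Q, D)
idx = torch.tensor([0, 1, 3, 5])

out = mask_first(M, idx)

# Compare with an explicit loop over rows.
expected = M.clone()
for i in range(N):
    expected[i, int(idx[i]):] = 0.0
print(torch.equal(out, expected))  # True
```

The boolean mask never materializes the D axis, so memory stays O(N·Q) until it is applied to M.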