So suppose I have a k-dimensional tensor and a 1-dimensional mask (a pattern used a lot in PyTorch for variable-length sequences), and I want to return a tensor containing the elements up to the first False value in the mask. Here is an example:
```python
import torch
a = torch.tensor([[1,2],[3,4],[5,6],[0,0],[0,0],[0,0]])
b = torch.tensor([True,True,True,False,False,False])
# magic goes here, result of c should be:
print(c)
# [[1,2],[3,4],[5,6]]
```
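For reference, a minimal sketch of one candidate approach on this exact example, assuming `b` is a `torch.bool` tensor whose length equals `a.shape[0]`: indexing a tensor with a 1-D boolean tensor selects along the first dimension and keeps the remaining dimensions intact.

```python
import torch

a = torch.tensor([[1, 2], [3, 4], [5, 6], [0, 0], [0, 0], [0, 0]])
b = torch.tensor([True, True, True, False, False, False])

# A 1-D bool index selects rows along dim 0; trailing dims are kept.
c = a[b]
print(c.tolist())  # [[1, 2], [3, 4], [5, 6]]
```

Note that boolean indexing produces a copy rather than a view, and it keeps every True position, not only those before the first False.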
In this example the input tensor is 2-D, but it could be k-dimensional, with any sizes along those dimensions. Only the first dimension needs to match the length of the mask. So torch.masked_select doesn't work, since the tensor to be truncated isn't 1-D like the mask (and masked_select returns a flattened 1-D result anyway), and since you don't know the dimensionality in advance, squeeze and unsqueeze are not a solution either.
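To check that the same boolean-indexing idea does not depend on the number of dimensions, here is a sketch on a 3-D tensor. `truncate_by_mask` is a hypothetical helper name used only for illustration; it assumes a 1-D bool mask whose length matches the first dimension of `x`.

```python
import torch

def truncate_by_mask(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: keep entries of x along dim 0 where mask is True."""
    # Assumption: mask is 1-D, boolean, and as long as x's first dimension.
    assert mask.dtype == torch.bool and mask.dim() == 1
    assert mask.shape[0] == x.shape[0]
    # A bool index on dim 0 works for any number of trailing dimensions,
    # so no squeeze/unsqueeze or knowledge of x.dim() is needed.
    return x[mask]

x = torch.arange(6 * 2 * 3).reshape(6, 2, 3)  # a 3-D example
mask = torch.tensor([True, True, True, False, False, False])
y = truncate_by_mask(x, mask)
print(y.shape)  # torch.Size([3, 2, 3])
```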
The mask will always be True for the first n elements and False for the rest, though it's fine if your solution does not depend on this.
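If the goal is strictly "everything before the first False" (which differs from `a[b]` when the mask is not a prefix), one possible sketch finds that index and uses plain slicing. `prefix_before_first_false` is a hypothetical helper name; it assumes a 1-D bool mask aligned with dim 0.

```python
import torch

def prefix_before_first_false(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: slice x along dim 0 up to (not including) the first False."""
    false_idx = (~mask).nonzero()  # shape (num_false, 1)
    # If the mask has no False entries, keep everything.
    n = int(false_idx[0, 0]) if false_idx.numel() > 0 else mask.shape[0]
    # Basic slicing returns a view that shares storage with x (no copy).
    return x[:n]

a = torch.tensor([[1, 2], [3, 4], [5, 6], [0, 0], [0, 0], [0, 0]])
m = torch.tensor([True, True, False, True, False, False])  # not a prefix mask
print(prefix_before_first_false(a, m).tolist())  # [[1, 2], [3, 4]]
print(a[m].tolist())  # [[1, 2], [3, 4], [0, 0]]

# When the mask is guaranteed to be a prefix, counting Trues is enough:
b = torch.tensor([True, True, True, False, False, False])
print(a[: int(b.sum())].tolist())  # [[1, 2], [3, 4], [5, 6]]
```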
This seems like something people would do all the time, yet I cannot find an answer to this question anywhere.