I would like to index a PyTorch tensor with both a boolean mask and normal (integer) indices. Something like this:
```python
import torch

i = 2
j = 0
mask = torch.randn(480, 360, 3) > 0
tensor = torch.zeros(480, 360, 4, 80)
tensor[mask[..., 0], i, j] = 1  # raises IndexError in PyTorch
```
The NumPy equivalent works, but in PyTorch it throws an error:

```
IndexError: The shape of the mask [480, 360] at index 1 does not match the shape of the indexed tensor [480, 80] at index 1
```
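For reference, here is a minimal NumPy sketch of the version that works. The shapes are shrunk from the question's to keep it light, and the `default_rng(0)` generator is just an illustrative choice. The 2D mask consumes the first two axes, and `i` and `j` then index axes 2 and 3:

```python
import numpy as np

i, j = 2, 0
rng = np.random.default_rng(0)
# Same layout as in the question, but shrunk from (480, 360, ...).
mask = rng.standard_normal((48, 36, 3)) > 0
array = np.zeros((48, 36, 4, 80))

# The 2D boolean mask covers axes 0 and 1; i and j index axes 2 and 3.
array[mask[..., 0], i, j] = 1

# Only the (i, j) slice is written, and exactly where the mask is True.
print(np.array_equal(array[:, :, i, j] == 1, mask[..., 0]))  # True
print(array.sum() == mask[..., 0].sum())                     # True
```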
Any ideas or hints?
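The `[480, 80]` in the message is consistent with `i` and `j` both being applied at dimension 1, as though the 2D mask counted for only one dimension. This is an inference from the shapes, not something confirmed from PyTorch's source. The sketch below shows two possible workarounds that avoid mixing a multi-dimensional mask with integer indices. Both assume the goal is the NumPy behaviour: set `tensor[r, c, i, j] = 1` wherever `mask[r, c, 0]` is True. The shapes are again shrunk for a quick check, and the names `m`, `rows`, `cols` and `tensor2` are illustrative.

```python
import torch

torch.manual_seed(0)
i, j = 2, 0
# Shapes shrunk from the question's for a quick check.
mask = torch.randn(48, 36, 3) > 0
m = mask[..., 0]  # 2D boolean mask over the first two dims

# Option 1: apply the integer indices first. `tensor[:, :, i, j]` uses only
# slices and ints, so it is a view of shape (48, 36); the boolean
# assignment on that view writes through to `tensor`.
tensor = torch.zeros(48, 36, 4, 80)
tensor[:, :, i, j][m] = 1

# Option 2: expand the 2D mask into explicit row/column index tensors,
# so each index tensor covers exactly one dimension.
tensor2 = torch.zeros(48, 36, 4, 80)
rows, cols = m.nonzero(as_tuple=True)
tensor2[rows, cols, i, j] = 1

print(torch.equal(tensor, tensor2))               # True
print(torch.equal(tensor[:, :, i, j].bool(), m))  # True
print(tensor.sum().item() == m.sum().item())      # True
```

Option 1 relies on basic indexing returning a view. If the integer indexing were done after the mask (for example `tensor[m][:, i, j] = 1`), the assignment would go to a copy and `tensor` would stay unchanged.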