I'm trying to index the maximum elements along the last dimension of a multidimensional tensor. For example, say I have a tensor
import torch
A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)
Here idx stores the maximum indices, which may look something like
>>> A
tensor([[[ 1.0503, 0.4448, 1.8663],
[ 0.8627, 0.0685, 1.4241]],
[[ 1.2924, 0.2456, 0.1764],
[ 1.3777, 0.9401, 1.4637]],
[[ 0.5235, 0.4550, 0.2476],
[ 0.7823, 0.3004, 0.7792]],
[[ 1.9384, 0.3291, 0.7914],
[ 0.5211, 0.1320, 0.6330]],
[[ 0.3292, 0.9086, 0.0078],
[ 1.3612, 0.0610, 0.4023]]])
>>> idx
tensor([[ 2, 2],
[ 0, 2],
[ 0, 0],
[ 0, 2],
[ 1, 0]])
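A minimal sketch of what idx holds, assuming a recent PyTorch. The seed and the name `values` are illustrative additions, not part of the original code. It checks that idx drops the reduced dimension and that each entry is the position of the row maximum:

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
A = torch.randn(5, 2, 3)

# torch.max with a dim returns (values, indices); both drop the reduced dim
values, idx = torch.max(A, dim=2)
print(idx.shape)  # torch.Size([5, 2])

# idx[i, j] is the position of the largest entry in the row A[i, j, :]
assert torch.equal(idx, A.argmax(dim=2))

# gathering along dim 2 with idx (dimension restored via unsqueeze) recovers the maxima
assert torch.equal(A.gather(2, idx.unsqueeze(-1)).squeeze(-1), values)
```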
I want to use these indices to copy the corresponding values of A into another tensor. That is, I want to be able to do something like
B = A.new_zeros(A.size())
B[idx] = A[idx]
so that B is 0 everywhere except at the positions where A attains its maximum along the last dimension. That is, B should store
>>> B
tensor([[[ 0, 0, 1.8663],
[ 0, 0, 1.4241]],
[[ 1.2924, 0, 0],
[ 0, 0, 1.4637]],
[[ 0.5235, 0, 0],
[ 0.7823, 0, 0]],
[[ 1.9384, 0, 0],
[ 0, 0, 0.6330]],
[[ 0, 0.9086, 0],
[ 1.3612, 0, 0]]])
This is proving to be much more difficult than I expected, because idx does not index A the way I need: A[idx] indexes along the first dimension of A rather than the last. So far I have been unable to find a vectorized way to use idx to index A.
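A short sketch of why the direct attempt misbehaves, using the same shapes as above. When a single integer tensor is used as an index, PyTorch treats each of its entries as a position along dimension 0, so the result has shape idx.shape + A.shape[1:]:

```python
import torch

A = torch.randn(5, 2, 3)
_, idx = torch.max(A, dim=2)

# A LongTensor used as a single index selects along dim 0 only:
# every entry of idx is treated as a "row number" into A's first dimension.
sel = A[idx]
print(sel.shape)  # torch.Size([5, 2, 2, 3])

# e.g. sel[0, 0] is the whole slice A[idx[0, 0]], not a single element
assert torch.equal(sel[0, 0], A[idx[0, 0]])
```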
Is there a good vectorized way to do this?
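One possible vectorized approach, sketched under the assumption that a single maximum per row is wanted (on ties, whichever index torch.max returns is used). It either scatters the maxima back into a zero tensor with `Tensor.scatter_`, or uses advanced indexing with broadcast `arange` tensors for the leading dimensions. The names `values`, `i`, `j` and `B2` are hypothetical:

```python
import torch

A = torch.randn(5, 2, 3)
values, idx = torch.max(A, dim=2)

# Option 1: scatter the maxima into a zero tensor at positions idx along dim 2.
# scatter_ needs index and src with the same number of dims as the target,
# so the reduced dimension is restored with unsqueeze(-1).
B = torch.zeros_like(A).scatter_(2, idx.unsqueeze(-1), values.unsqueeze(-1))

# Option 2: advanced indexing, supplying an index tensor for every dimension.
# i and j broadcast against idx (shape (5, 2)) to pick A[i, j, idx[i, j]].
i = torch.arange(A.size(0)).view(-1, 1)  # shape (5, 1)
j = torch.arange(A.size(1)).view(1, -1)  # shape (1, 2)
B2 = torch.zeros_like(A)
B2[i, j, idx] = A[i, j, idx]

# Both give the same result: zeros except at each row's maximum
assert torch.equal(B, B2)
```

The scatter version does not depend on how many leading dimensions A has, whereas the arange version needs one broadcast index tensor per leading dimension.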