I have a situation where I am using nested for-loops, and I want to know if there's a faster way of doing this with some advanced indexing in PyTorch.
I have a tensor named t:
t = torch.randn(3,8)
print(t)
tensor([[-1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920, -0.3160, -2.1152],
[ 0.4681, -0.1577, 1.4437, 0.2660, 0.1665, 0.8744, -0.1435, -0.1116],
[ 0.9318, 1.2590, 2.0050, 0.0537, 0.6181, -0.4128, -0.8411, -2.3160]])
I want to create a new tensor which indexes values from t.
Let's say these indexes are stored in a variable indexes:
indexes = [[(0, 1, 4, 5), (0, 1, 6, 7), (4, 5, 6, 7)],
[(2, 3, 4, 5)],
[(4, 5, 6, 7), (2, 3, 6, 7)]]
Each inner tuple in indexes represents four indexes that are to be taken from a row.
As an example, based on these indexes my output would be a 6x4 tensor (6 is the total number of tuples in indexes, and 4 is the number of values in each tuple).
For instance, this is what I want to do:
#counting the number of tuples in indexes
count_instances = sum([1 for lst in indexes for tupl in lst])
#creating a zero output matrix
final_tensor = torch.zeros(count_instances,4)
final_tensor[0] = t[0,indexes[0][0]]
final_tensor[1] = t[0,indexes[0][1]]
final_tensor[2] = t[0,indexes[0][2]]
final_tensor[3] = t[1,indexes[1][0]]
final_tensor[4] = t[2,indexes[2][0]]
final_tensor[5] = t[2,indexes[2][1]]
The final output looks like this:
print(final_tensor)
tensor([[-1.1258, -1.1524, 0.8487, 0.6920],
[-1.1258, -1.1524, -0.3160, -2.1152],
[ 0.8487, 0.6920, -0.3160, -2.1152],
[ 1.4437, 0.2660, 0.1665, 0.8744],
[ 0.6181, -0.4128, -0.8411, -2.3160],
[ 2.0050, 0.0537, -0.8411, -2.3160]])
I created a function build_tensor (shown below) to achieve this with nested for-loops, but I want to know if there's a faster way to do it with simple indexing in PyTorch. Speed matters because I'm doing this operation hundreds of times with larger indexes and larger t sizes.
Any help?
def build_tensor(indexes, t):
    # count tuples
    count_instances = sum([1 for lst in indexes for tupl in lst])
    # create a zero tensor
    final_tensor = torch.zeros(count_instances, 4)
    final_tensor_idx = 0
    for curr_idx, lst in enumerate(indexes):
        for tupl in lst:
            final_tensor[final_tensor_idx] = t[curr_idx, tupl]
            final_tensor_idx += 1
    return final_tensor
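For context, one idea I've been toying with (not sure if it's the idiomatic way) is to flatten all the tuples into a single column-index tensor, pair it with a broadcasted row-index tensor, and do the gather in one advanced-indexing call. The function name build_tensor_vectorized and the variable names row_idx/col_idx below are just mine:

import torch

def build_tensor_vectorized(indexes, t):
    # repeat each row number once per tuple belonging to that row -> shape (num_tuples, 1)
    row_idx = torch.tensor(
        [row for row, lst in enumerate(indexes) for _ in lst]
    ).unsqueeze(1)
    # stack all tuples into a (num_tuples, 4) column-index tensor
    col_idx = torch.tensor([tupl for lst in indexes for tupl in lst])
    # row_idx broadcasts against col_idx, so each output row pulls 4 values from one row of t
    return t[row_idx, col_idx]

This still builds the index tensors with Python comprehensions, though, so I'm not sure how much it actually buys me when indexes stays fixed across calls versus precomputing row_idx/col_idx once.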