
I need to concatenate a long list of small tensors. Each small tensor is a slice of a given (quite simple) constant matrix. Here is the code:

import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
max_node, counter = 0, 0
batch_size, n_days = (1000, 10)
n_interactions_in = torch.randint(low=100, high=200, size=(batch_size, n_days), dtype=torch.long)
max_interactions = int(n_interactions_in.max())  # plain int for expand()
# Row d of delay_table holds the constant n_days - 1 - d, repeated max_interactions times
delay_table = torch.arange(n_days, device=device, dtype=torch.float).expand([max_interactions, n_days]).t().contiguous()
delay_table = n_days - delay_table - 1
edge_delay_buf = []
for b in range(batch_size):
    # one slice per day, truncated to that day's interaction count
    delay_vec = [delay_table[d, :n_interactions_in[b, d]] for d in range(n_days)]
    edge_delay_buf.append(torch.cat(delay_vec))
res = torch.cat(edge_delay_buf)
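Running the same loop on a tiny, hand-picked input makes it easier to see what it computes: every row d of delay_table is the constant n_days - 1 - d, so for each (b, d) the loop simply emits n_interactions_in[b, d] copies of that value, in row-major (b, then d) order. The counts below are made up for illustration, and everything stays on the CPU.

```python
import torch

# Tiny, deterministic version of the loop (counts chosen by hand for illustration).
n_days = 3
n_interactions_in = torch.tensor([[1, 2, 0],
                                  [2, 0, 1]], dtype=torch.long)
batch_size = n_interactions_in.shape[0]
max_interactions = int(n_interactions_in.max())

# Row d of delay_table is filled with the constant n_days - 1 - d: [[2, 2], [1, 1], [0, 0]].
delay_table = torch.arange(n_days, dtype=torch.float).expand([max_interactions, n_days]).t().contiguous()
delay_table = n_days - delay_table - 1

edge_delay_buf = []
for b in range(batch_size):
    delay_vec = [delay_table[d, :n_interactions_in[b, d]] for d in range(n_days)]
    edge_delay_buf.append(torch.cat(delay_vec))
res = torch.cat(edge_delay_buf)
print(res)  # tensor([2., 1., 1., 2., 2., 0.])
```

Zero counts produce empty slices, which torch.cat simply skips over.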

This takes a lot of time. Is there a way to efficiently parallelize the creation of each element in edge_delay_buf? I have tried several variants, such as replacing the for loop with a list comprehension that produces a list of lists, then flattening that list and applying torch.cat to the flattened list. However, it didn't improve much. For some reason the slicing operation takes too long.

Is there a way to make the slicing faster? Is there a way to make the loop more efficient / parallel?

Note: while I'm using torch in this example, I can also use numpy.

Note 2: I apologize for a duplicate post in a different forum.

In numpy, slicing does not take long; it creates a view. But eventually, when concatenating all those views, it has to copy all the values into the new array. If batch_size is large, the list append step will take some time; otherwise I suspect the torch.cat step is the big time consumer. But you should be able to time those steps. – hpaulj
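A minimal sketch of timing the two steps separately, assuming CPU tensors and the sizes from the question. It reuses the question's slicing expression so that what gets measured matches the original code; the only change is that all slices go into one flat list and are concatenated once.

```python
import time
import torch

# Sizes from the question; CPU only so the timings are synchronous.
batch_size, n_days = 1000, 10
n_interactions_in = torch.randint(low=100, high=200, size=(batch_size, n_days), dtype=torch.long)
max_interactions = int(n_interactions_in.max())
delay_table = torch.arange(n_days, dtype=torch.float).expand([max_interactions, n_days]).t().contiguous()
delay_table = n_days - delay_table - 1

# Step 1: build the list of slices (views; no element data is copied yet).
t0 = time.perf_counter()
slices = [delay_table[d, :n_interactions_in[b, d]]
          for b in range(batch_size) for d in range(n_days)]
t1 = time.perf_counter()

# Step 2: a single concatenation that copies every value into a new tensor.
res = torch.cat(slices)
t2 = time.perf_counter()

print(f"slicing: {t1 - t0:.4f} s, cat: {t2 - t1:.4f} s, elements: {res.numel()}")
```

If the tensors live on a GPU, CUDA kernels run asynchronously, so you would need to call torch.cuda.synchronize() before reading each timestamp to get meaningful numbers.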

1 Answer


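Start by replacing the inner concatenation with list extension, and do only a single concatenation at the end. This avoids copying every element twice (once per batch row, then again in the final torch.cat) and should be noticeably faster.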

import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
max_node, counter = 0, 0
batch_size, n_days = (1000, 10)
n_interactions_in = torch.randint(low=100, high=200, size=(batch_size, n_days), dtype=torch.long)
max_interactions = int(n_interactions_in.max())  # plain int for expand()
delay_table = torch.arange(n_days, device=device, dtype=torch.float).expand([max_interactions, n_days]).t().contiguous()
delay_table = n_days - delay_table - 1
edge_delay_buf = []
for b in range(batch_size):
    delay_vec = [delay_table[d, :n_interactions_in[b, d]] for d in range(n_days)]
    edge_delay_buf += delay_vec  # extend the flat list instead of concatenating per row
res = torch.cat(edge_delay_buf)  # the only copy of the data
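One way to check that the flat list gives exactly the same tensor as the nested version is to build both and compare them. This is a small CPU-only sketch with a reduced batch_size; the seed and sizes are arbitrary choices for the check.

```python
import torch

torch.manual_seed(0)
batch_size, n_days = 50, 10  # smaller than the question, just for the comparison
n_interactions_in = torch.randint(low=100, high=200, size=(batch_size, n_days), dtype=torch.long)
max_interactions = int(n_interactions_in.max())
delay_table = torch.arange(n_days, dtype=torch.float).expand([max_interactions, n_days]).t().contiguous()
delay_table = n_days - delay_table - 1

# Question's version: one torch.cat per batch row, then one more at the end.
nested = torch.cat([
    torch.cat([delay_table[d, :n_interactions_in[b, d]] for d in range(n_days)])
    for b in range(batch_size)
])

# Answer's version: collect every slice in one flat list and concatenate once.
flat = []
for b in range(batch_size):
    flat += [delay_table[d, :n_interactions_in[b, d]] for d in range(n_days)]
single = torch.cat(flat)

print(torch.equal(nested, single))  # True
```

Because the slices are appended in the same (b, d) order either way, the two results are element-for-element identical; only the number of copies differs.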

Then, if it is still not fast enough, you can be much more efficient by extracting all the indices at once with advanced indexing. Suppose you have a matrix A of shape [N, M]. You can extract many elements in one call with A[B, C], where B is a vector of length K and C is a matrix of shape [L, K]: the two index tensors are broadcast together, so the result has shape [L, K] and its element [l, k] is A[B[k], C[l, k]]. Maybe this could fit your needs.
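A sketch of what this could look like, assuming the goal is exactly the loop above and CPU tensors. The first part only illustrates the A[B, C] broadcasting rule on a toy matrix. The second part applies advanced indexing through a boolean mask, which is a swapped-in variant and not necessarily what was meant above. The third uses torch.repeat_interleave, which works here only because every row of delay_table is a constant; it would not apply to an arbitrary table. Both vectorized results are compared with the flat-list loop.

```python
import torch

# 1) Advanced indexing: B (length K) and C (shape [L, K]) broadcast to [L, K].
A = torch.arange(12).reshape(3, 4)           # N=3, M=4
B = torch.tensor([0, 2])                     # K=2 row indices
C = torch.tensor([[1, 3], [0, 2], [3, 3]])   # L=3, K=2 column indices
print(A[B, C].tolist())  # out[l][k] = A[B[k], C[l, k]] -> [[1, 11], [0, 10], [3, 11]]

# Setup matching the question (CPU only, seeded for reproducibility).
torch.manual_seed(0)
batch_size, n_days = 1000, 10
n_interactions_in = torch.randint(low=100, high=200, size=(batch_size, n_days), dtype=torch.long)
max_interactions = int(n_interactions_in.max())
delay_table = torch.arange(n_days, dtype=torch.float).expand([max_interactions, n_days]).t().contiguous()
delay_table = n_days - delay_table - 1

# Reference result: the flat-list loop from the answer.
flat = []
for b in range(batch_size):
    flat += [delay_table[d, :n_interactions_in[b, d]] for d in range(n_days)]
res_loop = torch.cat(flat)

# 2) Boolean-mask indexing: mask[b, d, i] is True for i < n_interactions_in[b, d].
mask = torch.arange(max_interactions) < n_interactions_in.unsqueeze(-1)  # [batch, n_days, max_int]
table_3d = delay_table.unsqueeze(0).expand(batch_size, -1, -1)           # view, no copy
res_mask = table_3d[mask]  # selected elements come out in row-major (b, d, i) order

# 3) repeat_interleave: repeat value n_days-1-d exactly n_interactions_in[b, d] times.
values = torch.arange(n_days - 1, -1, -1, dtype=torch.float).repeat(batch_size)
res_rep = torch.repeat_interleave(values, n_interactions_in.flatten())

print(torch.equal(res_mask, res_loop), torch.equal(res_rep, res_loop))  # True True
```

Both vectorized versions replace the Python-level loop with a handful of tensor operations, so there is no per-slice overhead. On a GPU you would create the tensors on that device instead; the sketch above has only been reasoned through for the CPU case.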