There is a similar question for NumPy, so my answer is heavily inspired by its solutions. I will compare some of the methods mentioned there using perfplot. I will also generalize the problem to applying a mapping to a tensor (yours is just a specific case).
For the analysis, I will assume that the mapping contains all the unique elements in the tensor and that the number of elements is small and constant.
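To make the generalized problem concrete, here is a minimal pure-Python reference that spells out what every method below should return: each element of `b` equal to `a[i]` is replaced by `ids[i]`. The name `reference_map` is my own, hypothetical helper. It is only a sketch for checking results, not a fast implementation.

```python
import torch


def reference_map(a: torch.Tensor, ids: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: build a plain dict from the keys `a` to the values `ids`.
    mapping = dict(zip(a.tolist(), ids.tolist()))
    # Look up every element of `b` (any shape) and restore the original shape.
    return torch.tensor([mapping[x] for x in b.ravel().tolist()]).reshape(b.shape)


a = torch.arange(2, 8)      # keys: 2, 3, ..., 7
ids = torch.arange(0, 6)    # values: 0, 1, ..., 5
b = torch.tensor([3, 7, 2, 2])
print(reference_map(a, ids, b).tolist())  # [1, 5, 0, 0]
```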
```python
import torch


def apply(a: torch.Tensor, ids: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Map every element through a Python dict; `apply_` only works on CPU tensors.
    mapping = {k.item(): v.item() for k, v in zip(a, ids)}
    return b.clone().apply_(lambda x: mapping.__getitem__(x))


def bucketize(a: torch.Tensor, ids: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    mapping = {k.item(): v.item() for k, v in zip(a, ids)}
    # From `https://stackguides.com/questions/13572448`.
    palette, key = zip(*mapping.items())
    key = torch.tensor(key)
    palette = torch.tensor(palette)
    # `torch.bucketize` requires sorted boundaries, so sort the keys (and values) first.
    palette, order = torch.sort(palette)
    key = key[order]
    index = torch.bucketize(b.ravel(), palette)
    remapped = key[index].reshape(b.shape)
    return remapped


def iterate(a: torch.Tensor, ids: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Plain Python loop over a 1-D tensor.
    mapping = {k.item(): v.item() for k, v in zip(a, ids)}
    return torch.tensor([mapping[x.item()] for x in b])


def argmax(a: torch.Tensor, ids: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Returns the position of each element in `a`, which equals the mapped
    # value only when `ids == torch.arange(len(a))`.
    return (b.view(-1, 1) == a).int().argmax(dim=1)


if __name__ == "__main__":
    import perfplot

    a = torch.arange(2, 8)
    ids = torch.arange(0, 6)
    perfplot.show(
        setup=lambda n: torch.randint(2, 8, (n,)),
        kernels=[
            lambda x: apply(a, ids, x),
            lambda x: bucketize(a, ids, x),
            lambda x: iterate(a, ids, x),
            lambda x: argmax(a, ids, x),
        ],
        labels=["apply", "bucketize", "iterate", "argmax"],
        n_range=[2 ** k for k in range(25)],
        xlabel="len(b)",
    )
```
Running this yields the following plot:
Hence, depending on the number of elements in your tensor, you can pick either the `argmax` method (with the caveats mentioned and the restriction that the values you map to must be 0 to N), `apply`, or `bucketize`.
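The restriction on `argmax` comes from the fact that it returns the *position* of each element in `a`, not the mapped value. Assuming every element of `b` occurs in `a`, indexing `ids` with that position should lift the restriction. The function name `argmax_general` is hypothetical, and it still builds the full `len(b) × len(a)` comparison matrix, so the memory caveat is unchanged.

```python
import torch


def argmax_general(a: torch.Tensor, ids: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Row j of the comparison matrix is True where a == b[j]; argmax picks that column.
    # Assumes every value of `b` occurs in `a` (otherwise argmax silently returns 0).
    position = (b.view(-1, 1) == a).int().argmax(dim=1)
    # Translate positions in `a` into the target values instead of returning them directly.
    return ids[position].reshape(b.shape)


a = torch.tensor([10, 20, 30])
ids = torch.tensor([7, -1, 4])   # arbitrary targets, not 0..N
b = torch.tensor([30, 10, 20, 30])
print(argmax_general(a, ids, b).tolist())  # [4, 7, -1, 4]
```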
Now, if we increase the number of elements to be mapped to, say, tens of thousands, i.e. `a = torch.arange(2, 10002)` and `ids = torch.arange(0, 10000)`, we get the following results:
This means the speed-up of `bucketize` only becomes visible for larger arrays, but it still outperforms the other methods (the `argmax` method was killed, so I had to remove it).
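A rough estimate hints at why the `argmax` run may not have survived (this is my assumption, not a measurement): it materializes an `n × N` comparison matrix, which `.int()` then copies into 32-bit integers. For the largest benchmark size `n = 2**24` and `N = 10000` keys, the arithmetic is:

```python
# Back-of-the-envelope memory estimate for the argmax method (an assumption, not a measurement).
n = 2 ** 24          # largest len(b) in the benchmark's n_range
num_keys = 10_000    # len(a) in the second experiment

elements = n * num_keys
bool_bytes = elements * 1    # `b.view(-1, 1) == a` yields a torch.bool matrix (1 byte per element)
int_bytes = elements * 4     # `.int()` copies it into int32 (4 bytes per element)

print(f"{elements:.3e} elements")                                          # 1.678e+11 elements
print(f"{bool_bytes / 1e9:.0f} GB as bool + {int_bytes / 1e9:.0f} GB as int32")  # 168 GB as bool + 671 GB as int32
```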
Last, if the tensor contains values that are not keys of the mapping, we can first map every unique value of the tensor to itself and then update that dictionary with the actual mapping:

```python
mapping = {x.item(): x.item() for x in torch.unique(b)}
mapping.update({k.item(): v.item() for k, v in zip(a, ids)})
```
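Put together, a minimal sketch of the `apply` method with this identity fallback could look like the following. The function name `apply_with_identity` is hypothetical; values of `b` that are not keys of the mapping are left unchanged, and `apply_` only works on CPU tensors.

```python
import torch


def apply_with_identity(a: torch.Tensor, ids: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Start from the identity mapping over every unique value present in `b` ...
    mapping = {x.item(): x.item() for x in torch.unique(b)}
    # ... then overwrite the entries that have an explicit target.
    mapping.update({k.item(): v.item() for k, v in zip(a, ids)})
    # Every element of `b` is now a key, so `__getitem__` never raises KeyError.
    return b.clone().apply_(lambda x: mapping.__getitem__(x))


a = torch.arange(2, 8)
ids = torch.arange(0, 6)
b = torch.tensor([0, 3, 9, 7])   # 0 and 9 are not keys of the mapping
print(apply_with_identity(a, ids, b).tolist())  # [0, 1, 9, 5]
```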
Now, if the number of unique elements you want to map is orders of magnitude larger than the array, computing this may shift the value of n at which `bucketize` becomes faster than `apply` (since for `apply` you can simply change `mapping.__getitem__(x)` to `mapping.get(x, x)` instead).
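The `bucketize` approach can also get a fallback without building the identity dictionary over all unique values. The sketch below is my own variant, not one of the benchmarked methods, and the name `bucketize_with_identity` is hypothetical: it sorts the keys, looks each element up with `torch.searchsorted`, and keeps the original value wherever the key found does not actually match. I have not benchmarked it.

```python
import torch


def bucketize_with_identity(a: torch.Tensor, ids: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Keys must be sorted for torch.searchsorted; permute the values the same way.
    keys, order = torch.sort(a)
    values = ids[order]
    flat = b.ravel()
    # Index of the first key >= each element; clamp so that elements larger
    # than every key still produce a valid index.
    index = torch.searchsorted(keys, flat).clamp(max=len(keys) - 1)
    # An element is "known" only if the key at that index equals it exactly.
    found = keys[index] == flat
    # Mapped value where the key exists, otherwise keep the original element.
    return torch.where(found, values[index], flat).reshape(b.shape)


a = torch.tensor([7, 2, 5])      # unsorted keys
ids = torch.tensor([50, 20, 30])
b = torch.tensor([[2, 9], [5, 0]])
print(bucketize_with_identity(a, ids, b).tolist())  # [[20, 9], [30, 0]]
```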