I have the following tensor; let's call it lookup_table:

tensor([266, 103,  84,  12,  32,  34,   1, 523,  22, 136, 268, 432,  53,  63,
        201,  51, 164,  69,  31,  42, 122, 131, 119,  36, 245,  60,  28,  81,
          9, 114, 105,   3,  41,  86, 150,  79, 104, 120,  74, 420,  39, 427,
         40,  59,  24, 126, 202, 222, 145, 429,  43,  30,  38,  55,  10, 141,
         85, 121, 203, 240,  96,   7,  64,  89, 127, 236, 117,  99,  54,  90,
         57,  11,  21,  62,  82,  25, 267,  75, 111, 518,  76,  56,  20,   2,
         61, 516,  80,  78, 555, 246, 133, 497,  33, 421,  58, 107,  92,  68,
         13, 113, 235, 875,  35,  98, 102,  27,  14,  15,  72,  37,  16,  50,
        517, 134, 223, 163,  91,  44,  17, 412,  18,  48,  23,   4,  29,  77,
          6, 110,  67,  45, 161, 254, 112,   8, 106,  19, 498, 101,   5, 157,
         83, 350, 154, 238, 115,  26, 142, 143])

I also have another tensor, let's call it data, which looks like this:

tensor([[517, 235, 236,  76,  81,  25, 110,  59, 245,  39],
        [523, 114, 350, 246,  30, 222,  39, 517, 106,   2],
        [ 35, 235, 120,  99, 266,  63, 236, 133, 412,  38],
        [134,   2, 497,  21,  78,  60, 142, 498,  24,  89],
        [ 60, 111, 120, 145,  91, 141, 164,  81, 350,  55]])

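Now I want to get a tensor with the same shape as data, where each entry is the index in lookup_table of the corresponding value in data, something like this:

tensor([[112, 100, ..., 40],
        [  7,  29, ..., 83],
        ...])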
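
To pin down what each entry of the result means, here is a non-vectorized reference sketch. It runs the single-value (lookup_table == value).nonzero() lookup once per element of data and reshapes the positions back to data's shape. The function name lookup_indices_loop and the shortened tensors are hypothetical, chosen only to keep the example small, and it assumes each value occurs exactly once in lookup_table.

```python
import torch

# Hypothetical shortened versions of lookup_table and data, for illustration only.
lookup_table = torch.tensor([266, 103, 84, 12, 32, 517, 235, 39])
data = torch.tensor([[517, 235, 39],
                     [12, 266, 84]])


def lookup_indices_loop(lookup_table, data):
    """Non-vectorized reference: find each value's position in lookup_table one at a time."""
    positions = []
    for value in data.flatten():
        # Assuming the value occurs exactly once, nonzero() has a single entry,
        # and .item() extracts it as a Python int.
        positions.append((lookup_table == value).nonzero().item())
    return torch.tensor(positions).reshape(data.shape)


print(lookup_indices_loop(lookup_table, data).tolist())  # [[5, 6, 7], [3, 0, 2]]
```

This is slow for large tensors because of the Python loop, but it is a handy check for any vectorized version.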

In other words, I want to use the values in data to look up their indices in lookup_table. Basically, I want to vectorize this, which works when data is a single value:

(lookup_table == data).nonzero()

so that it also works when data is a multidimensional tensor.
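
A minimal vectorized sketch, assuming every value in data occurs exactly once in lookup_table: add a trailing dimension to data so the == comparison broadcasts against the whole lookup_table, then read each matching position from the last column of nonzero(). The helper name lookup_indices_broadcast and the small example tensors are hypothetical.

```python
import torch

# Hypothetical shortened versions of lookup_table and data, for illustration only.
lookup_table = torch.tensor([266, 103, 84, 12, 32, 517, 235, 39])
data = torch.tensor([[517, 235, 39],
                     [12, 266, 84]])


def lookup_indices_broadcast(lookup_table, data):
    # data.unsqueeze(-1) has shape (*data.shape, 1); comparing it with lookup_table,
    # shape (N,), broadcasts to a boolean tensor of shape (*data.shape, N).
    matches = data.unsqueeze(-1) == lookup_table
    # nonzero() lists the coordinates of every True in row-major order,
    # so the last column holds the position inside lookup_table.
    positions = matches.nonzero()[:, -1]
    # A count mismatch means some value is missing from, or duplicated in, lookup_table.
    if positions.numel() != data.numel():
        raise ValueError("every value in data must occur exactly once in lookup_table")
    return positions.reshape(data.shape)


print(lookup_indices_broadcast(lookup_table, data).tolist())  # [[5, 6, 7], [3, 0, 2]]
```

The intermediate boolean tensor has data.numel() × lookup_table.numel() elements (5 × 10 × 148 = 7400 for the tensors in this question), which is cheap at this size but grows quickly for larger inputs.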

I have read the following questions, but their answers don't work for my case:
How Pytorch Tensor get the index of specific value
How Pytorch Tensor get the index of elements?
Pytorch tensor - How to get the indexes by a specific tensor
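
The linked questions deal with one value at a time. If lookup_table holds distinct non-negative integers, as it appears to here, another option is to invert it once into a dense array indexed by value, after which the whole lookup is a single indexing operation. This is a sketch under that assumption; build_inverse and the example tensors are hypothetical.

```python
import torch

# Hypothetical shortened versions of lookup_table and data, for illustration only.
lookup_table = torch.tensor([266, 103, 84, 12, 32, 517, 235, 39])
data = torch.tensor([[517, 235, 39],
                     [12, 266, 84]])


def build_inverse(lookup_table):
    # Assumes lookup_table holds distinct non-negative integers.
    # inverse[v] == i whenever lookup_table[i] == v; every other slot stays -1.
    inverse = torch.full((int(lookup_table.max()) + 1,), -1, dtype=torch.long)
    inverse[lookup_table] = torch.arange(lookup_table.numel())
    return inverse


inverse = build_inverse(lookup_table)
# Same shape as data; -1 marks values (up to the table's maximum) absent from lookup_table.
result = inverse[data]
print(result.tolist())  # [[5, 6, 7], [3, 0, 2]]
```

The inverse array has max(lookup_table) + 1 entries (876 for the table above), so this trades memory proportional to the largest value for avoiding the pairwise comparison.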