Is there a map function in PyTorch (something like Python's built-in map)?
I need to map a 1xDxhxw tensor to a 1x(9D)xhxw tensor, augmenting the embedding of each pixel with the embeddings of its 8 neighbours. Is there any functionality in PyTorch that lets me do this efficiently?
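One way to do this without a Python-level loop is sketched below. It uses `torch.nn.functional.unfold` with a 3x3 kernel, which extracts every sliding 3x3 block at once. This is a minimal sketch under two assumptions. First, reflection padding is applied beforehand with `F.pad(..., mode="reflect")` to match the `ReflectionPad2d` used in the attempt below. Second, the channel ordering that `unfold` produces is acceptable: for each input channel `c`, its 9 neighbours appear in row-major order at output channels `c*9 + ki*3 + kj`. The function name `neighbour_embeddings` is hypothetical.

```python
import torch
import torch.nn.functional as F

def neighbour_embeddings(embedding: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (n, D, h, w) -> (n, 9*D, h, w) 3x3 neighbourhoods."""
    n, d, h, w = embedding.shape
    # Reflect-pad one pixel on each side -> (n, D, h+2, w+2); needs h, w >= 2.
    padded = F.pad(embedding, (1, 1, 1, 1), mode="reflect")
    # Extract all 3x3 blocks -> (n, D*9, h*w); channel c's 9 values are contiguous.
    patches = F.unfold(padded, kernel_size=3)
    # Restore the spatial layout -> (n, 9*D, h, w).
    return patches.reshape(n, d * 9, h, w)

x = torch.randn(1, 4, 5, 6)
out = neighbour_embeddings(x)
print(out.shape)  # torch.Size([1, 36, 5, 6])
```

Because index 4 of each group of 9 is the centre of the 3x3 window, `out[:, 4::9]` recovers the original embedding.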
I tried using map in Python this way:
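```python
import torch.nn as nn
# embedding: tensor of shape (1, D, h, w)
n, d, h, w = embedding.size()
# reflect-pad by one pixel on each side -> (1, D, h+2, w+2)
padder = nn.ReflectionPad2d(padding=1)
embedding = padder(embedding)
# intended: take the 3x3 neighbourhood around every pixel (i, j)
embedding = map(lambda i, j, M: M[:, :, i-1:i+2, j-1:j+2], range(1, h), range(1, w), embedding)
```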
But it does not work for w > 2 and h > 2.
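A likely reason the `map` attempt fails is how Python's `map` handles several iterables. It zips them together and stops at the shortest one. So `i` and `j` advance in lockstep, visiting only the diagonal. Iterating the tensor itself walks its first (batch) dimension, which has size 1, so the lambda runs at most once. The sketch below shows, for comparison only, what the `map` call was presumably meant to do. It visits every `(i, j)` pair via `itertools.product`. It is a slow Python loop rather than a vectorized PyTorch feature, and the name `neighbour_embeddings_loop` is hypothetical.

```python
import itertools
import torch
import torch.nn as nn

def neighbour_embeddings_loop(embedding: torch.Tensor) -> torch.Tensor:
    """Hypothetical loop-based reference: (n, D, h, w) -> (n, 9*D, h, w)."""
    n, d, h, w = embedding.size()
    padded = nn.ReflectionPad2d(padding=1)(embedding)  # (n, D, h+2, w+2)
    # itertools.product yields every (i, j) pair, row by row; map(f, range(h),
    # range(w)) would zip the ranges and visit only (0, 0), (1, 1), ...
    patches = [padded[:, :, i:i + 3, j:j + 3].reshape(n, d * 9)
               for i, j in itertools.product(range(h), range(w))]
    # Stack -> (n, D*9, h*w), then restore the spatial layout.
    return torch.stack(patches, dim=2).reshape(n, d * 9, h, w)

x = torch.randn(1, 3, 4, 5)
print(neighbour_embeddings_loop(x).shape)  # torch.Size([1, 27, 4, 5])
```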