2
votes

Is there a map function in PyTorch (something like `map` in Python)?

I need to map a 1xDxhxw tensor variable to a 1x(9D)xhxw tensor, to augment the embedding of each pixel with the embeddings of its 8 neighbours. Is there any functionality in PyTorch that lets me do that efficiently?

I tried using map in Python this way:

n, d, h, w = embedding.size()
padder = nn.ReflectionPad2d(padding=1)
embedding = padder(embedding)
embedding = map(lambda i, j, M: M[:, :, i-1:i+2, j-1:j+2], range(1, h), range(1, w), embedding)

But it does not work for w > 2 and h > 2.

1

1 Answer

0
votes

From your question, it is not clear what you are attempting to accomplish.

Note that PyTorch supports plain Python, but the last line of your code only creates a lazy map object; it is never actually evaluated. The following should work for what you seem to be after, though:
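As a quick illustration of the map-object point (plain Python, no PyTorch needed): in Python 3, `map` returns an iterator, and nothing runs until you consume it.

```python
# map() in Python 3 is lazy: it returns an iterator, not a list
squares = map(lambda x: x * x, range(4))
print(squares)          # <map object at 0x...>, nothing computed yet

# Materialise it (e.g. with list()) to actually apply the function
result = list(squares)
print(result)           # [0, 1, 4, 9]
```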

import torch
import torch.nn as nn
n, d, h, w = 20, 3, 32, 32
embedding = torch.randn(n, d, h, w)
padder = nn.ReflectionPad2d(padding=1)
embedding = padder(embedding)

# One 3x3 patch per pixel. A nested comprehension is needed to cover
# every (i, j) pair -- zip(range(1, h), range(1, w)) would only walk
# the diagonal. After padding by 1, valid centres run 1..h and 1..w.
new = [embedding[:, :, (i - 1):(i + 2), (j - 1):(j + 2)]
       for i in range(1, h + 1)
       for j in range(1, w + 1)]

Note, however, that there are more elegant ways to chunk up a tensor (e.g. torch.chunk()) or to operate on patches with convolutions (e.g. torch.nn.Conv2d).
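For the exact transformation in the question (a 1xDxhxw tensor to a 1x(9D)xhxw tensor of each pixel plus its 8 neighbours), torch.nn.Unfold extracts all 3x3 neighbourhoods in one call, with no Python loop; a minimal sketch:

```python
import torch
import torch.nn as nn

n, d, h, w = 1, 4, 5, 6
embedding = torch.randn(n, d, h, w)

# Reflection-pad by 1 so border pixels also have 8 neighbours
padded = nn.ReflectionPad2d(padding=1)(embedding)

# Extract every 3x3 patch; output shape is (n, d*9, h*w)
patches = nn.Unfold(kernel_size=3)(padded)

# Restore the spatial layout: (n, 9d, h, w)
out = patches.view(n, d * 9, h, w)
print(out.shape)  # torch.Size([1, 36, 5, 6])
```

Each output pixel `out[0, :, i, j]` is the flattened 3x3 window of `padded` centred on that pixel (channel-major, then row-major within the window), which is exactly the 9D augmented embedding.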