
I'm new to PyTorch and I come from functional programming languages (where the map function is used everywhere). The problem is that I have a tensor and I want to apply some operation to each element of it. The operation may vary, so I need a function like this:

map : (Numeric -> Numeric) -> Tensor -> Tensor

e.g. map(lambda x: x if x < 255 else -1, tensor) # the example is simple, but the lambda may be very complex

Is there such a function in PyTorch? How should I implement such function?


1 Answer


Most mathematical operations that are implemented for tensors (and similarly for ndarrays in NumPy) are applied elementwise, so you could write, for instance:

mask = tensor < 255                     # boolean tensor: True where the element is kept
result = tensor * mask + (-1) * ~mask   # original value where mask is True, -1 elsewhere
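The same select-by-mask idea can also be written with `torch.where(condition, input, other)`, which picks elementwise from one of two sources and avoids the multiply-and-add trick. A minimal sketch; the example values are made up for illustration:

```python
import torch

# Example input; the values are arbitrary and only for illustration.
tensor = torch.tensor([[10, 254, 255], [300, 0, 128]])

mask = tensor < 255  # boolean tensor: True where the element is kept

# Take the element from `tensor` where mask is True, otherwise use -1
# (the 0-d tensor is broadcast to the shape of `tensor`).
result = torch.where(mask, tensor, torch.tensor(-1))

print(result)
```

Unlike the arithmetic version, `torch.where` does not rely on multiplying by booleans, so it also works when the replacement value or the kept values are not numbers that combine cleanly under `*` and `+` (for example `inf` or `nan`).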

This is a quite general approach. For your current case, where you only want to overwrite certain elements, you can also use "logical indexing" (indexing with a boolean mask), which lets you modify the tensor in place:

tensor[tensor >= 255] = -1   # overwrite every element >= 255 with -1, in place
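A small runnable sketch of the in-place variant on example data (the values are made up). It also shows `Tensor.masked_fill_`, which does the same thing as a method call:

```python
import torch

# Example input; the values are arbitrary and only for illustration.
tensor = torch.tensor([10, 254, 255, 300, 0])

# Boolean indexing: select every element >= 255 and overwrite it in place.
tensor[tensor >= 255] = -1

# Equivalent in-place alternative: fill the positions where the mask is True.
other = torch.tensor([10, 254, 255, 300, 0])
other.masked_fill_(other >= 255, -1)

print(tensor)
print(other)
```

Both forms mutate the original tensor; if you still need the unmodified values, work on a copy made with `tensor.clone()`.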

So Python actually does have a built-in map() function, but there are usually better ways to do this in Python (in other languages, like Haskell, map/fmap is of course preferred in most contexts).
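If the function really cannot be expressed with tensor operations (for example, it calls arbitrary Python code), a literal `map` with the signature from the question can still be written. A minimal sketch, assuming a CPU tensor; `tensor_map` is a hypothetical helper name, not a PyTorch API. PyTorch does ship `Tensor.apply_`, which applies a Python callable to every element in place, but it only works on CPU tensors and is slow for the same reason a Python loop is.

```python
from typing import Callable, Union

import torch

Number = Union[int, float]


def tensor_map(f: Callable[[Number], Number], t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: map : (Numeric -> Numeric) -> Tensor -> Tensor.

    Runs a Python-level loop, so it is only sensible for small tensors or
    for functions that cannot be vectorized.
    """
    # Convert to plain Python numbers, apply f to each one,
    # then rebuild a tensor with the original dtype and shape.
    values = [f(x) for x in t.flatten().tolist()]
    return torch.tensor(values, dtype=t.dtype).reshape(t.shape)


def clip_to_minus_one(x: Number) -> Number:
    # The example function from the question.
    return x if x < 255 else -1


t = torch.tensor([[10, 254], [255, 300]])
clipped = tensor_map(clip_to_minus_one, t)  # returns a new tensor, t is unchanged

# Built-in in-place alternative (CPU tensors only).
t2 = t.clone()
t2.apply_(clip_to_minus_one)

print(clipped)
print(t2)
```

Going through `tolist()` means `f` receives plain Python numbers rather than 0-d tensors, which keeps the per-element overhead lower than iterating over the tensor directly, but it is still far slower than the vectorized versions.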

So the key takeaway here is that the preferred method is to take advantage of vectorization. This also makes the code more efficient: tensor operations are implemented in a low-level language and process the whole tensor in a single call, while map() is nothing but a Python for loop that calls your function once per element, which is a lot slower.
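To see the difference on your own machine, you can time both approaches with the standard-library `timeit`. A rough sketch; the tensor size and repeat count are arbitrary choices, and no timings are quoted because they depend on hardware and PyTorch version:

```python
import timeit

import torch

# Arbitrary test input: 100,000 random integers in [0, 512).
tensor = torch.randint(0, 512, (100_000,))


def vectorized(t: torch.Tensor) -> torch.Tensor:
    # A single call into PyTorch's compiled kernels for the whole tensor.
    return torch.where(t < 255, t, torch.tensor(-1))


def clip(x: int) -> int:
    return x if x < 255 else -1


def python_loop(t: torch.Tensor) -> torch.Tensor:
    # Python's map() calls clip once per element.
    return torch.tensor(list(map(clip, t.tolist())), dtype=t.dtype)


# Check that both approaches agree before comparing their speed.
assert torch.equal(vectorized(tensor), python_loop(tensor))

print("vectorized :", timeit.timeit(lambda: vectorized(tensor), number=5))
print("python loop:", timeit.timeit(lambda: python_loop(tensor), number=5))
```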