45
votes

Is there a PyTorch-internal procedure to detect NaNs in tensors? TensorFlow has the tf.is_nan and tf.check_numerics operations ... Does PyTorch have something similar, somewhere? I could not find anything like this in the docs...

I am looking specifically for a PyTorch-internal routine, since I would like this to happen on the GPU as well as on the CPU. This excludes numpy-based solutions (like np.isnan(sometensor.numpy()).any()) ...

5
this might be of help: x.isnan().any() - Charlie Parker

5 Answers

73
votes

You can always leverage the fact that nan != nan:

>>> import torch
>>> import numpy as np
>>> x = torch.tensor([1, 2, np.nan])
>>> x
tensor([  1.,   2., nan])
>>> x != x
tensor([ 0,  0,  1], dtype=torch.uint8)

With PyTorch 0.4 there is also torch.isnan:

>>> torch.isnan(x)
tensor([ 0,  0,  1], dtype=torch.uint8)
37
votes

Starting with PyTorch 0.4.1 there is the torch.autograd.detect_anomaly context manager, which automatically inserts assertions equivalent to assert not torch.isnan(grad).any() between all steps of backward propagation. It's very useful when issues only show up during the backward pass.
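
A minimal usage sketch (the linear model, data, and loss below are placeholders, not part of the original answer):

import torch
from torch import nn

# Toy model and data, just to show where the context manager goes
model = nn.Linear(4, 1)
x = torch.randn(8, 4)
target = torch.randn(8, 1)

# Inside the block, every backward step is checked for NaNs; if one appears,
# autograd raises a RuntimeError and prints the traceback of the forward
# operation that produced the offending gradient.
with torch.autograd.detect_anomaly():
    loss = nn.functional.mse_loss(model(x), target)
    loss.backward()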

18
votes

As suggested by @cleros in the comment on @nemo's answer, you can get this as a boolean using the any() operator:

torch.isnan(your_tensor).any()
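
Since torch.isnan is also implemented for CUDA tensors, the same one-liner covers the GPU case from the question. A small sketch (the example tensor is just an illustration):

import torch

x = torch.tensor([1.0, float('nan'), 3.0])
print(torch.isnan(x).any())             # tensor(True)
print(torch.isnan(x).any().item())      # True, as a plain Python bool

# Works directly on the GPU as well, with no round-trip through numpy
if torch.cuda.is_available():
    print(torch.isnan(x.cuda()).any())  # tensor(True, device='cuda:0')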
4
votes

If you want to call it on a tensor directly:

import torch

x = torch.randn(5, 4)
print(x.isnan().any())

out:

tensor(False)
3
votes

True if any value is nan:

torch.any(tensor.isnan())

True if all values are nan:

torch.all(tensor.isnan())