Sometimes I need to modify some of the values in a PyTorch tensor. For example, given a tensor x, I need to multiply its positive entries by 2 and its negative entries by 3:
import torch
x = torch.randn(1000, requires_grad=True)
x[x>0] = 2 * x[x>0]
x[x<0] = 3 * x[x<0]
y = x.sum()
y.backward()
However, autograd forbids such in-place operations on a leaf tensor that requires grad:
Traceback (most recent call last):
File "test_rep.py", line 4, in <module>
x[x>0] = 2 * x[x>0]
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
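A minimal sketch of a pattern that does run (same element-wise scaling as above, assuming the masks are taken from the untouched x): clone x first, so the in-place writes land on a non-leaf copy:

```python
import torch

x = torch.randn(1000, requires_grad=True)

# y is a non-leaf tensor produced by a differentiable op (clone),
# so in-place indexed assignment on it is allowed by autograd.
y = x.clone()
y[x > 0] = 2 * y[x > 0]
y[x < 0] = 3 * y[x < 0]

z = y.sum()
z.backward()
# dz/dx is 2 on the positive entries and 3 on the negative ones.
```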
Therefore, so far I've been using the following workaround:
import torch
x = torch.randn(1000, requires_grad=True)
y = torch.zeros_like(x)  # zeros_like already matches x's device and dtype
y[x>0] = 2 * x[x>0]
y[x<0] = 3 * x[x<0]
z = y.sum()
z.backward()
which requires manually allocating a new tensor. I wonder if there is a better way to do this.
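For reference, one out-of-place alternative (a sketch, not necessarily the best way) is torch.where, which builds the result in a single expression with no in-place writes at all:

```python
import torch

x = torch.randn(1000, requires_grad=True)

# torch.where selects elementwise from 2*x or 3*x, producing a new
# tensor out-of-place, so the autograd graph stays intact.
y = torch.where(x > 0, 2 * x, 3 * x)

z = y.sum()
z.backward()
# dz/dx is 2 where x > 0 and 3 where x <= 0.
```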