I have a tensor A of size torch.Size([32, 32, 3, 3]) and I want to split it and extract a tensor B of size torch.Size([16, 16, 3, 3]) from it. The tensor can be 1d or 4d, and the split has to follow the given new tensor dimensions. I have been able to generate the target dimensions, but I'm unable to split the source tensor and extract the values from it. I have tried torch.narrow, but it takes only 3 arguments (dim, start, length), so it restricts a single dimension per call, and in many cases I need to restrict 4. torch.split takes dim as an int, so the tensor is split along one dimension only. But I want to split it along multiple dimensions.
You have multiple options:
- use .split multiple times
- use .narrow multiple times
- use slicing
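For the .narrow option, here is a minimal sketch, assuming the same 32x32x3x3 shape as tensor A in the question. Each Tensor.narrow(dim, start, length) call restricts one dimension, so you chain one call per dimension you want to cut:

```python
import torch

# Same shape as tensor A in the question
t = torch.rand(32, 32, 3, 3)

# Each narrow call restricts a single dimension, so two chained calls
# cut out the top-left 16x16 block along dims 0 and 1
b = t.narrow(0, 0, 16).narrow(1, 0, 16)
print(b.shape)  # torch.Size([16, 16, 3, 3])

# The result is a view that holds the same values as the equivalent slice
print(torch.equal(b, t[:16, :16]))  # True
```

Because narrow returns a view, no data is copied. The block starting at offset 0 even shares its first element's address with t.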
e.g. with .split and slicing:
```python
import torch

t = torch.rand(32, 32, 3, 3)
# split along dim 0 into two halves of 16
t0, t1 = t.split((16, 16), 0)
print(t0.shape, t1.shape)
# torch.Size([16, 32, 3, 3]) torch.Size([16, 32, 3, 3])
# split the first half again along dim 1
t00, t01 = t0.split((16, 16), 1)
print(t00.shape, t01.shape)
# torch.Size([16, 16, 3, 3]) torch.Size([16, 16, 3, 3])
# the same two blocks via slicing (t01 is rows 0-15, columns 16-31)
t00_alt, t01_alt = t[:16, :16, :, :], t[:16, 16:, :, :]
print(t00_alt.shape, t01_alt.shape)
# torch.Size([16, 16, 3, 3]) torch.Size([16, 16, 3, 3])
```
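Since the question asks for this to work for both 1d and 4d tensors given a target shape, one way to generalize the slicing option is to build the index from that shape. This is a minimal sketch, assuming the block starts at offset 0 in every dimension unless offsets are passed. extract_block and offsets are hypothetical names for illustration, not part of PyTorch:

```python
import torch

def extract_block(t, target_shape, offsets=None):
    """Return the sub-tensor of `t` with shape `target_shape`.

    Hypothetical helper (not PyTorch API). `offsets` is the start index
    per dimension and defaults to all zeros. Works for any rank, e.g. 1d or 4d.
    """
    if len(target_shape) != t.dim():
        raise ValueError("target_shape must have one entry per dimension")
    if offsets is None:
        offsets = (0,) * t.dim()
    # One slice object per dimension; indexing with the tuple slices all dims at once
    index = tuple(slice(o, o + n) for o, n in zip(offsets, target_shape))
    block = t[index]
    # Slicing past the end truncates silently, so check the shape explicitly
    if tuple(block.shape) != tuple(target_shape):
        raise ValueError("block does not fit inside the tensor at these offsets")
    return block

A = torch.rand(32, 32, 3, 3)
B = extract_block(A, (16, 16, 3, 3))
print(B.shape)  # torch.Size([16, 16, 3, 3])

# Bottom-right block instead of top-left
B11 = extract_block(A, (16, 16, 3, 3), offsets=(16, 16, 0, 0))
print(torch.equal(B11, A[16:, 16:]))  # True

# The same helper on a 1d tensor
v = torch.arange(10)
print(extract_block(v, (4,), offsets=(3,)))  # tensor([3, 4, 5, 6])
```

Basic slicing like this returns a view that shares memory with the source tensor. Call .clone() on the result if B must be independent of A.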