I have a tensor A of size torch.Size([32, 32, 3, 3]) and I want to split it and extract a tensor B of size torch.Size([16, 16, 3, 3]) from it. The tensor can be 1d or 4d, and the split has to follow the given new tensor dimensions. I have been able to generate the target dimensions, but I'm unable to split and extract the values from the source tensor. I have tried torch.narrow, but it takes only 3 arguments and I need 4 in many cases. torch.split takes dim as an int, so the tensor is split along one dimension only, but I want to split it along multiple dimensions.

1 Answer

You have multiple options:

  • use .split multiple times
  • use .narrow multiple times
  • use slicing

e.g.:

import torch

t = torch.rand(32, 32, 3, 3)

# Split dim 0 into two chunks of 16
t0, t1 = t.split((16, 16), 0)
print(t0.shape, t1.shape)
# torch.Size([16, 32, 3, 3]) torch.Size([16, 32, 3, 3])

# Split the first chunk again along dim 1
t00, t01 = t0.split((16, 16), 1)
print(t00.shape, t01.shape)
# torch.Size([16, 16, 3, 3]) torch.Size([16, 16, 3, 3])

# Same two blocks via slicing (t01 is rows :16 of dim 0, 16: of dim 1)
t00_alt, t01_alt = t[:16, :16, :, :], t[:16, 16:, :, :]
print(t00_alt.shape, t01_alt.shape)
# torch.Size([16, 16, 3, 3]) torch.Size([16, 16, 3, 3])
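
Since the question mentions that the tensor can be 1d or 4d and that the target shape is already known, the slicing and .narrow options can be wrapped in a small helper that works for any number of dimensions. This is only a sketch: the names crop_leading and crop_leading_narrow are made up for illustration, and it assumes you want the block starting at index 0 along every dimension.

```python
import torch


def crop_leading(t, target_shape):
    """Return the leading block of `t` with shape `target_shape` via slicing."""
    if len(target_shape) != t.dim():
        raise ValueError("target_shape needs one entry per dimension of t")
    if any(n > s for n, s in zip(target_shape, t.shape)):
        raise ValueError("target_shape must not exceed t.shape in any dimension")
    # Build one slice per dimension, e.g. (slice(0, 16), slice(0, 16), ...)
    index = tuple(slice(0, n) for n in target_shape)
    return t[index]


def crop_leading_narrow(t, target_shape):
    """Same result, obtained by calling .narrow once per dimension."""
    out = t
    for dim, n in enumerate(target_shape):
        # narrow(dim, start, length) keeps `length` entries from `start` on `dim`
        out = out.narrow(dim, 0, n)
    return out


a = torch.rand(32, 32, 3, 3)
b = crop_leading(a, (16, 16, 3, 3))
b_narrow = crop_leading_narrow(a, (16, 16, 3, 3))

v = torch.rand(32)
w = crop_leading(v, (16,))

print(b.shape, b_narrow.shape, w.shape)
```

Both helpers return views into the source tensor rather than copies, so call .clone() on the result if B needs to be modified independently of A.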