
torch.squeeze removes dimensions of size 1 from a tensor's shape.

I want to squeeze my tensor in every dimension except one (in this example, I want to keep the size-1 dimension at index 3).

All I can find in the documentation is the description of the dim argument:

dim (int, optional) – if given, the input will be squeezed only in this dimension
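For reference, this is how the existing dim argument behaves on the same tensor used in the example. With no argument every size-1 dimension is dropped; with an integer only that one dimension is dropped, and only if its size is 1. A short check, assuming PyTorch is installed:

```python
import torch

t = torch.zeros(5, 1, 6, 1, 7, 1)

# No dim: every size-1 dimension is removed.
print(t.squeeze().shape)   # torch.Size([5, 6, 7])

# dim=1: only dimension 1 (size 1) is removed.
print(t.squeeze(1).shape)  # torch.Size([5, 6, 1, 7, 1])

# dim=0 has size 5, so squeezing it leaves the shape unchanged.
print(t.squeeze(0).shape)  # torch.Size([5, 1, 6, 1, 7, 1])
```

So dim only selects which single dimension to remove; there is no built-in "keep this one" option.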

I want the opposite:

t = torch.zeros(5, 1, 6, 1, 7, 1)

squeezed = torch.magic_squeeze(t, keep_dim=3)  # hypothetical; no such function exists

assert squeezed.shape == (5, 6, 1, 7)

Can this be done?


2 Answers


reshape() can do this if you know the target shape in advance:

import torch

t = torch.zeros(5, 1, 6, 1, 7, 1)
t = t.reshape((5, 6, 1, 7))  # drops dims 1 and 5, keeps the singleton at dim 3
print(t.shape)
# torch.Size([5, 6, 1, 7])
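Hard-coding (5, 6, 1, 7) only works for this one input shape. A minimal sketch that derives the target shape instead, keeping every dimension whose size is not 1 plus the one index you want to preserve. squeeze_except is a hypothetical helper name, not a PyTorch function:

```python
import torch

def squeeze_except(t: torch.Tensor, keep_dim: int) -> torch.Tensor:
    """Drop every size-1 dimension of t except keep_dim (hypothetical helper)."""
    keep_dim = keep_dim % t.dim()  # allow negative indices such as -1
    new_shape = [size for i, size in enumerate(t.shape)
                 if size != 1 or i == keep_dim]
    # Removing size-1 dims never changes the element count, so reshape is valid.
    return t.reshape(new_shape)

t = torch.zeros(5, 1, 6, 1, 7, 1)
print(squeeze_except(t, keep_dim=3).shape)  # torch.Size([5, 6, 1, 7])
```

Because the shape is computed from t.shape at runtime, the same call works for other inputs, e.g. squeeze_except(t, -1) keeps the trailing singleton.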

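You can squeeze out every size-1 dimension and then add the one you want back with unsqueeze(). After squeeze() the shape is (5, 6, 7), so inserting at index 2 gives (5, 6, 1, 7):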

import torch

t = torch.zeros(5, 1, 6, 1, 7, 1)
squeezed = t.squeeze().unsqueeze(2)
print(squeezed.shape)
# torch.Size([5, 6, 1, 7])
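In general the unsqueeze() index is keep_dim minus the number of size-1 dimensions that come before it, since squeeze() removes those too (here 3 - 1 = 2). A sketch that computes that index for an arbitrary keep_dim; squeeze_keep is a hypothetical helper name:

```python
import torch

def squeeze_keep(t: torch.Tensor, keep_dim: int) -> torch.Tensor:
    """squeeze() everything, then re-insert the singleton at keep_dim (hypothetical helper)."""
    keep_dim = keep_dim % t.dim()
    if t.shape[keep_dim] != 1:
        # keep_dim is not a singleton, so plain squeeze() already keeps it
        return t.squeeze()
    # Every size-1 dim before keep_dim disappears, shifting keep_dim left.
    removed_before = sum(1 for size in t.shape[:keep_dim] if size == 1)
    return t.squeeze().unsqueeze(keep_dim - removed_before)

t = torch.zeros(5, 1, 6, 1, 7, 1)
print(squeeze_keep(t, 3).shape)  # torch.Size([5, 6, 1, 7])
```

Like plain squeeze(), this also removes any other dimension that happens to be 1 at runtime (for example a batch of size 1), so it only fits when that is what you want.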