In Keras, the Flatten() layer retains the batch size. For example, if the input shape to Flatten is (32, 100, 100), the output of Flatten in Keras is (32, 10000), but in PyTorch it is a single dimension of 320000. Why is that?
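A minimal sketch reproducing the PyTorch side of the question. It assumes a zero-filled tensor with the shape from the question and shows that the default torch.flatten(), like reshape(-1), merges every dimension, including the batch dimension:

```python
import torch

# A batch of 32 samples, each 100x100 (the shape from the question; values are arbitrary).
x = torch.zeros(32, 100, 100)

# torch.flatten defaults to start_dim=0, so the batch dimension is merged too.
flat = torch.flatten(x)
print(flat.shape)  # torch.Size([320000])

# Tensor.reshape(-1) behaves the same way: it has no notion of a batch.
reshaped = x.reshape(-1)
print(reshaped.shape)  # torch.Size([320000])
```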
2 Answers
As the OP already pointed out in their answer, PyTorch tensor operations do not assume a batch dimension by default. You can use torch.flatten() or Tensor.flatten() with start_dim=1 to start flattening after the batch dimension.
Alternatively, since PyTorch 1.2.0 you can define an nn.Flatten() layer in your model, which defaults to start_dim=1.
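A short sketch of the three options, assuming the same (32, 100, 100) input as in the question. Each form keeps dimension 0 and flattens the rest:

```python
import torch
import torch.nn as nn

x = torch.zeros(32, 100, 100)

a = torch.flatten(x, start_dim=1)  # function form
b = x.flatten(start_dim=1)         # Tensor method form
c = nn.Flatten()(x)                # module form; start_dim=1 by default

for t in (a, b, c):
    print(t.shape)  # torch.Size([32, 10000])
```

The module form is convenient inside nn.Sequential, where a layer is needed rather than a function call.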
Yes. As mentioned in this thread, PyTorch tensor operations such as flatten, view, and reshape do not treat the first dimension as a batch dimension.
In general, when using modules like Conv2d, you don't need to worry about the batch size; they expect a leading batch dimension and PyTorch takes care of it. But when working directly with tensors, you need to handle the batch size yourself.
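A sketch of the contrast, where the channel count and kernel size are arbitrary assumptions for illustration. Conv2d consumes an (N, C, H, W) batch and keeps N, while view needs the batch size passed explicitly:

```python
import torch
import torch.nn as nn

# Conv2d expects (N, C, H, W) input and preserves N in its output.
# in_channels/out_channels/kernel_size are illustrative choices.
conv = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3)
x = torch.randn(32, 1, 100, 100)
y = conv(x)
print(y.shape)  # torch.Size([32, 4, 98, 98])

# view/reshape know nothing about batches: pass the batch size explicitly
# and let -1 infer the per-sample feature count (4 * 98 * 98).
flat = y.view(y.size(0), -1)
print(flat.shape)  # torch.Size([32, 38416])
```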
In Keras, Flatten() is a layer. In PyTorch, flatten() is an operation on the tensor, so the batch dimension has to be handled manually.
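To illustrate the layer-versus-operation distinction, here is a sketch with a hypothetical model class (Net) and an arbitrary 10-unit output. It flattens with a tensor operation inside forward(), where start_dim=1 must be given explicitly, and compares it with the layer-style nn.Flatten, which is closer to Keras' Flatten():

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    """Hypothetical model that flattens with a tensor op inside forward()."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(100 * 100, 10)

    def forward(self, x):
        # Tensor-level flatten: start_dim=1 keeps the batch dimension (dim 0).
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)

model = Net()
out = model(torch.randn(32, 100, 100))
print(out.shape)  # torch.Size([32, 10])

# Layer-style alternative: nn.Flatten defaults to start_dim=1, like Keras' Flatten().
seq = nn.Sequential(nn.Flatten(), nn.Linear(100 * 100, 10))
seq_out = seq(torch.randn(32, 100, 100))
print(seq_out.shape)  # torch.Size([32, 10])
```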