
TensorFlow 2.0 Dataset API's batch() is not working as I expected.

I created a dataset like this:

self.train_dataset = tf.data.Dataset.from_generator(generator=train_generator, output_types=(tf.float32, tf.float32), output_shapes=(tf.TensorShape([6]), tf.TensorShape([])))

This yields DatasetV1Adapter shapes: ((6,), ()), types: (tf.float32, tf.float32). I then applied the batch function from tf.data.Dataset to this dataset:

self.train_dataset.batch(1024)

This yields DatasetV1Adapter shapes: ((None, 6), (None,)), types: (tf.float32, tf.float32). Changing the batch size doesn't help at all: the outer dimension is always None.
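A minimal sketch reproducing this, assuming TensorFlow 2.x running eagerly. train_generator here is a hypothetical stand-in (2500 samples of 6 features plus a scalar label), since the real generator isn't shown. It illustrates that the None lives only in the static element_spec; the batches you actually iterate over have concrete sizes. (Newer TF releases prefer output_signature over output_types/output_shapes, but still accept the latter.)

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the question's train_generator:
# 2500 samples, each a 6-element float feature vector and a scalar label.
NUM_SAMPLES = 2500

def train_generator():
    for i in range(NUM_SAMPLES):
        yield np.full(6, i, dtype=np.float32), np.float32(i % 2)

train_dataset = tf.data.Dataset.from_generator(
    generator=train_generator,
    output_types=(tf.float32, tf.float32),
    output_shapes=(tf.TensorShape([6]), tf.TensorShape([])),
)

# batch() returns a new dataset; it does not modify train_dataset in place.
batched = train_dataset.batch(1024)

# Static shapes: the outer (batch) dimension is unknown.
features_spec, labels_spec = batched.element_spec
print(features_spec.shape.as_list(), labels_spec.shape.as_list())  # [None, 6] [None]

# Runtime shapes: every batch has a concrete outer dimension.
batch_sizes = [int(features.shape[0]) for features, _ in batched]
print(batch_sizes)  # [1024, 1024, 452]
```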

From the official description of batch:

The components of the resulting element will have an additional outer dimension, which will be batch_size (or N % batch_size for the last element if batch_size does not divide the number of input elements N evenly and drop_remainder is False). If your program depends on the batches having the same outer dimension, you should set the drop_remainder argument to True to prevent the smaller batch from being produced.
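The documented rule can be modeled in a few lines of plain Python. This is not TensorFlow code; batch_sizes is a hypothetical helper that only computes how many elements each emitted batch would hold for N input elements:

```python
def batch_sizes(n, batch_size, drop_remainder=False):
    """Sizes of the batches emitted for n elements, per the documented rule."""
    full, remainder = divmod(n, batch_size)
    sizes = [batch_size] * full
    # The partial batch of n % batch_size elements survives only if
    # drop_remainder is False.
    if remainder and not drop_remainder:
        sizes.append(remainder)
    return sizes

print(batch_sizes(2500, 1024))                       # [1024, 1024, 452]
print(batch_sizes(2500, 1024, drop_remainder=True))  # [1024, 1024]
print(batch_sizes(2048, 1024))                       # [1024, 1024]
```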

I expected this function to produce shapes [batch, 6] and [batch], but that's not what I got.

I originally used PyTorch and only recently started using TF 2.0, so I need some help with proper batching. Thanks in advance.


1 Answer


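You can get the desired result by setting drop_remainder=True. Note that batch() returns a new dataset rather than modifying the existing one, so assign the result: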

train_dataset = train_dataset.batch(2, drop_remainder=True)
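A sketch of the effect, again assuming TF 2.x and a hypothetical five-sample generator in place of the real one: with drop_remainder=True the batch size becomes part of the static shape, and the incomplete final batch (here one leftover sample) is discarded.

```python
import numpy as np
import tensorflow as tf

# Hypothetical generator: 5 samples of 6 features plus a scalar label.
def train_generator():
    for i in range(5):
        yield np.full(6, i, dtype=np.float32), np.float32(i % 2)

train_dataset = tf.data.Dataset.from_generator(
    generator=train_generator,
    output_types=(tf.float32, tf.float32),
    output_shapes=(tf.TensorShape([6]), tf.TensorShape([])),
)

train_dataset = train_dataset.batch(2, drop_remainder=True)

# The batch size is now known statically.
features_spec, labels_spec = train_dataset.element_spec
print(features_spec.shape.as_list(), labels_spec.shape.as_list())  # [2, 6] [2]

# Two full batches; the fifth sample is dropped.
batch_sizes = [int(features.shape[0]) for features, _ in train_dataset]
print(batch_sizes)  # [2, 2]
```

The trade-off is that up to batch_size - 1 samples are skipped on every pass over the data. If a fixed batch dimension isn't actually required (Keras models, for instance, accept a None batch dimension), leaving drop_remainder=False is usually fine.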

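drop_remainder is False by default. In that case the static first dimension must be None: tf.data does not use the number of samples when inferring the shape, so it has to allow for a final batch with fewer than batch_size elements, which is exactly what you get whenever the number of samples is not divisible by batch_size. This None is only the static shape, though; every batch actually produced still has a concrete size (batch_size for all but possibly the last one).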
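A quick check, assuming TF 2.x, that the None appears even when the count does divide evenly; Dataset.range stands in for the real data here:

```python
import tensorflow as tf

# 8 elements with batch size 4: the count divides evenly...
even = tf.data.Dataset.range(8).batch(4)
# ...yet the static batch dimension is still unknown.
print(even.element_spec.shape.as_list())  # [None]
even_sizes = [int(batch.shape[0]) for batch in even]
print(even_sizes)  # [4, 4]

# 10 elements with batch size 4: the last batch is partial.
uneven = tf.data.Dataset.range(10).batch(4)
print(uneven.element_spec.shape.as_list())  # [None]
uneven_sizes = [int(batch.shape[0]) for batch in uneven]
print(uneven_sizes)  # [4, 4, 2]
```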