The batch method of the TensorFlow 2.0 Dataset API is not working as I expected.
I've made a dataset like this:
self.train_dataset = tf.data.Dataset.from_generator(generator=train_generator, output_types=(tf.float32, tf.float32), output_shapes=(tf.TensorShape([6]), tf.TensorShape([])))
This yields DatasetV1Adapter shapes: ((6,), ()), types: (tf.float32, tf.float32). To this dataset I applied the batch function from tf.data.Dataset:
self.train_dataset.batch(1024)
which yields DatasetV1Adapter shapes: ((None, 6), (None,)), types: (tf.float32, tf.float32). Changing the batch size doesn't help at all; the outer dimension is always None.
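A minimal sketch that reproduces this, assuming TF 2.x and a hypothetical train_generator (with a hypothetical NUM_EXAMPLES = 2500) standing in for the real generator. It prints the static shapes of the batched dataset for two batch sizes. In both cases the outer dimension is None: when the dataset is defined, batch() cannot guarantee that every batch (in particular the last one) has exactly batch_size elements, so the static shape records that dimension as unknown. None means "not known statically", not "empty". element_spec is how TF 2.x exposes these static shapes.

```python
import numpy as np
import tensorflow as tf

NUM_EXAMPLES = 2500  # hypothetical dataset size, not from the question


def train_generator():
    # Hypothetical stand-in for the real generator:
    # yields (6 features, scalar label) pairs.
    for i in range(NUM_EXAMPLES):
        yield np.full(6, i, dtype=np.float32), np.float32(i % 2)


dataset = tf.data.Dataset.from_generator(
    generator=train_generator,
    output_types=(tf.float32, tf.float32),
    output_shapes=(tf.TensorShape([6]), tf.TensorShape([])),
)

for batch_size in (32, 1024):
    x_spec, y_spec = dataset.batch(batch_size).element_spec
    # The static outer dimension is None for every batch size,
    # because the final batch may be smaller than batch_size.
    print(batch_size, x_spec.shape, y_spec.shape)
```

Newer TF 2.x releases prefer output_signature over output_types/output_shapes in from_generator, but the sketch keeps the question's arguments.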
From the official description of batch:
The components of the resulting element will have an additional outer dimension, which will be batch_size (or N % batch_size for the last element if batch_size does not divide the number of input elements N evenly and drop_remainder is False). If your program depends on the batches having the same outer dimension, you should set the drop_remainder argument to True to prevent the smaller batch from being produced.
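Following the quoted documentation, here is a sketch of setting drop_remainder=True under the same assumptions (TF 2.x, hypothetical train_generator and NUM_EXAMPLES). With 2500 examples and a batch size of 1024, the last 452 examples are dropped, and the static shapes become fully defined as (1024, 6) and (1024,).

```python
import numpy as np
import tensorflow as tf

NUM_EXAMPLES = 2500  # hypothetical dataset size
BATCH_SIZE = 1024


def train_generator():
    # Hypothetical generator: (6 features, scalar label) per example.
    for i in range(NUM_EXAMPLES):
        yield np.full(6, i, dtype=np.float32), np.float32(i % 2)


dataset = tf.data.Dataset.from_generator(
    generator=train_generator,
    output_types=(tf.float32, tf.float32),
    output_shapes=(tf.TensorShape([6]), tf.TensorShape([])),
)

# Drop the final partial batch so every batch has exactly BATCH_SIZE rows.
batched = dataset.batch(BATCH_SIZE, drop_remainder=True)
x_spec, y_spec = batched.element_spec
print(x_spec.shape, y_spec.shape)  # static shapes now include the batch size

num_batches = sum(1 for _ in batched)  # only the full batches remain
print(num_batches)
```

The trade-off is that up to BATCH_SIZE - 1 examples per epoch are discarded. This is only needed when something downstream requires a fixed batch dimension.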
I expected this function to produce shapes [batch, 6] and [batch], but it didn't work out that way.
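The sketch below uses the same hypothetical generator and assumes TF 2.x eager execution. It iterates over the batched dataset and records the concrete shape of each batch. The actual tensors do have shapes [batch, 6] and [batch], and only the last batch is smaller (452 rows here). It also keeps the return value of batch(), because tf.data transformations return a new dataset rather than modifying the existing one. In the class above, that would be self.train_dataset = self.train_dataset.batch(1024).

```python
import numpy as np
import tensorflow as tf

NUM_EXAMPLES = 2500  # hypothetical dataset size


def train_generator():
    # Hypothetical generator: (6 features, scalar label) per example.
    for i in range(NUM_EXAMPLES):
        yield np.full(6, i, dtype=np.float32), np.float32(i % 2)


train_dataset = tf.data.Dataset.from_generator(
    generator=train_generator,
    output_types=(tf.float32, tf.float32),
    output_shapes=(tf.TensorShape([6]), tf.TensorShape([])),
)

# batch() returns a new dataset; reassign it so the batching is kept.
train_dataset = train_dataset.batch(1024)

# Eager iteration yields concrete tensors whose shapes are fully known.
batch_shapes = []
for x, y in train_dataset:
    batch_shapes.append((tuple(x.shape.as_list()), tuple(y.shape.as_list())))
print(batch_shapes)
```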
I originally used PyTorch and only recently started using TF 2.0, so I need some help with proper batching. Thanks in advance.