I'm using the tf.data Dataset API and reading data as follows:
dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda serialized: tf.parse_single_example(serialized, feature_schema))
I now want to use flat_map to filter out some samples while duplicating others dynamically at training time (this is the input function feeding my model).
The API for flat_map requires the mapped function to return a Dataset object, but I don't know how to create one. Here's a pseudo-code implementation of what I want to achieve:
def flat_map_impl(tf_example):
    # Pseudo-code:
    # if tf_example["a"] == 1:
    #     return []
    # else:
    #     return [tf_example, tf_example]

# flat_map returns a new dataset, so the result has to be reassigned.
dataset = dataset.flat_map(flat_map_impl)
How can I implement this in the flat_map function?
NOTE: I assume this could be implemented with tf.py_func, but I'd prefer to avoid that.
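For illustration, here is a minimal sketch of one way the pseudo-code might be expressed without py_func. It wraps the example in a one-element dataset with Dataset.from_tensors and repeats it a data-dependent number of times (0 to drop it, 2 to duplicate it). It assumes that feature "a" is a scalar integer, and it uses a small in-memory dataset (the hypothetical name toy) in place of the parsed TFRecord data. The final iteration line assumes eager execution (TF 2.x); in graph mode an iterator and session would be needed instead.

```python
import tensorflow as tf


def flat_map_impl(tf_example):
    # Number of copies to emit: 0 drops the example, 2 duplicates it.
    # Assumption: tf_example["a"] is a scalar integer tensor.
    count = tf.cond(
        tf.equal(tf_example["a"], 1),
        lambda: tf.constant(0, dtype=tf.int64),
        lambda: tf.constant(2, dtype=tf.int64),
    )
    # from_tensors builds a one-element dataset; repeat(count) yields it
    # `count` times, and repeat(0) yields an empty dataset.
    return tf.data.Dataset.from_tensors(tf_example).repeat(count)


# Hypothetical in-memory stand-in for the parsed TFRecord dataset.
toy = tf.data.Dataset.from_tensor_slices({"a": [1, 0, 1, 5]})
result = toy.flat_map(flat_map_impl)

# Collect the surviving values of "a" (requires eager execution).
values = [int(example["a"]) for example in result]
```

With this toy input, the two examples where "a" equals 1 are dropped and the remaining two each appear twice. The same function should be usable as dataset.flat_map(flat_map_impl) on the parsed dataset, provided "a" is parsed as a scalar.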