
I have a generator that yields an infinite amount of data (random image crops). I would like to create a tf.data.Dataset from, say, the first 10,000 data points and cache it, so I can reuse them to train models.

Currently, my generator takes 1-2 seconds to create each data point, and this is the main performance bottleneck: I have to wait a minute or more to generate a batch of 64 images. The preprocessing() function is very expensive, so I would like to reuse its results.

The tf.data.Dataset.from_generator() method lets us create such an infinite dataset. Instead, I would like to create a finite dataset from the first N outputs of the generator and cache it, like:

ds = ds.cache()


An alternative solution would be to keep generating new data while training on the cached data points in the meantime.
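A minimal sketch of that alternative with tf.data, assuming a recent TF 2.x where tf.data.Dataset.sample_from_datasets is available (older releases expose it as tf.data.experimental.sample_from_datasets). The generator body, N_CACHED = 8, and the 0.9/0.1 weights are placeholders, not values from the question. Most elements come from a cached, repeated subset; a small fraction comes from a second, independent generator instance that keeps producing new crops.

```python
import numpy as np
import tensorflow as tf

N_CACHED = 8  # placeholder; the question mentions 10,000

def generate_example():
    # Hypothetical stand-in for the slow "random crop + preprocessing()" generator.
    while True:
        yield np.random.random((64, 64, 3)).astype(np.float32)

spec = tf.TensorSpec(shape=(64, 64, 3), dtype=tf.float32)

# Reused part: the first N_CACHED generator outputs, computed once, then served from memory.
cached = (
    tf.data.Dataset.from_generator(generate_example, output_signature=spec)
    .take(N_CACHED)
    .cache()
    .repeat()
)

# Fresh part: each iterator over this dataset calls generate_example() again,
# so it keeps producing new crops.
fresh = tf.data.Dataset.from_generator(generate_example, output_signature=spec)

# Draw roughly 90% of the elements from the cache and 10% freshly generated.
mixed = tf.data.Dataset.sample_from_datasets([cached, fresh], weights=[0.9, 0.1], seed=0)

batch = next(iter(mixed.batch(16)))
print(batch.shape)
```

Because each fresh crop still costs 1-2 seconds, the fresh fraction has to stay small enough that generation does not stall training; appending .prefetch(tf.data.AUTOTUNE) lets some of that work overlap with the training step.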

In this case I would run a separate TF session that creates your inputs with batch_size=1 and stores them to disk, and a separate training session that reads the cache folder and feeds the data to the model. The only thing is, I don't think there is a straightforward way to make your dataset aware of new inputs after initialization. - y.selivonchyk
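A rough TF 2.x (eager) sketch of that producer/consumer split, using TFRecord files in place of TF1 sessions. It assumes each example is a float32 tensor of shape (64, 64, 3); generate_example, write_first_n and read_dataset are hypothetical names, and the file path is a temporary placeholder.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

def generate_example():
    # Hypothetical stand-in for the expensive crop + preprocessing() generator.
    rng = np.random.default_rng(0)
    while True:
        yield rng.random((64, 64, 3), dtype=np.float32)

def write_first_n(path, n):
    # Producer side: serialize the first n generator outputs into one TFRecord file.
    gen = generate_example()
    with tf.io.TFRecordWriter(path) as writer:
        for _ in range(n):
            x = tf.convert_to_tensor(next(gen))
            writer.write(tf.io.serialize_tensor(x).numpy())

def read_dataset(path):
    # Consumer side: parse each record back into a float32 tensor with a known shape.
    ds = tf.data.TFRecordDataset(path)
    return ds.map(
        lambda s: tf.ensure_shape(tf.io.parse_tensor(s, out_type=tf.float32), (64, 64, 3))
    )

path = os.path.join(tempfile.mkdtemp(), "crops.tfrecord")
write_first_n(path, 10)
train_ds = read_dataset(path).shuffle(10).batch(4)
```

The producer could run in its own process and write new files over time; picking those up would mean rebuilding the dataset from the current file list, which matches the comment's point that an existing dataset will not see new inputs on its own.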

1 Answer


You can combine the Dataset.take function with the Dataset.cache function to accomplish this.

If everything fits in memory, it's as simple as doing something like this:

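import tensorflow as tf

n = 3  # number of generator outputs to keep and cache

def generate_example():
  i = 0
  while True:
    print('yielding value {}'.format(i))
    yield tf.random.uniform((64, 64, 3))
    i += 1

ds = tf.data.Dataset.from_generator(
    generate_example,
    output_signature=tf.TensorSpec(shape=(64, 64, 3), dtype=tf.float32))

# take() makes the stream finite; cache() stores those n elements after the first full pass
first_n_datapoints = ds.take(n).cache()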
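Building on this, a sketch of how the cached subset might feed training. It assumes TF 2.4 or later (for output_signature and tf.data.AUTOTUNE); expensive_crop is a hypothetical placeholder for the question's crop + preprocessing() step, and N_CACHED = 10 / BATCH_SIZE = 4 only keep it small (the question uses 10,000 and 64). The design choice is placing cache() before shuffle() and batch(), so the expensive work runs once while each epoch still gets a new shuffle order.

```python
import numpy as np
import tensorflow as tf

N_CACHED = 10   # the question uses 10,000; kept small here
BATCH_SIZE = 4  # the question uses 64

calls = {"count": 0}  # how many times the expensive step actually ran

def expensive_crop():
    # Hypothetical placeholder for "random crop + preprocessing()".
    calls["count"] += 1
    return np.random.random((64, 64, 3)).astype(np.float32)

def generate_example():
    while True:  # infinite stream, as in the question
        yield expensive_crop()

ds = tf.data.Dataset.from_generator(
    generate_example,
    output_signature=tf.TensorSpec(shape=(64, 64, 3), dtype=tf.float32),
)

train_ds = (
    ds.take(N_CACHED)          # make the stream finite
    .cache()                   # compute the N_CACHED examples once, keep them in memory
    .shuffle(N_CACHED)         # after the cache, so every epoch is reshuffled
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

for epoch in range(3):
    batch_sizes = [int(b.shape[0]) for b in train_ds]  # one full pass = one epoch
    print(epoch, batch_sizes, calls["count"])
```

After the first epoch the call count should stay at N_CACHED; putting shuffle() or batch() before cache() would instead freeze one particular order or batching into the cache.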

Note that if I set n to, say, 3 and then do something trivial like:

for i in first_n_datapoints.repeat():
  print()
  print(i.shape)

then I see output confirming that the first 3 values are cached (the "yielding value {i}" line is printed only once for each of the first 3 values generated):

yielding value 0
(64,64,3)
yielding value 1
(64,64,3)
yielding value 2
(64,64,3)
(64,64,3)
(64,64,3)
(64,64,3)
...
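A related pitfall, sketched with a call counter instead of prints: the cache is only kept once a pass over it completes. In the TF 2.x versions I have used, if the first pass stops early the partial cache is discarded (TF logs a warning that recommends dataset.take(k).cache().repeat() over dataset.cache().take(k).repeat()), and the next full pass runs the generator again. N = 5 and the all-zero images are placeholders.

```python
import numpy as np
import tensorflow as tf

N = 5
calls = {"count": 0}  # number of elements the generator has produced

def generate_example():
    while True:
        calls["count"] += 1
        yield np.zeros((64, 64, 3), dtype=np.float32)

cached = tf.data.Dataset.from_generator(
    generate_example,
    output_signature=tf.TensorSpec(shape=(64, 64, 3), dtype=tf.float32),
).take(N).cache()

# 1) Partial pass: only 2 of the N elements are read, so the cache is never completed.
for _ in cached.take(2):
    pass
after_partial = calls["count"]

# 2) First full pass: the partial cache was dropped, so all N elements are generated again.
list(cached)
after_full = calls["count"]

# 3) Second full pass: served from the now-complete cache, no new generator calls.
list(cached)
after_second = calls["count"]

print(after_partial, after_full, after_second)
```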

If everything does not fit in memory, you can pass a filename to cache() instead, and the generated tensors will be cached to disk.
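A sketch of the on-disk variant, assuming TF 2.x; make_dataset and the crops_cache prefix are hypothetical names, and the directory containing the prefix must already exist. The first full pass writes the cache files; a new pipeline built with the same prefix reads them back instead of calling the generator, which is what lets the expensive results be reused across training runs.

```python
import glob
import os
import tempfile

import numpy as np
import tensorflow as tf

N = 5
calls = {"count": 0}  # number of elements the generator has produced

def generate_example():
    while True:
        calls["count"] += 1
        yield np.random.random((64, 64, 3)).astype(np.float32)

def make_dataset(cache_prefix):
    # Same take(N).cache() pipeline, but cache() gets a filename, so the tensors
    # go to files on disk instead of memory.
    ds = tf.data.Dataset.from_generator(
        generate_example,
        output_signature=tf.TensorSpec(shape=(64, 64, 3), dtype=tf.float32),
    )
    return ds.take(N).cache(cache_prefix)

cache_prefix = os.path.join(tempfile.mkdtemp(), "crops_cache")

first = [x.numpy() for x in make_dataset(cache_prefix)]   # runs the generator, writes the files
second = [x.numpy() for x in make_dataset(cache_prefix)]  # new pipeline, reads the files back

print(calls["count"], sorted(glob.glob(cache_prefix + "*")))
```

Since the data is random, identical first and second passes show the second one came from disk. Delete the cache files to force regeneration; as with the in-memory cache, an interrupted first pass leaves no usable cache.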

More info here: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#cache