
I used tf.Variable for W (weights) and b (biases), but tf.placeholder for X (input batch) and Y (expected values for this batch), and everything works fine. But today I found this topic in the TensorFlow GitHub issues, which says:

Feed_dict does a single-threaded memcpy of contents from Python runtime into TensorFlow runtime. If data is needed on GPU, then you'll have an additional CPU->GPU transfer. I'm used to seeing up to 10x improvement in performance when switching from feed_dict to native TensorFlow (Variable/Queue)
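One way to read the "Variable" half of that advice is shown below. This is a sketch, not the approach from the linked issue. The idea is to copy the whole dataset into non-trainable variables once, then select each batch inside the graph with tf.gather. Per step, only a small vector of indices goes through feed_dict instead of the image data. The dataset, its sizes and all names below are made up. The sketch assumes the data fits in (GPU) memory and uses tf.compat.v1 names so it also runs on TensorFlow 2.x in graph mode. Whether it gives a real speed-up depends on the model and hardware.

```python
import numpy as np
import tensorflow as tf

# Made-up in-memory dataset: 1000 RGB images of 30x30 pixels, integer labels.
NUM_EXAMPLES = 1000
BATCH_SIZE = 50
all_images = np.random.rand(NUM_EXAMPLES, 30, 30, 3).astype(np.float32)
all_labels = np.random.randint(0, 10, size=NUM_EXAMPLES).astype(np.int32)

graph = tf.Graph()
with graph.as_default():
    # These placeholders are fed only once, to initialize the variables; this
    # keeps the arrays out of the serialized graph (a GraphDef is limited to 2 GB).
    images_init = tf.compat.v1.placeholder(tf.float32, shape=all_images.shape)
    labels_init = tf.compat.v1.placeholder(tf.int32, shape=all_labels.shape)
    images_var = tf.Variable(images_init, trainable=False)
    labels_var = tf.Variable(labels_init, trainable=False)

    # Each step feeds only a small vector of indices; the batch itself is
    # gathered from the variables inside the graph.
    batch_indices = tf.compat.v1.placeholder(tf.int32, shape=[None])
    images_batch = tf.gather(images_var, batch_indices)
    labels_batch = tf.gather(labels_var, batch_indices)

with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(images_var.initializer, feed_dict={images_init: all_images})
    sess.run(labels_var.initializer, feed_dict={labels_init: all_labels})

    # One epoch: shuffle the indices, then walk through them batch by batch.
    perm = np.random.permutation(NUM_EXAMPLES).astype(np.int32)
    for start in range(0, NUM_EXAMPLES, BATCH_SIZE):
        idx = perm[start:start + BATCH_SIZE]
        imgs, lbls = sess.run([images_batch, labels_batch],
                              feed_dict={batch_indices: idx})
        # ... a training op built on images_batch/labels_batch would run here
```

Feeding the initial value through a placeholder, rather than passing the NumPy array straight to tf.Variable, means the data is copied into the variable once at initialization instead of being stored as a constant in the graph.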

Now I am trying to find out how to use tf.Variable or a queue for the input data instead of feed_dict, to get that speed improvement, especially for batches. I need to feed the data batch by batch. When all batches are done, that is the end of an epoch. Then it starts again from the beginning for the second epoch, and so on.
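For the epoch behaviour described here, the queue API can do the bookkeeping itself. Below is a minimal sketch, assuming the data is already in NumPy arrays. The tiny arrays and sizes are made up. tf.train.slice_input_producer with num_epochs reshuffles the examples every epoch and raises tf.errors.OutOfRangeError once all epochs are used up. tf.train.batch groups the single examples into batches. The code uses tf.compat.v1 names so it also runs on TensorFlow 2.x. On older 1.x releases the same functions live directly under tf.train.

```python
import numpy as np
import tensorflow as tf

# Made-up data: 200 examples with one feature each; the label is the example's index.
NUM_EXAMPLES = 200
BATCH_SIZE = 50
NUM_EPOCHS = 2
features = np.arange(NUM_EXAMPLES, dtype=np.float32).reshape(NUM_EXAMPLES, 1)
labels = np.arange(NUM_EXAMPLES, dtype=np.int32)

graph = tf.Graph()
with graph.as_default():
    # Emits one (feature, label) pair at a time, reshuffled every epoch, and
    # raises OutOfRangeError after NUM_EPOCHS passes over the data.
    feature, label = tf.compat.v1.train.slice_input_producer(
        [tf.constant(features), tf.constant(labels)],
        num_epochs=NUM_EPOCHS, shuffle=True)
    # Groups single examples into batches of BATCH_SIZE.
    feature_batch, label_batch = tf.compat.v1.train.batch(
        [feature, label], batch_size=BATCH_SIZE, capacity=2 * BATCH_SIZE)
    # The epoch counter behind num_epochs is a local variable, so the
    # local initializer has to run as well.
    init_op = tf.group(tf.compat.v1.global_variables_initializer(),
                       tf.compat.v1.local_variables_initializer())

seen_labels = []
step = 0
with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init_op)
    coord = tf.compat.v1.train.Coordinator()
    threads = tf.compat.v1.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            batch_features, batch_labels = sess.run([feature_batch, label_batch])
            # ... a training op would normally run here
            seen_labels.extend(batch_labels.tolist())
            step += 1
    except tf.errors.OutOfRangeError:
        # All epochs are exhausted.
        print('Done: %d epochs, %d steps' % (NUM_EPOCHS, step))
    finally:
        coord.request_stop()
    coord.join(threads)
```

The producer does not signal individual epoch boundaries. If you need them (for example, to run validation after each epoch), one option is to count steps: epoch = step // (NUM_EXAMPLES // BATCH_SIZE). This is only exact when NUM_EXAMPLES is a multiple of BATCH_SIZE and the batches are assembled by a single thread, as here.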

Sorry, but I don't understand how to use them.

See cifar10_input.py from this tutorial: tensorflow.org/versions/r0.11/tutorials/deep_cnn/index.html (mdaoust)


Here is a self-contained example of how you might use queues to feed training batches:

import tensorflow as tf

IMG_SIZE = [30, 30, 3]
BATCH_SIZE_TRAIN = 50

def get_training_batch(batch_size):
    ''' training data pipeline -- normally you would read data from files here using
    a TF reader of some kind. '''
    image = tf.random_uniform(shape=IMG_SIZE)
    label = tf.random_uniform(shape=[])

    min_after_dequeue = 100
    capacity = min_after_dequeue + 3 * batch_size
    images, labels = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, capacity=capacity,
        min_after_dequeue=min_after_dequeue)
    return images, labels

# define the graph
images_train, labels_train = get_training_batch(BATCH_SIZE_TRAIN)
# inference, training and other ops are generally defined here too

# start a session
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # tf.initialize_all_variables() in TensorFlow releases before 0.12
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # do something interesting here -- training, validation, etc.
    for _ in range(5):
        # typical training step where batch data are drawn from the training queue
        py_images, py_labels = sess.run([images_train, labels_train])
        print('\nData from queue:')
        print('\tImages shape, first element: ', py_images.shape, py_images[0][0, 0, 0])
        print('\tLabels shape, first element: ', py_labels.shape, py_labels[0])

    # close threads
    coord.request_stop()
    coord.join(threads)
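The docstring in get_training_batch mentions reading data from files with a TF reader. That is what cifar10_input.py does for the binary CIFAR-10 files. Below is a hedged sketch of the same idea for a made-up CSV file with three columns: two float features and an integer label. It again uses tf.compat.v1 names. The reader and decode_csv produce one example at a time, and shuffle_batch turns them into shuffled batches. These ops take the place of the tf.random_uniform calls in the answer's code. The file path, column layout and sizes are assumptions for illustration only.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical training file: each line is "feature1,feature2,label".
tmp_dir = tempfile.mkdtemp()
csv_path = os.path.join(tmp_dir, 'train.csv')
with open(csv_path, 'w') as out:
    for i in range(100):
        out.write('%f,%f,%d\n' % (i * 0.5, i * 2.0, i % 2))

BATCH_SIZE = 10

graph = tf.Graph()
with graph.as_default():
    # Queue of file names; without num_epochs it cycles over the files forever.
    filename_queue = tf.compat.v1.train.string_input_producer([csv_path])
    reader = tf.compat.v1.TextLineReader()
    _, line = reader.read(filename_queue)
    # One CSV line -> two float32 columns and one int32 label.
    col1, col2, label = tf.io.decode_csv(line, record_defaults=[[0.0], [0.0], [0]])
    example = tf.stack([col1, col2])

    min_after_dequeue = 20
    features_batch, labels_batch = tf.compat.v1.train.shuffle_batch(
        [example, label], batch_size=BATCH_SIZE,
        capacity=min_after_dequeue + 3 * BATCH_SIZE,
        min_after_dequeue=min_after_dequeue)

with tf.compat.v1.Session(graph=graph) as sess:
    coord = tf.compat.v1.train.Coordinator()
    threads = tf.compat.v1.train.start_queue_runners(sess=sess, coord=coord)
    py_features, py_labels = sess.run([features_batch, labels_batch])
    coord.request_stop()
    coord.join(threads)
```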