I am trying to maximize GPU occupancy during training. I have variable length sequences that I would like densely packed into fixed length batches. Essentially, I want short sequences to be followed by another sequence, and I want long sequences to be split such that they continue in the next batch. Example:

# Say batch size is 2 and desired sequence length is 4
s1 = [a, b, c, d, e, f]
s2 = [x, y, z]
s3 = [l, m, n, o]

# Resulting batches:
b1 = [[a, b, c, d]
      [x, y, z, l]]
b2 = [[e, f, _, _]
      [m, n, o, _]]
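
To make the intended behavior concrete, here is a minimal plain-Python sketch of the packing, with no TensorFlow involved. The name pack_batches and the '_' padding token are made up for illustration; the sketch assumes each row keeps pulling whole sequences from a shared source until it is full, and carries any overflow to the same row of the next batch.

```python
def pack_batches(sequences, batch_size, seq_len, pad="_"):
    """Pack variable-length sequences into batches of shape [batch_size, seq_len].

    Hypothetical helper: rows are filled greedily, and tokens that do not fit
    are carried over to the same row of the following batch.
    """
    source = iter(sequences)
    leftovers = [[] for _ in range(batch_size)]  # per-row carry-over
    exhausted = False
    batches = []
    while True:
        batch = []
        emitted = 0  # number of real (non-padding) tokens in this batch
        for row in range(batch_size):
            tokens = leftovers[row]
            # Pull whole sequences until the row is full or the input runs out.
            while len(tokens) < seq_len and not exhausted:
                try:
                    tokens = tokens + list(next(source))
                except StopIteration:
                    exhausted = True
            clipped = tokens[:seq_len]
            leftovers[row] = tokens[seq_len:]
            emitted += len(clipped)
            batch.append(clipped + [pad] * (seq_len - len(clipped)))
        if emitted == 0:
            break  # only padding would remain, so stop
        batches.append(batch)
    return batches


s1, s2, s3 = list("abcdef"), list("xyz"), list("lmno")
for b in pack_batches([s1, s2, s3], batch_size=2, seq_len=4):
    print(b)
```

With the three sequences above this yields the two batches b1 and b2 shown in the example.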

Is there an easy way to do this in TensorFlow? My sequences are coming from a tf.TextLineReader:

# string_input_producer expects a list of filenames
file_queue = tf.train.string_input_producer(['./example_text'])
reader = tf.TextLineReader()
key, sentence = reader.read(file_queue)
# convert string to int32 vector (to_sequence is my own helper, not shown)
sequence_tensor = to_sequence(sentence)

# what I wish I had:
batch = tf.fixed_length_batch_from_variable_length_sequences(
    sequence_tensor, batch_size, fixed_length)

Thank you in advance for any suggestions.

1 Answer

OK, I have a working example that is almost what I was hoping for. The following code produces batches in the manner I would like, but it requires passing data in and out of the TF session with the use of placeholders. I want to be able to build these batches entirely from within the TF graph.

Hopefully I am being silly and there is some obvious solution someone can point out. Also please forgive the camelCase.

import tensorflow as tf

def buildBatch(seqLength, batchSize):

    def lineToSequence(line):
        line = tf.expand_dims(line, axis=0)
        line = tf.sparse_tensor_to_dense(tf.string_split(line), '_')
        line = tf.concat([line, [['<GO>']]], 1)
        return line

    data = tf.contrib.data.TextLineDataset(['./exampleFile.txt'])
    data = data.map(lambda line: lineToSequence(line))
    iterator = data.make_initializable_iterator()

    # Grab lines from the file until the sequence length is met and shave off any extra
    def getFixedLengthSequence(start):
        c = lambda s: tf.shape(s)[1] < seqLength # while sequence is too short
        b = lambda s: tf.concat([s, iterator.get_next()], 1) # concatenate the next line
        sentences = tf.while_loop(c, b, [start], back_prop=False, parallel_iterations=1,
            shape_invariants=[tf.TensorShape([1, None])])

        clippedToLength = tf.expand_dims(sentences[0, :seqLength], axis=0)
        leftover = tf.expand_dims(sentences[0, seqLength:], axis=0)
        return clippedToLength, leftover

    # Placeholders pass in the start of each sequence (which are saved from the last batch)
    startOfThisBatch = [tf.placeholder(tf.string, shape=[1,None]) for i in range(batchSize)]
    # Capture what is leftover from each sequence so it can be passed in to start the next batch
    startOfNextBatch = [tf.TensorArray(tf.string, size=1) for i in range(batchSize)]

    # Build the batch
    thisBatch = []
    for i, seqStart in enumerate(startOfThisBatch):
        seq, leftover = getFixedLengthSequence(seqStart)
        thisBatch.append(seq)
        startOfNextBatch[i] = startOfNextBatch[i].write(0, leftover)
    thisBatch = tf.concat(thisBatch, axis=0)
    startOfNextBatch = [b.read(0) for b in startOfNextBatch]

    return thisBatch, startOfThisBatch, startOfNextBatch, iterator.initializer


def printBatch():
    sequenceLength = 10
    batchSize = 3

    batch, startOfThisBatch, startOfNextBatch, iteratorInit = buildBatch(sequenceLength, batchSize)
    # The very first batch starts with <GO> tokens
    batchStarts = [[['<GO>']]]*batchSize

    sv = tf.train.Supervisor()
    with sv.managed_session() as sess:
        sess.run(iteratorInit)
        for b in range(4):
            # Populate feed dict with the beginning of each sequence in the batch
            feed = {}
            for i in range(batchSize):
                feed[startOfThisBatch[i]] = batchStarts[i]

            # Call TF to get this batch and the starting sequences of the next batch
            out, batchStarts = sess.run([batch, startOfNextBatch], feed_dict=feed)

            print('Batch %d :' % b)
            for seq in out:
                # Tokens come back as byte strings; decode so this also runs on Python 3
                print(b" ".join(seq).decode('utf-8'))
            print('')

printBatch()

Result:

Batch 0 :  
<GO> A spokesman said the company has been affected by  
<GO> Having a little flexibility on that issue would go  
<GO> Long before the advent of e-commerce , Wal-Mart 's 

Batch 1 :  
the credit crunch in the United States . <GO> Abu  
a long way to putting together a final package .  
founder Sam Walton set out his vision for a successful  

Batch 2 :  
Dhabi is going ahead to build solar city and no  
<GO> Her back was torn open , her liver was  
retail operation : " We let folks know we 're  

Batch 3 :  
pollution city . <GO> Now it has 175 staging centers  
ruptured , one of her lungs had collapsed and the  
interested in them and that they 're vital to us--  

Notice that each sentence continues in the following batch. The example text file used is from the 1-billion word benchmark dataset and contains one sentence on each line.
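
For checking the output above without running a session, the same carry-over logic can be modeled in plain Python. This is only a rough sketch under some assumptions: stream_batches is a made-up name, lines are split on whitespace, and an incomplete final batch is silently dropped, whereas the graph version would raise an out-of-range error once the iterator is exhausted.

```python
def stream_batches(lines, batch_size, seq_len, go="<GO>"):
    """Yield batches of token lists, mimicking the placeholder feed loop.

    Hypothetical model of the graph: every row starts with <GO>, each line
    is followed by <GO>, and overflow is carried to the next batch.
    """
    source = iter(lines)
    starts = [[go] for _ in range(batch_size)]  # the first batch starts with <GO>
    while True:
        batch = []
        for row in range(batch_size):
            tokens = starts[row]
            # Concatenate further lines while the row is too short.
            while len(tokens) < seq_len:
                line = next(source, None)
                if line is None:
                    return  # input exhausted: drop the incomplete batch
                tokens = tokens + line.split() + [go]
            batch.append(tokens[:seq_len])   # clipped to length
            starts[row] = tokens[seq_len:]   # leftover starts the next batch
        yield batch


example = ["a b c", "d e", "f g h i", "j k"]
for i, batch in enumerate(stream_batches(example, batch_size=2, seq_len=3)):
    print("Batch", i, ":")
    for seq in batch:
        print(" ".join(seq))
```

The starts list plays the role of the startOfThisBatch placeholders and the startOfNextBatch outputs; keeping that state inside the graph is exactly the part I have not worked out.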