0 votes

When looking at the RNN example in the TensorFlow tutorial, I'm having an issue with how the initial state is constructed. At graph-build time, the graph is limited to handling input of one fixed batch size. This is a problem for me because I want to be able to feed in a single example and get a prediction for that single example.

The part of the code that restricts this is:

# batch_size is baked into the graph when this tensor is created
initial_state = state = tf.zeros([batch_size, lstm.state_size])

So my question is: how can I change the example to use a variable batch size, so that I can train the model with a full batch size and then feed in a single example for prediction?


2 Answers

2 votes

This is how I do it: pass the batch_size in as a placeholder, like this:

batch_size = tf.placeholder(tf.int32)  # scalar, fed at run time
init_state = cell.zero_state(batch_size, tf.float32)

where cell is one of the RNN cells (BasicLSTMCell, BasicGRUCell, MultiRNNCell, etc.). However, if you want to preserve the state across multiple batches, this won't work, because the state tensor's size has to stay constant.
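
For context, here is a minimal sketch of how that fits together with dynamic_rnn in TF 1.x graph mode. The layer sizes and placeholder names are just assumptions for illustration:

import tensorflow as tf  # TF 1.x graph-mode API

# Hypothetical sizes, for illustration only.
input_dim = 50
num_units = 128

# [batch, time, features]; both batch and time are left dynamic.
inputs = tf.placeholder(tf.float32, [None, None, input_dim])
batch_size = tf.placeholder(tf.int32, [])

cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
init_state = cell.zero_state(batch_size, tf.float32)

# dynamic_rnn unrolls over the time dimension at run time.
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=init_state)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Training: feed a batch, e.g. shape [32, T, input_dim], with batch_size: 32.
    # Prediction: feed one example of shape [1, T, input_dim], with batch_size: 1.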

0 votes

The TensorFlow text generation tutorial explains how to do this (now on TF 2.0). Because batch_size becomes part of the built model, you have to rebuild the model and restore it from the saved weights with the new batch size:

https://www.tensorflow.org/tutorials/text/text_generation#restore_the_latest_checkpoint

To keep this prediction step simple, use a batch size of 1.

Because of the way the RNN state is passed from timestep to timestep, the model only accepts a fixed batch size once built.

To run the model with a different batch_size, we need to rebuild the model and restore the weights from the checkpoint.

# Rebuild the same architecture with batch_size=1, then restore the trained weights.
model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
model.build(tf.TensorShape([1, None]))  # [batch=1, sequence_length=None]
model.summary()

I don't know for sure why this is required, but I've always assumed it's because batching in recurrent layers means managing multiple parallel hidden-state pipelines, so the model preallocates them for a fixed batch size.
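
That preallocation shows up in the tutorial's model definition: the recurrent layer is stateful, so its state tensors are allocated with the batch size baked in. A rough sketch of what build_model looks like there (details may differ from the current tutorial):

import tensorflow as tf

def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
    # stateful=True makes the GRU keep its hidden state between calls,
    # which requires state tensors of shape [batch_size, rnn_units]
    # allocated up front; that is what fixes the batch size.
    return tf.keras.Sequential([
        tf.keras.layers.Embedding(vocab_size, embedding_dim,
                                  batch_input_shape=[batch_size, None]),
        tf.keras.layers.GRU(rnn_units,
                            return_sequences=True,
                            stateful=True,
                            recurrent_initializer='glorot_uniform'),
        tf.keras.layers.Dense(vocab_size),
    ])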