
There are quite a few examples of how to use LSTMs alone in TF, but I couldn't find any good examples of how to train a CNN + LSTM jointly. From what I see, it is not quite straightforward to do such training, and I can think of a few options here:

  • First, I believe the simplest (or most primitive) solution would be to train the CNN independently to learn features, then train the LSTM on those CNN features without updating the CNN part; one would probably have to extract and save these features as numpy arrays and then feed them to the LSTM in TF. But in that scenario one would probably need a differently labeled dataset to pretrain the CNN, which eliminates the advantage of end-to-end training, i.e. learning features for the final objective targeted by the LSTM (besides the fact that one has to have these additional labels in the first place).
  • The second option would be to concatenate all time slices along the batch dimension (a 4-d tensor), feed that to the CNN, then somehow repack those features into the 5-d tensor needed for training the LSTM, and then apply a cost function. My main concern is whether it is possible to do such a thing at all. Handling variable-length sequences also becomes a little tricky; for example, in a prediction scenario you would only feed a single frame at a time. So I would be really happy to see some examples, if that is the right way of doing joint training. Besides that, this solution looks more like a hack, so if there is a better way to do it, it would be great if someone could share it.
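To make the first option concrete, the feature hand-off it describes (run the CNN once, save its features with numpy, feed them to the LSTM later) amounts to bookkeeping like the following; the array shapes and file name here are made up for illustration, and the random array stands in for an actual `sess.run` on the CNN's feature layer:

```python
import os
import tempfile
import numpy as np

# Stand-in for CNN features: 100 frames, 2048 features each
# (in practice these would come from a sess.run on the CNN's feature layer)
features = np.random.rand(100, 2048).astype(np.float32)

# Save once after the CNN forward pass...
path = os.path.join(tempfile.mkdtemp(), "cnn_features.npy")
np.save(path, features)

# ...then load them later as the (frozen) input sequence for the LSTM
loaded = np.load(path)
print(loaded.shape)  # (100, 2048)
```

Since the CNN is frozen in this scheme, no gradient ever flows from the LSTM's loss back into the convolutional weights, which is exactly the end-to-end advantage that is lost.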

Thank you in advance!


1 Answer


For joint training, you can consider using tf.map_fn, as described in the documentation: https://www.tensorflow.org/api_docs/python/tf/map_fn.
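For intuition, tf.map_fn applies a function to each slice of its input along axis 0 and stacks the results; a NumPy equivalent of that behavior (with a toy per-slice function, purely for illustration) looks like this:

```python
import numpy as np

def map_fn_like(fn, arr):
    # Apply fn to each slice along axis 0 and stack the results,
    # mirroring what tf.map_fn does to a tensor
    return np.stack([fn(x) for x in arr])

batch = np.arange(6, dtype=np.float32).reshape(3, 2)  # 3 slices of shape (2,)
doubled = map_fn_like(lambda x: x * 2.0, batch)
print(doubled.shape)  # (3, 2)
```

In the answer below, the slices along axis 0 are the time steps of the video, so mapping the CNN's inference function over them runs the CNN on every frame.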

Let's assume that the CNN is built along similar lines to the one described here: https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10.py.

def joint_inference(sequence):
    # `sequence` is assumed time-major: [time, batch, height, width, channels]
    inference_fn = lambda image: inference(image)
    logit_sequence = tf.map_fn(inference_fn, sequence, dtype=tf.float32, swap_memory=True)
    lstm_cell = tf.contrib.rnn.LSTMCell(128)
    # dynamic_rnn returns (outputs, final_state); dtype is required when no initial state is given
    outputs, final_state = tf.nn.dynamic_rnn(cell=lstm_cell, inputs=logit_sequence, dtype=tf.float32, time_major=True)
    projection_fn = lambda state: tf.contrib.layers.linear(state, num_outputs=num_classes, activation_fn=tf.nn.sigmoid)
    projection_logits = tf.map_fn(projection_fn, outputs)
    return projection_logits

Warning: you might have to look into device placement, as described here: https://www.tensorflow.org/tutorials/using_gpu, if your model is larger than the memory the GPU can allocate.

An alternative would be to flatten the video batch into an image batch, do a forward pass through the CNN, and reshape the features for the LSTM.
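The reshape bookkeeping for that alternative can be sketched with NumPy (the shapes are illustrative, the zero arrays stand in for real data and for the CNN's output, and in TF the same reshapes would be done with tf.reshape so that gradients still flow through the CNN):

```python
import numpy as np

B, T, H, W, C = 2, 5, 32, 32, 3          # batch, time, height, width, channels
video_batch = np.zeros((B, T, H, W, C), dtype=np.float32)

# Flatten time into the batch dimension: a plain image batch for the CNN
image_batch = video_batch.reshape(B * T, H, W, C)

# Stand-in for the CNN forward pass: each image becomes an F-dim feature vector
F = 128
features = np.zeros((B * T, F), dtype=np.float32)

# Repack to [batch, time, features], the sequence layout the LSTM expects
feature_sequence = features.reshape(B, T, F)
print(feature_sequence.shape)  # (2, 5, 128)
```

Because reshape only reinterprets the layout, frame order within each video is preserved, which is what makes this flatten-and-repack trick safe for joint training.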