
I am implementing a Seq2Seq model with a multi-layer bidirectional RNN and an attention mechanism. While following this tutorial https://github.com/tensorflow/nmt I got confused about how to correctly manipulate the encoder_state after the bidirectional layer.

Quoting the tutorial: "For multiple bidirectional layers, we need to manipulate the encoder_state a bit, see model.py, method _build_bidirectional_rnn() for more details." This is the relevant part of the code (https://github.com/tensorflow/nmt/blob/master/nmt/model.py line 770):

encoder_outputs, bi_encoder_state = (
    self._build_bidirectional_rnn(
        inputs=self.encoder_emb_inp,
        sequence_length=sequence_length,
        dtype=dtype,
        hparams=hparams,
        num_bi_layers=num_bi_layers,
        num_bi_residual_layers=num_bi_residual_layers))

if num_bi_layers == 1:
    encoder_state = bi_encoder_state
else:
    # alternatively concat forward and backward states
    encoder_state = []
    for layer_id in range(num_bi_layers):
        encoder_state.append(bi_encoder_state[0][layer_id])  # forward
        encoder_state.append(bi_encoder_state[1][layer_id])  # backward
    encoder_state = tuple(encoder_state)

So this is what I have now:

def get_a_cell(lstm_size):
    lstm = tf.nn.rnn_cell.BasicLSTMCell(lstm_size)
    # drop = tf.nn.rnn_cell.DropoutWrapper(lstm,
    #                                      output_keep_prob=keep_prob)
    return lstm


encoder_FW = tf.nn.rnn_cell.MultiRNNCell(
    [get_a_cell(num_units) for _ in range(num_layers)])
encoder_BW = tf.nn.rnn_cell.MultiRNNCell(
    [get_a_cell(num_units) for _ in range(num_layers)])


bi_outputs, bi_encoder_state = tf.nn.bidirectional_dynamic_rnn(
    encoder_FW, encoder_BW, encoderInput,
    sequence_length=x_lengths, dtype=tf.float32)
encoder_output = tf.concat(bi_outputs, -1)

encoder_state = []

for layer_id in range(num_layers):
    encoder_state.append(bi_encoder_state[0][layer_id])  # forward
    encoder_state.append(bi_encoder_state[1][layer_id])  # backward
encoder_state = tuple(encoder_state)

#DECODER -------------------

decoder_cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell(num_units) for _ in range(num_layers)])

# Create an attention mechanism
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units_attention, encoder_output, memory_sequence_length=x_lengths)

decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    decoder_cell, attention_mechanism,
    attention_layer_size=num_units_attention)

decoder_initial_state = decoder_cell.zero_state(batch_size, tf.float32).clone(
    cell_state=encoder_state)

The problem is that I get the following error:

The two structures don't have the same nested structure.

First structure: type=AttentionWrapperState 
str=AttentionWrapperState(cell_state=(LSTMStateTuple(c=, h=), 
LSTMStateTuple(c=, h=)), attention=, time=, alignments=, alignment_history=
(), attention_state=)

Second structure: type=AttentionWrapperState 
str=AttentionWrapperState(cell_state=(LSTMStateTuple(c=, h=), 
LSTMStateTuple(c=, h=), LSTMStateTuple(c=, h=), LSTMStateTuple(c=, h=)), 
attention=, time=, alignments=, alignment_history=(), attention_state=)

And this kind of makes sense to me: (I guess) the decoder expects one LSTMStateTuple per decoder layer, while in the encoder_state I am actually interleaving the forward and backward states of all the layers, so I end up with twice as many tuples (four instead of two).
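
For what it's worth, a quick check (my own sketch, using the variables built above with num_layers = 2) makes the mismatch explicit:

# Sketch, not from the tutorial: compare the two nested structures directly.
# The AttentionWrapper expects one LSTMStateTuple per decoder layer ...
print(len(decoder_cell.state_size.cell_state))   # 2
# ... while the interleaved encoder_state holds forward + backward per layer.
print(len(encoder_state))                        # 4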

So, as I expected, when I keep only the last layer's forward and backward states, like the following:

encoder_state = []
encoder_state.append(bi_encoder_state[0][num_layers-1])  # forward
encoder_state.append(bi_encoder_state[1][num_layers-1])  # backward
encoder_state = tuple(encoder_state)

It runs without errors.

To the best of my knowledge, there is no part of the code in which they transform the encoder_state again before passing it to the attention layer. So how can their code work? And, more importantly, is my fix breaking the correct behavior of the attention mechanism?


1 Answer


Here is the problem:

Only the encoder is bidirectional, but you are feeding its bidirectional states to the decoder (which is always unidirectional).

Solution:

What you have to do is simply concatenate the forward and backward states of each layer, so that you are manipulating "uni-directional" states again:

encoder_state = []

for layer_id in range(num_layers):
    state_fw = bi_encoder_state[0][layer_id]
    state_bw = bi_encoder_state[1][layer_id]

    # Merge the forward and backward states of this layer
    cell_state = tf.concat([state_fw.c, state_bw.c], 1)
    hidden_state = tf.concat([state_fw.h, state_bw.h], 1)

    # This state has the same structure as a uni-directional encoder state
    state = tf.nn.rnn_cell.LSTMStateTuple(c=cell_state, h=hidden_state)

    encoder_state.append(state)

encoder_state = tuple(encoder_state)
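
One caveat worth adding (my own sketch, not part of the answer above): after the concatenation each merged LSTMStateTuple is 2 * num_units wide, so for the clone into the decoder's initial state to fit, the decoder cells (and, with LuongAttention, the attention depth) have to be built with 2 * num_units. Reusing the question's variable names, something like this:

# Sketch under the assumption that the encoder cells use num_units units:
# the merged states are 2 * num_units wide, so the decoder cells must match.
decoder_cell = tf.nn.rnn_cell.MultiRNNCell(
    [get_a_cell(2 * num_units) for _ in range(num_layers)])

# LuongAttention scores the decoder output against the projected memory,
# so its depth is set to the decoder size (2 * num_units) here.
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    2 * num_units, encoder_output, memory_sequence_length=x_lengths)

decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    decoder_cell, attention_mechanism,
    attention_layer_size=2 * num_units)

# Now the cloned cell_state has the same nested structure and shapes
# as the decoder's zero state, so the structure error goes away.
decoder_initial_state = decoder_cell.zero_state(batch_size, tf.float32).clone(
    cell_state=encoder_state)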