I am implementing a Seq2Seq model with a multi-layer bidirectional RNN and an attention mechanism. While following this tutorial https://github.com/tensorflow/nmt, I got confused about how to correctly manipulate the encoder_state after the bidirectional layer.
Quoting the tutorial: "For multiple bidirectional layers, we need to manipulate the encoder_state a bit, see model.py, method _build_bidirectional_rnn() for more details." This is the relevant part of the code (https://github.com/tensorflow/nmt/blob/master/nmt/model.py, line 770):
encoder_outputs, bi_encoder_state = (
    self._build_bidirectional_rnn(
        inputs=self.encoder_emb_inp,
        sequence_length=sequence_length,
        dtype=dtype,
        hparams=hparams,
        num_bi_layers=num_bi_layers,
        num_bi_residual_layers=num_bi_residual_layers))

if num_bi_layers == 1:
    encoder_state = bi_encoder_state
else:
    # alternatively concat forward and backward states
    encoder_state = []
    for layer_id in range(num_bi_layers):
        encoder_state.append(bi_encoder_state[0][layer_id])  # forward
        encoder_state.append(bi_encoder_state[1][layer_id])  # backward
    encoder_state = tuple(encoder_state)
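To see what the `else` branch actually produces, here is a minimal pure-Python sketch. It assumes `bi_encoder_state` is a pair `(forward_states, backward_states)`, each a tuple with one state per layer, which is what `bidirectional_dynamic_rnn` returns for two `MultiRNNCell`s. The `LSTMStateTuple` below is a hypothetical namedtuple stand-in for TensorFlow's class, and strings replace the tensors.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple; strings replace tensors.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def fake_bi_state(num_layers):
    """Build a (forward, backward) pair with one LSTMStateTuple per layer."""
    fw = tuple(LSTMStateTuple(c=f"fw{i}.c", h=f"fw{i}.h") for i in range(num_layers))
    bw = tuple(LSTMStateTuple(c=f"bw{i}.c", h=f"bw{i}.h") for i in range(num_layers))
    return fw, bw


def interleave(bi_encoder_state, num_bi_layers):
    """Mirror the tutorial's else branch: fw0, bw0, fw1, bw1, ..."""
    encoder_state = []
    for layer_id in range(num_bi_layers):
        encoder_state.append(bi_encoder_state[0][layer_id])  # forward
        encoder_state.append(bi_encoder_state[1][layer_id])  # backward
    return tuple(encoder_state)


state = interleave(fake_bi_state(2), num_bi_layers=2)
print(len(state))            # 4
print([s.c for s in state])  # ['fw0.c', 'bw0.c', 'fw1.c', 'bw1.c']
```

In this sketch the result is a flat tuple of `2 * num_bi_layers` states, alternating forward and backward, rather than one state per layer.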
So this is what I have now:
import tensorflow as tf

def get_a_cell(lstm_size):
    lstm = tf.nn.rnn_cell.BasicLSTMCell(lstm_size)
    # drop = tf.nn.rnn_cell.DropoutWrapper(lstm,
    #                                      output_keep_prob=keep_prob)
    return lstm

encoder_FW = tf.nn.rnn_cell.MultiRNNCell(
    [get_a_cell(num_units) for _ in range(num_layers)])
encoder_BW = tf.nn.rnn_cell.MultiRNNCell(
    [get_a_cell(num_units) for _ in range(num_layers)])

bi_outputs, bi_encoder_state = tf.nn.bidirectional_dynamic_rnn(
    encoder_FW, encoder_BW, encoderInput,
    sequence_length=x_lengths, dtype=tf.float32)
encoder_output = tf.concat(bi_outputs, -1)

encoder_state = []
for layer_id in range(num_layers):
    encoder_state.append(bi_encoder_state[0][layer_id])  # forward
    encoder_state.append(bi_encoder_state[1][layer_id])  # backward
encoder_state = tuple(encoder_state)

# DECODER -------------------
decoder_cell = tf.nn.rnn_cell.MultiRNNCell(
    [get_a_cell(num_units) for _ in range(num_layers)])

# Create an attention mechanism
attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units_attention, encoder_output,
    memory_sequence_length=x_lengths)

decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
    decoder_cell, attention_mechanism,
    attention_layer_size=num_units_attention)

decoder_initial_state = decoder_cell.zero_state(batch_size, tf.float32).clone(
    cell_state=encoder_state)
The problem is that I get this error:
The two structures don't have the same nested structure.
First structure: type=AttentionWrapperState
str=AttentionWrapperState(cell_state=(LSTMStateTuple(c=, h=),
LSTMStateTuple(c=, h=)), attention=, time=, alignments=, alignment_history=
(), attention_state=)
Second structure: type=AttentionWrapperState
str=AttentionWrapperState(cell_state=(LSTMStateTuple(c=, h=),
LSTMStateTuple(c=, h=), LSTMStateTuple(c=, h=), LSTMStateTuple(c=, h=)),
attention=, time=, alignments=, alignment_history=(), attention_state=)
This makes some sense to me: as far as I can tell, the encoder outputs come only from the last layer, whereas for the state we are concatenating the states of all the layers.
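The mismatch in the error can be reproduced without TensorFlow by counting states. A minimal sketch, assuming `num_layers = 2` (the two `LSTMStateTuple`s in the first structure suggest this), a hypothetical namedtuple stand-in for `LSTMStateTuple`, and a hypothetical `same_structure` helper that is only a simplified approximation of the nested-structure comparison that raises this error, not TensorFlow's implementation:

```python
from collections import namedtuple

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple; strings replace tensors.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def same_structure(a, b):
    """Simplified check: same container types and lengths, recursively.

    Illustration only; leaves (tensors) are treated as always matching.
    """
    if isinstance(a, tuple) != isinstance(b, tuple):
        return False
    if not isinstance(a, tuple):
        return True  # two leaves
    if type(a) is not type(b) or len(a) != len(b):
        return False
    return all(same_structure(x, y) for x, y in zip(a, b))


num_layers = 2  # assumed, matching the first structure in the error
leaf = LSTMStateTuple(c="c", h="h")

# Decoder MultiRNNCell with num_layers cells: one state per layer
decoder_cell_state = tuple(leaf for _ in range(num_layers))
# Interleaved encoder state: forward and backward state for each of num_layers layers
encoder_state = tuple(leaf for _ in range(2 * num_layers))

print(len(decoder_cell_state), len(encoder_state))        # 2 4
print(same_structure(decoder_cell_state, encoder_state))  # False
```

Under these assumptions the decoder expects `num_layers` states while the interleaved tuple holds `2 * num_layers`, which is the same two-versus-four difference shown in the error.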
As I expected, when I concatenate only the last layer's states, like this:
encoder_state = []
encoder_state.append(bi_encoder_state[0][num_layers-1]) # forward
encoder_state.append(bi_encoder_state[1][num_layers-1]) # backward
encoder_state = tuple(encoder_state)
It runs without errors.
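The workaround can be checked with the same kind of mock. A minimal sketch, again using a hypothetical namedtuple stand-in for `LSTMStateTuple` and assuming the decoder is a `MultiRNNCell` with `num_layers` cells, so its `cell_state` needs exactly `num_layers` entries:

```python
from collections import namedtuple

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple; strings replace tensors.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def fake_bi_state(num_layers):
    """(forward, backward) pair with one LSTMStateTuple per layer."""
    fw = tuple(LSTMStateTuple(c=f"fw{i}.c", h=f"fw{i}.h") for i in range(num_layers))
    bw = tuple(LSTMStateTuple(c=f"bw{i}.c", h=f"bw{i}.h") for i in range(num_layers))
    return fw, bw


def last_layer_state(bi_encoder_state, num_layers):
    """The workaround from the question: keep only the top layer of each direction."""
    return (bi_encoder_state[0][num_layers - 1],  # forward, top layer
            bi_encoder_state[1][num_layers - 1])  # backward, top layer


for num_layers in (1, 2, 3):
    state = last_layer_state(fake_bi_state(num_layers), num_layers)
    # Does the count match a decoder with num_layers cells?
    print(num_layers, len(state), len(state) == num_layers, [s.c for s in state])
```

In this sketch the workaround always yields exactly two states, so the count matches a `num_layers`-layer decoder only when `num_layers == 2`; the lower encoder layers' states are dropped, and the top forward and backward states become the initial states of decoder layers 0 and 1 respectively.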
To the best of my knowledge, there is no part of the code in which they transform the encoder_state again before passing it to the attention layer. So how can their code work? And, more importantly, does my fix break the correct behavior of the attention mechanism?