
I am implementing a seq2seq model for text summarization using TensorFlow. For the encoder I'm using a bidirectional RNN layer. Encoding layer:

    def encoding_layer(self, rnn_inputs, rnn_size, num_layers, keep_prob,
                       source_vocab_size,
                       encoding_embedding_size,
                       source_sequence_length,
                       emb_matrix):

        embed = tf.nn.embedding_lookup(emb_matrix, rnn_inputs)

        stacked_cells = tf.contrib.rnn.MultiRNNCell(
            [tf.contrib.rnn.DropoutWrapper(tf.contrib.rnn.LSTMCell(rnn_size), keep_prob)
             for _ in range(num_layers)])

        outputs, state = tf.nn.bidirectional_dynamic_rnn(cell_fw=stacked_cells,
                                                         cell_bw=stacked_cells,
                                                         inputs=embed,
                                                         sequence_length=source_sequence_length,
                                                         dtype=tf.float32)

        # outputs = (output_fw, output_bw); join them along the feature axis
        concat_outputs = tf.concat(outputs, 2)

        # state = (state_fw, state_bw); only the forward state is returned
        return concat_outputs, state[0]
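
For reference, the shapes involved in tf.concat(outputs, 2) can be checked with NumPy arrays standing in for the TensorFlow tensors. The sizes below (batch_size=4, max_time=7, rnn_size=16) are illustrative assumptions, not values from the model. Each direction produces (batch_size, max_time, rnn_size), so the concatenated attention memory is 2 * rnn_size wide, while each of the two final states stays rnn_size wide.

```python
import numpy as np

# Illustrative sizes only (assumed, not taken from the model)
batch_size, max_time, rnn_size = 4, 7, 16
rng = np.random.default_rng(0)

# Stand-ins for outputs = (output_fw, output_bw) from bidirectional_dynamic_rnn
output_fw = rng.standard_normal((batch_size, max_time, rnn_size))
output_bw = rng.standard_normal((batch_size, max_time, rnn_size))

# Same operation as tf.concat(outputs, 2): join the two directions on the feature axis
concat_outputs = np.concatenate([output_fw, output_bw], axis=2)

print(concat_outputs.shape)  # (4, 7, 32)
```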

For the decoder I'm using an attention mechanism. Decoding layer:

    def decoding_layer_train(self, encoder_outputs, encoder_state, dec_cell, dec_embed_input,
                             target_sequence_length, max_summary_length,
                             output_layer, keep_prob, rnn_size, batch_size,
                             source_sequence_length):
        """
        Create a training process in decoding layer
        :return: BasicDecoderOutput containing training logits and sample_id
        """

        dec_cell = tf.contrib.rnn.DropoutWrapper(dec_cell,
                                                 output_keep_prob=keep_prob)

        train_helper = tf.contrib.seq2seq.TrainingHelper(dec_embed_input, target_sequence_length)

        # The attention memory is the encoder output, so mask it with the source (encoder) lengths
        attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(rnn_size, encoder_outputs,
                                                                   memory_sequence_length=source_sequence_length)

        # Integer division: attention_layer_size is a number of units and should be an int
        attention_cell = tf.contrib.seq2seq.AttentionWrapper(dec_cell, attention_mechanism,
                                                             attention_layer_size=rnn_size // 2)

        # Start from a zero attention state, then swap in the encoder's final state
        state = attention_cell.zero_state(dtype=tf.float32, batch_size=batch_size)
        state = state.clone(cell_state=encoder_state)

        decoder = tf.contrib.seq2seq.BasicDecoder(cell=attention_cell, helper=train_helper,
                                                  initial_state=state,
                                                  output_layer=output_layer)
        outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, impute_finished=True,
                                                          maximum_iterations=max_summary_length)

        return outputs
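
memory_sequence_length tells the attention mechanism which time steps of its memory are padding. Since the memory here is encoder_outputs, the lengths that apply are the encoder (source) lengths rather than target_sequence_length. Roughly, padded positions are excluded from the softmax over alignment scores. The NumPy sketch below illustrates that masking idea only; it is not TensorFlow's implementation, the scores are made up, and it assumes every length is at least 1.

```python
import numpy as np


def masked_attention_weights(scores, memory_sequence_length):
    """Softmax over encoder time steps, ignoring padded positions.

    scores: (batch, max_time) raw alignment scores
    memory_sequence_length: (batch,) true length of each encoder sequence (>= 1)
    """
    batch, max_time = scores.shape
    # True where the time step is a real (non-padded) encoder position
    mask = np.arange(max_time)[None, :] < memory_sequence_length[:, None]
    # Padded positions get -inf, so exp() turns them into exactly 0
    masked = np.where(mask, scores, -np.inf)
    # Subtract the row max for numerical stability before exponentiating
    masked = masked - masked.max(axis=1, keepdims=True)
    weights = np.exp(masked)
    return weights / weights.sum(axis=1, keepdims=True)


scores = np.array([[1.0, 2.0, 3.0, 4.0],
                   [0.5, 0.5, 9.0, 9.0]])
lengths = np.array([4, 2])  # second sequence has only 2 real steps
weights = masked_attention_weights(scores, lengths)
```

With the wrong lengths, the large scores on the second sequence's padded steps would dominate the attention weights; with the source lengths they receive zero weight.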

The initial state of BasicDecoder is expected to have shape (batch_size, rnn_size), but my encoder outputs two states (forward and backward), each of shape (batch_size, rnn_size).

To make it work, I'm using only one encoder state (the forward state). What are the possible ways to use both the forward and the backward encodings from the encoding layer? Should I add the forward and backward states together?
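
The two obvious options differ only in the resulting width. Below is a minimal sketch with NumPy arrays standing in for TensorFlow tensors (in TF 1.x the same operations would be + and tf.concat), assuming single-layer LSTM encoders and illustrative sizes. LSTMStateTuple here is a local namedtuple mirroring tf.nn.rnn_cell.LSTMStateTuple, which has fields c and h.

```python
import numpy as np
from collections import namedtuple

# Local stand-in for tf.nn.rnn_cell.LSTMStateTuple (a namedtuple with fields c and h)
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

batch_size, rnn_size = 4, 16  # illustrative sizes
rng = np.random.default_rng(0)

# Hypothetical final states of a single-layer forward and backward LSTM encoder
state_fw = LSTMStateTuple(c=rng.standard_normal((batch_size, rnn_size)),
                          h=rng.standard_normal((batch_size, rnn_size)))
state_bw = LSTMStateTuple(c=rng.standard_normal((batch_size, rnn_size)),
                          h=rng.standard_normal((batch_size, rnn_size)))

# Option 1: element-wise sum keeps the width at rnn_size, so the decoder cell size is unchanged
summed = LSTMStateTuple(c=state_fw.c + state_bw.c, h=state_fw.h + state_bw.h)

# Option 2: concatenation doubles the width, so the decoder cell needs 2 * rnn_size units
concatenated = LSTMStateTuple(c=np.concatenate([state_fw.c, state_bw.c], axis=1),
                              h=np.concatenate([state_fw.h, state_bw.h], axis=1))

print(summed.c.shape, concatenated.c.shape)  # (4, 16) (4, 32)
```

Summing fits the existing decoder without changes; concatenating keeps the two directions separate but requires resizing the decoder cell.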

P.S. The decoder doesn't use a bidirectional layer.


2 Answers


If you want to use only the backward encoding:

    # Get only the last cell state of the backward cell
    (_, _), (_, cell_state_bw) = tf.nn.bidirectional_dynamic_rnn(...)
    # Pass the cell_state_bw as the initial state of the decoder cell
    decoder = tf.contrib.seq2seq.BasicDecoder(..., initial_state=cell_state_bw, ...)

What I suggest you do:

    # Get both last states
    # (assumes single-layer LSTMCell encoders; with MultiRNNCell each state is
    # a tuple of per-layer LSTMStateTuples and must be merged layer by layer)
    (_, _), (cell_state_fw, cell_state_bw) = tf.nn.bidirectional_dynamic_rnn(...)
    # Concatenate the cell states together
    cell_state_final = tf.concat([cell_state_fw.c, cell_state_bw.c], 1)
    # Concatenate the hidden states together
    hidden_state_final = tf.concat([cell_state_fw.h, cell_state_bw.h], 1)
    # Create the actual final state
    encoder_final_state = tf.nn.rnn_cell.LSTMStateTuple(c=cell_state_final, h=hidden_state_final)
    # Now you can pass this as the initial state of the decoder

Beware, however: for the second approach to work, the decoder cell must be twice as wide as the encoder cell (2 * rnn_size units), because the concatenated c and h states are twice as wide.
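
The concatenation above treats cell_state_fw as a single LSTMStateTuple, which holds for a single-layer LSTMCell. In the question's encoder each direction is a MultiRNNCell, so the final state of each direction is a tuple with one LSTMStateTuple per layer, and c and h have to be concatenated layer by layer. A minimal NumPy sketch of that merge follows, with np.concatenate standing in for tf.concat and a local namedtuple standing in for LSTMStateTuple; random_layer_state is a hypothetical helper used only to build test data. The decoder would then need the same number of layers, each 2 * rnn_size units wide.

```python
import numpy as np
from collections import namedtuple

# Local stand-in for tf.nn.rnn_cell.LSTMStateTuple (a namedtuple with fields c and h)
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def merge_bidirectional_states(states_fw, states_bw):
    """Concatenate forward and backward LSTM states layer by layer.

    states_fw and states_bw are tuples with one LSTMStateTuple per layer,
    the structure a MultiRNNCell returns as its final state.
    """
    return tuple(
        LSTMStateTuple(c=np.concatenate([fw.c, bw.c], axis=1),
                       h=np.concatenate([fw.h, bw.h], axis=1))
        for fw, bw in zip(states_fw, states_bw)
    )


def random_layer_state(rng, batch_size, rnn_size):
    """Hypothetical helper: one layer's state filled with random values."""
    return LSTMStateTuple(c=rng.standard_normal((batch_size, rnn_size)),
                          h=rng.standard_normal((batch_size, rnn_size)))


rng = np.random.default_rng(0)
num_layers, batch_size, rnn_size = 2, 4, 16  # illustrative sizes
states_fw = tuple(random_layer_state(rng, batch_size, rnn_size) for _ in range(num_layers))
states_bw = tuple(random_layer_state(rng, batch_size, rnn_size) for _ in range(num_layers))

encoder_final_state = merge_bidirectional_states(states_fw, states_bw)
print(len(encoder_final_state), encoder_final_state[0].h.shape)  # 2 (4, 32)
```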


Most of this is already covered in the previous answer.

Regarding your question "Should I add both forward and backward states?": in my view, you should use both encoder states; otherwise the trained backward encoder state goes unused. Moreover, bidirectional_dynamic_rnn should be given two different stacks of LSTM cells, one for the forward direction and another for the backward direction, whereas your encoder passes the same stacked_cells object as both cell_fw and cell_bw.
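
To make the "separate cells" structure concrete, here is a toy bidirectional RNN in NumPy. It is a simplification for illustration only: a plain tanh RNN instead of an LSTM, no sequence_length or padding handling, and made-up sizes. The forward and backward passes each get an independent parameter set, and the backward outputs are flipped back so both directions align per time step. In TF 1.x terms, this corresponds to building the MultiRNNCell twice, once for cell_fw and once for cell_bw.

```python
import numpy as np


def make_rnn_params(input_size, hidden_size, rng):
    """One independent set of weights for a simple tanh RNN cell."""
    return {
        "W_x": rng.standard_normal((input_size, hidden_size)) * 0.1,
        "W_h": rng.standard_normal((hidden_size, hidden_size)) * 0.1,
        "b": np.zeros(hidden_size),
    }


def run_rnn(params, inputs):
    """Run the cell over inputs of shape (batch, time, input_size)."""
    batch_size, max_time, _ = inputs.shape
    h = np.zeros((batch_size, params["W_h"].shape[0]))
    outputs = []
    for t in range(max_time):
        h = np.tanh(inputs[:, t] @ params["W_x"] + h @ params["W_h"] + params["b"])
        outputs.append(h)
    # All per-step outputs, plus the final hidden state
    return np.stack(outputs, axis=1), h


def bidirectional_rnn(params_fw, params_bw, inputs):
    """Forward and backward passes with separate weights (no padding handling)."""
    out_fw, state_fw = run_rnn(params_fw, inputs)
    # The backward cell reads the sequence reversed; flip its outputs back to align in time
    out_bw_reversed, state_bw = run_rnn(params_bw, inputs[:, ::-1])
    return (out_fw, out_bw_reversed[:, ::-1]), (state_fw, state_bw)


rng = np.random.default_rng(0)
inputs = rng.standard_normal((2, 5, 8))  # (batch, time, embedding_size)
params_fw = make_rnn_params(8, 16, rng)  # weights for the forward direction
params_bw = make_rnn_params(8, 16, rng)  # separate weights for the backward direction
(out_fw, out_bw), (state_fw, state_bw) = bidirectional_rnn(params_fw, params_bw, inputs)
```

Note that the backward final state corresponds to the first time step of out_bw, since the backward cell finishes at the start of the sequence.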