I'm learning how to build a seq2seq model based on this TensorFlow 2 NMT tutorial, and I'm trying to extend it by stacking multiple RNN layers in the encoder and decoder. However, I'm having trouble retrieving the outputs that correspond to the encoder's final hidden state.

Here's my code for building the stacked bidirectional GRUCell layers in the encoder:

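from tensorflow.keras import layers

# Encoder initializer
def __init__(self, n_layers, dropout, ...):
    ...
    # One GRUCell per layer; passing a list of cells to layers.RNN stacks them
    gru_cells = [layers.GRUCell(units,
                                recurrent_initializer='glorot_uniform',
                                dropout=dropout)
                 for _ in range(n_layers)]
    # Bidirectional runs a forward and a backward copy of the stacked RNN
    self.gru = layers.Bidirectional(layers.RNN(gru_cells,
                                               return_sequences=True,
                                               return_state=True))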

Assuming the above is correct, I then call the layer I created:

# Encoder call method
def call(self, inputs, state):
    ...
    list_outputs = self.gru(inputs, initial_state=state)
    print(len(list_outputs)) # test

list_outputs has length 3 when n_layers = 1, which is the expected behavior according to this SO post. Each time I increase n_layers by one, the number of outputs increases by two, which I presume are the forward and backward final states of the new layer. So 2 layers -> 5 outputs, 3 layers -> 7 outputs, etc. However, I can't figure out which output corresponds to which layer and in which direction.

Ultimately, what I'd like to know is: how can I get the forward and backward final states of the last layer in this stacked bidirectional RNN? If I understand the seq2seq model correctly, they make up the hidden state that is passed to the decoder.

1 Answer

After digging through the TensorFlow source code for the RNN and Bidirectional classes, my best guess for the output format of a stacked bidirectional RNN layer is the following list of 1 + 2n tensors, where n is the number of stacked layers (so len//2 == n):

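  • [0] output sequence of the last layer, with the forward and backward outputs concatenated along the feature axis (Bidirectional's default merge_mode='concat')
  • [1 : len//2 + 1] final states of the forward layers, from first to last
  • [len//2 + 1 :] final states of the backward layers, from first to last

So the forward and backward final states of the last layer are list_outputs[len(list_outputs)//2] and list_outputs[-1].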
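The slicing above can be wrapped in a small helper. This is a minimal sketch that assumes the 1 + 2n layout guessed above and one state tensor per cell (as with GRUCell); the function names are hypothetical, not part of Keras. It uses plain Python so it works on any list, including the list returned by the Bidirectional layer.

```python
from typing import Any, List, Sequence, Tuple


def split_bidirectional_outputs(outputs: Sequence[Any]) -> Tuple[Any, List[Any], List[Any]]:
    """Split [sequence, fw_1..fw_n, bw_1..bw_n] into its three parts.

    Hypothetical helper; assumes one state per cell (e.g. GRUCell), so the
    list must have 1 + 2n entries.
    """
    if len(outputs) < 3 or len(outputs) % 2 == 0:
        raise ValueError(f"expected 1 + 2n outputs, got {len(outputs)}")
    n_layers = len(outputs) // 2
    sequence = outputs[0]
    forward_states = list(outputs[1:n_layers + 1])   # first layer -> last layer
    backward_states = list(outputs[n_layers + 1:])   # first layer -> last layer
    return sequence, forward_states, backward_states


def last_layer_states(outputs: Sequence[Any]) -> Tuple[Any, Any]:
    """Return (forward, backward) final states of the top layer."""
    _, forward_states, backward_states = split_bidirectional_outputs(outputs)
    return forward_states[-1], backward_states[-1]


# Demo with placeholder labels standing in for tensors (3 stacked layers)
labels = ["seq", "fw1", "fw2", "fw3", "bw1", "bw2", "bw3"]
print(split_bidirectional_outputs(labels))  # ('seq', ['fw1', 'fw2', 'fw3'], ['bw1', 'bw2', 'bw3'])
print(last_layer_states(labels))            # ('fw3', 'bw3')
```

In the encoder's call method this would be used as `fw_last, bw_last = last_layer_states(list_outputs)`. How to combine them for the decoder is a design choice; one common option is `tf.concat([fw_last, bw_last], axis=-1)`, which requires the decoder GRU to have 2 * units units.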