
Hi, I have a question about how to extract the correct results from a BiLSTM module's output.

Suppose I feed a sequence of length 10 into a single-layer bidirectional LSTM module with 100 hidden units:

import torch.nn as nn

lstm = nn.LSTM(5, 100, 1, bidirectional=True)  # input_size=5, hidden_size=100, num_layers=1

output will be of shape:

[10 (seq_length), 1 (batch), 200 (num_directions * hidden_size)]
# or, according to the docs, it can be viewed as
[10 (seq_length), 1 (batch), 2 (num_directions), 100 (hidden_size)]
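A minimal sketch to confirm these shapes with the configuration above (input size 5, hidden size 100, one layer, bidirectional). The random input x, the seed, and the batch size of 1 are assumptions chosen only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same configuration as above: input_size=5, hidden_size=100, num_layers=1
lstm = nn.LSTM(5, 100, 1, bidirectional=True)

# Hypothetical input: seq_len=10, batch=1, input_size=5 (default seq-first layout)
x = torch.randn(10, 1, 5)
output, (h_n, c_n) = lstm(x)

print(output.shape)  # torch.Size([10, 1, 200])

# Expose the direction axis: (seq_len, batch, num_directions, hidden_size)
by_direction = output.view(10, 1, 2, 100)
print(by_direction.shape)  # torch.Size([10, 1, 2, 100])
```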

If I want to get the output for the 3rd (1-indexed) input in both directions (two 100-dim vectors), how can I do it correctly?

I know output[2, 0] will give me a 200-dim vector. Does this 200-dim vector represent the output for the 3rd input in both directions?

What bothers me is that, when feeding the sequence in reverse, the 3rd (1-indexed) output vector is computed from the 8th (1-indexed) input, right?

Will PyTorch automatically take care of this and align the outputs of both directions by input position?

Thanks!


2 Answers


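Yes. When using a BiLSTM, the hidden states of the two directions are simply concatenated at each time step: the first half is the forward direction, and the second half (after the middle) is the hidden state from feeding in the reversed sequence, already aligned to the same input position.
So splitting in the middle works just fine.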
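To address the concern about reverse feeding, here is a small sketch (hidden size 4, random input, and seed are assumptions for illustration) that perturbs only the first input. If the backward half at position t is aligned to input t, it has only seen inputs t to the end, so it should be unaffected, while the forward half at the same position should change.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size = 4  # assumed small size for illustration
lstm = nn.LSTM(input_size=5, hidden_size=hidden_size, bidirectional=True)

x = torch.randn(10, 1, 5)   # (seq_len, batch, input_size)
x_perturbed = x.clone()
x_perturbed[0] += 1.0       # change only the 1st input

with torch.no_grad():
    out, _ = lstm(x)
    out_p, _ = lstm(x_perturbed)

t = 2  # the 3rd input (0-indexed position 2)
fwd_changed = not torch.allclose(out[t, 0, :hidden_size], out_p[t, 0, :hidden_size])
bwd_changed = not torch.allclose(out[t, 0, hidden_size:], out_p[t, 0, hidden_size:])

# Forward state at t has seen inputs 0..t, so it reacts to the change.
# Backward state at t has seen only inputs t..9, so it should not.
print(fwd_changed, bwd_changed)
```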

Because view reshapes in row-major order (the rightmost dimension varies fastest), splitting the last dimension into (2, hidden_size) separates the two directions without any problems.
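A minimal sketch comparing three equivalent ways to split the 200-dim vector at position 2, using the question's sizes with a random input (an assumption for illustration): slicing at the middle, reshaping with view, and torch.chunk. All three select the same elements, so they compare exactly equal.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size = 10, 1, 5, 100
lstm = nn.LSTM(input_size, hidden_size, bidirectional=True)
output, _ = lstm(torch.randn(seq_len, batch, input_size))

# 1) Slice the last dimension in the middle
fwd_a, bwd_a = output[2, 0, :hidden_size], output[2, 0, hidden_size:]

# 2) Reshape the last dimension into (num_directions, hidden_size)
split = output.view(seq_len, batch, 2, hidden_size)
fwd_b, bwd_b = split[2, 0, 0], split[2, 0, 1]

# 3) Chunk the last dimension into two equal parts
fwd_c, bwd_c = torch.chunk(output[2, 0], 2, dim=-1)

print(torch.equal(fwd_a, fwd_b), torch.equal(bwd_a, bwd_b))  # True True
print(torch.equal(fwd_a, fwd_c), torch.equal(bwd_a, bwd_c))  # True True
```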


Here is a small example:

import torch

# so these are your original hidden states for each direction
# in this case hidden size is 5, but this works for any size
direction_one_out = torch.tensor(range(5))
direction_two_out = torch.tensor(list(reversed(range(5))))
print('Direction one:')
print(direction_one_out)
print('Direction two:')
print(direction_two_out)

# before outputting they will be concatenated
# here I add the sequence-length and batch dimensions; the sequence length is 1 in this case
hidden = torch.cat((direction_one_out, direction_two_out), dim=0).view(1, 1, -1)
print('\nYour hidden output:')
print(hidden, hidden.shape)

# trivial case, reshaping for one hidden state
hidden_reshaped = hidden.view(1, 1, 2, -1)
print('\nReshaped:')
print(hidden_reshaped, hidden_reshaped.shape)

# This also works for arbitrary sequence lengths, as you can see here;
# the sequence length is set to 5, but any other value works as well
print('\nThis also works for more multiple hidden states in a tensor:')
multi_hidden = hidden.expand(5, 1, 10)  # repeat the single time step 5 times
print(multi_hidden, multi_hidden.shape)
print('Directions can be split up just like this:')
multi_hidden = multi_hidden.view(5, 1, 2, 5)
print(multi_hidden, multi_hidden.shape)

Output:

Direction one:
tensor([0, 1, 2, 3, 4])
Direction two:
tensor([4, 3, 2, 1, 0])

Your hidden output:
tensor([[[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]]]) torch.Size([1, 1, 10])

Reshaped:
tensor([[[[0, 1, 2, 3, 4],
          [4, 3, 2, 1, 0]]]]) torch.Size([1, 1, 2, 5])

This also works for more multiple hidden states in a tensor:
tensor([[[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]],

        [[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]],

        [[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]],

        [[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]],

        [[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]]]) torch.Size([5, 1, 10])
Directions can be split up just like this:
tensor([[[[0, 1, 2, 3, 4],
          [4, 3, 2, 1, 0]]],


        [[[0, 1, 2, 3, 4],
          [4, 3, 2, 1, 0]]],


        [[[0, 1, 2, 3, 4],
          [4, 3, 2, 1, 0]]],


        [[[0, 1, 2, 3, 4],
          [4, 3, 2, 1, 0]]],


        [[[0, 1, 2, 3, 4],
          [4, 3, 2, 1, 0]]]]) torch.Size([5, 1, 2, 5])

Hope this helps! :)


I know output[2, 0] will give me a 200-dim vector. Does this 200 dim vector represent the output of 3rd input at both directions?

The answer is YES.

The output tensor of the LSTM module is the concatenation of the forward LSTM output and the backward LSTM output at the corresponding position in the input sequence. The h_n tensor holds the output at the last time step of each direction: for the forward LSTM that is the output for the last token, but for the backward LSTM it is the output for the first token.
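One way to check this concatenation claim directly is to rebuild the bidirectional LSTM from two unidirectional ones. A sketch, assuming small sizes and a random input for illustration: the backward LSTM reuses the parameters with the "_reverse" suffix, is fed the time-flipped sequence, and its output is flipped back so that position t again corresponds to input t.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, seq_len, batch = 5, 3, 4, 1
bi = nn.LSTM(input_size, hidden_size, bidirectional=True)

# Two unidirectional LSTMs that will reuse the bidirectional weights
fwd = nn.LSTM(input_size, hidden_size)
bwd = nn.LSTM(input_size, hidden_size)
with torch.no_grad():
    for name in ["weight_ih_l0", "weight_hh_l0", "bias_ih_l0", "bias_hh_l0"]:
        getattr(fwd, name).copy_(getattr(bi, name))
        # The backward direction's parameters carry a "_reverse" suffix
        getattr(bwd, name).copy_(getattr(bi, name + "_reverse"))

x = torch.randn(seq_len, batch, input_size)
with torch.no_grad():
    out_bi, (h_n, _) = bi(x)
    out_f, _ = fwd(x)
    # Feed the reversed sequence, then flip the result back along time
    out_b = torch.flip(bwd(torch.flip(x, dims=[0]))[0], dims=[0])

manual = torch.cat([out_f, out_b], dim=-1)
print(torch.allclose(manual, out_bi, atol=1e-6))
# h_n: forward final state sits at the last position, backward at the first
print(torch.allclose(h_n[0], out_f[-1], atol=1e-6),
      torch.allclose(h_n[1], out_b[0], atol=1e-6))
```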

In [1]: import torch
   ...: lstm = torch.nn.LSTM(input_size=5, hidden_size=3, bidirectional=True)
   ...: seq_len, batch, input_size, num_directions = 3, 1, 5, 2
   ...: in_data = torch.randint(10, (seq_len, batch, input_size)).float()
   ...: output, (h_n, c_n) = lstm(in_data)
   ...: 

In [2]: # output of shape (seq_len, batch, num_directions * hidden_size)
   ...: 
   ...: print(output)
   ...: 
tensor([[[ 0.0379,  0.0169,  0.2539,  0.2547,  0.0456, -0.1274]],

        [[ 0.7753,  0.0862, -0.0001,  0.3897,  0.0688, -0.0002]],

        [[ 0.7120,  0.2965, -0.3405,  0.0946,  0.0360, -0.0519]]],
       grad_fn=<CatBackward>)

In [3]: # h_n of shape (num_layers * num_directions, batch, hidden_size)
   ...: 
   ...: print(h_n)
   ...: 
tensor([[[ 0.7120,  0.2965, -0.3405]],

        [[ 0.2547,  0.0456, -0.1274]]], grad_fn=<ViewBackward>)

In [4]: output = output.view(seq_len, batch, num_directions, lstm.hidden_size)
   ...: print(output[-1, 0, 0])  # forward LSTM output of last token
   ...: print(output[0, 0, 1])  # backward LSTM output of first token
   ...: 
tensor([ 0.7120,  0.2965, -0.3405], grad_fn=<SelectBackward>)
tensor([ 0.2547,  0.0456, -0.1274], grad_fn=<SelectBackward>)

In [5]: h_n = h_n.view(lstm.num_layers, num_directions, batch, lstm.hidden_size)
   ...: print(h_n[0, 0, 0])  # h_n of forward LSTM
   ...: print(h_n[0, 1, 0])  # h_n of backward LSTM
   ...: 
tensor([ 0.7120,  0.2965, -0.3405], grad_fn=<SelectBackward>)
tensor([ 0.2547,  0.0456, -0.1274], grad_fn=<SelectBackward>)
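The session above uses a single layer. A sketch, assuming 2 layers and a batch size of 2 (both chosen only for illustration), shows that the same view applies with more layers. Since output only contains the top layer, it is the last layer's slice of h_n that matches it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size, num_layers = 3, 2, 5, 3, 2
lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, bidirectional=True)
output, (h_n, c_n) = lstm(torch.randn(seq_len, batch, input_size))

# output comes from the top layer only: (seq_len, batch, 2 * hidden_size)
out = output.view(seq_len, batch, 2, hidden_size)
# h_n covers every layer: (num_layers * 2, batch, hidden_size)
h = h_n.view(num_layers, 2, batch, hidden_size)

top = h[-1]  # final states of the top layer, shape (2, batch, hidden_size)
print(torch.allclose(top[0], out[-1, :, 0]))  # forward: last position
print(torch.allclose(top[1], out[0, :, 1]))   # backward: first position
```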