Yes, when using a BiLSTM, the hidden states of the two directions are simply concatenated along the last dimension (the second half is the hidden state produced by feeding in the reversed sequence).
So splitting down the middle works just fine.
Since view() fills dimensions from right to left (the rightmost dimension varies fastest), you won't have any problems separating the two directions.
Here is a small example:
import torch

# Simulate the per-direction outputs of a BiLSTM
direction_one_out = torch.tensor(range(5))
direction_two_out = torch.tensor(list(reversed(range(5))))
print('Direction one:')
print(direction_one_out)
print('Direction two:')
print(direction_two_out)

# Concatenate the directions along the last dimension, as a BiLSTM does
hidden = torch.cat((direction_one_out, direction_two_out), dim=0).view(1, 1, -1)
print('\nYour hidden output:')
print(hidden, hidden.shape)

# Split the last dimension back into (num_directions, hidden_size)
hidden_reshaped = hidden.view(1, 1, 2, -1)
print('\nReshaped:')
print(hidden_reshaped, hidden_reshaped.shape)

print('\nThis also works for multiple hidden states in a tensor:')
multi_hidden = hidden.expand(5, 1, 10)
print(multi_hidden, multi_hidden.shape)

print('Directions can be split up just like this:')
multi_hidden = multi_hidden.view(5, 1, 2, 5)
print(multi_hidden, multi_hidden.shape)
Output:
Direction one:
tensor([0, 1, 2, 3, 4])
Direction two:
tensor([4, 3, 2, 1, 0])
Your hidden output:
tensor([[[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]]]) torch.Size([1, 1, 10])
Reshaped:
tensor([[[[0, 1, 2, 3, 4],
          [4, 3, 2, 1, 0]]]]) torch.Size([1, 1, 2, 5])
This also works for multiple hidden states in a tensor:
tensor([[[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]],
        [[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]],
        [[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]],
        [[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]],
        [[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]]]) torch.Size([5, 1, 10])
Directions can be split up just like this:
tensor([[[[0, 1, 2, 3, 4],
          [4, 3, 2, 1, 0]]],
        [[[0, 1, 2, 3, 4],
          [4, 3, 2, 1, 0]]],
        [[[0, 1, 2, 3, 4],
          [4, 3, 2, 1, 0]]],
        [[[0, 1, 2, 3, 4],
          [4, 3, 2, 1, 0]]],
        [[[0, 1, 2, 3, 4],
          [4, 3, 2, 1, 0]]]]) torch.Size([5, 1, 2, 5])
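If you want to verify this on an actual model, here is a minimal sketch (the sizes are made up for illustration) using nn.LSTM with bidirectional=True; it splits the output's last dimension into the two directions and checks the result against the per-direction hidden state h_n:

import torch
import torch.nn as nn

# Hypothetical sizes, just for illustration
seq_len, batch, input_size, hidden_size = 7, 3, 4, 5

lstm = nn.LSTM(input_size, hidden_size, bidirectional=True)
x = torch.randn(seq_len, batch, input_size)
output, (h_n, c_n) = lstm(x)

# output: (seq_len, batch, 2 * hidden_size)
# -> split the last dimension into (num_directions, hidden_size)
directions = output.view(seq_len, batch, 2, hidden_size)
forward_out = directions[:, :, 0]   # forward direction
backward_out = directions[:, :, 1]  # backward direction

# Sanity checks: the forward direction's last time step equals h_n[0],
# and the backward direction's first time step equals h_n[1]
# (the backward LSTM processes position 0 last)
assert torch.allclose(forward_out[-1], h_n[0])
assert torch.allclose(backward_out[0], h_n[1])

The two asserts document the layout: the first half of the last dimension belongs to the forward pass, the second half to the backward pass.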
Hope this helps! :)