I'm trying to solve a regression problem with a stacked RNN in TensorFlow. The RNN output should be fed into a fully connected layer for the final prediction. Currently I'm struggling with how to feed the RNN output into that final fully_connected layer. My input is of shape [batch_size, max_sequence_length, num_features].
The RNN layers are created like this:

cells = []
for i in range(num_rnn_layers):
    cell = tf.contrib.rnn.LSTMCell(num_rnn_units)
    cells.append(cell)
multi_rnn_cell = tf.contrib.rnn.MultiRNNCell(cells)

# outputs: [batch_size, max_sequence_length, num_rnn_units]
# states: one LSTMStateTuple per layer
outputs, states = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
                                    inputs=Bx_rnn,
                                    dtype=tf.float32)
outputs is of shape [batch_size, max_sequence_length, num_rnn_units]. I tried using only the output of the last time step, like this:
final_outputs = tf.contrib.layers.fully_connected(
    outputs[:, -1, :],   # last time step: [batch_size, num_rnn_units]
    n_targets,
    activation_fn=None)  # final_outputs: [batch_size, n_targets]
I also found examples and books that recommend reshaping the output and target like this:
rnn_outputs = tf.reshape(outputs, [-1, num_rnn_units])
y_reshaped = tf.reshape(y, [-1])
Since I'm currently using a batch size of 500 and a sequence length of 10,000, this leads to huge matrices, really long training times, and huge memory consumption.
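For completeness, here is roughly how the full per-time-step variant would look (a sketch, assuming y has shape [batch_size, max_sequence_length, n_targets], i.e. one target vector per time step):

# one row per time step
rnn_outputs = tf.reshape(outputs, [-1, num_rnn_units])   # [batch*time, num_rnn_units]
flat_preds = tf.contrib.layers.fully_connected(
    rnn_outputs,
    n_targets,
    activation_fn=None)                                  # [batch*time, n_targets]
# restore the sequence structure to match the targets
predictions = tf.reshape(flat_preds, [-1, max_sequence_length, n_targets])
loss = tf.losses.mean_squared_error(labels=y, predictions=predictions)

With batch_size=500 and max_sequence_length=10,000 the flattened matrix has 5,000,000 rows, which is where the memory blows up.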
I've also found many articles recommending unstacking the inputs and stacking the outputs again, which I couldn't get to work due to shape mismatches (my understanding of that pattern is sketched below).
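As far as I understand, the pattern those articles describe targets tf.nn.static_rnn and looks roughly like this (a sketch of my understanding, not working code; with 10,000 time steps this also unrolls into an enormous graph):

# unstack along the time axis into a list of max_sequence_length
# tensors, each of shape [batch_size, num_features]
inputs_list = tf.unstack(Bx_rnn, axis=1)
outputs_list, states = tf.nn.static_rnn(multi_rnn_cell,
                                        inputs_list,
                                        dtype=tf.float32)
# stack the per-step outputs back to [batch_size, max_sequence_length, num_rnn_units]
outputs = tf.stack(outputs_list, axis=1)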
What would be the correct way to feed the RNN output into a fully_connected layer? Or should I use the RNN states instead of the outputs?
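Regarding the states option: if I read the MultiRNNCell docs correctly, states is a tuple with one LSTMStateTuple per layer, and for full-length sequences the hidden state of the top layer should equal the last output, so this should be equivalent to the outputs[:,-1,:] version above (sketch):

# states[-1] is the LSTMStateTuple(c, h) of the topmost layer;
# with full-length sequences states[-1].h == outputs[:, -1, :]
final_outputs = tf.contrib.layers.fully_connected(
    states[-1].h,
    n_targets,
    activation_fn=None)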
Edit: For clarification: I do need these long sequences, because I'm trying to model a physical system. The input is a single feature consisting of white noise. I have multiple outputs (45 in this specific system). Impulses affect the system state for roughly 10,000 time steps.
I.e., currently I'm trying to model a car's gear bridging that was excited by a shaker. The outputs were measured by 15 acceleration sensors in 3 directions (X, Y & Z).
Batch size of 500 was arbitrarily picked.
Regardless of possible vanishing gradients or memory issues from the long sequences, I'd be interested in how to feed the data correctly. We do have appropriate hardware (an Nvidia Titan V). Furthermore, we were already able to model the system behaviour with classic DNNs using lags of >3,000 time steps with good accuracy.
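One more observation, in case it is relevant for answers: tf.layers.dense seems to accept tensors of rank > 2 and applies the same weight matrix along the last axis, which would avoid the explicit reshape entirely. I'm not sure whether this is the recommended approach (sketch):

# dense applies one [num_rnn_units, n_targets] weight matrix
# along the last axis of the rank-3 outputs tensor
predictions = tf.layers.dense(outputs, n_targets, activation=None)
# predictions: [batch_size, max_sequence_length, n_targets]
loss = tf.losses.mean_squared_error(labels=y, predictions=predictions)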