I'm trying to build a TensorFlow graph in which one part is already pre-trained and runs in prediction mode, while the rest trains. I've defined my pre-trained cell like so:
import tensorflow as tf

rnn_cell = tf.contrib.rnn.BasicLSTMCell(100)
state0 = tf.Variable(pretrained_state0, trainable=False)
state1 = tf.Variable(pretrained_state1, trainable=False)
pretrained_state = [state0, state1]
outputs, states = tf.contrib.rnn.static_rnn(rnn_cell,
                                            data_input,
                                            dtype=tf.float32,
                                            initial_state=pretrained_state)
Setting the initial variables to trainable=False doesn't help. These variables are only used for initialization, so the cell's weights still change during training.
I still need to run an optimizer in my training step, since the rest of my model needs to train. But how can I prevent the optimizer from changing the weights in this rnn cell?
Is there an rnn_cell equivalent to trainable=False?
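For reference, one approach I've seen suggested is to pass the optimizer an explicit var_list so that it only updates the variables I want trained. Below is a minimal sketch of that idea, not my actual model: the scope names "pretrained" and "head", the toy shapes, and the dense layers standing in for the RNN cell are all assumptions made for illustration, and it uses tensorflow.compat.v1 so that it is not tied to tf.contrib. For the real cell, I assume the same filtering would be applied to whatever variable scope static_rnn creates its weights in.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode, matching the TF 1.x style used above

# Hypothetical stand-ins: "pretrained" plays the role of the frozen RNN cell,
# "head" plays the role of the part of the model that should keep training.
x = tf.placeholder(tf.float32, shape=[None, 3])
y = tf.placeholder(tf.float32, shape=[None, 1])

with tf.variable_scope("pretrained"):
    w_frozen = tf.get_variable("w", shape=[3, 4],
                               initializer=tf.ones_initializer())
    hidden = tf.tanh(tf.matmul(x, w_frozen))

with tf.variable_scope("head"):
    w_head = tf.get_variable("w", shape=[4, 1],
                             initializer=tf.ones_initializer())
    pred = tf.matmul(hidden, w_head)

loss = tf.reduce_mean(tf.square(pred - y))

# Hand the optimizer only the variables created under the "head" scope.
train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="head")
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
    loss, var_list=train_vars)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    frozen_before = sess.run(w_frozen)
    head_before = sess.run(w_head)
    feed = {x: np.ones((2, 3), np.float32), y: np.ones((2, 1), np.float32)}
    sess.run(train_op, feed_dict=feed)  # one training step
    frozen_after = sess.run(w_frozen)
    head_after = sess.run(w_head)
```

In this sketch w_frozen is still marked trainable, yet the training step leaves it untouched because it is not in var_list, while w_head is updated. I'm not sure whether this is the intended way to freeze an RNN cell's kernel and bias, or whether there is a more direct switch.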