
I'm trying to build a TensorFlow graph in which part of the graph is already pre-trained and runs in prediction mode while the rest trains. I've defined my pre-trained cell like so:

import tensorflow as tf

rnn_cell = tf.contrib.rnn.BasicLSTMCell(100)

# pretrained_state0/1 and data_input are defined elsewhere
state0 = tf.Variable(pretrained_state0, trainable=False)
state1 = tf.Variable(pretrained_state1, trainable=False)
pretrained_state = [state0, state1]

outputs, states = tf.contrib.rnn.static_rnn(rnn_cell,
                                            data_input,
                                            dtype=tf.float32,
                                            initial_state=pretrained_state)

Setting these initial-state variables to trainable=False doesn't help: they only set the cell's initial state, not its weights, so the LSTM weights still change during training.

I still need to run an optimizer in my training step, since the rest of my model needs to train. But how can I prevent the optimizer from changing the weights in this rnn cell?

Is there an rnn_cell equivalent of trainable=False?

Is the output of the pre-trained model an input to the new model you're training? If so, why not just pre-compute the output of the pre-trained model, i.e. keep the two graphs independent? – Yuwen Yan
@YuwenYan You're right, I could do this. I was hoping to avoid pre-computing by running the two graphs together, though, since that makes it simpler to ensure all the data lines up and saves a step every time I change the pre-trained model. – AlexR

1 Answer


You can either apply tf.stop_gradient() to the output of the pre-trained part of the graph, so that no gradients flow back into its weights, or pass a var_list to optimizer.minimize() to specify which variables should be trained. The second method would involve:

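import tensorflow as tf

# Build the trainable parts of the graph inside a 'train' variable scope;
# the pre-trained RNN cell stays outside it.
with tf.variable_scope('train'):
    ...  # new, trainable layers go here

# get trainable variables
t_vars = tf.trainable_variables()
# keep only variables created under 'train/' (the slash avoids matching e.g. 'training')
train_vars = [var for var in t_vars if var.name.startswith('train/')]
# train only the 'train' variables (optimizer and cost are defined elsewhere)
opt = optimizer.minimize(cost, var_list=train_vars)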
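To make the two options concrete without depending on a particular TensorFlow version, here is a framework-free sketch in plain Python (this is not TensorFlow code). It assumes a toy model: one scalar "pretrained" weight feeding one scalar trainable head, with hand-derived gradients. The functions `forward`, `loss`, `gradients`, `train` and the parameter keys `pretrained/w` and `train/w` are hypothetical stand-ins for the RNN cell's variables and the new layers. Passing `stop_gradient=True` mimics wrapping the pretrained output in tf.stop_gradient(), and `prefix="train/"` mimics filtering the var_list by scope.

```python
# Framework-free sketch (plain Python, not TensorFlow). A scalar "pretrained"
# weight feeds a scalar trainable head; gradients are derived by hand.

def forward(params, x):
    h = params["pretrained/w"] * x  # stands in for the pre-trained RNN cell
    y = params["train/w"] * h       # stands in for the new, trainable layers
    return h, y


def loss(params, x, target):
    _, y = forward(params, x)
    return (y - target) ** 2


def gradients(params, x, target, stop_gradient=False):
    """Gradients of loss = (y - target)**2 w.r.t. every parameter."""
    h, y = forward(params, x)
    dloss_dy = 2.0 * (y - target)
    grads = {"train/w": dloss_dy * h}
    # Mimics tf.stop_gradient(h): nothing flows back past h, so the
    # pretrained weight receives a zero gradient.
    dloss_dh = 0.0 if stop_gradient else dloss_dy * params["train/w"]
    grads["pretrained/w"] = dloss_dh * x
    return grads


def train(params, x, target, steps=50, lr=0.01, stop_gradient=False, prefix=None):
    """Plain gradient descent; returns a new dict and leaves `params` untouched."""
    params = dict(params)
    for _ in range(steps):
        grads = gradients(params, x, target, stop_gradient)
        # Mimics minimize(cost, var_list=...): only names under `prefix` are
        # updated (prefix=None updates every parameter).
        var_list = [n for n in params if prefix is None or n.startswith(prefix)]
        for name in var_list:
            params[name] -= lr * grads[name]
    return params


initial = {"pretrained/w": 0.5, "train/w": 1.0}
x, target = 2.0, 3.0

frozen_by_stop = train(initial, x, target, stop_gradient=True)
frozen_by_scope = train(initial, x, target, prefix="train/")
unfrozen = train(initial, x, target)

for label, p in [("stop_gradient", frozen_by_stop),
                 ("var_list", frozen_by_scope),
                 ("no freezing", unfrozen)]:
    print(label, p, "loss:", round(loss(p, x, target), 4))
```

Both frozen runs leave `pretrained/w` at 0.5 and end with the same head weight, while the unfrozen run also moves the pretrained weight. The trailing slash in `"train/"` matters in real TensorFlow too: variables created under `tf.variable_scope('train')` are named like `train/...`, so `startswith('train')` would also pick up a scope such as `training`.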