I'd like to build a network with the TensorFlow 2 Object Detection API that takes 5-channel images as input. However, I am stuck on how to modify the weights of the first convolutional layer using TensorFlow 2.2.
I have downloaded the pre-trained RetinaNet from the V2 Model Zoo. I then tried the following to modify the weights in the first layer of the checkpoint and save them back:
import numpy as np
import tensorflow as tf

tf_path = tf.train.latest_checkpoint('./RetinaNet/checkpoint/')
init_vars = tf.train.list_variables(tf_path)
tf_vars = {}
for name, shape in init_vars:
    array = tf.train.load_variable(tf_path, name)
    # Look for a 4-D kernel whose 3rd dimension (input channels) is 3,
    # i.e. the 1st convolutional layer.
    if len(shape) == 4 and shape[2] == 3:
        # Reuse the first two input channels to go from 3 to 5 channels.
        array = np.concatenate((array, array[:, :, :2, :]), axis=2)
        array = array.astype('float32')
    tf_vars[name] = tf.Variable(array)

saver = tf.compat.v1.train.Saver(var_list=tf_vars)
sess = tf.compat.v1.Session()
saver.save(sess, './RetinaNet/checkpoint/ckpt-0')
I loaded the model back in to make sure the first convolutional layer had been changed, and everything looked fine.
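To show the kind of check I mean, here is a minimal NumPy-only sketch of the kernel expansion in isolation. The kernel shape (7, 7, 3, 64) and the random values are assumptions for illustration, not values read from the RetinaNet checkpoint, and `expand_input_channels` is a hypothetical helper name.

```python
import numpy as np


def expand_input_channels(kernel, extra=2):
    """Append copies of the first `extra` input channels to a conv kernel.

    `kernel` is expected in (height, width, in_channels, out_channels) layout.
    """
    if kernel.ndim != 4:
        raise ValueError("expected a 4-D convolution kernel")
    if extra > kernel.shape[2]:
        raise ValueError("cannot copy more channels than the kernel has")
    expanded = np.concatenate((kernel, kernel[:, :, :extra, :]), axis=2)
    return expanded.astype("float32")


# Stand-in for a pretrained first-layer kernel (the shape is an assumption).
rng = np.random.default_rng(0)
kernel = rng.standard_normal((7, 7, 3, 64)).astype("float32")

expanded = expand_input_channels(kernel)
print(expanded.shape)  # (7, 7, 5, 64)
```

Note that a check like this only confirms that the stored array has five input channels; it says nothing about the input shape the model itself is constructed with.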
But when I go to train the model, I get the following error:
Model was constructed with shape (None, None, None, 3) for input Tensor("input_1:0", shape=(None, None, None, 3), dtype=float32), but it was called on an input with incompatible shape (64, 128, 128, 5)
This leads me to believe that my method of modifying the weights is not so "ok" after all. Does anyone have tips on how to modify these weights correctly?
Thanks