
I'd like to build a network with the TensorFlow 2 Object Detection API that takes 5-channel images. However, I am stuck on how to modify the weights of the first convolutional layer in TensorFlow 2.2.

I have downloaded the pre-trained RetinaNet from the V2 Model Zoo. I then tried the following to modify the weights in the first layer of the checkpoint and save them back:

import numpy as np
import tensorflow as tf

tf_path = tf.train.latest_checkpoint('./RetinaNet/checkpoint/')
init_vars = tf.train.list_variables(tf_path)
tf_vars = {}
for name, shape in init_vars:
    array = tf.train.load_variable(tf_path, name)
    # Look for a 4-D kernel whose 3rd dimension (input channels) is 3,
    # i.e. the 1st convolutional layer.
    if len(shape) == 4 and shape[2] == 3:
        # Pad from 3 to 5 input channels by repeating the first two channels.
        array = np.concatenate((array, array[:, :, :2, :]), axis=2)
        array = array.astype('float32')
    tf_vars[name] = tf.Variable(array)

# Write the modified variables back out as checkpoint 'ckpt-0'.
saver = tf.compat.v1.train.Saver(var_list=tf_vars)
sess = tf.compat.v1.Session()
saver.save(sess, './RetinaNet/checkpoint/ckpt-0')

I loaded the model back in to confirm that the 1st convolutional layer had been changed, and everything looked fine.

But when I go to train the model, I get the following error:

Model was constructed with shape (None, None, None, 3) for input Tensor("input_1:0", shape=(None, None, None, 3), dtype=float32), but it was called on an input with incompatible shape (64, 128, 128, 5)

This leads me to believe my method of modifying the weights is not so "ok" after all. Does anyone have tips on how to modify these weights correctly?

Thanks

1 Answer

This now works, but the solution is very hacky. It also means not training from the pretrained Model Zoo weights, so you need to comment out everything related to fine_tune_checkpoint in the config file. Then go to .\Lib\site-packages\official\vision\image_classification\efficientnet and change the number of input channels and the number of classes in efficientnet_model.py and efficientnet_config.py.
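
A note on why editing the checkpoint alone did not help: a checkpoint stores only variable values, while the (None, None, None, 3) input shape in the error comes from the model definition, so the model itself has to be built for 5 channels. If giving up the pretrained weights entirely is undesirable, one possible middle ground is to widen the pretrained first-layer kernel and assign it to the 5-channel model. The following is a minimal NumPy-only sketch of that widening step, not a tested recipe for the Object Detection API. `expand_input_channels` is a hypothetical helper, and the (7, 7, 3, 64) kernel is an assumed stand-in for the real first-layer weights.

```python
import numpy as np


def expand_input_channels(kernel, new_channels):
    """Return a copy of a (H, W, C_in, C_out) kernel widened to new_channels inputs.

    Extra input channels are filled by cycling through the existing ones,
    which for 3 -> 5 matches the concatenation used in the question.
    """
    if kernel.ndim != 4:
        raise ValueError("expected a 4-D convolution kernel")
    old_channels = kernel.shape[2]
    if new_channels < old_channels:
        raise ValueError("new_channels must be >= the current channel count")
    # For 3 -> 5 this gives indices 0, 1, 2, 0, 1: originals first, then repeats.
    source = np.arange(new_channels) % old_channels
    return kernel[:, :, source, :].astype(np.float32)


# Random stand-in for the pretrained kernel (the shape is an assumption).
pretrained = np.random.default_rng(0).normal(size=(7, 7, 3, 64)).astype(np.float32)
widened = expand_input_channels(pretrained, 5)
print(widened.shape)  # (7, 7, 5, 64)
```

The widened array could then be copied into the first convolutional layer of a model constructed with 5 input channels, for example with Keras `set_weights` on that layer, while the remaining layers keep their pretrained values. Whether this trains better than starting from scratch would need to be checked on the actual data.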