
I am trying to convert Keras weights to PyTorch. I defined the same model in both frameworks and converted the weights, and I checked the weights and biases at every layer: they are identical. However, when I test it with an input image, the results are different, and when I compare every layer's feature maps they are also different (the per-layer comparison I use is sketched after the function below). The model is not complex: it has Conv2d, MaxPool2d, BatchNorm2d, ReLU/Sigmoid, Dropout2d, and one Linear layer. I checked the kernel sizes, padding, and strides, and as far as I can tell they are all the same in PyTorch and Keras. Here is the process I use to convert the weights; I wrote this function myself:

import numpy as np
import torch


def keras2torch(kerasNet, torchNet, kerasWeighpath):
    '''
    The Keras model must have the same layer structure as the PyTorch model.
    In the PyTorch model, every sublayer (Conv2d, etc.) must be wrapped in a single
    Sequential container; here its name is "layers".
    filtername = ['MaxPool2d', 'MaxPool', 'Sigmoid', 'Dropout2d']
    filtername lists the sublayers that carry no weights.
    Note: in Keras, model.get_weights() returns all weights as a flat list, so if the
    first layer is a Conv2D the list starts with two arrays, its kernel and its bias.
    In PyTorch, you access them by name ("weight", "bias") through the state_dict.
    '''
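    # Illustrative shapes for the naming/layout difference described above
    # (shapes shown for channels_last Keras models; not part of the conversion itself):
    #   Keras Conv2D:   get_weights() -> [kernel (H, W, C_in, C_out), bias (C_out,)]
    #   PyTorch Conv2d: state_dict()  -> {'weight': (C_out, C_in, H, W), 'bias': (C_out,)}
    #   Keras Dense:    get_weights() -> [kernel (in, out), bias (out,)]
    #   PyTorch Linear: state_dict()  -> {'weight': (out, in), 'bias': (out,)}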
    filtername=[ 'MaxPool2d','MaxPool','Sigmoid','Dropout2d','Flatten']
    filtername=[x.lower() for x in filtername]
    kerasNet.load_weights(kerasWeighpath)
    kerasweig=kerasNet.get_weights()
    print(kerasNet.layers)
    print(len(kerasNet.layers)==len(torchNet.layers))
    sublayer_with_param=[]
    # collect the indices of the sublayers that carry parameters
    for number, module in torchNet.layers.named_children():
        modulestr = str(module).lower()
        # filtername entries are already lower-case, so compare against the lowered repr
        if not any(name in modulestr for name in filtername):
            sublayer_with_param.append(number)
    # check each layer param name, like weight, bias, running mean....
    torchPall=torchNet.state_dict()
    for i in sublayer_with_param:
        print('**************')
        kerasP=kerasNet.layers[int(i)].get_weights()
        torchP=torchNet.layers[int(i)].state_dict()
        torchPd=dict(torchP)
        torch_Pname_list=list(torchPd)
        print('keras param',len(kerasP))
        print('torch param',len(torchP))
        print(torch_Pname_list)
        for P in range(len(kerasP)):
            Kp=kerasP[P]
            torch_Pname=torch_Pname_list[P]
            print(torch_Pname)
            Tp=torchPall[f"layers.{i}.{torch_Pname}"]
            print('Keras before:',Kp.shape,'\n','torch before:',Tp.shape)
            if Kp.shape != Tp.shape:
                # np.transpose with no axes argument reverses ALL dimensions of the Keras array
                KpShapeConv = np.transpose(Kp)
                torchPall[f"layers.{i}.{torch_Pname}"] = torch.from_numpy(KpShapeConv)
            else:
                KpShapeConv = Kp
                torchPall[f"layers.{i}.{torch_Pname}"] = torch.from_numpy(Kp)
            print('Keras after:',KpShapeConv.shape,'\n','torch after:',Tp.shape)
        print('####')
    # load the converted parameters into the PyTorch model once, after the loop
    torchNet.load_state_dict(torchPall)
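
And roughly, the way I compare the per-layer features is something like this. This is only a minimal sketch: compare_layer_outputs is just a helper name I made up here, the input shape (1, 64, 64, 3) is a placeholder for my real image size, and I assume both models are already built, with the PyTorch sublayers living in torchNet.layers as in the function above.

import numpy as np
import torch
from tensorflow import keras

def compare_layer_outputs(kerasNet, torchNet, atol=1e-4):
    # eval mode so Dropout2d is disabled and BatchNorm2d uses its running stats
    torchNet.eval()
    x_keras = np.random.rand(1, 64, 64, 3).astype('float32')                      # NHWC, placeholder shape
    x_torch = torch.from_numpy(np.ascontiguousarray(x_keras.transpose(0, 3, 1, 2)))  # NCHW for PyTorch

    # Keras: a probe model that returns the output of every layer
    probe = keras.Model(inputs=kerasNet.input,
                        outputs=[layer.output for layer in kerasNet.layers])
    keras_outs = probe.predict(x_keras)

    # PyTorch: run the Sequential by hand and keep every intermediate output
    torch_outs = []
    h = x_torch
    with torch.no_grad():
        for layer in torchNet.layers:
            h = layer(h)
            torch_outs.append(h.numpy())

    # compare layer by layer (convert NCHW back to NHWC for 4-D feature maps);
    # note: after Flatten the element order differs between NHWC and NCHW, so a raw
    # comparison there can report a mismatch even when the conv features agree
    for idx, (k_out, t_out) in enumerate(zip(keras_outs, torch_outs)):
        if t_out.ndim == 4:
            t_out = t_out.transpose(0, 2, 3, 1)
        same = k_out.shape == t_out.shape and np.allclose(k_out, t_out, atol=atol)
        print(f"layer {idx}: shapes {k_out.shape} vs {t_out.shape}, match={same}")

The eval() call matters because Dropout2d and BatchNorm2d behave differently in training mode, and Keras predict() already runs with training=False.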