I am trying to convert Keras weights to PyTorch. I defined both models with the same structure, converted the weights, and checked the weight and bias at every layer; they are identical. However, when I test it by feeding in an image, the outputs turn out to be different, and when I check the per-layer features, they differ as well. The model is not complex: it has Conv2d, MaxPool2d, BatchNorm2d, ReLU/Sigmoid, Dropout2d, and one Linear layer. I checked the kernel size, padding, and stride, and as far as I can tell they are the same in PyTorch and Keras. Here is the function I wrote myself to convert the weights:
import numpy as np
import torch


def keras2torch(kerasNet, torchNet, kerasWeighpath):
    '''
    The Keras model should have the same structure as the PyTorch model.
    In PyTorch, all sublayers (Conv2d etc.) should be wrapped in one big
    Sequential; in this case its name is "layers".
    filtername = ['MaxPool2d', 'MaxPool', 'Sigmoid', 'Dropout2d']
    filtername should list the sublayer types without weights.
    Note: in Keras you can use model.get_weights() to get all the weights,
    but if, say, the first layer is Conv2D, that list will contain 2 arrays
    for it (the conv kernel and the bias). In PyTorch you can instead access
    each parameter by name, e.g. "weight" and "bias".
    '''
    filtername = ['MaxPool2d', 'MaxPool', 'Sigmoid', 'Dropout2d', 'Flatten']
    filtername = [x.lower() for x in filtername]
    kerasNet.load_weights(kerasWeighpath)
    kerasweig = kerasNet.get_weights()
    print(kerasNet.layers)
    print(len(kerasNet.layers) == len(torchNet.layers))
    sublayer_with_param = []
    # collect the indices of the sublayers that carry parameters
    for number, module in torchNet.layers.named_children():
        modulestr = str(module).lower()
        if any(name in modulestr for name in filtername):
            continue
        sublayer_with_param.append(number)
    # check each layer's parameter names: weight, bias, running_mean, ...
    torchPall = torchNet.state_dict()
    for i in sublayer_with_param:
        print('**************')
        kerasP = kerasNet.layers[int(i)].get_weights()
        torchP = torchNet.layers[int(i)].state_dict()
        torch_Pname_list = list(torchP)
        print('keras param', len(kerasP))
        print('torch param', len(torchP))
        print(torch_Pname_list)
        for P in range(len(kerasP)):
            Kp = kerasP[P]
            torch_Pname = torch_Pname_list[P]
            print(torch_Pname)
            Tp = torchPall[f"layers.{i}.{torch_Pname}"]
            print('Keras before:', Kp.shape, '\n', 'torch before:', Tp.shape)
            if Kp.shape != Tp.shape:
                # np.transpose with no axes argument reverses all axes
                Kp = np.transpose(Kp)
            torchPall[f"layers.{i}.{torch_Pname}"] = torch.from_numpy(Kp)
            print('Keras after:', Kp.shape, '\n',
                  'torch after:', torchPall[f"layers.{i}.{torch_Pname}"].shape)
        print('####')
    torchNet.load_state_dict(torchPall)
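For reference, here is a minimal NumPy-only sketch of what the shape-fixing branch does (the shapes are hypothetical). Note that `np.transpose` with no axes argument reverses the axis order completely, which for a 4-D conv kernel also swaps the two spatial axes; for non-square kernels the resulting shape would not even match what PyTorch expects, so this step may be worth double-checking against an explicit axis permutation:

```python
import numpy as np

# Keras stores Conv2D kernels as (H, W, in_channels, out_channels);
# PyTorch stores them as (out_channels, in_channels, H, W).
k = np.zeros((3, 3, 2, 16))      # hypothetical 3x3 kernel, 2 -> 16 channels
print(np.transpose(k).shape)     # (16, 2, 3, 3) -- shape matches for square kernels

# For a non-square kernel, the full reversal also swaps H and W:
k2 = np.zeros((5, 3, 2, 16))
print(np.transpose(k2).shape)    # (16, 2, 3, 5), but PyTorch expects (16, 2, 5, 3)

# Dense/Linear weights: Keras (in, out) -> transpose -> PyTorch (out, in)
w = np.zeros((128, 10))
print(np.transpose(w).shape)     # (10, 128)
```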
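Since the function pairs parameters by index, one assumption worth confirming is that the per-layer orderings line up. A small check of what the PyTorch side exposes (Keras's `get_weights()` for BatchNormalization returns four arrays, `[gamma, beta, moving_mean, moving_var]`, so the index loop above never touches PyTorch's extra `num_batches_tracked` buffer):

```python
import torch

# Conv2d exposes two named entries, matching Keras's [kernel, bias] pair.
print(list(torch.nn.Conv2d(2, 16, 3).state_dict().keys()))
# ['weight', 'bias']

# BatchNorm2d exposes five entries; the first four line up with Keras's
# [gamma, beta, moving_mean, moving_var], and num_batches_tracked is extra.
print(list(torch.nn.BatchNorm2d(8).state_dict().keys()))
# ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']
```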