master
/ tools / pytorch_to_keras.py

pytorch_to_keras.py @b6444b2

2f424c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
转换pytorch版本OCR到keras 
暂时只支持dense ocr ,lstm层不支持
"""
import numpy as np
def set_cnn_weight(name, keramodel, torchmodelDict):
    """
    Copy one convolution layer's parameters from a PyTorch state dict into
    the Keras layer of the same name.

    Args:
        name: Layer name. Selects the ``<name>.weight`` and ``<name>.bias``
            entries of ``torchmodelDict`` and is also passed to
            ``keramodel.get_layer``.
        keramodel: Object exposing ``get_layer(name)``; the returned layer
            must provide ``set_weights``.
        torchmodelDict: Mapping from parameter key to an object with a
            ``.numpy()`` method.

    The Keras layer is left untouched when either the weight or the bias
    entry cannot be found.
    """
    weight_key = name + '.weight'
    bias_key = name + '.bias'
    weight = None
    bias = None

    for key, tensor in torchmodelDict.items():
        # Match the exact parameter path, optionally behind a dotted prefix
        # such as ``module.``. A plain substring test would let ``cnn.conv1``
        # pick up ``cnn.conv10.weight`` as well.
        if key == weight_key or key.endswith('.' + weight_key):
            weight = tensor.numpy()
        elif key == bias_key or key.endswith('.' + bias_key):
            bias = tensor.numpy()

    if weight is None or bias is None:
        return

    # Reorder axes (0, 1, 2, 3) -> (2, 3, 1, 0). Assumes the kernel is stored
    # as (out, in, kh, kw) and the Keras layer wants (kh, kw, in, out)
    # -- TODO confirm against the layer definitions.
    kernel = weight.transpose(2, 3, 1, 0)
    keramodel.get_layer(name).set_weights([kernel, bias])
    
    
def set_bn_weight(name, keramodel, torchmodelDict):
    """
    Copy one batch-normalisation layer's parameters from a PyTorch state
    dict into the Keras layer of the same name.

    The values are handed to ``set_weights`` in the order
    ``[gamma, beta, mean, variance]``, read from the ``weight``, ``bias``,
    ``running_mean`` and ``running_var`` entries respectively. Note the
    fourth entry is a variance, not a standard deviation.

    Args:
        name: Layer name. Selects the ``<name>.<param>`` entries of
            ``torchmodelDict`` and is also passed to ``keramodel.get_layer``.
        keramodel: Object exposing ``get_layer(name)``; the returned layer
            must provide ``set_weights``.
        torchmodelDict: Mapping from parameter key to an object with a
            ``.numpy()`` method.

    Raises:
        ValueError: If any of the four parameters is missing, instead of
            passing ``None`` entries on to ``set_weights``.
    """
    # Order matters: it is the order in which values reach ``set_weights``.
    params = ('weight', 'bias', 'running_mean', 'running_var')
    found = dict.fromkeys(params)

    for key, tensor in torchmodelDict.items():
        for param in params:
            target = name + '.' + param
            # Exact path match (optionally behind a dotted prefix) so that
            # ``cnn.batchnorm2`` does not also match ``cnn.batchnorm20.*``.
            if key == target or key.endswith('.' + target):
                found[param] = tensor.numpy()
                break

    missing = [param for param in params if found[param] is None]
    if missing:
        raise ValueError(
            "missing batch-norm parameters for layer %r: %s"
            % (name, ', '.join(missing))
        )

    keramodel.get_layer(name).set_weights([found[param] for param in params])
    
def set_dense_weight(name, keramodel, torchmodelDict):
    """
    Copy one linear layer's parameters from a PyTorch state dict into the
    Keras dense layer of the same name.

    Args:
        name: Layer name. Selects the ``<name>.weight`` and ``<name>.bias``
            entries of ``torchmodelDict`` and is also passed to
            ``keramodel.get_layer``.
        keramodel: Object exposing ``get_layer(name)``; the returned layer
            must provide ``set_weights``.
        torchmodelDict: Mapping from parameter key to an object with a
            ``.numpy()`` method.

    The Keras layer is left untouched when either the weight or the bias
    entry cannot be found.
    """
    weight_key = name + '.weight'
    bias_key = name + '.bias'
    weight = None
    bias = None

    for key, tensor in torchmodelDict.items():
        # Match the exact parameter path, optionally behind a dotted prefix
        # such as ``module.``. A plain substring test would let ``linear``
        # pick up ``linear2.weight`` as well.
        if key == weight_key or key.endswith('.' + weight_key):
            weight = tensor.numpy()
        elif key == bias_key or key.endswith('.' + bias_key):
            bias = tensor.numpy()

    if weight is None or bias is None:
        return

    # The weight matrix is transposed before being handed to Keras; assumes
    # the source stores (out, in) and the dense layer wants (in, out)
    # -- TODO confirm against the layer definitions.
    keramodel.get_layer(name).set_weights([np.transpose(weight), bias])
 
if __name__ == '__main__':
    import os
    import sys

    # An empty device list keeps the conversion off the GPU. This is set
    # before torch is imported below.
    GPUID = ''
    os.environ["CUDA_VISIBLE_DEVICES"] = GPUID
    # The ``crnn`` package is imported relative to the parent directory.
    sys.path.append('..')

    import torch
    from collections import OrderedDict
    from crnn.keys import alphabetChinese
    from crnn.network_keras import keras_crnn

    # Build the Keras network without LSTM layers (lstmFlag=False); only the
    # dense OCR model is supported for now.
    n_class = len(alphabetChinese) + 1
    kerasModel = keras_crnn(32, 1, n_class, 256, 1, lstmFlag=False)

    # Load the PyTorch checkpoint, keeping every tensor on its loaded storage.
    ocrModel = 'models/ocr-dense.pth'
    state_dict = torch.load(ocrModel, map_location=lambda storage, loc: storage)

    # Drop the `module.` fragment from every key (presumably a wrapper prefix
    # -- verify against how the checkpoint was saved).
    new_state_dict = OrderedDict(
        (key.replace('module.', ''), tensor) for key, tensor in state_dict.items()
    )

    # Layer names to convert, grouped by the kind of layer.
    cnn = ['cnn.conv%d' % idx for idx in range(7)]
    BN = ['cnn.batchnorm%d' % idx for idx in (2, 4, 6)]
    linear = ['linear']

    # Same order as before: convolutions, then batch norm, then linear.
    conversion_plan = (
        (set_cnn_weight, cnn),
        (set_bn_weight, BN),
        (set_dense_weight, linear),
    )
    for copy_weights, layer_names in conversion_plan:
        for layer_name in layer_names:
            copy_weights(layer_name, kerasModel, new_state_dict)

    # Persist the converted Keras weights.
    kerasModel.save_weights('models/ocr-dense-keras.h5')