Warning: file_get_contents(/data/phpspider/zhask/data//catemap/4/matlab/16.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Matlab 如何将训练过的TF网络导出到h5文件_Matlab_Tensorflow - Fatal编程技术网

Matlab 如何将训练过的TF网络导出到h5文件

Matlab 如何将训练过的TF网络导出到h5文件,matlab,tensorflow,Matlab,Tensorflow,我有一个经过训练的模型保存在3个文件中,还有一些tensorflow(1.x)代码行可以使用它。 我需要将模型转换为.h5(keras),以便在Matlab中导入它 模型保存在3个文件中: xxx.ckpt.data-00000-of-00001 xxx.ckpt.index xxx.ckpt.meta tensorflow代码行如下所示: zzz = "a function of the previous line" session_name = "a function of xxx" s

我有一个经过训练的模型保存在3个文件中,还有一些tensorflow(1.x)代码行可以使用它。 我需要将模型转换为.h5(keras),以便在Matlab中导入它

模型保存在3个文件中:

  • xxx.ckpt.data-00000-of-00001
  • xxx.ckpt.index
  • xxx.ckpt.meta
tensorflow代码行如下所示:

zzz = "a function of the previous line"
session_name = "a function of xxx"
sess = tf.Session()  
saver = tf.train.Saver()
saver.restore(sess,session_name)
output=sess.run(zzz, ...)
是否可以将经过培训的网络导出为H5文件? 如果是,怎么做


补充:我已经向一位朋友提出了要求,显然这不可能直接实现,因为首先,我必须“翻译”Keras的网络结构。这一步对于我在该领域的级别来说太难了。

据我所知,目前还没有使用检查点文件并直接转换到h5(keras)的方法。相反,您可以通过如下所述的一些变通方法来实现这一点

  • 如果您想要网络架构,那么您需要在Keras中重写代码
  • 如果您只想转换权重(假设您有相同模型的代码),则必须使用随机权重创建一个模型,然后使用
    tf.train.NewCheckpointReader
    读取
    TensorFlow.ckpt
    文件,然后为每个对应层调用
    set_weights()
    方法 下面提到的相同的示例代码:

    reader = tf.train.NewCheckpointReader(filename)
    
    for key in reader.get_variable_to_shape_map():
        # not saving the following tensors
        if key == 'global_step':
            continue
        if 'AuxLogit' in key:
            continue
    
        # convert tensor name into the corresponding Keras layer weight name and save
        path = os.path.join(output_folder, get_filename(key))
        arr = reader.get_tensor(key)
        np.save(path, arr)
    
    您可以使用保存的.npy文件输出h5(keras)中的模型权重,如下所示

    for layer in tqdm(model.layers):
        if layer.weights:
            weights = []
            for w in layer.weights:
                weight_name = os.path.basename(w.name).replace(':0', '')
                weight_file = layer.name + '_' + weight_name + '.npy'
                weight_arr = np.load(numpy_weight_file)
    
            # remove the "background class"
            if weight_file.startswith('Logits_bias'):
                weight_arr = weight_arr[1:]
            elif weight_file.startswith('Logits_kernel'):
                weight_arr = weight_arr[:, 1:]
    
            weights.append(weight_arr)
        layer.set_weights(weights)
    model.save_weights("Keras_model.h5")
    

    @Zal78-如果答案回答了您的问题,请向上投票并接受答案。非常感谢。