Python 3.x 使用tf CoreML将Tensorflow传输至CoreML

Python 3.x 使用tf CoreML将Tensorflow传输至CoreML,python-3.x,tensorflow,tensorflow-lite,coreml,batch-normalization,Python 3.x,Tensorflow,Tensorflow Lite,Coreml,Batch Normalization,我有一个多输入网络,它使用tf.booltf.placeholder来管理如何在培训和验证/测试中执行批量标准化。 我一直试图通过tf CoreML库将这个经过训练的模型转换为CoreML,但没有成功,错误如下: tensorflow.python.framework.errors\u impl.InvalidArgumentError:Retval[26]没有值 我知道这个错误表示某个节点缺少一个值,因此转换器可以执行模型。我还了解到,此错误与控制流操作有关(链接到批量规范化方法,创建诸如Sw

我有一个多输入网络,它使用
tf.bool
tf.placeholder
来管理如何在培训和验证/测试中执行批量标准化。 我一直试图通过
tf CoreML
库将这个经过训练的模型转换为
CoreML
,但没有成功,错误如下:

tensorflow.python.framework.errors\u impl.InvalidArgumentError:Retval[26]没有值

我知道这个错误表示某个节点缺少一个值,因此转换器可以执行模型。我还了解到,此错误与控制流操作有关(链接到批量规范化方法,创建诸如
Switch
Merge
之类的操作)。报告显示:

def testSwitchDeadBranch(self):
    with self.cached_session():
      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
      ports = ops.convert_to_tensor(True, name="ports")
      switch_op = control_flow_ops.switch(data, ports)
      dead_branch = array_ops.identity(switch_op[0])

      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          lambda e: "Retval[0] does not have value" in str(e)):
        self.evaluate(dead_branch)
请注意,我的错误是
Retval[26]
(我得到了[24],等等),而不是
Retval[0]
。我假设它测试了
开关
“死分支”,它应该是用于推断的未使用分支。代码对
Merge
“死分支”也执行相同的操作

是否有任何我遗漏的细节可能导致这个错误(当然不是我在转换过程中遇到的第一个错误)?推理的方式是什么?批处理规范化是如何实现的?保存模型的方式

到目前为止我所做的:

  • 我使用的是
    Tensorflow 1.14.0
  • 我知道
    tf.layers.batch_normalization
    创建操作
    Switch
    Merge
    ,这两个操作与CoreML不兼容
  • 我尝试过转换为Tensorflow Lite,但遇到了类似的问题
  • 我遵循了
    Facenet
    (此模型使用相同的
    tf.bool
    逻辑进行培训、验证和测试)转换过程,但没有成功
  • 我试过图书馆
  • 我已尝试使用脚本删除/修改控制流
  • 我创建了单独的图表来避免额外的操作,但没有成功
注意:我已经提取了大部分代码来发布这个问题

这就是批量标准化的实现方式(在卷积块中)

下面是训练和保存模型的代码

saver = tf.train.Saver()

    with tf.Session(config = config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(init_train_op)

        for epoch in range(MAX_EPOCHS):

            for step in range(10):

                l, _, se = sess.run(
                    [loss, train_op, mean_squared_error],
                     feed_dict = {training: True})

            print('\nRunning validation operation...')

            sess.run(init_val_op)
            for _ in range(10):
                val_out, val_l, val_se = sess.run(
                    [out, val_loss, val_mean_squared_error],
                    feed_dict = {training: False})

            sess.run(init_train_op) # switch back to training set

        #Save model
        print('Saving Model...\n')
        saver.save(sess, join(saveDir, './model_saver_validation'.format(modelIndex)), write_meta_graph = True)
# Dummy data for inference
b = np.zeros((1, 80, 160, 1), np.float32)
ill = np.ones((1,3), np.float32)
is_train = False

def freeze():
    with tf.Graph().as_default():
        with tf.Session() as sess:
            bIn = tf.placeholder(dtype=tf.float32, shape=[
                             1, 80, 160, 1], name='bIn')
            illumIn = tf.placeholder(dtype=tf.float32, shape=[
                                     1, 3], name='illumIn')
            training = tf.placeholder(tf.bool, shape=(), name = 'training')

            # Load the model metagraph and checkpoint
            meta_file = meta_graph #.meta file from saver.save()
            ckpt_file = checkpoint_file #checkpoint file

            # Load graph to redirect inputs from iterator to expected inputs
            saver = tf.train.import_meta_graph(meta_file, input_map={
                'IteratorGetNext:0': bIn,
                'IteratorGetNext:3': illumIn,
                'training:0': training},  clear_devices = True)

            tf.get_default_session().run(tf.global_variables_initializer())
            tf.get_default_session().run(tf.local_variables_initializer())
            saver.restore(tf.get_default_session(), ckpt_file)

            pred = tf.get_default_graph().get_tensor_by_name('Out:0')

            tf.get_default_session().run(pred, feed_dict={'bIn:0': b, 'poseIn:0': po, 'training:0': is_train})

            # Retrieve the protobuf graph definition and fix the batch norm nodes
            input_graph_def = sess.graph.as_graph_def()

            # Freeze the graph def
            output_graph_def = freeze_graph_def(
                sess, input_graph_def, output_node_names)

        # Serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(frozen_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())

freeze()
下面是加载、更新输入、执行推断和冻结模型的代码

saver = tf.train.Saver()

    with tf.Session(config = config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(init_train_op)

        for epoch in range(MAX_EPOCHS):

            for step in range(10):

                l, _, se = sess.run(
                    [loss, train_op, mean_squared_error],
                     feed_dict = {training: True})

            print('\nRunning validation operation...')

            sess.run(init_val_op)
            for _ in range(10):
                val_out, val_l, val_se = sess.run(
                    [out, val_loss, val_mean_squared_error],
                    feed_dict = {training: False})

            sess.run(init_train_op) # switch back to training set

        #Save model
        print('Saving Model...\n')
        saver.save(sess, join(saveDir, './model_saver_validation'.format(modelIndex)), write_meta_graph = True)
# Dummy data for inference
b = np.zeros((1, 80, 160, 1), np.float32)
ill = np.ones((1,3), np.float32)
is_train = False

def freeze():
    with tf.Graph().as_default():
        with tf.Session() as sess:
            bIn = tf.placeholder(dtype=tf.float32, shape=[
                             1, 80, 160, 1], name='bIn')
            illumIn = tf.placeholder(dtype=tf.float32, shape=[
                                     1, 3], name='illumIn')
            training = tf.placeholder(tf.bool, shape=(), name = 'training')

            # Load the model metagraph and checkpoint
            meta_file = meta_graph #.meta file from saver.save()
            ckpt_file = checkpoint_file #checkpoint file

            # Load graph to redirect inputs from iterator to expected inputs
            saver = tf.train.import_meta_graph(meta_file, input_map={
                'IteratorGetNext:0': bIn,
                'IteratorGetNext:3': illumIn,
                'training:0': training},  clear_devices = True)

            tf.get_default_session().run(tf.global_variables_initializer())
            tf.get_default_session().run(tf.local_variables_initializer())
            saver.restore(tf.get_default_session(), ckpt_file)

            pred = tf.get_default_graph().get_tensor_by_name('Out:0')

            tf.get_default_session().run(pred, feed_dict={'bIn:0': b, 'poseIn:0': po, 'training:0': is_train})

            # Retrieve the protobuf graph definition and fix the batch norm nodes
            input_graph_def = sess.graph.as_graph_def()

            # Freeze the graph def
            output_graph_def = freeze_graph_def(
                sess, input_graph_def, output_node_names)

        # Serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(frozen_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())

freeze()
下面是转换为CoreML的代码

tfcoreml.convert(
    tf_model_path=frozen_graph,
    mlmodel_path='./coreml_model.mlmodel',
    output_feature_names=['Out:0'],
    input_name_shape_dict={
        'bIn:0': [1, 80, 160, 1],
        'illumIn:0': [1, 3], 
        'training:0': []})
下面是
tf coreml
引发的错误

Loading the TF graph...
Graph Loaded.
Collecting all the 'Const' ops from the graph, by running it....

Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1356, in _do_call
    return fn(*args)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1341, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1429, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Retval[26] does not have value

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "tf2opencv.py", line 392, in <module>
    'illumIn:0': [1, 3], 'poseIn:0': [1, 16], 'training:0': []})
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tfcoreml/_tf_coreml_converter.py", line 586, in convert
    custom_conversion_functions=custom_conversion_functions)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tfcoreml/_tf_coreml_converter.py", line 243, in _convert_pb_to_mlmodel
    tensors_evaluated = sess.run(tensors, feed_dict=input_feed_dict)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1350, in _do_run
    run_metadata)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1370, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Retval[26] does not have value
正在加载TF图。。。
已加载图形。
通过运行图,从图中收集所有“Const”操作。。。。
回溯(最近一次呼叫最后一次):
文件“/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site packages/tensorflow/python/client/session.py”,第1356行,在
返回fn(*args)
文件“/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py”,第1341行,在
选项、提要、获取列表、目标列表、运行元数据)
文件“/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site packages/tensorflow/python/client/session.py”,第1429行,在调用会话运行中
运行(元数据)
tensorflow.python.framework.errors\u impl.InvalidArgumentError:Retval[26]没有值
在处理上述异常期间,发生了另一个异常:
回溯(最近一次呼叫最后一次):
文件“tf2opencv.py”,第392行,在
‘illumIn:0’:[1,3],‘poseIn:0’:[1,16],‘training:0’:[]}
文件“/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site packages/tfcoreml/_tf_coreml_converter.py”,第586行,转换为
自定义转换函数=自定义转换函数)
文件“/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site packages/tfcoreml/_tf_coreml_converter.py”,第243行,在“convert_pb_to_mlmodel”中
计算的张量=sess.run(张量,feed\u dict=input\u feed\u dict)
文件“/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site packages/tensorflow/python/client/session.py”,第950行,正在运行
运行_元数据_ptr)
文件“/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site packages/tensorflow/python/client/session.py”,第1173行,正在运行
feed_dict_tensor、options、run_元数据)
文件“/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site packages/tensorflow/python/client/session.py”,第1350行,在
运行(元数据)
文件“/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site packages/tensorflow/python/client/session.py”,第1370行,在
提升类型(e)(节点定义、操作、消息)
tensorflow.python.framework.errors\u impl.InvalidArgumentError:Retval[26]没有值