Warning: file_get_contents(/data/phpspider/zhask/data//catemap/8/design-patterns/2.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Arm 从TensorFlow图中删除掉操作_Arm_Tensorflow - Fatal编程技术网

Arm 从TensorFlow图中删除掉操作

Arm 从TensorFlow图中删除掉操作,arm,tensorflow,Arm,Tensorflow,我有一个经过训练的冻结图,我正试图在ARM设备上运行。基本上,我使用的是contrib/pi_examples/label_image,但不是Inception,而是我的网络。我的网络是使用 dropout 训练的,现在这给我带来了麻烦: Invalid argument: No OpKernel was registered to support Op 'Switch' with these attrs. Registered kernels: device='CPU'; T in [DT_FLOA

我有一个经过训练的冻结图,我正试图在ARM设备上运行。基本上,我使用的是contrib/pi_examples/label_image,但不是Inception,而是我的网络。我的网络是使用 dropout 训练的,现在这给我带来了麻烦:

Invalid argument: No OpKernel was registered to support Op 'Switch' with these attrs.  Registered kernels:
  device='CPU'; T in [DT_FLOAT]
  device='CPU'; T in [DT_INT32]
  device='GPU'; T in [DT_STRING]
  device='GPU'; T in [DT_BOOL]
  device='GPU'; T in [DT_INT32]
  device='GPU'; T in [DT_FLOAT]

 [[Node: l_fc1_dropout/cond/Switch = Switch[T=DT_BOOL](is_training_pl, is_training_pl)]]
我看到的一个解决方案是构建这样一个包含相应操作的TF静态库。从另一方面来说,为了使其更简单、更快,从网络中消除 dropout 操作可能是一个更好的主意。有办法吗?


谢谢。

作为一个更通用的解决方案:

#!/usr/bin/env python2

import argparse

import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2

def print_graph(input_graph):
    """Print every node of a GraphDef as ``name : op ( inputs )``.

    Args:
        input_graph: a GraphDef-like object exposing a ``node`` iterable,
            where each node has ``name``, ``op`` and ``input`` attributes.
    """
    for node in input_graph.node:
        # print() with a single argument behaves identically on Python 2 and 3,
        # unlike the original Python-2-only `print "..."` statement.
        print("{0} : {1} ( {2} )".format(node.name, node.op, node.input))

def strip(input_graph, drop_scope, input_before, output_after, pl_name):
    """Return a copy of ``input_graph`` with the ``drop_scope`` subgraph removed.

    Every node whose name lives under ``drop_scope + '/'`` is dropped, as is the
    placeholder ``pl_name`` (e.g. an ``is_training`` switch feeding the dropout
    cond). The consumer node ``output_after`` is rewired so that references to
    ``drop_scope + '/cond/Merge'`` point at ``input_before`` instead, bypassing
    the removed subgraph entirely.

    Args:
        input_graph: source GraphDef.
        drop_scope: name scope of the subgraph to remove (e.g. 'l_fc1_dropout').
        input_before: node name that fed the dropped scope (e.g. 'l_fc1/Relu').
        output_after: node name that consumed the dropped scope's output.
        pl_name: name of the placeholder node to remove along with the scope.

    Returns:
        A new GraphDef containing the surviving (possibly rewired) nodes.
    """
    nodes_after_strip = []
    for node in input_graph.node:
        print("{0} : {1} ( {2} )".format(node.name, node.op, node.input))

        # Skip everything inside the scope being removed.
        if node.name.startswith(drop_scope + '/'):
            continue

        # Skip the training-phase placeholder that fed the dropped cond.
        if node.name == pl_name:
            continue

        new_node = node_def_pb2.NodeDef()
        new_node.CopyFrom(node)
        if new_node.name == output_after:
            # Rewire: the consumer's reference to the dropped scope's Merge
            # output is replaced by the node that originally fed the scope.
            new_input = [
                input_before if node_name == drop_scope + '/cond/Merge' else node_name
                for node_name in new_node.input
            ]
            del new_node.input[:]
            new_node.input.extend(new_input)
        nodes_after_strip.append(new_node)

    output_graph = graph_pb2.GraphDef()
    output_graph.node.extend(nodes_after_strip)
    return output_graph

def main():
    """Load a GraphDef, strip a hard-coded dropout cond subgraph, and save it.

    Reads the graph named by --input-graph (binary or text protobuf), removes
    the 'l_fc1_dropout' scope plus its 'is_training_pl' placeholder via
    strip(), and writes the result to --output-graph.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-graph', action='store', dest='input_graph')
    # BUG FIX: the original used action='store_true' together with default=True,
    # which made --input-binary / --output-binary no-ops (the value was always
    # True and text mode was unreachable). The defaults are kept for backward
    # compatibility; explicit --input-text / --output-text switches now make
    # the text protobuf path actually selectable.
    parser.add_argument('--input-binary', action='store_true', default=True, dest='input_binary')
    parser.add_argument('--input-text', action='store_false', dest='input_binary')
    parser.add_argument('--output-graph', action='store', dest='output_graph')
    parser.add_argument('--output-binary', action='store_true', dest='output_binary', default=True)
    parser.add_argument('--output-text', action='store_false', dest='output_binary')

    args = parser.parse_args()

    input_graph = args.input_graph
    input_binary = args.input_binary
    output_graph = args.output_graph
    output_binary = args.output_binary

    if not tf.gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return

    input_graph_def = tf.GraphDef()
    mode = "rb" if input_binary else "r"
    with tf.gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            # Text protobuf: parse with the protobuf text_format module.
            text_format.Merge(f.read().decode("utf-8"), input_graph_def)

    print("Before:")
    print_graph(input_graph_def)
    # NOTE: scope/node names are specific to the author's network; adjust for
    # other graphs.
    output_graph_def = strip(input_graph_def, u'l_fc1_dropout', u'l_fc1/Relu', u'prediction/MatMul', u'is_training_pl')
    print("After:")
    print_graph(output_graph_def)

    if output_binary:
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
    else:
        with tf.gfile.GFile(output_graph, "w") as f:
            f.write(text_format.MessageToString(output_graph_def))
    print("%d ops in the final graph." % len(output_graph_def.node))


if __name__ == "__main__":
    main()
# Fragment (relies on helpers defined elsewhere: temp_graph_def, input_node_map,
# node_name_from_input, node_from_map — presumably from
# tensorflow.python.tools.optimize_for_inference_lib; verify against that module).
# It rewires every consumer of a dropout scope's cond/Merge output so that it
# reads the tensor that originally fed the dropout block instead.
for node in temp_graph_def.node:
    for idx, i in enumerate(node.input):
        # Normalize the input reference to a bare node name
        # (presumably strips control/^ and :output suffixes — TODO confirm).
        input_clean = node_name_from_input(i)
        # Match inputs of the form '<...>/dropout*/cond/Merge'.
        if input_clean.endswith('/cond/Merge') and input_clean.split('/')[-3].startswith('dropout'):
            # Merge's first input is expected to be an Identity node.
            identity = node_from_map(input_node_map, i).input[0]
            assert identity.split('/')[-1] == 'Identity'
            # The Identity's producer; its second input is the cond predicate.
            parent = node_from_map(input_node_map, node_from_map(input_node_map, identity).input[0])
            pred_id = parent.input[1]
            assert pred_id.split('/')[-1] == 'pred_id'            
            # parent.input[0] is the tensor that fed the dropout block:
            # wire the consumer directly to it, bypassing dropout.
            good = parent.input[0]
            node.input[idx] = good

您可以在文本编辑器中编辑
graph.pbtxt
,并去除 dropout 的部分(即,用 Identity op 替换 dropout 节点)。脚本似乎删除了这些层,但是如果我删除了中间的 dropout 部分,下一层就需要 dropout 的输出张量。在我的情况下,当我尝试读取图中留下的层时,我收到一个错误:ValueError: graph_def 在节点 u'fc7/Conv2D' 处无效:在 graph_def 中找不到输入张量 'dropout/mul_1:0'。如何在 protobuf 中更改输入层张量名称 u'fc7/Conv2D'?脚本也提供了该功能……非常好,谢谢。我设法在 tensorflow.python.tools.optimize_for_inference_lib 中找到了 node_name_from_input 和 node_from_map 这些函数
,但是从哪里可以得到
input_node_map
?它应该是什么?好吧,在谷歌搜索之后,我想你是指这个