Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/289.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python Tensorflow:从图形中删除节点_Python_Tensorflow - Fatal编程技术网

Python Tensorflow:从图形中删除节点

Python Tensorflow:从图形中删除节点,python,tensorflow,Python,Tensorflow,我正在尝试从图中删除一些节点并将其保存在.pb中 只有所需的节点才能添加到新的mod_graph\u defgraph,但该图在其他节点输入中仍有一些对已删除节点的引用,但我无法修改节点的输入: def delete_ops_from_graph(): with open(input_model_filepath, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read(

我正在尝试从图中删除一些节点并将其保存在.pb中

只有所需的节点才能添加到新的
mod_graph\u def
graph,但该图在其他节点输入中仍有一些对已删除节点的引用,但我无法修改节点的输入:

def delete_ops_from_graph():
    with open(input_model_filepath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    nodes = []
    for node in graph_def.node:
        if 'Neg' in node.name:
            print('Drop', node.name)
        else:
            nodes.append(node)

    mod_graph_def = tf.GraphDef()
    mod_graph_def.node.extend(nodes)

    # The problem that graph still have some references to deleted node in other nodes inputs
    for node in mod_graph_def.node:
        inp_names = []
        for inp in node.input:
            if 'Neg' in inp:
                pass
            else:
                inp_names.append(inp)

        node.input = inp_names # TypeError: Can't set composite field

    with open(output_model_filepath, 'wb') as f:
        f.write(mod_graph_def.SerializeToString())

前面的答案很好,但我建议将删除的节点输入绑定到下一个节点输入。比如,如果我们有一个链
a-input b->b-input c->c-input d->d
,并且要删除say节点
b
,那么我们不应该只删除
input c
,而是用
input b
替换它。 请看下面的代码:

#  remove node and connect its input to follower
def remove_node(graph_def, node_name, input_name):
    nodes = []
    for node in graph_def.node:
        if node.name == node_name:
            assert(input_name in node.input or len(node.input) == 0),\
                "Node input to use is not among inputs of node to remove"
            input_of_removed_node = input_name if len(node.input) else ''
            print("Removing {} and using its input {}".format(node.name, 
                   input_of_removed_node))
            continue
        nodes.append(node)
    
    # modify inputs where required
    # removed name must be replaced with input from removed node
    for node in nodes:
        inp_names = []
        replace = False
        for inp in node.input:
            if inp == node_name:
                inp_names.append(input_of_removed_node)
                print("For node {} replacing input {} 
                       with {}".format(node.name, inp, input_of_removed_node))
                replace = True
            else:
                inp_names.append(inp)
        if replace:
            del node.input[:]
            node.input.extend(inp_names)
    mod_graph_def = tf.GraphDef()
    mod_graph_def.node.extend(nodes)
    return mod_graph_def

前面的答案很好,但我建议将删除的节点输入绑定到下一个节点输入。比如,如果我们有一个链
a-input b->b-input c->c-input d->d
,并且要删除say节点
b
,那么我们不应该只删除
input c
,而是用
input b
替换它。 请看下面的代码:

#  remove node and connect its input to follower
def remove_node(graph_def, node_name, input_name):
    nodes = []
    for node in graph_def.node:
        if node.name == node_name:
            assert(input_name in node.input or len(node.input) == 0),\
                "Node input to use is not among inputs of node to remove"
            input_of_removed_node = input_name if len(node.input) else ''
            print("Removing {} and using its input {}".format(node.name, 
                   input_of_removed_node))
            continue
        nodes.append(node)
    
    # modify inputs where required
    # removed name must be replaced with input from removed node
    for node in nodes:
        inp_names = []
        replace = False
        for inp in node.input:
            if inp == node_name:
                inp_names.append(input_of_removed_node)
                print("For node {} replacing input {} 
                       with {}".format(node.name, inp, input_of_removed_node))
                replace = True
            else:
                inp_names.append(inp)
        if replace:
            del node.input[:]
            node.input.extend(inp_names)
    mod_graph_def = tf.GraphDef()
    mod_graph_def.node.extend(nodes)
    return mod_graph_def