Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/351.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python 在TensorFlow中导入GraphDef后设置py_func op_Python_Tensorflow - Fatal编程技术网

Python 在TensorFlow中导入GraphDef后设置py_func op

Python 在TensorFlow中导入GraphDef后设置py_func op,python,tensorflow,Python,Tensorflow,我有一个保存的图形定义,它是通过tf.train.import\u meta\u graph导入的。该图包含不可序列化的py_funcop。我是否可以定义python函数并将其分配给该op,而不必从头构建图形?这是可能的,但可能有点脆弱。特别是,pyfuncs需要按照它们在原始图形中定义的相同顺序重新定义(以便它们在原始图形中具有相同的标识符) 举个例子。我们可以定义一个包含py_func的图: import tensorflow as tf def my_py_func(x): retu

我有一个保存的图形定义,它是通过
tf.train.import\u meta\u graph
导入的。该图包含不可序列化的
py_func
op。我是否可以定义python函数并将其分配给该op,而不必从头构建图形?

这是可能的,但可能有点脆弱。特别是,pyfuncs需要按照它们在原始图形中定义的相同顺序重新定义(以便它们在原始图形中具有相同的标识符)

举个例子。我们可以定义一个包含py_func的图:

import tensorflow as tf

def my_py_func(x):
  return 13. * x + 2.

def train_model():
  with tf.Graph().as_default():
    some_input = tf.constant([[1., 2., 3., 4.],
                              [5., 6., 7., 8.]])
    after_py_func = tf.py_func(my_py_func, [some_input], Tout=tf.float32,
                               name="my_py_func")
    coefficient = tf.get_variable(
        "coefficient",
        shape=[])
    bias = tf.get_variable(
        "bias",
        shape=[])
    loss = tf.reduce_sum((coefficient * some_input + bias - after_py_func) ** 2)
    global_step = tf.contrib.framework.get_or_create_global_step()
    train_op = tf.group(tf.train.AdamOptimizer(0.1).minimize(loss),
                        tf.assign_add(global_step, 1))
    # Make it easy to retreive things we care about when the metagraph is reloaded.
    tf.add_to_collection('useful_ops', bias)
    tf.add_to_collection('useful_ops', coefficient)
    tf.add_to_collection('useful_ops', loss)
    tf.add_to_collection('useful_ops', train_op)
    tf.add_to_collection('useful_ops', global_step)
    tf.add_to_collection('useful_ops', some_input)
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()
    with tf.Session() as session:
      session.run(init_op)
      for i in range(5000):
        (_, evaled_loss, evaled_coefficient, evaled_bias,
         evaled_global_step) = session.run(
             [train_op, loss, coefficient, bias, global_step])
        if i % 1000 == 0:
          print(evaled_global_step, evaled_loss, evaled_coefficient,
                evaled_bias)
      saver.save(session, "./trained_pyfunc_model", global_step=global_step)
这将进行一些基本训练(与py_func中的线性函数相匹配):

如果我们在新的Python会话中尝试加载元图而不重新定义pyfunc,则会出现错误:

def load_model():
  with tf.Graph().as_default():
    saver = tf.train.import_meta_graph("./trained_pyfunc_model-5000.meta")
    bias, coefficient, loss, train_op, global_step, some_input = tf.get_collection('useful_ops')
    #after_py_func = tf.py_func(my_py_func, [some_input], Tout=tf.float32,
    #                           name="my_py_func")
    with tf.Session() as session:
      saver.restore(session, "./trained_pyfunc_model-5000")
      (_, evaled_loss, evaled_coefficient, evaled_bias,
       evaled_global_step) = session.run(
           [train_op, loss, coefficient, bias, global_step])
      print("Restored: ", evaled_global_step, evaled_loss, evaled_coefficient, evaled_bias)
UnknownError(请参见上面的回溯):KeyError:“pyfunc_0”

但是,只要py_func的定义顺序相同且实现方式相同,我们就可以:

def load_model():
  with tf.Graph().as_default():
    saver = tf.train.import_meta_graph("./trained_pyfunc_model-5000.meta")
    bias, coefficient, loss, train_op, global_step, some_input = tf.get_collection('useful_ops')
    after_py_func = tf.py_func(my_py_func, [some_input], Tout=tf.float32,
                               name="my_py_func")
    with tf.Session() as session:
      saver.restore(session, "./trained_pyfunc_model-5000")
      (_, evaled_loss, evaled_coefficient, evaled_bias,
       evaled_global_step) = session.run(
           [train_op, loss, coefficient, bias, global_step])
      print("Restored: ", evaled_global_step, evaled_loss, evaled_coefficient, evaled_bias)
这让我们可以继续训练,或者对恢复的模型执行任何其他操作:

Restored:  5001 1.77897e-09 13.0 2.00003
请注意,有状态py_funcs将更难处理:TensorFlow没有保存任何可能与之关联的Python变量

Restored:  5001 1.77897e-09 13.0 2.00003