Python 如何在TensorFlow导入过程中更改输入维度

Python 如何在TensorFlow导入过程中更改输入维度,python,tensorflow,Python,Tensorflow,我的设想: 定义RNN模型结构,并使用具有固定批量大小和序列长度的输入对其进行训练 冻结模型(即将所有可训练变量转换为常数),生成一个GraphDef,包含测试时使用模型所需的一切(通过tf.graph\u util.convert\u variables\u to\u constants) 通过tf导入GraphDef。导入\u graph\u def并使用input\u map参数替换输入。新输入需要具有任意的批大小和序列长度 问题是:在我将输入传递到测试时间图之前,上述所有操作都有效,该测

我的设想:

  • 定义RNN模型结构,并使用具有固定批量大小和序列长度的输入对其进行训练
  • 冻结模型(即将所有可训练变量转换为常数),生成一个
    GraphDef
    ,包含测试时使用模型所需的一切(通过
    tf.graph\u util.convert\u variables\u to\u constants
  • 通过
    tf导入
    GraphDef
    。导入\u graph\u def
    并使用
    input\u map
    参数替换输入。新输入需要具有任意的批大小和序列长度
  • 问题是:在我将输入传递到测试时间图之前,上述所有操作都有效,该测试时间图使用的批大小或序列长度与培训时使用的原始大小不同。在这一点上,我得到了如下错误:

    InvalidArgumentError (see above for traceback): ConcatOp : Dimensions of inputs should match: shape[0] = [1,5] vs. shape[1] = [2,7]
         [[Node: import/rnn/while/basic_rnn_cell/basic_rnn_cell_1/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](import/rnn/while/TensorArrayReadV3, import/rnn/while/Identity_2, import/rnn/while/basic_rnn_cell/basic_rnn_cell_1/concat/axis)]]
    
    为了说明和重现问题,请考虑下面的最小例子。
    • v1
      :创建具有任意批量大小和序列长度的图形。这很好,但不幸的是,我必须在训练时使用固定的批大小和序列长度,并且必须在测试时使用任意的批大小和序列长度,所以我不能使用这种简单的方法
    • v2a
      :我们模拟创建具有固定批量大小(2)和序列长度(3)的训练时间图,并冻结该图
    • v2ba
      :我们证明在unchanged中加载冻结的模型仍然会产生相同的结果
    • v2bb
      :我们演示了用仍然使用固定批量大小和序列长度的替换输入加载冻结的模型仍然会产生相同的结果
    • v2bc
      :我们证明,只要输入是根据原始批次大小和序列长度成形的,使用使用任意批次大小和序列长度的替换输入加载冻结模型仍然会产生相同的结果。它与
      数据
      一起工作,但与
      数据2
      一起失败——唯一的区别是前者的批量大小为2,后者的批量大小为1
    是否可以通过
    input\u map
    参数将RNN图更改为
    tf.import\u graph\u def
    ,从而使输入不再具有固定的批大小和序列长度?

    以下代码适用于TensorFlow 1.1 RC2,也可能适用于TensorFlow 1.0

    import numpy
    import tensorflow as tf
    from tensorflow import graph_util as tf_graph_util
    from tensorflow.contrib import rnn as tfc_rnn
    
    
    def v1(data):
        with tf.Graph().as_default():
            tf.set_random_seed(1)
            x = tf.placeholder(tf.float32, shape=(None, None, 5))
            _, s = tf.nn.dynamic_rnn(tfc_rnn.BasicRNNCell(7), x, dtype=tf.float32)
    
            with tf.Session() as session:
                session.run(tf.global_variables_initializer())
                print session.run(s, feed_dict={x: data})
    
    
    def v2a():
        with tf.Graph().as_default():
            tf.set_random_seed(1)
            x = tf.placeholder(tf.float32, shape=(2, 3, 5), name="x")
            _, s = tf.nn.dynamic_rnn(tfc_rnn.BasicRNNCell(7), x, dtype=tf.float32)
    
            with tf.Session() as session:
                session.run(tf.global_variables_initializer())
                return tf_graph_util.convert_variables_to_constants(
                    session, session.graph_def, [s.op.name]), s.name
    
    
    def v2ba((graph_def, s_name), data):
        with tf.Graph().as_default():
            x, s = tf.import_graph_def(graph_def,
                                       return_elements=["x:0", s_name])
    
            with tf.Session() as session:
                print '2ba', session.run(s, feed_dict={x: data})
    
    
    def v2bb((graph_def, s_name), data):
        with tf.Graph().as_default():
            x = tf.placeholder(tf.float32, shape=(2, 3, 5))
            [s] = tf.import_graph_def(graph_def, input_map={"x:0": x},
                                      return_elements=[s_name])
    
            with tf.Session() as session:
                print '2bb', session.run(s, feed_dict={x: data})
    
    
    def v2bc((graph_def, s_name), data):
        with tf.Graph().as_default():
            x = tf.placeholder(tf.float32, shape=(None, None, 5))
            [s] = tf.import_graph_def(graph_def, input_map={"x:0": x},
                                      return_elements=[s_name])
    
            with tf.Session() as session:
                print '2bc', session.run(s, feed_dict={x: data})
    
    
    def main():
        data1 = numpy.random.random_sample((2, 3, 5))
        data2 = numpy.random.random_sample((1, 3, 5))
        v1(data1)
        model = v2a()
        v2ba(model, data1)
        v2bb(model, data1)
        v2bc(model, data1)
        v2bc(model, data2)
    
    
    if __name__ == "__main__":
        main()
    

    这是tensorflow中的一个bug,已经存在了一段时间:您无法可靠地用一个已定义形状的占位符替换另一个(部分)未定义形状的占位符

    你会发现一个相关的问题,显然没有得到太多的关注