Python 为什么自定义只读操作在测试会话上有效

Python 为什么自定义只读操作在测试会话上有效,python,tensorflow,Python,Tensorflow,我在tensorflow中编写了一个自定义内核操作,用于读取csv格式的数据 它在TestCase中工作良好,通过test_session函数返回sess对象 当我转到普通代码时,读卡器op每次都返回相同的结果。然后,我在MyOp:Compute函数的开头添加了一些调试打印。似乎在第一次运行之后,sess.runmyop根本不会调用MyOp:Compute函数 然后我返回到我的测试用例,如果我用tf.session而不是self.test_session替换session对象,它也会以同样的方式

我在tensorflow中编写了一个自定义内核操作,用于读取csv格式的数据

它在TestCase中工作良好,通过test_session函数返回sess对象

当我转到普通代码时,读卡器op每次都返回相同的结果。然后,我在MyOp:Compute函数的开头添加了一些调试打印。似乎在第一次运行之后,sess.runmyop根本不会调用MyOp:Compute函数

然后我返回到我的测试用例,如果我用tf.session而不是self.test_session替换session对象,它也会以同样的方式失败

有人知道这件事吗

要分享更多详细信息,以下是我的迷你演示代码:

在测试用例中:

def testSimple(self):
  input_data_schema, feas, batch_size = self.get_simple_format()
  iter_op = ops.csv_iter('./sample_data.txt', input_data_schema, feas, batch_size=batch_size, label='label2')
  with self.test_session() as sess:
    label,sign = sess.run(iter_op)
    print label

    self.assertAllEqual(label.shape, [batch_size])
    self.assertAllEqual(sign.shape, [batch_size, len(feas)])
    self.assertAllEqual(sum(label), 2)
    self.assertAllEqual(sign[0,:], [7,0,4,1,1,1,5,9,8])

    label,sign = sess.run(iter_op)
    self.assertAllEqual(label.shape, [batch_size])
    self.assertAllEqual(sign.shape, [batch_size, len(feas)])
    self.assertAllEqual(sum(label), 1)
    self.assertAllEqual(sign[0,:], [9,9,3,1,1,1,5,4,8])
正常通话:

def testing_tf():
    path = './sample_data.txt'
    input_data_schema, feas, batch_size = get_simple_format()
    with tf.device('/cpu:0'):
        n_data_op = tf.placeholder(dtype=tf.float32)
        iter_op = ops.csv_iter(path, input_data_schema, feas, batch_size=batch_size, label='label2') 
        init_op = [tf.global_variables_initializer(), tf.local_variables_initializer() ]

    with tf.Session() as sess:
      sess.run(init_op)
      n_data = 0
      for batch_idx in range(3):
        print '>>>>>>>>>>>>>> before run batch', batch_idx
        ## it should be some debug printing here, but nothing come out when batch_idx>0
        label,sign = sess.run(iter_op)
        print '>>>>>>>>>>>>>> after run batch', batch_idx
        ## the content of sign remain the same every time
        print sign
        if len(label) == 0:
          break
查看tf.test.TestCase.test_会话的配置可以提供一些线索,因为它对会话的配置与对tf.session的直接调用略有不同。特别是,测试持续的折叠优化。默认情况下,TensorFlow会将图形的无状态部分转换为tf.constant节点,因为每次运行它们时,它们都会产生相同的结果

在注册CsvIter op时,有一个on-SetIsStateful注释,因此TensorFlow将把它视为无状态的,因此需要不断折叠。然而,它的实现是非常有状态的:一般来说,您希望使用相同的输入张量产生不同结果的任何操作,或者在成员变量中存储可变状态的任何操作,都应该标记为有状态

解决方案是对CsvIter的寄存器_OP进行一行更改:


你能分享具体的代码来更清楚地理解吗?请尝试分享这个问题的一个最简单的工作示例。如果没有说明您的问题的代码,我们就不可能提供帮助。@Engineero我已更新了最低工作代码,您想检查一下吗?@Salikaragoz我已更新了最低工作代码,您想检查一下吗?
REGISTER_OP("CsvIter")
    .Input("data_file: string")
    .Output("labels: float32")
    .Output("signs: int64")
    .Attr("input_schema: list(string)")
    .Attr("feas: list(string)")
    .Attr("label: string = 'label' ")
    .Attr("batch_size: int = 10000")
    .SetIsStateful();  // Add this line.