Python 批量迭代在Tensorflow中是如何工作的?

Python 批量迭代在Tensorflow中是如何工作的?,python,multithreading,tensorflow,machine-learning,queue,Python,Multithreading,Tensorflow,Machine Learning,Queue,我试图在我的数据上重用Tensorflow,但缺乏Tensorflow的知识,无法理解它如何处理训练数据上的批量迭代。以下是我如何理解培训期间的批处理迭代: while epoch <= maxepoch do for minibatch in data_iterator() do model.forward(minibatch) (...) end end 此函数在调用后返回x输入和y目标。我在这里没有看到Python迭代器的迹象,但是有一个对tf.stridd

我试图在我的数据上重用Tensorflow,但缺乏Tensorflow的知识,无法理解它如何处理训练数据上的批量迭代。以下是我如何理解培训期间的批处理迭代:

while epoch <= maxepoch do
  for minibatch in data_iterator() do
    model.forward(minibatch)
    (...)
  end
end
此函数在调用后返回
x
输入和
y
目标。我在这里没有看到Python迭代器的迹象,但是有一个对
tf.stridded\u slice
的调用,它使用
tf.train.range\u input\u producer
生成的
I
索引,因此这应该模拟数据上的滑动窗口。但是,该函数在训练之前只被调用一次,那么它如何迭代我的数据呢?这还不清楚。有人能解释一下这个“魔法”和完全模糊的张量流机制吗?

这个“魔法”隐藏在一行调用:

i=tf.train.range\u input\u producer(历元大小,shuffle=False).dequeue()
。。。这将创建一个op,该op从队列中弹出值,保存
0..epoch\u size-1
整数。换句话说,它在
0..epoch\u size-1
范围内迭代


是的,这似乎违反直觉。下面是一个在tensorflow中处理队列的简单可运行示例:

index=tf.train.range\u input\u producer(10,shuffle=False).dequeue()
使用tf.Session()作为sess:
coord=tf.train.Coordinator()
线程=tf.train.start\u queue\u runner(coord=coord)
对于范围(15)内的i:
打印(sess.run(索引))
协调请求停止()
坐标连接(线程)
运行时,您应该看到从
0
9
的值,然后再从
0
4
的值。请注意,
sess.run
计算相同的张量
索引
,但每次都得到不同的值。可以添加更多依赖于
索引
的操作,并使用新的
索引
值对其进行评估

还要注意,队列在另一个线程中运行,因此为了使用
tf.train.range\u input\u producer
,必须启动
协调器并生成多个线程(最后停止它们)。如果尝试在没有协调器的情况下运行同一示例,则
sess.run(index)
将阻止脚本执行

您可以使用此示例,例如,设置
shuffle=True
,等等


返回到PTB生产者代码段:

i=tf.train.range\u input\u producer(历元大小,shuffle=False).dequeue()
x=tf.跨步切片(数据,[0,i*num\u步数],[batch\u size,(i+1)*num\u步数])
x、 设置形状([批次大小,数量步数])
y=tf.跨步切片(数据,[0,i*num\u步数+1],[batch\u size,(i+1)*num\u步数+1])
y、 设置形状([批次大小,数量步数])
现在应该很清楚,尽管
x
y
被定义为简单张量,但它们实际上是
数据片上的迭代器。所有的螺纹工作都由工程师负责。因此,调用优化操作(取决于
x
y
)将自动获取新批次


建议如下:

  • Tensorflow教程
def ptb_producer(raw_data, batch_size, num_steps, name=None):
    with tf.name_scope(name, "PTBProducer", [raw_data, batch_size, num_steps]):
        raw_data = tf.convert_to_tensor(raw_data, name="raw_data", dtype=tf.int32)

        data_len = tf.size(raw_data)
        batch_len = data_len // batch_size
        data = tf.reshape(raw_data[0 : batch_size * batch_len],
                                            [batch_size, batch_len])

        epoch_size = (batch_len - 1) // num_steps
        assertion = tf.assert_positive(
                epoch_size,
                message="epoch_size == 0, decrease batch_size or num_steps")
        with tf.control_dependencies([assertion]):
            epoch_size = tf.identity(epoch_size, name="epoch_size")

        i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
        x = tf.strided_slice(data, [0, i * num_steps], [batch_size, (i + 1) * num_steps])
        x.set_shape([batch_size, num_steps])
        y = tf.strided_slice(data, [0, i * num_steps + 1], [batch_size, (i + 1) * num_steps + 1])
        y.set_shape([batch_size, num_steps])
        return x, y