Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/tensorflow/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python 默认情况下tf.Dataset.batch是否预加载以及如何禁用?_Python_Tensorflow_Tensorflow Datasets - Fatal编程技术网

Python 默认情况下tf.Dataset.batch是否预加载以及如何禁用?

Python 默认情况下tf.Dataset.batch是否预加载以及如何禁用?,python,tensorflow,tensorflow-datasets,Python,Tensorflow,Tensorflow Datasets,当使用tf.Dataset.batch时,get_next()将在调用时预加载一些数据。看起来有一个后台线程正在执行此操作。有没有办法禁用它 复制代码块: import tensorflow as tf def pr(x): print(x) return x dataset = tf.data.Dataset.range(10000) dataset = dataset.map(lambda x: tf.py_func(pr, [x], [tf.int64])) dat

当使用
tf.Dataset.batch
时,
get_next()
将在调用时预加载一些数据。看起来有一个后台线程正在执行此操作。有没有办法禁用它

复制代码块:

import tensorflow as tf

def pr(x):
    print(x)
    return x


dataset = tf.data.Dataset.range(10000)
dataset = dataset.map(lambda x: tf.py_func(pr, [x], [tf.int64]))

dataset = dataset.batch(3)

iterator = dataset.make_initializable_iterator()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(iterator.initializer)
    next_element = iterator.get_next()

    for i in range(2):
        fetches = sess.run(next_element)
        print(fetches)
不稳定的样本输出如下所示:

0
1
2
3
(array([0, 1, 2]),)
4
5
6
(array([3, 4, 5]),)
7
8
我希望确定性输出为:

0
1
2
(array([0, 1, 2]),)
3
4
5
(array([3, 4, 5]),)

由于giser_yugang的评论,CPU模式下的环境是OSX+python3.7.2+tensorflow1.13.1。我从1.13的变更日志中找到了一些提示。()

设置dataset选项可在1.13中解决此问题


import tensorflow as tf

def pr(x):
    print(x)
    return x


dataset = tf.data.Dataset.range(10000)

options = tf.data.Options()
options.experimental_optimization.apply_default_optimizations = False
dataset = dataset.with_options(options)

dataset = dataset.map(lambda x: tf.py_func(pr, [x], [tf.int64]))

dataset = dataset.batch(3)

iterator = dataset.make_initializable_iterator()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(iterator.initializer)
    next_element = iterator.get_next()

    for i in range(2):
        fetches = sess.run(next_element)
        print(fetches)

我运行的代码的返回是您期望的结果。谢谢您的评论。你的环境是什么?我的是OS X+python3.7.2+tensorflow1.13.1+CPU模式,添加到post.Ubuntu16.04+python3.6+tensorflow1.12+GPUWeird。我找到一个和你的一样的服务器,它按预期打印。看起来像是1.13中的修改。