Warning: file_get_contents(/data/phpspider/zhask/data//catemap/4/video/2.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Tensorflow数据集API中的条件语句_Tensorflow - Fatal编程技术网

Tensorflow数据集API中的条件语句

Tensorflow数据集API中的条件语句,tensorflow,Tensorflow,我已经使用Tensorflow数据集API构建了一个数据管道,但我希望一些操作(比如洗牌)取决于我是在训练数据集上迭代还是在测试数据集。我想知道是否有办法在dataset API管道中使用条件语句?我尝试了下面的代码,但它说它无法将类型为ShuffleDataset的对象转换为张量 # This is the placeholder I feed with proper file name depending on whether I'm training or testing filename

我已经使用Tensorflow数据集API构建了一个数据管道,但我希望一些操作(比如洗牌)取决于我是在训练数据集上迭代还是在测试数据集。我想知道是否有办法在dataset API管道中使用条件语句?我尝试了下面的代码,但它说它无法将类型为
ShuffleDataset
的对象转换为张量

# This is the placeholder I feed with proper file name depending on whether I'm training or testing
filenames_placeholder = tf.placeholder(tf.string, shape = (None), name = 'filenames_placeholder')

# This it the placeholder I would like to feed with True/False to influence shuffling
shuffle = tf.placeholder(tf.bool, shape = (None), name = 'shuffle')

dataset = tf.data.TFRecordDataset(self.filenames_placeholder)
dataset = dataset.map(lambda x: parse(x), num_parallel_calls = 4)

# The following does not work
def shuffle_true():
    return dataset.shuffle(buffer_size = 1024)
def shuffle_false():
    return dataset
dataset = tf.cond(self.shuffle, shuffle_true, shuffle_false)

你可以定义一个函数

def tr_input_fn(filename, mode):
    dataset = tf.data.TFRecordDataset(filename)
    if mode == 'Train':
        dataset = dataset.shuffle()
        dataset = dataset.map(map_func)
        return dataset
    return dataset
据我所知,dataset api中现在有显式的条件语句。

您可以定义一个函数

def tr_input_fn(filename, mode):
    dataset = tf.data.TFRecordDataset(filename)
    if mode == 'Train':
        dataset = dataset.shuffle()
        dataset = dataset.map(map_func)
        return dataset
    return dataset
据我所知,dataset api中现在有显式的条件语句。