Python 如何在每次迭代中仅从一个类中对批进行采样

Python 如何在每次迭代中仅从一个类中对批进行采样,python,tensorflow,Python,Tensorflow,我想在一个ImageNet数据集上训练一个分类器(1000个类,每个类大约有1300个图像)。出于某种原因,我需要每个批包含来自同一类的64个图像,以及来自不同类的连续批。最新的TensorFlow是否可行(且有效) tf 1.9中的tf.contrib.data.sample\u from\u Dataset允许从tf.data.Dataset对象列表中进行采样,其中权重表示概率。我想知道以下想法是否有意义: 将每个类的数据保存为单独的tfrecord文件 将tf.data.Dataset.

我想在一个ImageNet数据集上训练一个分类器(1000个类,每个类大约有1300个图像)。出于某种原因,我需要每个批包含来自同一类的64个图像,以及来自不同类的连续批。最新的TensorFlow是否可行(且有效)

tf 1.9中的
tf.contrib.data.sample\u from\u Dataset
允许从
tf.data.Dataset
对象列表中进行采样,其中
权重表示概率。我想知道以下想法是否有意义:

  • 将每个类的数据保存为单独的tfrecord文件
  • tf.data.Dataset.from_generator
    对象作为
    权重传递。对象从分类分布中采样,使每个样本看起来像999
    0
    s和1
    1
  • 创建1000个
    tf.data.Dataset
    对象,每个对象链接一个tfrecord文件
我想,通过这种方式,也许在每次迭代中,
sample\u from\u datasets
将首先对稀疏权重向量进行采样,该向量指示从哪个
tf.data.Dataset
进行采样,然后从该类中进行相同的采样

对吗?还有其他有效的方法吗

更新

正如p-Gn善意建议的那样,从一个类别中抽取数据的一种方法是:

dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(some_parser_fun)  # parse one datum from tfrecord
dataset = dataset.shuffle(buffer_size)

if sample_same_class:
    group_fun = tf.contrib.data.group_by_window(
        key_func=lambda data_x, data_y: data_y,
        reduce_func=lambda key, d: d.batch(batch_size),
        window_size=batch_size)
    dataset = dataset.apply(group_fun)
else:
    dataset = dataset.batch(batch_size)

dataset = dataset.repeat()
data_batch = dataset.make_one_shot_iterator().get_next()

后续问题可在

中找到,如果我理解正确,我认为您的解决方案不可行,因为来自_数据集
样本_需要其权重的值列表,而不是张量

但是,如果您不介意在建议的解决方案中使用1000个
Dataset
s,那么我建议您

  • 为每个类创建一个
    数据集
  • batch
    这些数据集中的每一个-每一个批次都有来自单个类别的样本
  • zip
    将它们全部放入一个大的
    Dataset
    批次中
  • shuffle
    数据集
    -洗牌将发生在批次上,而不是样本上,因此它不会改变批次是单个类的事实
一个更复杂的方法是依靠。让我用一个综合例子来说明这一点

import numpy as np
import tensorflow as tf

def gen():
  while True:
    x = np.random.normal()
    label = np.random.randint(10)
    yield x, label

batch_size = 4
batch = (tf.data.Dataset
  .from_generator(gen, (tf.float32, tf.int64), (tf.TensorShape([]), tf.TensorShape([])))
  .apply(tf.contrib.data.group_by_window(
    key_func=lambda x, label: label,
    reduce_func=lambda key, d: d.batch(batch_size),
    window_size=batch_size))
  .make_one_shot_iterator()
  .get_next())

sess = tf.InteractiveSession()
sess.run(batch)
# (array([ 0.04058843,  0.2843775 , -1.8626076 ,  1.1154234 ], dtype=float32),
# array([6, 6, 6, 6], dtype=int64))
sess.run(batch)
# (array([ 1.3600663,  0.5935658, -0.6740045,  1.174328 ], dtype=float32),
# array([3, 3, 3, 3], dtype=int64))

非常感谢您的回复。我可以确认到目前为止,tf.contrib.data.groupby_窗口在cifar10数据集上工作正常。。。我将做更多的测试,看看它在ImageNet上是否有效。