Python tensorflow中不平衡数据集的二次采样

Python tensorflow中不平衡数据集的二次采样,python,tensorflow,tensorflow-datasets,Python,Tensorflow,Tensorflow Datasets,这里是Tensorflow初学者。这是我的第一个项目,我正在使用预定义的估计器 我有一个非常不平衡的数据集,其中正面结果约占总数据的0.1%,我怀疑这种不平衡会极大地影响我的模型的性能。作为解决这个问题的第一次尝试,由于我有大量的数据,我想扔掉我的大部分负面信息,以便创建一个平衡的数据集。我可以看到两种方法:对数据进行预处理,只保留千分之一的底片,然后将其保存在新文件中,然后再将其传递给tensorflow,例如pyspark;让tensorflow只使用它找到的一千个负片中的一个 我试图编写最

这里是Tensorflow初学者。这是我的第一个项目,我正在使用预定义的估计器

我有一个非常不平衡的数据集,其中正面结果约占总数据的0.1%,我怀疑这种不平衡会极大地影响我的模型的性能。作为解决这个问题的第一次尝试,由于我有大量的数据,我想扔掉我的大部分负面信息,以便创建一个平衡的数据集。我可以看到两种方法:对数据进行预处理,只保留千分之一的底片,然后将其保存在新文件中,然后再将其传递给tensorflow,例如pyspark;让tensorflow只使用它找到的一千个负片中的一个

我试图编写最后一个想法,但没有成功。我修改了我的输入函数,读起来像

def train_input_fn(data_file="../data/train_input.csv", shuffle_size=100_000, batch_size=128):
    """Generate an input function for the Estimator."""

    dataset = tf.data.TextLineDataset(data_file)  # Extract lines from input files using the Dataset API.
    dataset = dataset.map(parse_csv, num_parallel_calls=3)
    dataset = dataset.shuffle(shuffle_size).repeat().batch(batch_size)

    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()

    # TRY TO IMPLEMENT THE SELECTION OF NEGATIVES
    thrown = 0
    flag = np.random.randint(1000)
    while labels == 0 and flag != 0:
        features, labels = iterator.get_next()
        thrown += 1
        flag = np.random.randint(1000)
    print("I've thrown away {} negative examples before going for label {}!".format(thrown, labels))
    return features, labels
当然,这不起作用,因为迭代器不知道其中包含什么,所以永远不会满足labels==0条件。另外,stdout中只有一个打印,这意味着这个函数只被调用一次(这意味着我仍然不理解tensorflow是如何工作的)。不管怎样,有没有办法实现我想要的


PS:我怀疑之前的代码,即使它按预期工作,也会返回不到初始负数的千分之一,因为每次它发现一个正数时计数都会重新开始。这是一个小问题,到目前为止,我甚至可以在标志中找到一个神奇的数字,它给出了我期望的结果,而不用太担心它的数学美。

如果对代表性不足的类进行过采样,而不是在代表性过高的类中丢弃数据,可能会得到更好的结果。这样,您可以在过度表示的类中保持差异。你也可以使用你拥有的数据

实现这一点的最简单方法可能是创建两个数据集,每个类一个。然后,您可以使用
Dataset.interleave
从两个数据集中平均采样


非常感谢您的回答。关于
数据集,有一点我不清楚。交错
:我必须创建两个单独的文件,一个有积极的结果,一个有消极的结果,对吗?这听起来是一个合理的方法,但没有必要。首先需要创建两个单独的Dataset对象,每个类一个。如何做到这一点取决于你自己。单独的文件听起来很容易,但是您可能会找到一种方法从每个数据集中过滤掉不需要的类<代码>数据集。交错
要求传入多个数据集,它只需依次从每个数据集中采样一个值,并将其作为自己的数据集返回。因此,它会自动完成平衡类的工作。确保将
.shuffle
分别添加到每个类的数据集中。