Pytorch中类不平衡的多标签分类

Pytorch中类不平衡的多标签分类,pytorch,multilabel-classification,imbalanced-data,Pytorch,Multilabel Classification,Imbalanced Data,我有一个多标签分类问题,我正试图用Pytorch中的CNN解决这个问题。我有80000个培训示例和7900个课程;每个示例可以同时属于多个类,每个示例的平均类数为130 问题是我的数据集非常不平衡。对于某些类,我只有约900个示例,约为1%。对于“人数过多”的课程,我有大约12000个例子(15%)。当我训练模型时,我使用带有正权重参数的BCEWithLogitsLoss from。我计算权重的方法与文档中描述的相同:负面示例数除以正面示例数 结果,我的模型几乎高估了每一门课……我得到的预测几乎

我有一个多标签分类问题,我正试图用Pytorch中的CNN解决这个问题。我有80000个培训示例和7900个课程;每个示例可以同时属于多个类,每个示例的平均类数为130

问题是我的数据集非常不平衡。对于某些类,我只有约900个示例,约为1%。对于“人数过多”的课程,我有大约12000个例子(15%)。当我训练模型时,我使用带有正权重参数的BCEWithLogitsLoss from。我计算权重的方法与文档中描述的相同:负面示例数除以正面示例数

结果,我的模型几乎高估了每一门课……我得到的预测几乎是真实标签的两倍。我的AUPRC只有0.18。尽管这比根本不加权要好得多,因为在这种情况下,模型预测所有东西都为零


所以我的问题是,如何提高性能?还有什么我能做的吗?我尝试了不同的批量抽样技术(对少数群体进行过抽样),但它们似乎不起作用。

我建议采用以下任一策略

焦点损失
中介绍了一种通过调整损失函数来处理不平衡训练数据的非常有趣的方法 Tsung Yi Lin、Priya Goyal、Ross Girshick、He Kaiming和Piotr Dollar(ICCV 2017)。
他们建议修改二元交叉熵损失,以减少易分类示例的损失和梯度,同时“集中精力”处理模型存在严重错误的示例

硬负开采 另一种流行的方法是进行“硬负挖掘”;也就是说,只为部分训练示例传播梯度——“硬”示例。
参见,例如:

Abhinav Shrivastava、Abhinav Gupta和Ross Girshick(CVPR 2016)

@Shai提供了在深度学习时代开发的两种策略。我想为您提供一些额外的传统机器学习选项:过采样欠采样

它们的主要思想是在开始训练之前通过采样生成一个更平衡的数据集。请注意,您可能会面临一些问题,例如丢失数据多样性(欠采样)和过度拟合训练数据(过采样),但这可能是一个很好的起点


有关更多信息,请参阅。

在这种情况下,焦点丢失可能不是一个好的选择,因为它有7900个等级。有太多的超参数需要微调。@zihaozhihao确实很棘手。但我会尝试在所有类中使用相同的gamma。