Pytorch 二进制分类-BCELoss和模型输出大小不对应

Pytorch 二进制分类-BCELoss和模型输出大小不对应,pytorch,Pytorch,我正在进行二元分类,因此我使用了二元交叉熵损失: criterion = torch.nn.BCELoss() 但是,我得到了一个错误: Using a target size (torch.Size([64, 1])) that is different to the input size (torch.Size([64, 2])) is deprecated. Please ensure they have the same size. 我的模型以以下内容结尾: x = self

我正在进行二元分类,因此我使用了二元交叉熵损失:

criterion = torch.nn.BCELoss()
但是,我得到了一个错误:

Using a target size (torch.Size([64, 1])) that is different to the input size (torch.Size([64, 2])) is deprecated. Please ensure they have the same size.
我的模型以以下内容结尾:

    x = self.wave_block6(x)
    x = self.sigmoid(self.fc(x))
    return x.squeeze()

我试着消除挤压,但没有效果。我的批量是64。我好像做错了什么。我的模型是否有1个输出,BCE损耗是否有2个输入?那么我应该使用哪种损失?

二进制交叉熵损失(BCELoss)用于二进制分类任务。因此,如果N是批次大小,则模型输出应为形状
[64,1]
,标签必须为形状
[64]
。因此,只需在第二维度压缩输出,并将其传递给损失函数- 下面是一个简单的工作示例

import torch
a = torch.randn((64, 1))
b = torch.randn((64))
loss = torch.nn.BCELoss()

b = torch.round(torch.sigmoid(b)) # just to create some labels
a = torch.sigmoid(a).squeeze(1)
l = loss(a, b)
更新-根据评论中的对话,焦点丢失的定义如下-

class focalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=3):
        super(focalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred_logits: torch.Tensor, target: torch.Tensor):
        batch_size = pred_logits.shape[0]
        pred = pred.view(batch_size, -1)
        target = target.view(batch_size, -1)
        pred = pred_logits.sigmoid()
        ce = F.binary_cross_entropy(pred_logits, target, reduction='none')
        alpha = target * self.alpha + (1. - target) * (1. - self.alpha)
        pt = torch.where(target == 1, pred, 1 - pred)
        return alpha * (1. - pt) ** self.gamma * ce

你的模型应该有单一输出…尝试包含更多信息这是否回答了你的问题?仍然得到相同的错误。也许我的colab没有从本地驱动器刷新代码。我还注意到,在培训中,这被称为:loss=criteria(output,target.long())对于BCELoss,目标和输入都应为float类型,因此在解决此错误后,您将得到该错误,请注意,CrossEntropy loss函数的目标类型为long,所以,也许最初的代码是为CrossEntropy编写的loss@dorien,回到手头的问题,你能告诉我你的标签形状(即你的目标)和你的模型输出形状吗?你在问题中提到它是'64,2',但是你使用的是BCELoss,对于一个二进制分类问题,你只需要输出一个值,而不是2。我在原始损失中使用的是FocalLoss,我正在尝试适应BCELoss。让我运行打印形状(抱歉,这需要一段时间,大数据集),尝试在此时运行本地。此外,我将添加一个焦距损失定义,您可以使用开箱即用