Neural network 焦损实现
在介绍焦点损耗的章节中,他们指出损耗函数的公式如下: 在哪里 我在Github页面上找到了它的一个实现,这是另一位作者在他们的文章中使用的。我在我拥有的一个分割问题数据集上试用了这个函数,它似乎运行得很好 实施情况如下:Neural network 焦损实现,neural-network,pytorch,conv-neural-network,probability,loss-function,Neural Network,Pytorch,Conv Neural Network,Probability,Loss Function,在介绍焦点损耗的章节中,他们指出损耗函数的公式如下: 在哪里 我在Github页面上找到了它的一个实现,这是另一位作者在他们的文章中使用的。我在我拥有的一个分割问题数据集上试用了这个函数,它似乎运行得很好 实施情况如下: def binary_focal_loss(pred, truth, gamma=2., alpha=.25): eps = 1e-8 pred = nn.Softmax(1)(pred) truth = F.one_hot(truth, num_c
def binary_focal_loss(pred, truth, gamma=2., alpha=.25):
eps = 1e-8
pred = nn.Softmax(1)(pred)
truth = F.one_hot(truth, num_classes = pred.shape[1]).permute(0,3,1,2).contiguous()
pt_1 = torch.where(truth == 1, pred, torch.ones_like(pred))
pt_0 = torch.where(truth == 0, pred, torch.zeros_like(pred))
pt_1 = torch.clamp(pt_1, eps, 1. - eps)
pt_0 = torch.clamp(pt_0, eps, 1. - eps)
out1 = -torch.mean(alpha * torch.pow(1. - pt_1, gamma) * torch.log(pt_1))
out0 = -torch.mean((1 - alpha) * torch.pow(pt_0, gamma) * torch.log(1. - pt_0))
return out1 + out0
我不理解的部分是pt_0和pt_1的计算。我为自己创建了一个小例子来尝试解决这个问题,但它仍然让我有点困惑
# one hot encoded prediction tensor
pred = torch.tensor([
[
[.2, .7, .8], # probability
[.3, .5, .7], # of
[.2, .6, .5] # background class
],
[
[.8, .3, .2], # probability
[.7, .5, .3], # of
[.8, .4, .5] # class 1
]
])
# one-hot encoded ground truth labels
truth = torch.tensor([
[1, 0, 0],
[1, 1, 0],
[1, 0, 0]
])
truth = F.one_hot(truth, num_classes = 2).permute(2,0,1).contiguous()
print(truth)
# gives me:
# tensor([
# [
# [0, 1, 1],
# [0, 0, 1],
# [0, 1, 1]
# ],
# [
# [1, 0, 0],
# [1, 1, 0],
# [1, 0, 0]
# ]
# ])
pt_0 = torch.where(truth == 0, pred, torch.zeros_like(pred))
pt_1 = torch.where(truth == 1, pred, torch.ones_like(pred))
print(pt_0)
# gives me:
# tensor([[
# [0.2000, 0.0000, 0.0000],
# [0.3000, 0.5000, 0.0000],
# [0.2000, 0.0000, 0.0000]
# ],
# [
# [0.0000, 0.3000, 0.2000],
# [0.0000, 0.0000, 0.3000],
# [0.0000, 0.4000, 0.5000]
# ]
# ])
print(pt_1)
# gives me:
# tensor([[
# [1.0000, 0.7000, 0.8000],
# [1.0000, 1.0000, 0.7000],
# [1.0000, 0.6000, 0.5000]
# ],
# [
# [0.8000, 1.0000, 1.0000],
# [0.7000, 0.5000, 1.0000],
# [0.8000, 1.0000, 1.0000]
# ]
# ])
我不明白的是,为什么在pt_0中,我们把0放在火炬的位置,而在pt_1中,我们把1放在火炬的位置。从我对这篇论文的理解来看,我本以为你应该把1-p放在这里,而不是放0或1
有人能帮我解释一下吗?所以你想理解的部分是当人们想把不需要的额外计算归零时通常会做的一个过程 再看看
pt
的公式:
以下代码正是通过分离这两个条件来实现这一点的:
# if y=1
pt_1 = torch.where(truth == 1, pred, torch.ones_like(pred))
# otherwise
pt_0 = torch.where(truth == 0, pred, torch.zeros_like(pred))
如果在pt_0
中设置为零,在pt_1
中设置为一,则输出为零,因此对贡献损失值没有影响,即:
# Because pow(0., gamma) == 0. and log(1.) == 0.
# out1 == 0. if pt_1 == 1.
out1 = -torch.mean(alpha * torch.pow(1. - pt_1, gamma) * torch.log(pt_1))
# out0 == 0. if pt_0 == 0.
out0 = -torch.mean((1 - alpha) * torch.pow(pt_0, gamma) * torch.log(1. - pt_0))
而pt_0
使用p
值而不是1-p
的原因与您上一个问题的原因相同,即:
1 - (1 - p) == 1 - 1 + p == p
因此,它可以通过以下方式计算FL(pt)
:
# -a * pow(1 - (1 - p), gamma )* log(1 - p) == -a * pow(p, gamma )* log(1 - p)
out0 = -torch.mean((1 - alpha) * torch.pow(pt_0, gamma) * torch.log(1. - pt_0))
好的,我现在看到了。谢谢你帮我回答这个问题和最后一个问题:)@SteveAhlswede没问题,很乐意帮忙:)祝你过得愉快