Machine learning 用于直方图匹配的简单Pyrotch模型不会更新其参数

Machine learning 用于直方图匹配的简单Pyrotch模型不会更新其参数,machine-learning,computer-vision,pytorch,pytorch-lightning,Machine Learning,Computer Vision,Pytorch,Pytorch Lightning,我目前正在尝试拟合一个非常简单的模型,该模型基本上应该找到一个用于直方图匹配的最佳直方图。我编写了一个超级简单的模型,其中只包含一个我直接使用的参数对象: import torch.nn.functional as F class AutoHist(pl.LightningModule): def __init__(self, channel=1, bins=255): super().__init__() self.hist = torch.n

我目前正在尝试拟合一个非常简单的模型,该模型基本上应该找到一个用于直方图匹配的最佳直方图。我编写了一个超级简单的模型,其中只包含一个我直接使用的参数对象:

import torch.nn.functional as F

class AutoHist(pl.LightningModule):    
    def __init__(self, channel=1, bins=255):
        super().__init__()
        self.hist = torch.nn.Parameter(torch.rand((1, channel, bins), requires_grad=True))
        self.eps = 1e-5    

    def b_distance(self, h1, h2):
        distance = 1
        distance -= 1/(torch.sqrt(torch.mean(h1, axis=2)*torch.mean(h2, axis=2)*h1.size(2)**2))
        distance *= torch.sum(torch.sqrt(h1*h2 + self.eps),axis=2)
        return torch.sqrt(distance + self.eps)    

    def training_step(self, batch, batch_idx):
        # training_step defined the train loop.
        # It is independent of forward
        x, y = batch
        hist = self.hist / self.hist.sum()
        distances = self.b_distance(hist,x)
        loss = F.binary_cross_entropy(distances[:,0], y)
        return loss    

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
但是由于某些原因,backprop没有通过,参数也没有得到更新。有人知道问题出在哪里吗?梯度实际上是存在的,并且随着批次的变化而变化。我使用pytorch lightning删除样板代码,但关键应该在于我编写的代码