使用transforms.LinearTransformation在PyTorch中应用白化

使用transforms.LinearTransformation在PyTorch中应用白化,pytorch,Pytorch,我需要在PyTorch中使用ZCA美白。我想我已经找到了一种方法,可以通过使用transforms.LinearTransformation来实现这一点,并且我在PyTorch repo中找到了一个测试,它提供了一些关于如何实现这一点的见解,请参见下面的最终代码块或链接 我正在努力弄清楚我自己是如何应用这样的东西的 目前,我进行了以下几方面的转换: transform_test = transforms.Compose([ transforms.ToTensor(), t

我需要在PyTorch中使用ZCA美白。我想我已经找到了一种方法,可以通过使用transforms.LinearTransformation来实现这一点,并且我在PyTorch repo中找到了一个测试,它提供了一些关于如何实现这一点的见解,请参见下面的最终代码块或链接

我正在努力弄清楚我自己是如何应用这样的东西的

目前,我进行了以下几方面的转换:

    transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(np.array([125.3, 123.0, 113.9]) / 255.0,
                         np.array([63.0, 62.1, 66.7]) / 255.0),
])
文件中说,他们使用线性转换的方法如下:

torchvision.transforms.LinearTransformation(transformation_matrix, mean_vector) 
白化变换:假设X是以零为中心的列向量 数据然后计算数据协方差矩阵[D x D] torch.mmX.t,X,在该矩阵上执行SVD,并将其作为 变换矩阵

我可以从上面链接和下面复制的测试中看出,他们正在使用torch.mm来计算所谓的主分量:

def test_linear_transformation(self):
    num_samples = 1000
    x = torch.randn(num_samples, 3, 10, 10)
    flat_x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3))
    # compute principal components
    sigma = torch.mm(flat_x.t(), flat_x) / flat_x.size(0)
    u, s, _ = np.linalg.svd(sigma.numpy())
    zca_epsilon = 1e-10  # avoid division by 0
    d = torch.Tensor(np.diag(1. / np.sqrt(s + zca_epsilon)))
    u = torch.Tensor(u)
    principal_components = torch.mm(torch.mm(u, d), u.t())
    mean_vector = (torch.sum(flat_x, dim=0) / flat_x.size(0))
    # initialize whitening matrix
    whitening = transforms.LinearTransformation(principal_components, mean_vector)
    # estimate covariance and mean using weak law of large number
    num_features = flat_x.size(1)
    cov = 0.0
    mean = 0.0
    for i in x:
        xwhite = whitening(i)
        xwhite = xwhite.view(1, -1).numpy()
        cov += np.dot(xwhite, xwhite.T) / num_features
        mean += np.sum(xwhite) / num_features
    # if rtol for std = 1e-3 then rtol for cov = 2e-3 as std**2 = cov
    assert np.allclose(cov / num_samples, np.identity(1), rtol=2e-3), "cov not close to 1"
    assert np.allclose(mean / num_samples, 0, rtol=1e-3), "mean not close to 0"

    # Checking if LinearTransformation can be printed as string
    whitening.__repr__()
我如何应用这样的东西?我是在定义变换时使用它,还是在迭代训练循环时在训练循环中应用它


提前感谢

ZCA白化通常是一个预处理步骤,如中心缩减,其基本目的是使数据更友好,并提供以下附加信息。因此,应该在培训前应用一次

所以在你开始用给定的数据集X训练你的模型之前,计算白化的数据集Z,它是X与ZCA矩阵W_ZCA的乘积,你可以学习计算。然后在白化数据集上训练模型。 最后,你应该有这样的东西

class MyModule(torch.nn.Module):
    def __init__(self):
        super(MyModule,self).__init__()
        # Feel free to use something more useful than a simple linear layer
        self._network = torch.nn.Linear(...)
        # Do your stuff
        ...

    def fit(self, inputs, labels):
    """ Trains the model to predict the right label for a given input """
        # Compute the whitening matrix and inputs
        self._zca_mat = compute_zca(inputs)
        whitened_inputs = torch.mm(self._zca_mat, inputs)

        # Apply training on the whitened data
        outputs = self._network(whitened_inputs)
        loss = torch.nn.MSEloss()(outputs, labels)
        loss.backward()
        optimizer.step()

     def forward(self, input):
         # You always need to apply the zca transform before forwarding, 
         # because your network has been trained with whitened data
         whitened_input = torch.mm(self._zca_mat, input)
         predicted_label = self._network.forward(whitened_input)
         return predicted_label
附加信息 白化数据意味着对其维度进行解相关,以便白化数据的相关矩阵为单位矩阵。这是一个旋转缩放操作,因此是线性的,实际上有无穷多个可能的ZCA变换。要理解ZCA背后的数学,请阅读