Machine learning 神经网络只是不擅长解决这个简单的线性问题,还是因为训练不好?

Machine learning 神经网络只是不擅长解决这个简单的线性问题,还是因为训练不好?,machine-learning,neural-network,pytorch,skorch,Machine Learning,Neural Network,Pytorch,Skorch,我试着用PyTorch和skorch训练一个非常简单(我认为)的NN模型,但是糟糕的性能真的让我很困惑,所以如果你对此有任何了解,那就太好了 问题是这样的:有五个物体,A,B,C, D、 E,(用指纹标记,例如,(0,0)是A,(0.2,0.5)是B, 等)每个都对应一个数字,问题是试图找到 每个数字对应什么数字。培训数据是一个列表 “收款”和相应的金额。例如:[A,A,A,B,B] ==[(0,0)、(0,0)、(0,0)、(0.2,0.5)、(0.2,0.5)]-->15、[B、C、D、E]

我试着用PyTorch和skorch训练一个非常简单(我认为)的NN模型,但是糟糕的性能真的让我很困惑,所以如果你对此有任何了解,那就太好了

问题是这样的:有五个物体,A,B,C, D、 E,(用指纹标记,例如,(0,0)是A,(0.2,0.5)是B, 等)每个都对应一个数字,问题是试图找到 每个数字对应什么数字。培训数据是一个列表 “收款”和相应的金额。例如:[A,A,A,B,B] ==[(0,0)、(0,0)、(0,0)、(0.2,0.5)、(0.2,0.5)]-->15、[B、C、D、E]=[(0.2,0.5)、(0.5,0.8)、(0.3,0.9)、(1,1)]-->30。。。。请注意,一个集合中的对象数不是常数

没有噪音或任何东西,所以它只是一个可以直接求解的线性系统。所以我认为这对于NN来说是很容易找到的。我实际上是在用这个例子来检查一个更复杂问题的合理性,但我很惊讶NN甚至不能解决这个问题

现在,我只是想准确地指出哪里出了问题。模型定义似乎是正确的,数据输入是正确的,糟糕的性能是由于糟糕的培训造成的吗?还是NN就是不擅长这些

以下是模型定义:

class NN(nn.Module):
    def __init__(
        self,
        input_dim,
        num_nodes,
        num_layers,
        batchnorm=False,
        activation=Tanh,
    ):
        super(SingleNN, self).__init__()
        self.get_forces = get_forces
        self.activation_fn = activation

        self.model = MLP(
            n_input_nodes=input_dim,
            n_layers=num_layers,
            n_hidden_size=num_nodes,
            activation=activation,
            batchnorm=batchnorm,
        )

    def forward(self, batch):
        if isinstance(batch, list):
            batch = batch[0]
        with torch.enable_grad():
            fingerprints = batch.fingerprint.float()
            fingerprints.requires_grad = True
            #index of the current "collection" in the training list
            idx = batch.idx
            sorted_idx = torch.unique_consecutive(idx)
            o = self.model(fingerprints)
            total = scatter(o, idx, dim=0)[sorted_idx]

            return total

    @property
    def num_params(self):
        return sum(p.numel() for p in self.parameters())

class MLP(nn.Module):
    def __init__(
        self,
        n_input_nodes,
        n_layers,
        n_hidden_size,
        activation,
        batchnorm,
        n_output_nodes=1,
    ):
        super(MLP, self).__init__()
        if isinstance(n_hidden_size, int):
            n_hidden_size = [n_hidden_size] * (n_layers)
        self.n_neurons = [n_input_nodes] + n_hidden_size + [n_output_nodes]
        self.activation = activation
        layers = []
        for _ in range(n_layers - 1):
            layers.append(nn.Linear(self.n_neurons[_], self.n_neurons[_ + 1]))
            layers.append(activation())
            if batchnorm:
                layers.append(nn.BatchNorm1d(self.n_neurons[_ + 1]))
        layers.append(nn.Linear(self.n_neurons[-2], self.n_neurons[-1]))
        self.model_net = nn.Sequential(*layers)

    def forward(self, inputs):
        return self.model_net(inputs)
skorch的部分很简单

model = NN(2, 100, 2)
net = NeuralNetRegressor(
        module=model,
        ...
    )
net.fit(train_dataset, None)
对于测试运行,数据集如下所示(总共16个集合):

相应总计: [10,11,14,14,17,18,…]

只要看一眼就可以很容易地分辨出一个收藏中的物品是什么/有多少 培训过程如下所示:

 epoch    train_energy_mae    train_loss    cp     dur
-------  ------------------  ------------  ----  ------
      1              4.9852        0.5425     +  0.1486
      2             16.3659        4.2273        0.0382
      3              6.6945        0.7403        0.0025
      4              7.9199        1.2694        0.0024
      5             12.0389        2.4982        0.0024
      6              9.9942        1.8391        0.0024
      7              5.6733        0.7528        0.0024
      8              5.7007        0.5166        0.0024
      9              7.8929        1.0641        0.0024
     10              9.2560        1.4663        0.0024
     11              8.5545        1.2562        0.0024
     12              6.7690        0.7589        0.0024
     13              5.3769        0.4806        0.0024
     14              5.1117        0.6009        0.0024
     15              6.2685        0.8831        0.0024
....
    290              5.1899        0.4750        0.0024
    291              5.1899        0.4750        0.0024
    292              5.1899        0.4750        0.0024
    293              5.1899        0.4750        0.0024
    294              5.1899        0.4750        0.0025
    295              5.1899        0.4750        0.0025
    296              5.1899        0.4750        0.0025
    297              5.1899        0.4750        0.0025
    298              5.1899        0.4750        0.0025
    299              5.1899        0.4750        0.0025
    300              5.1899        0.4750        0.0025
    301              5.1899        0.4750        0.0024
    302              5.1899        0.4750        0.0025
    303              5.1899        0.4750        0.0024
    304              5.1899        0.4750        0.0024
    305              5.1899        0.4750        0.0025
    306              5.1899        0.4750        0.0024
    307              5.1899        0.4750        0.0025
你可以看到,它只是在一段时间后停止训练。 我可以确认,对于不同的指纹,神经网络确实给出了不同的结果,但不知何故,最终的预测值永远不够好


我尝试过不同的神经网络大小、学习率、批量大小、激活函数(tanh、relu等),但它们似乎都没有帮助。你对此有什么见解吗?有什么我做错了/可以尝试的,或者NN只是不擅长这种任务吗?

我注意到的第一件事是:
super(SingleNN,self)。\uuu init\uuu()
应该是
super(NN,self)。\uuu init\uuu()
。如果仍然有错误,请更改此选项并告诉我。

我对神经网络了解不多,但我的评论可能仍然有用。第一:我不太理解你对问题的描述。为什么示例15和30中的列表总和是?第二:神经网络不是有固定的输入大小吗?如何将可变长度列表输入神经网络?
 epoch    train_energy_mae    train_loss    cp     dur
-------  ------------------  ------------  ----  ------
      1              4.9852        0.5425     +  0.1486
      2             16.3659        4.2273        0.0382
      3              6.6945        0.7403        0.0025
      4              7.9199        1.2694        0.0024
      5             12.0389        2.4982        0.0024
      6              9.9942        1.8391        0.0024
      7              5.6733        0.7528        0.0024
      8              5.7007        0.5166        0.0024
      9              7.8929        1.0641        0.0024
     10              9.2560        1.4663        0.0024
     11              8.5545        1.2562        0.0024
     12              6.7690        0.7589        0.0024
     13              5.3769        0.4806        0.0024
     14              5.1117        0.6009        0.0024
     15              6.2685        0.8831        0.0024
....
    290              5.1899        0.4750        0.0024
    291              5.1899        0.4750        0.0024
    292              5.1899        0.4750        0.0024
    293              5.1899        0.4750        0.0024
    294              5.1899        0.4750        0.0025
    295              5.1899        0.4750        0.0025
    296              5.1899        0.4750        0.0025
    297              5.1899        0.4750        0.0025
    298              5.1899        0.4750        0.0025
    299              5.1899        0.4750        0.0025
    300              5.1899        0.4750        0.0025
    301              5.1899        0.4750        0.0024
    302              5.1899        0.4750        0.0025
    303              5.1899        0.4750        0.0024
    304              5.1899        0.4750        0.0024
    305              5.1899        0.4750        0.0025
    306              5.1899        0.4750        0.0024
    307              5.1899        0.4750        0.0025