
How do I build a basic neural network in Java?

Tags: java, machine-learning, neural-network, backpropagation

I am trying to build a basic neural network in Java that computes the logical XOR function.

The network has two input neurons, one hidden layer with three neurons, and one output neuron.

But after a few iterations, the error at the output becomes NaN.

I have read other implementations and tutorials on implementing neural networks, but I cannot find the mistake. I suspect the problem is in my backward function.

Please help me understand where I went wrong.

My code:

import org.ejml.simple.SimpleMatrix;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// SimpleMatrix constructor format: SimpleMatrix(rows, cols)
//The layers are represented as a matrix with 1 row and multiple columns (row vector)
public class Network {
    private SimpleMatrix inputs, outputs, hidden, W1, W2, predicted;
    static final double LEARNING_RATE = 0.3;

    Network(List<double[]> ips, List<double[]> ops){
        hidden = new SimpleMatrix(1, 3);
        W1 = new SimpleMatrix(ips.get(0).length, hidden.numCols());
        W2 = new SimpleMatrix(hidden.numCols(), ops.get(0).length);
        initWeights(W1,W2);

        for(int i=0;i<5000;i++){
            for(int j=0;j<ips.size();j++){
                train(ips.get(j), ops.get(j));
            }
        }
        System.out.println("Trained");
    }

    //Prints output matrix
    SimpleMatrix predict(double[] ip){
        SimpleMatrix bkpInputs = inputs.copy();
        SimpleMatrix bkpOutputs = outputs.copy();

        inputs = new SimpleMatrix(1, ip.length);
        inputs.setRow(0, 0, ip);

        forward();
        inputs = bkpInputs;
        outputs = bkpOutputs;

        predicted.print();
        return predicted;
    }

    void train(double[] inputs, double[] outputs){
        this.inputs = new SimpleMatrix(1, inputs.length);
        this.inputs.setRow(0, 0, inputs);
        this.outputs = new SimpleMatrix(1, outputs.length);
        this.outputs.setRow(0,0,outputs);
        this.predicted = new SimpleMatrix(1,outputs.length);

        forward();
        backward();
    }

    private void initWeights(SimpleMatrix... W){
        Random random = new Random();
        for (SimpleMatrix aW : W) {
            for (int i = 0; i < aW.numRows(); i++)
                for (int j = 0; j < aW.numCols(); j++)
                    aW.set(i, j, random.nextDouble());
        }
    }

    //Using logistic function
    double sigmoid(double x){
        return (1/(1+Math.exp(-x)));
    }

    double sigmoidPrime(double x){
        return sigmoid(x)/(1-sigmoid(x));
    }

    void forward(){
        hidden = inputs.mult(W1);
        for(int i=0;i<hidden.numCols();i++){
            double x = sigmoid(hidden.get(0,i));
            hidden.set(0,i,x);
        }
        predicted = hidden.mult(W2);
        for(int i=0;i<predicted.numRows();i++){
            for(int j=0;j<predicted.numCols();j++){
                predicted.set(i,j, sigmoid(predicted.get(i,j)));
            }
        }
    }

    void backward(){

        //Error in output
        double o_error = 0.0;
        //Error functions I tried: (1/2)( (predicted-actual) ^ 2) and (predicted - actual)
        for(int i=0;i<outputs.numCols();i++)
            o_error += (predicted.get(0, i)-outputs.get(0, i));//Math.pow(predicted.get(0, i)-outputs.get(0, i), 2)/2;
        //Checking output error
        System.out.println(o_error);

        //Output deltas
        SimpleMatrix o_deltas = new SimpleMatrix(1, outputs.numCols());
        for(int i=0;i<outputs.numCols();i++)
            o_deltas.set(0, i, o_error*sigmoidPrime(predicted.get(0, i))); 


        //Error in hidden layer and deltas
        double h_error = o_deltas.dot(W2.transpose());
        SimpleMatrix h_deltas = new SimpleMatrix(1, hidden.numCols());
        for(int i=0;i<hidden.numCols();i++)
            h_deltas.set(0, i, h_error*sigmoidPrime(hidden.get(0, i)));


        //Hidden->Output layer update
        SimpleMatrix W2_delta = W2.mult(o_deltas.transpose());
        for(int i=0;i<W2.numRows();i++){
            for(int j=0;j<W2.numCols();j++){
                W2.set(i,j, W2.get(i,j) + LEARNING_RATE*W2_delta.get(i, 0));
            }
        }

        //Input->Hidden layer update
        SimpleMatrix W1_delta = W1.mult(h_deltas.transpose());
        for(int i=0;i<W1.numRows();i++){
            for(int j=0;j<W1.numCols();j++){
                W1.set(i,j, W1.get(i,j) + LEARNING_RATE*W1_delta.get(i, 0));
            }
        }
    }


    public static void main(String[] args){
        double[][] ips = {
                {0,0},
                {0,1},
                {1,0},
                {1,1}
        };

        double[][] ops = {
                {0},
                {1},
                {1},
                {0}
        };

        List<double[]> ip = new ArrayList<>();
        List<double[]> op = new ArrayList<>();

        for(int i=0;i<ips.length;i++){
            ip.add(ips[i]);
            op.add(ops[i]);
        }

        double[] testip = {1,0};
        Network n = new Network(ip,op);
        n.predict(testip);
    }
}
Try a lower learning rate. When the error is NaN, it usually means your cost/error function has blown up. Try something in the range [10^-3, 10^-5].

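As a concrete illustration of that suggestion (just a sketch of the constant in the question's Network class, not a tested fix), the change would be:

static final double LEARNING_RATE = 1e-4; // somewhere in the suggested [10^-3, 10^-5] range
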
This is probably not the cause of the problem, but I noticed:

W1.get(i,j) + LEARNING_RATE*W1_delta.get(i, 0));
When you update the weights, I think the correct formula is:

w(i,j) += LEARNING_RATE * delta * (output of the connected node)

So your code should be:

W1(i,j) += LEARNING_RATE * W1_delta.get(i, 0) *  <output from the connected node>;

It may not fix the problem, but it is worth a try!
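
To make that formula concrete, here is a minimal, self-contained sketch of the per-weight rule the answer describes, using the same EJML SimpleMatrix type as the question. The class and method names (WeightUpdateSketch, updateWeights) are made up for illustration, and it assumes the deltas are defined as (target - predicted) * sigmoid'(net) so that the += form moves the error downhill:

import org.ejml.simple.SimpleMatrix;

public class WeightUpdateSketch {
    // W has one row per source node and one column per destination node.
    // "activation" is the 1-row output of the layer the weights come from;
    // "delta" is the 1-row delta of the layer the weights feed into.
    static void updateWeights(SimpleMatrix W, SimpleMatrix activation,
                              SimpleMatrix delta, double learningRate) {
        for (int i = 0; i < W.numRows(); i++) {
            for (int j = 0; j < W.numCols(); j++) {
                // w(i,j) += rate * (delta of destination node j) * (output of source node i)
                W.set(i, j, W.get(i, j)
                        + learningRate * delta.get(0, j) * activation.get(0, i));
            }
        }
    }

    public static void main(String[] args) {
        SimpleMatrix hidden = new SimpleMatrix(new double[][]{{0.5, 0.7, 0.2}}); // 1x3 hidden activations
        SimpleMatrix oDelta = new SimpleMatrix(new double[][]{{0.1}});           // 1x1 output delta
        SimpleMatrix W2     = new SimpleMatrix(3, 1);                            // hidden -> output weights
        updateWeights(W2, hidden, oDelta, 0.3);
        W2.print();
    }
}

If applied to the question's backward(), that would mean scaling each W2 update by the corresponding hidden activation and each W1 update by the corresponding input value.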

The correct code? h_deltas.set(h_error*… does not compile, or is that the bug (method overloading)? But even after correcting it, the same problem remains: doing that brings the output for every set of inputs close to 1, and increasing the number of training epochs again produces NaN.