为了账号安全,请及时绑定邮箱和手机立即绑定

如何用Java构建一个基本的神经网络?

如何用Java构建一个基本的神经网络?

慕桂英546537 2021-05-31 17:35:37
我正在尝试XOR用 Java构建一个基本的神经网络来计算逻辑函数。该网络有两个输入神经元,一个包含三个神经元的隐藏层和一个输出神经元。但是经过几次迭代后,输出中的误差变为NaN。我已经浏览了其他实现神经网络的实现和教程,但我找不到错误。我觉得问题在于我的落后功能。请帮助我理解我哪里出错了。我的代码:import org.ejml.simple.SimpleMatrix;import java.util.ArrayList;import java.util.List;import java.util.Random;// SimpleMatrix constructor format: SimpleMatrix(rows, cols)//The layers are represented as a matrix with 1 row and multiple columns (row vector)public class Network {    private SimpleMatrix inputs, outputs, hidden, W1, W2, predicted;    static final double LEARNING_RATE = 0.3;    Network(List<double[]> ips, List<double[]> ops){        hidden = new SimpleMatrix(1, 3);        W1 = new SimpleMatrix(ips.get(0).length, hidden.numCols());        W2 = new SimpleMatrix(hidden.numCols(), ops.get(0).length);        initWeights(W1,W2);        for(int i=0;i<5000;i++){            for(int j=0;j<ips.size();j++){                train(ips.get(j), ops.get(j));            }        }        System.out.println("Trained");    }    //Prints output matrix    SimpleMatrix predict(double[] ip){        SimpleMatrix bkpInputs = inputs.copy();        SimpleMatrix bkpOutputs = outputs.copy();        inputs = new SimpleMatrix(1, ip.length);        inputs.setRow(0, 0, ip);        forward();        inputs = bkpInputs;        outputs = bkpOutputs;        predicted.print();        return predicted;    }    void train(double[] inputs, double[] outputs){        this.inputs = new SimpleMatrix(1, inputs.length);        this.inputs.setRow(0, 0, inputs);        this.outputs = new SimpleMatrix(1, outputs.length);        this.outputs.setRow(0,0,outputs);        this.predicted = new SimpleMatrix(1,outputs.length);        forward();        backward();    }
查看完整描述

2 回答

?
江户川乱折腾

TA贡献1851条经验 获得超5个赞

因此,这可能不是导致您出现问题的原因,但我注意到:

W1.get(i,j) + LEARNING_RATE*W1_delta.get(i, 0));

当您更新权重时。我认为正确的公式是:

//img1.sycdn.imooc.com//60b6e3e20001ea5405630109.jpg

所以你的代码应该是:

W1(i,j) += LEARNING_RATE * W1_delta.get(i, 0) *  <output from the connected node>;

它可能无法解决它,但值得一试!


查看完整回答
反对 回复 2021-06-02
?
千巷猫影

TA贡献1829条经验 获得超7个赞

尝试使用较低的学习率。当错误出现时NaN,通常意味着您的成本/错误函数已经爆炸。尝试范围内的东西[10^-3, 10^-5]


查看完整回答
反对 回复 2021-06-02
  • 2 回答
  • 0 关注
  • 171 浏览

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信