我目前正在尝试编程一个神经网络。。。为了学习我想用反向传播算法!我的问题是,我不知道我的错误在哪里。我试着训练它的逻辑和逻辑。
第一轮后我的网络错误是:
28.68880035284087输入1 | 1
22.17048518538824输入1 | 0
21.346787829014342输入0 | 1
20.44791655274438输入0 | 0
如果我进行几次迭代,我的错误如下:
34.17584528001372输入1 | 1
18.315643070675343输入1 | 0
17.568891920535222输入0 | 1
17.753497551261436输入0 | 0
我完全不知道为什么输入1 | 1的误差越来越大,而其他的却越来越小。。。
这是我的密码:
testdata的类:
public class Trainingset
{
private double[] input;
private double[] target;
public Trainingset(double[] input, double[] target)
{
this.input = input;
this.target = target;
}
public double[] getInput()
{
return input;
}
public double[] getTarget()
{
return target;
}
}
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
public class TrainingData
{
private List<Trainingset> trainingSets;
private Random random;
private int nextEntry;
public TrainingData()
{
random = new Random();
trainingSets = new ArrayList<Trainingset>();
nextEntry = 0;
}
public void addData(double[] input, double[] target)
{
Trainingset ts = new Trainingset(input.clone(), target.clone());
trainingSets.add(ts);
}
public Trainingset getRandomTrainingset()
{
return trainingSets.get(random.nextInt(trainingSets.size()));
}
public Trainingset getNext()
{
if(nextEntry == trainingSets.size())
nextEntry = 0;
return trainingSets.get(nextEntry++);
}
}
网络类:
import java.util.ArrayList;
import java.util.List;
public class FFN3
{
private List<FFNlayer3> layers;
private double learningrate = 0.45;
private double momentum = 0.9;
private double outputError;
private double networkErrkor;
public FFN3()
{
layers = new ArrayList<>();
layers.add(new FFNlayer3(2));
layers.add(new FFNlayer3(1));
layers.get(0).setNextLayer(layers.get(1));
layers.get(1).setPrevLayer(layers.get(0));
double[][] ItoH = {
{ 0.4, 0.1 },
{ -0.1, -0.1 }
};
double[][] HtoO = {
{ 0.06, -0.4 }
};
layers.get(0).setWeights(ItoH);
layers.get(1).setWeights(HtoO);
networkErrkor = Double.MAX_VALUE;
}
public void learn(TrainingData td)
{
Trainingset ts = td.getNext();
double[] results = compute(ts.getInput());
double error = 0;
for(int i = 0; i < results.length; i++)
{
error += Math.pow(ts.getTarget()[i] - results[i], 2);
}
networkErrkor = error / results.length;
layers.get(layers.size()-1).updateWeights(learningrate, momentum, ts.getTarget());
layers.get(0).updateHiddenWeights(learningrate, momentum, ts.getInput());
}
public double getNetworkError()
{
return networkErrkor;
}
public double[] compute(double[] input)
{
return layers.get(0).compute(input);
}
}
图层类:
public class FFNlayer3
{
private double[][] incomingWeights;
private double[][] prevWeightChanges;
private double[] neuronValues;
private double[] neuronSums;
private double[] errors;
private FFNlayer3 prevLayer;
private FFNlayer3 nextLayer;
public FFNlayer3(int neuroncount)
{
neuronValues = new double[neuroncount];
neuronSums = new double[neuroncount];
errors = new double[neuroncount];
nextLayer = null;
prevLayer = null;
}
public void setWeights(double[][] weights)
{
incomingWeights = weights;
prevWeightChanges = new double[incomingWeights.length][incomingWeights[0].length];
}
public void setPrevLayer(FFNlayer3 prevLayer)
{
this.prevLayer = prevLayer;
}
public void setNextLayer(FFNlayer3 nextLayer)
{
this.nextLayer = nextLayer;
}
public void updateWeights(double learningrate, double momentum, double[] targetValues)
{
for(int i = 0; i < errors.length; i++)
{
errors[i] = neuronValues[i] * (1 - neuronValues[i]) * (targetValues[i] - neuronValues[i]);
}
for(int i = 0; i < incomingWeights.length; i++)
{
for(int j = 0; j < incomingWeights[i].length; j++)
{
double delta = learningrate * errors[i] * prevLayer.getNeuronValues()[j];
incomingWeights[i][j] += delta + momentum * prevWeightChanges[i][j];
}
}
prevLayer.updateHiddenWeights(learningrate, momentum);
}
public void updateHiddenWeights(double learningrate, double momentum)
{
if(prevLayer==null)
return;
for(int i = 0; i < errors.length; i++)
{
for(int j = 0; j < nextLayer.getErrors().length; j++)
{
errors[i] += nextLayer.getErrors()[j] * nextLayer.getWeights()[j][i];
}
}
for(int i = 0; i < incomingWeights.length; i++)
{
for(int j = 0; j < incomingWeights[i].length; j++)
{
double delta = learningrate * errors[i] * prevLayer.getNeuronValues()[j];
incomingWeights[i][j] += delta + momentum * prevWeightChanges[i][j];
}
}
prevLayer.updateHiddenWeights(learningrate, momentum);
}
public void updateHiddenWeights(double learningrate, double momentum, double[] input)
{
for(int i = 0; i < errors.length; i++)
{
for(int j = 0; j < nextLayer.getErrors().length; j++)
{
errors[i] += nextLayer.getErrors()[j] * nextLayer.getWeights()[j][i];
}
}
for(int i = 0; i < incomingWeights.length; i++)
{
for(int j = 0; j < incomingWeights[i].length; j++)
{
double delta = learningrate * errors[i] * input[j];
incomingWeights[i][j] += delta + momentum * prevWeightChanges[i][j];
}
}
}
public double[][] getWeights()
{
return incomingWeights;
}
public double[] getErrors()
{
return errors;
}
public double[] getNeuronValues()
{
return neuronValues;
}
public double[] compute(double[] input)
{
for(int i = 0; i < neuronValues.length; i++)
{
for(int j = 0; j < incomingWeights[i].length; j++)
{
neuronSums[i] += input[j] * incomingWeights[i][j];
}
neuronValues[i] = SIGMOID(neuronSums[i]);
neuronSums = new double[neuronSums.length];
}
if(nextLayer==null)
return neuronValues;
return nextLayer.compute(neuronValues);
}
private double SIGMOID(double value)
{
return 1 / (1+ Math.exp(-value));
}
}
还有我的主要观点:
FFN3 network = new FFN3();
double[] input = new double[2];
double[] target = new double[1];
TrainingData td = new TrainingData();
input[0] = 1;
input[1] = 1;
target[0] = 1;
td.addData(input, target);
input[0] = 1;
input[1] = 0;
target[0] = 0;
//target[1] = 1;
td.addData(input, target);
input[0] = 0;
input[1] = 1;
target[0] = 0;
td.addData(input, target);
input[0] = 0;
input[1] = 0;
target[0] = 0;
td.addData(input, target);
while(Double.compare(network.getNetworkError(), 0.001)>0)
{
network.learn(td);
System.out.println(network.getNetworkError()*100);
}
我正在使用此文档:http://www.dataminingmasters.com/uploads/studentprojects/neuralnetworks.pdf
第一个纪元之后的值与文档中的值相似。。。怎么了?是文件,还是我的代码?
希望你能帮助我!
1条答案
按热度按时间okxuctiv1#
您可以尝试使用bigdecimal而不是double,因为它们可能会引起麻烦(请查看此处以了解更多信息)