I've been trying to use stochastic gradient descent with a sum-of-squared-errors cost function to build a feed-forward neural network, trained with backpropagation, that can represent this training data:
Input Output
{{0,1} , {1,0,0,0,0,0,0,0}}
{{0.1,1}, {0,1,0,0,0,0,0,0}}
{{0.2,1}, {0,0,1,0,0,0,0,0}}
{{0.3,1}, {0,0,0,1,0,0,0,0}}
{{0.4,1}, {0,0,0,0,1,0,0,0}}
{{0.5,1}, {0,0,0,0,0,1,0,0}}
{{0.6,1}, {0,0,0,0,0,0,1,0}}
{{0.7,1}, {0,0,0,0,0,0,0,1}}
The network consists of 1 input unit, 1 bias unit, 8 output units, and a total of 16 weights (8 input weights and 8 bias weights; each output unit has its own pair of weights, one from the input unit and one from the bias unit). However, training is very slow to converge. I'm using a sigmoid activation function for all output units:
output = 1/(1+e^(-weightedSum))
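For example, the activation of output unit out is computed from that unit's own pair of weights. This is just the indexing scheme above restated as code; input and bias here stand for the two entries of a training row:

// weights[out*2]     : weight from the input unit to output unit `out`
// weights[out*2 + 1] : weight from the bias unit to output unit `out`
double weightedSum = input * weights[out*2] + bias * weights[out*2 + 1];
double output = 1 / (1 + Math.exp(-weightedSum));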
The error gradient I derived for each weight is:
errorGradient = (output - trainingData) * output * (1 - output) * inputUnit
where trainingData refers to the target output specified in the training set at the index of the current output unit, and inputUnit refers to the unit (input or bias) connected to the current weight.
Therefore, I update each individual weight at each iteration with the following equation:
weights[i] = weights[i] - (learningRate * errorGradient)
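Put together, a single weight update looks like this. This is only a sketch mirroring the inner update loop in the code below; target, inputUnit, and i are placeholders for the current target output, the value of the connected unit, and the weight index:

double delta = (output - target) * output * (1 - output);  // dE/d(weightedSum) for this output unit
double errorGradient = delta * inputUnit;                  // dE/d(weights[i])
weights[i] = weights[i] - learningRate * errorGradient;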
Code:
package ann;

import java.util.Arrays;
import java.util.Random;

public class MSEANN {

    static double learningRate = 0.1;
    static double totalError = 0;
    static double previousTotalError = Double.POSITIVE_INFINITY;
    static double[] weights;

    public static void main(String[] args) {
        genRanWeights();

        double[][][] trainingData = {
            {{0,   1}, {1,0,0,0,0,0,0,0}},
            {{0.1, 1}, {0,1,0,0,0,0,0,0}},
            {{0.2, 1}, {0,0,1,0,0,0,0,0}},
            {{0.3, 1}, {0,0,0,1,0,0,0,0}},
            {{0.4, 1}, {0,0,0,0,1,0,0,0}},
            {{0.5, 1}, {0,0,0,0,0,1,0,0}},
            {{0.6, 1}, {0,0,0,0,0,0,1,0}},
            {{0.7, 1}, {0,0,0,0,0,0,0,1}},
        };

        while (true) {
            int errorCount = 0;
            totalError = 0;

            // Iterate through the training set
            for (int i = 0; i < trainingData.length; i++) {
                // Iterate through the list of output units
                for (int out = 0; out < trainingData[i][1].length; out++) {
                    double weightedSum = 0;
                    // Weighted sum for this training example and this output unit
                    for (int ii = 0; ii < trainingData[i][0].length; ii++) {
                        weightedSum += trainingData[i][0][ii] * weights[out * 2 + ii];
                    }
                    // Sigmoid output
                    double output = 1 / (1 + Math.exp(-weightedSum));
                    // Squared error for this output unit
                    double error = Math.pow(trainingData[i][1][out] - output, 2) / 2;
                    totalError += error;
                    if (error >= 0.001) {
                        errorCount++;
                    }
                    // Update the two weights (input and bias) feeding this output unit
                    for (int iii = out * 2; iii < (out + 1) * 2; iii++) {
                        double firstGrad = -(trainingData[i][1][out] - output) * output * (1 - output);
                        weights[iii] -= learningRate * firstGrad * trainingData[i][0][iii % 2];
                    }
                }
            }

            // Total error accumulated over this pass
            System.out.println(totalError);

            // If the error is getting worse every iteration, terminate the program.
            if (totalError - previousTotalError >= 0) {
                System.out.println("FAIL TO CONVERGE");
                System.exit(0);
            }
            previousTotalError = totalError;

            // Every output unit within tolerance for every example: training is done.
            if (errorCount == 0) {
                System.out.println("Final weights: " + Arrays.toString(weights));
                System.exit(0);
            }
        }
    }

    // Generate random initial weights in [-1/sqrt(2), 1/sqrt(2)]
    static void genRanWeights() {
        Random r = new Random();
        double low = -1 / Math.sqrt(2);
        double high = 1 / Math.sqrt(2);
        double[] result = new double[16];
        for (int i = 0; i < result.length; i++) {
            result[i] = low + (high - low) * r.nextDouble();
        }
        System.out.println(Arrays.toString(result));
        weights = result;
    }
}
To debug the code above, I printed the total accumulated error as the program runs, and it shows the error decreasing at every iteration, but at a VERY SLOW rate. I've tweaked the learning rate, and it does not change much. Additionally, I've tried simplifying the training set to the following:
Input Output
{{0 ,1}, {1,0,0,0,0,0,0,0}},
{{0.1,1}, {0,1,0,0,0,0,0,0}},
// {{0.2,1}, {0,0,1,0,0,0,0,0}},
and the network trains very quickly (almost instantly) and reproduces the targeted result. However, if I uncomment the third line, training becomes very slow and never converges while the program runs, even though the sum of errors keeps decreasing. Based on this experimentation, the pattern I found is: with 3 training examples, training takes so long that I have never seen the ANN finish; with 2 examples or fewer, the network instantly produces the correct output.
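Commenting rows out is just a quick way of shrinking the set; slicing the array would be equivalent (Arrays.copyOfRange is not used in my code above, it's only shown here to make the experiment explicit):

// Train on only the first N rows of the full training set.
double[][][] reducedSet = Arrays.copyOfRange(trainingData, 0, 2); // first 2 rows: converges instantly
// double[][][] reducedSet = Arrays.copyOfRange(trainingData, 0, 3); // first 3 rows: extremely slow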
So my question is: is this 'anomaly' due to a wrong choice of activation function, a poor choice of learning rate, or simply a wrong implementation? And in the future, what steps would you recommend for debugging this type of problem effectively?