0
votes

i'm trying to make a classifier using Logistic Regression to predict the right digit (label) based on the values of the pixels (features).
i'm using Apache Spark in java and i'm using the data from mnist database after i convert it into libsvm format, here is my code:

package ml;

import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.optimization.L1Updater;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import scala.Tuple2;


public class MNIST5 {

    static String trainImagesPath = "train-images.idx3-ubyte";
    static String trainLabelsPath = "train-labels.idx1-ubyte";
    static String testImagesPath = "t10k-images.idx3-ubyte";
    static String testLabelsPath = "t10k-labels.idx1-ubyte";


    static SparkConf conf = new SparkConf()
            .setMaster("local")
            .setAppName("Machine learning - MNIST Example");

    static SparkContext sc = SparkContext.getOrCreate(conf);

    public static void main(String[] args) throws FileNotFoundException, UnsupportedEncodingException {

        mnist_spark_logistic_regression();
        //saveMnistDataLibsvmFormat();

    }

    static void mnist_spark_logistic_regression(){

        long t;

        System.out.println("Loading training data ...");
        t = System.currentTimeMillis();
        JavaRDD<LabeledPoint> trainData = MLUtils.loadLibSVMFile(sc, "mnist-train-data.txt").toJavaRDD();
        System.out.println(System.currentTimeMillis()-t+" ms"); // 6661 ms


        System.out.println("Training logistic regression classifier ...");
        t = System.currentTimeMillis();
        // Run training algorithm to build the model.         
        LogisticRegressionWithLBFGS lr = new LogisticRegressionWithLBFGS()           
            .setNumClasses(10);
        //lr.optimizer().setUpdater(new L1Updater());        
        LogisticRegressionModel model = lr.run(trainData.rdd());

        System.out.println(System.currentTimeMillis()-t+" ms"); // 1951 ms
        // print weights and intercept
        System.out.println("numClasses: "+model.numClasses());
        System.out.println("numFeatures: "+model.numFeatures());
        System.out.println("Weights: "+model.weights());
        System.out.println("Wlength: "+model.weights().size());
        System.out.println("Intercept: "+model.intercept());


        System.out.println("Loading testing data ...");
        t = System.currentTimeMillis();
        JavaRDD<LabeledPoint> testData = MLUtils.loadLibSVMFile(sc, "mnist-test-data.txt").toJavaRDD();
        System.out.println(System.currentTimeMillis()-t+" ms"); // 11356 ms


        System.out.println("Compute raw scores on the test set ...");
        t = System.currentTimeMillis();
        // Compute raw scores on the test set.
        JavaPairRDD<Object, Object> predictionAndLabels = testData.mapToPair(
            (p) -> {
                return new Tuple2<>(model.predict(p.features()), p.label());
            }
        );
        System.out.println(System.currentTimeMillis()-t+" ms"); // 47 ms


        System.out.println("Iterate ...");
        t = System.currentTimeMillis();
        JavaRDD<Integer> wyw = testData.map(new Function<LabeledPoint, Integer>() {
            @Override
            public Integer call(LabeledPoint t1) throws Exception {
                double yb = model.predict(t1.features());

                if(yb==t1.label())
                    System.out.println("label: "+t1.label()+", predicted: "+yb);
                return 0;
            }
        });
        wyw.collect();
        System.out.println(System.currentTimeMillis()-t+" ms");


        System.out.println("Evaluating ...");
        t = System.currentTimeMillis();
        // Get evaluation metrics.
        MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());
        double accuracy = metrics.accuracy();
        System.out.println("Accuracy = " + accuracy); // 0.098
        System.out.println(System.currentTimeMillis()-t+" ms"); // 1108 ms


        // Save and load model
        model.save(sc, "mnist_logreg_model"+"/javaMNISTLogisticRegressionWithLBFGSModel");
        LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "mnist_logreg_model"+"/javaMNISTLogisticRegressionWithLBFGSModel");
        System.out.println(sameModel);

    }


    static ArrayList<LabeledPoint> getData(String imagesPath, String labelsPath){        

        JavaRDD<LabeledPoint> data;
        ArrayList<LabeledPoint> lpts = new ArrayList<>();

        FileInputStream inImage = null;
        FileInputStream inLabel = null;

        try {
            inImage = new FileInputStream(imagesPath);
            inLabel = new FileInputStream(labelsPath);

            int magicNumberImages = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());
            int numberOfImages = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());
            int numberOfRows  = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());
            int numberOfColumns = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());

            int magicNumberLabels = (inLabel.read() << 24) | (inLabel.read() << 16) | (inLabel.read() << 8) | (inLabel.read());
            int numberOfLabels = (inLabel.read() << 24) | (inLabel.read() << 16) | (inLabel.read() << 8) | (inLabel.read());

            int numberOfPixels = numberOfRows * numberOfColumns;
            double[] imgPixels = new double[numberOfPixels];

            for(int i = 0; i < numberOfImages; i++) {

                //if(i % 100 == 0) {System.out.println("Number of images extracted: " + i);}

                for(int p = 0; p < numberOfPixels; p++) {
                    imgPixels[p] = inImage.read();
                }

                int label = inLabel.read();

                LabeledPoint lp = LabeledPoint.apply(label, Vectors.dense(imgPixels));
                lpts.add(lp);

            }

        } 
        catch (FileNotFoundException e) { e.printStackTrace(); } 
        catch (IOException e) { e.printStackTrace(); } 
        finally {
            if (inImage != null) {
                try {
                    inImage.close();
                } catch (IOException e) { e.printStackTrace(); }
            }
            if (inLabel != null) {
                try {
                    inLabel.close();
                } catch (IOException e) { e.printStackTrace(); }
            }
        }

        return lpts;
    }

    static JavaRDD<LabeledPoint> loadData(String imagesPath, String labelsPath){        

        ArrayList<LabeledPoint> lpts = getData(imagesPath, labelsPath);

        JavaSparkContext jsc = new JavaSparkContext(sc);
        JavaRDD<LabeledPoint> data = jsc.parallelize(lpts);

        return data;
    }

    static void saveMnistDataLibsvmFormat() throws FileNotFoundException, UnsupportedEncodingException{

        ArrayList<LabeledPoint> data = getData(testImagesPath, testLabelsPath);        
        PrintWriter writer = new PrintWriter("mnist-test-data.txt", "UTF-8");
        for(LabeledPoint lp : data){
            StringBuilder s = new StringBuilder();
            s.append(lp.label()).append(" ");
            int i;
            double[] arr = lp.features().toArray();
            for(i=0;i<arr.length-1;i++)
                if(arr[i]!=0)
                    s.append(i+1).append(":").append(arr[i]).append(" ");
            if(arr[i]!=0)
                s.append(i+1).append(":").append(arr[i]);
            writer.println(s.toString());
        }
        // writer.println("The first line");
        // writer.println("The second line");
        writer.close();

        ArrayList<LabeledPoint> data2 = getData(trainImagesPath, trainLabelsPath);
        PrintWriter writer2 = new PrintWriter("mnist-train-data.txt", "UTF-8");
        for(LabeledPoint lp : data2){
            StringBuilder s = new StringBuilder();
            s.append(lp.label()).append(" ");
            int i;
            double[] arr = lp.features().toArray();
            for(i=0;i<arr.length-1;i++)
                if(arr[i]!=0)
                    s.append(i+1).append(":").append(arr[i]).append(" ");
            if(arr[i]!=0)
                s.append(i+1).append(":").append(arr[i]);
            writer2.println(s.toString());
        }
        // writer.println("The first line");
        // writer.println("The second line");
        writer2.close();

    }

}


The values of the weights are all equal to zero, i don't understand why ? please some help, thank you.

1

1 Answers

0
votes

What do you mean by

            `if(yb==t1.label())
                System.out.println("label: "+t1.label()+", predicted: "+yb);
            return 0;`

It always return 0.