
I trained a RandomForest like this (Spark 1.6.0):

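import org.apache.spark.mllib.tree.RandomForest

val numClasses = 4 // labels are 0.0-2.0, so numClasses = 3 would suffice
val categoricalFeaturesInfo = Map[Int, Int]() // empty map: all features treated as continuous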
val numTrees = 9
val featureSubsetStrategy = "auto" // Let the algorithm choose.
val impurity = "gini"
val maxDepth = 6
val maxBins = 32

val model = RandomForest.trainClassifier(trainRDD, numClasses, 
                                         categoricalFeaturesInfo, numTrees, 
                                         featureSubsetStrategy, impurity, 
                                         maxDepth, maxBins)
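MLlib's `RandomForest.trainClassifier` documentation says labels should take values {0, 1, ..., numClasses - 1}, so with the labels 0.0, 1.0 and 2.0 listed below, `numClasses = 3` is the minimal valid value. A small plain-Python check along those lines might look like this. It uses a hypothetical helper `infer_num_classes` on labels already collected to the driver and involves no Spark API:

```python
def infer_num_classes(labels):
    """Smallest numClasses that MLlib accepts for these label values.

    MLlib classifiers expect labels in {0, 1, ..., numClasses - 1},
    so the minimum valid value is max(label) + 1.
    """
    distinct = sorted(set(labels))
    if not distinct:
        raise ValueError("no labels given")
    for lab in distinct:
        # Labels must be non-negative whole numbers (stored as floats in LabeledPoint).
        if lab < 0 or lab != int(lab):
            raise ValueError("invalid class label: %r" % (lab,))
    return int(distinct[-1]) + 1


# Labels as collected in the question with .distinct().collect()
print(infer_num_classes([0.0, 1.0, 2.0]))  # 3
```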

The input labels:

labels = labeledRDD.map(lambda lp: lp.label).distinct().collect()
for label in sorted(labels):
    print(label)

0.0
1.0
2.0

But the predictions contain only two classes (the third column is all zeros):

metrics = MulticlassMetrics(labelsAndPredictions)
df_confusion = metrics.confusionMatrix()
display_cm(df_confusion)

Output:

83017.0  81.0    0.0
8703.0   2609.0  0.0
10232.0  255.0   0.0
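MLlib's `confusionMatrix()` puts true labels in rows and predicted classes in columns, both in ascending label order, so an all-zero column means that class was never predicted. The sketch below checks for such columns in plain Python; `never_predicted` is a hypothetical helper, and with a PySpark `DenseMatrix` you would first convert it, e.g. with `.toArray().tolist()`:

```python
def never_predicted(confusion):
    """Return indices of classes whose predicted-column sum is zero.

    Assumes MLlib's layout: rows are true labels, columns are predictions,
    both ordered by ascending label.
    """
    n_cols = len(confusion[0])
    col_sums = [sum(row[j] for row in confusion) for j in range(n_cols)]
    return [j for j, s in enumerate(col_sums) if s == 0]


# The matrix from the question
cm = [
    [83017.0, 81.0, 0.0],
    [8703.0, 2609.0, 0.0],
    [10232.0, 255.0, 0.0],
]
print(never_predicted(cm))  # [2]
```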

Output when I load the same model in PySpark and run it against other data (a subset of the above):

DenseMatrix([[  1.75280000e+04,   3.26000000e+02],
            [  3.00000000e+00,   1.27400000e+03]])
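As I understand the Spark 1.x MLlib implementation, `MulticlassMetrics` takes an RDD of (prediction, label) pairs and sizes the confusion matrix by the distinct *true* labels in that input; predictions of a class that never occurs as a true label are not shown. If so, a 2x2 matrix only says the evaluated subset contained two distinct labels, and passing (label, prediction) pairs instead would swap the two roles. The plain-Python function below is a simplified, hypothetical re-implementation (`confusion_from_pairs`) to illustrate that behavior, not Spark's own code:

```python
from collections import Counter


def confusion_from_pairs(prediction_and_labels):
    """Simplified sketch of how MulticlassMetrics is assumed to build its matrix.

    Input: (prediction, label) pairs, the order MulticlassMetrics expects.
    The matrix is n x n, where n is the number of distinct true labels;
    rows are true labels, columns are predictions.
    """
    labels = sorted({label for _, label in prediction_and_labels})
    counts = Counter((label, pred) for pred, label in prediction_and_labels)
    return labels, [[counts.get((l, p), 0) for p in labels] for l in labels]


# One prediction of 2.0 for a true label 0.0: it does not appear in the matrix,
# because 2.0 is not among the true labels of this subset.
pairs = [(0.0, 0.0), (2.0, 0.0), (1.0, 1.0), (0.0, 1.0), (1.0, 1.0)]
labels, cm = confusion_from_pairs(pairs)
print(labels)  # [0.0, 1.0]
print(cm)      # [[1, 0], [1, 2]]
```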
I cannot reproduce this, or at least there is nothing wrong with the confusion matrix I get. And you actually have 3 classes :) - zero323
@zero323 I saved the model and loaded it again in between; that might be it. Can you post your reproduction? - oluies
I basically replicated the example from the docs, replacing the data with data.map{case LabeledPoint(i, v) => if (i == 0.0) LabeledPoint(2.0, v) else LabeledPoint(3.0, v) } ++ data - zero323
@zero323 From Scala I can see that class 2.0 is always 0; with PySpark I get the DenseMatrix above (with no third column). - oluies
Could you provide minimal reproducible examples for both Python and Scala? - zero323

1 Answer


It got better. I used Pearson correlation to find the columns that showed no correlation, deleted the ten columns with the lowest correlation, and now I get OK results:


Test Error = 0.0401823
Precision = 0.959818
Recall = 0.959818

ConfusionMatrix([[ 17323.,      0.,    359.],
                 [     0.,   1430.,     92.],
                 [   208.,    170.,   1049.]])
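The reported numbers are consistent with the confusion matrix: for single-label multiclass data, overall (micro-averaged) precision and recall both equal accuracy, because each misclassification counts once as a false positive and once as a false negative, so both are 1 - test error. A plain-Python check, using a hypothetical helper `overall_metrics`:

```python
def overall_metrics(confusion):
    """Overall accuracy and error rate from a square confusion matrix."""
    total = sum(sum(row) for row in confusion)
    correct = sum(confusion[i][i] for i in range(len(confusion)))
    accuracy = correct / total
    error = (total - correct) / total
    return accuracy, error


# The matrix from the answer (rows: true labels, columns: predictions)
cm = [
    [17323.0, 0.0, 359.0],
    [0.0, 1430.0, 92.0],
    [208.0, 170.0, 1049.0],
]
accuracy, error = overall_metrics(cm)
print("Test Error = %.7f" % error)              # Test Error = 0.0401823
print("precision = recall = %.6f" % accuracy)   # precision = recall = 0.959818
```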

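The answer does not show the feature-selection code or say what the columns were correlated against. The sketch below assumes correlation with the label and works in plain Python on data already collected to the driver; `pearson` and `drop_least_correlated` are hypothetical helpers. On RDDs, `pyspark.mllib.stat.Statistics.corr(x, y, method="pearson")` computes the same coefficient for two RDDs of floats.

```python
import math


def pearson(xs, ys):
    """Pearson correlation; returns 0.0 if either input has zero variance."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    if vx == 0 or vy == 0:
        return 0.0
    return cov / math.sqrt(vx * vy)


def drop_least_correlated(rows, labels, k):
    """Indices of feature columns kept after dropping the k columns with the
    smallest absolute Pearson correlation with the label."""
    n_features = len(rows[0])
    scores = []
    for j in range(n_features):
        column = [row[j] for row in rows]
        scores.append((abs(pearson(column, labels)), j))
    scores.sort()  # ascending by |r|, so the weakest columns come first
    dropped = {j for _, j in scores[:k]}
    return [j for j in range(n_features) if j not in dropped]


# Toy data: column 0 tracks the label, column 1 is constant, column 2 is weakly related
rows = [[0.1, 5.0, 1.0], [1.2, 5.0, 0.0], [2.1, 5.0, 1.0], [2.9, 5.0, 0.0]]
labels = [0.0, 1.0, 2.0, 3.0]
print(drop_least_correlated(rows, labels, 1))  # [0, 2]
```

In the answer's case `k` would be 10. Ranking by absolute value keeps strongly negative correlations, which are as informative as positive ones.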