The developer API example (https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala) gives a simple example implementation of predictRaw() in a classification model. This is a function of the abstract class ClassificationModel that must be implemented in the concrete class. According to the developer API example, it can be calculated as follows:
override def predictRaw(features: Features.Type): Vector = {
  val margin = BLAS.dot(features, coefficients)
  // Binary classification, so we return a length-2 vector,
  // where index i corresponds to class i (i = 0, 1).
  Vectors.dense(-margin, margin)
}
My understanding of BLAS.dot(features, coefficients) is that it is simply the dot product of the features vector (of length numFeatures) and the coefficients vector (also of length numFeatures), so effectively each feature is multiplied by its coefficient and the products are summed to give val margin. However, Spark no longer provides access to BLAS, as it is private in MLlib. Instead, matrix multiplication is provided by the Matrix trait, which has several multiply methods.
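To make that reading of the dot product concrete, here is a minimal sketch in plain Scala with no Spark dependency. The names MarginSketch, dot and rawPrediction are hypothetical and only illustrate the arithmetic; this is not Spark's BLAS implementation.

```scala
object MarginSketch {
  // Dot product of two equal-length arrays: sum over i of features(i) * coefficients(i).
  def dot(features: Array[Double], coefficients: Array[Double]): Double = {
    require(features.length == coefficients.length, "vectors must have the same length")
    var sum = 0.0
    var i = 0
    while (i < features.length) {
      sum += features(i) * coefficients(i)
      i += 1
    }
    sum
  }

  // Length-2 raw prediction for binary classification:
  // index 0 corresponds to class 0, index 1 to class 1.
  def rawPrediction(features: Array[Double], coefficients: Array[Double]): Array[Double] = {
    val margin = dot(features, coefficients)
    Array(-margin, margin)
  }
}
```

For example, with features (1.0, 2.0, 3.0) and coefficients (0.5, -1.0, 2.0) the margin is 0.5 - 2.0 + 6.0 = 4.5, so the raw prediction is (-4.5, 4.5).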
My understanding of how to implement predictRaw() using the Matrix multiply methods is as follows:
override def predictRaw(features: Vector): Vector = {
  // coefficients is a Vector of length numFeatures, e.g. val coefficients = Vectors.zeros(numFeatures)
  val coefficientsArray = coefficients.toArray
  // A 1 x numFeatures row matrix: multiply(features) requires numCols == features.size
  val coefficientsMatrix: SparkDenseMatrix = new SparkDenseMatrix(1, numFeatures, coefficientsArray)
  val margin: Array[Double] = coefficientsMatrix.multiply(features).toArray // contains a single element
  val rawPredictions: Array[Double] = Array(-margin(0), margin(0))
  new SparkDenseVector(rawPredictions)
}
This incurs the overhead of converting the data structures to Arrays. Is there a better way? It seems strange that BLAS is now private. NB: the code is not tested! At the moment, val coefficients: Vector is just a vector of zeros, but once I have implemented the learning algorithm it will contain the results.
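One possible way to avoid both the matrix and the array copies, as far as I can tell, is to iterate over the active entries of the features vector with foreachActive. The sketch below assumes the org.apache.spark.ml.linalg package (the code above may be using the mllib one instead), and RawPredictionSketch and its methods are hypothetical names, passing coefficients as a parameter only to keep the example standalone. Like the code above, it has not been run inside a real ClassificationModel.

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}

object RawPredictionSketch {
  // Hypothetical helper: dot product that visits only the active (stored)
  // entries of `features`, so a sparse vector is never densified.
  def margin(features: Vector, coefficients: Vector): Double = {
    require(features.size == coefficients.size, "dimension mismatch")
    var sum = 0.0
    features.foreachActive { (i, v) => sum += v * coefficients(i) }
    sum
  }

  // Same length-2 convention as the developer API example:
  // index 0 is class 0, index 1 is class 1.
  def predictRaw(features: Vector, coefficients: Vector): Vector = {
    val m = margin(features, coefficients)
    Vectors.dense(-m, m)
  }
}
```

Looking up coefficients(i) is cheapest when coefficients is a dense vector, which is the case for Vectors.zeros(numFeatures).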