
So I perform the necessary imports, etc.:

import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types._
import spark.implicits._

Then I define some lat/long points:

val london = (1.0, 1.0)
val suburbia = (2.0, 2.0)
val southampton = (3.0, 3.0)  
val york = (4.0, 4.0)  

I then create a Spark DataFrame like this and check that it works:

val exampleDF = Seq((List(london,suburbia),List(southampton,york)),
    (List(york,london),List(southampton,suburbia))).toDF("AR1","AR2")
exampleDF.show()

The DataFrame consists of the following types:

DataFrame = [AR1: array<struct<_1:double,_2:double>>, AR2: array<struct<_1:double,_2:double>>]

I create a function to build all combinations (the cross product) of the points:

// function to do what I want
val latlongexplode = (x: Array[(Double, Double)], y: Array[(Double, Double)]) => {
  for (a <- x; b <- y) yield (a, b)
}

I check that the function works:

latlongexplode(Array(london,york),Array(suburbia,southampton))
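
In the REPL this returns the expected cross product, something like:

res0: Array[((Double, Double), (Double, Double))] = Array(((1.0,1.0),(2.0,2.0)), ((1.0,1.0),(3.0,3.0)), ((4.0,4.0),(2.0,2.0)), ((4.0,4.0),(3.0,3.0)))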

So the function itself works. However, after I create a UDF out of this function:

// declare function into a Spark UDF
val latlongexplodeUDF = udf(latlongexplode)

and try to use it on the Spark DataFrame I created above:

exampleDF.withColumn("latlongexplode", latlongexplodeUDF($"AR1",$"AR2")).show(false)

I get a really long stack trace which basically boils down to:

java.lang.ClassCastException: scala.collection.mutable.WrappedArray$ofRef cannot be cast to [Lscala.Tuple2;
  org.apache.spark.sql.catalyst.expressions.ScalaUDF.$anonfun$f$3(ScalaUDF.scala:121)
  org.apache.spark.sql.catalyst.expressions.ScalaUDF.eval(ScalaUDF.scala:1063)
  org.apache.spark.sql.catalyst.expressions.Alias.eval(namedExpressions.scala:151)
  org.apache.spark.sql.catalyst.expressions.InterpretedProjection.apply(Projection.scala:50)
  org.apache.spark.sql.catalyst.expressions.InterpretedProjection.apply(Projection.scala:32)
  scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:273)

How can I get this UDF to work in Scala Spark? (I'm using Spark 2.4 at the moment, if that helps.)

EDIT: It could be that the way I construct my example DataFrame has an issue. But the actual data I have is an array (of unknown size) of lat/long tuples in each column.

You may want to contact Raphael Roth on this; he seems to go that step further than most. – thebluephantom
It has to do with the struct aspect of the array, but I am not sure how to get around this. – thebluephantom
@raphaelroth can you comment, please? – thebluephantom
@thebluephantom no need for Raphael, I've solved it :) – mck
@mck thanks for the explanation... and the solution. Really appreciate it. – Mamonu

1 Answer


When struct types are passed into a UDF, they arrive as Row objects, and array columns arrive as Seq rather than Array (which is exactly what the ClassCastException above is complaining about). Likewise, to return structs you have to build them as Rows, and you have to pass an explicit schema to udf, because Spark cannot infer a return type from Row.

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types._
import spark.implicits._

val london = (1.0, 1.0)
val suburbia = (2.0, 2.0)
val southampton = (3.0, 3.0)  
val york = (4.0, 4.0)
val exampleDF = Seq((List(london,suburbia),List(southampton,york)),
    (List(york,london),List(southampton,suburbia))).toDF("AR1","AR2")
exampleDF.show(false)
+------------------------+------------------------+
|AR1                     |AR2                     |
+------------------------+------------------------+
|[[1.0, 1.0], [2.0, 2.0]]|[[3.0, 3.0], [4.0, 4.0]]|
|[[4.0, 4.0], [1.0, 1.0]]|[[3.0, 3.0], [2.0, 2.0]]|
+------------------------+------------------------+
// struct elements arrive as Row, and the array columns as Seq;
// the returned structs must also be built as Rows
val latlongexplode = (x: Seq[Row], y: Seq[Row]) => {
    for (a <- x; b <- y) yield Row(a, b)
}

// schema of the returned value: an array of (city1, city2) struct pairs.
// The example coordinates are Doubles, so use DoubleType for the fields.
val udf_schema = ArrayType(
    StructType(Seq(
        StructField(
            "city1",
            StructType(Seq(
                StructField("lat", DoubleType),
                StructField("long", DoubleType)
            ))
        ),
        StructField(
            "city2",
            StructType(Seq(
                StructField("lat", DoubleType),
                StructField("long", DoubleType)
            ))
        )
    ))
)

// on Spark 3.0+, if you see errors like
// "You're using untyped Scala UDF, which does not have the input type information."
// uncomment this line:
// spark.sql("set spark.sql.legacy.allowUntypedScalaUDF = true")

val latlongexplodeUDF = udf(latlongexplode, udf_schema)
val result = exampleDF.withColumn("latlongexplode", latlongexplodeUDF($"AR1", $"AR2"))
result.show(false)
+------------------------+------------------------+--------------------------------------------------------------------------------------------------------+
|AR1                     |AR2                     |latlongexplode                                                                                          |
+------------------------+------------------------+--------------------------------------------------------------------------------------------------------+
|[[1.0, 1.0], [2.0, 2.0]]|[[3.0, 3.0], [4.0, 4.0]]|[[[1.0, 1.0], [3.0, 3.0]], [[1.0, 1.0], [4.0, 4.0]], [[2.0, 2.0], [3.0, 3.0]], [[2.0, 2.0], [4.0, 4.0]]]|
|[[4.0, 4.0], [1.0, 1.0]]|[[3.0, 3.0], [2.0, 2.0]]|[[[4.0, 4.0], [3.0, 3.0]], [[4.0, 4.0], [2.0, 2.0]], [[1.0, 1.0], [3.0, 3.0]], [[1.0, 1.0], [2.0, 2.0]]]|
+------------------------+------------------------+--------------------------------------------------------------------------------------------------------+
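
As a side note: if you would rather avoid Row objects and a hand-written schema entirely, here is a minimal sketch of a typed alternative using the Dataset API, where encoders map Scala tuples to structs automatically. The rename to _1/_2 is needed because the tuple encoder resolves columns by those names.

import spark.implicits._

// Typed alternative: no Row handling and no explicit schema required.
val typedResult = exampleDF
  .select($"AR1".as("_1"), $"AR2".as("_2"))   // tuple encoder expects fields named _1, _2
  .as[(Seq[(Double, Double)], Seq[(Double, Double)])]
  .map { case (ar1, ar2) => for (a <- ar1; b <- ar2) yield (a, b) }  // same cross product
  .toDF("latlongexplode")

typedResult.show(false)

One trade-off: the struct fields in the result come out named _1/_2 instead of city1/city2, so stick with the UDF version if those field names matter to you.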