2
votes

I am using Spark 2.3 in my Scala application. I have a DataFrame created from Spark SQL, named sqlDF in the sample code I shared. I also have a string list containing the items below.

List[String] stringList items:

-9,-8,-7,-6

I want to replace, in every column of the DataFrame, all values that match an item in this list with 0.

Initial DataFrame:

column1 | column2 | column3
1       | 1       | 1
2       | -5      | 1
6       | -6      | 1
-7      | -8      | -7

It must return:

column1 | column2 | column3
1       | 1       | 1
2       | -5      | 1
6       | 0       | 1
0       | 0       | 0

To do this, I iterate the query below over all columns (more than 500) in sqlDF.

sqlDF = sqlDF.withColumn(currColumnName, when(col(currColumnName).isin(stringList:_*), 0).otherwise(col(currColumnName)))

But I get the error below. If I pick only one column to iterate over, it works; when I run the code above over all 500 columns, it fails:

Exception in thread "streaming-job-executor-0" java.lang.StackOverflowError
    at scala.collection.generic.GenTraversableFactory$GenericCanBuildFrom.apply(GenTraversableFactory.scala:57)
    at scala.collection.generic.GenTraversableFactory$GenericCanBuildFrom.apply(GenTraversableFactory.scala:52)
    at scala.collection.TraversableLike$class.builder$1(TraversableLike.scala:229)
    at scala.collection.TraversableLike$class.map(TraversableLike.scala:233)
    at scala.collection.immutable.List.map(List.scala:285)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:333)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:187)

What am I missing?


3 Answers

2
votes

Here is a different approach: apply a left anti join between each columnX and X, where X is your list of items transferred into a DataFrame. The left anti join returns all the items not present in X; we then concatenate all the results together through an outer join (which can be replaced with a left join for better performance, though that would exclude records with all zeros, i.e. id == 3) based on an id assigned with monotonically_increasing_id:

import org.apache.spark.sql.functions.{col, monotonically_increasing_id}
import spark.implicits._ // for toDF and the $"" column syntax; spark is your SparkSession

// Tag each row with an id so the per-column results can be joined back together
val df = Seq(
  (1, 1, 1),
  (2, -5, 1),
  (6, -6, 1),
  (-7, -8, -7))
  .toDF("c1", "c2", "c3")
  .withColumn("id", monotonically_increasing_id())

// The exclusion list as a single-column DataFrame
val exdf = Seq(-9, -8, -7, -6).toDF("x")

df.columns.filter(_ != "id").map { c => // skip the helper id column itself
  // left_anti keeps only the rows whose value in column c is not in x
  df.select("id", c).join(exdf, col(c) === $"x", "left_anti")
}
.reduce((df1, df2) => df1.join(df2, Seq("id"), "outer")) // reassemble by id
.na.fill(0) // rows removed by the anti join come back as 0
.show

Output:

+---+---+---+---+
| id| c1| c2| c3|
+---+---+---+---+
|  0|  1|  1|  1|
|  1|  2| -5|  1|
|  3|  0|  0|  0|
|  2|  6|  0|  1|
+---+---+---+---+
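
If the helper id column should not appear in the final result, it can be dropped just before showing. A minimal adjustment to the last lines of the snippet above (an editorial suggestion, not part of the original answer):

.reduce((df1, df2) => df1.join(df2, Seq("id"), "outer"))
.na.fill(0)
.drop("id") // remove the helper join key from the final output
.show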

1
votes

foldLeft works perfectly for your case here, as below:

import org.apache.spark.sql.functions.{col, when}
import spark.implicits._ // for toDF; spark is your SparkSession

val df = spark.sparkContext.parallelize(Seq(
  (1, 1, 1),
  (2, -5, 1),
  (6, -6, 1),
  (-7, -8, -7)
)).toDF("a", "b", "c")

val list = Seq(-7, -8, -9)

// Thread the DataFrame through the fold, rewriting one column per step
val resultDF = df.columns.foldLeft(df) { (acc, name) =>
  acc.withColumn(name, when(col(name).isin(list: _*), 0).otherwise(col(name)))
}

resultDF.show(false)

Output:

+---+---+---+
|a  |b  |c  |
+---+---+---+
|1  |1  |1  |
|2  |-5 |1  |
|6  |-6 |1  |
|0  |0  |0  |
+---+---+---+
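
Note that chaining withColumn over 500 columns still grows the logical plan by one node per column, so very wide DataFrames may hit the same StackOverflowError. An editorial alternative (not part of this answer): build every replacement expression up front and apply them in a single select, which keeps the plan to one projection:

// Same df and list as above; one projection rewrites every column at once
val replaced = df.columns.map { name =>
  when(col(name).isin(list: _*), 0).otherwise(col(name)).alias(name)
}
val singleSelectDF = df.select(replaced: _*)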

0
votes

I would suggest you broadcast the list of Strings:

val stringList = sc.broadcast(<your List[String]>)

After that, use this:

sqlDF = sqlDF.withColumn(currColumnName, when(col(currColumnName).isin(stringList.value:_*), 0).otherwise(col(currColumnName)))

Make sure your currColumnName is also in String format; the comparison should be String to String.
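
Putting the two fragments together with the foldLeft from the answer above, an editorial sketch (assuming sc is your SparkContext, sqlDF is the DataFrame from the question, and its columns are string-typed; cleanedDF is a name introduced here for illustration):

import org.apache.spark.sql.functions.{col, when}

// Broadcast the exclusion list once instead of shipping it per column
val stringList = sc.broadcast(List("-9", "-8", "-7", "-6"))

val cleanedDF = sqlDF.columns.foldLeft(sqlDF) { (acc, currColumnName) =>
  acc.withColumn(currColumnName,
    when(col(currColumnName).isin(stringList.value: _*), 0)
      .otherwise(col(currColumnName)))
}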