
I have a DataFrame which I have to apply a series of filter queries against. For example, I load my DataFrame as follows.

val df = spark.read.parquet("hdfs://box/some-parquet")

I then have a bunch of "arbitrary" filters as follows.

  • C0='true' and C1='false'
  • C0='false' and C3='true'
  • and so on...

I typically get these filters dynamically using a util method.

val filters: List[String] = getFilters()

All I do is apply these filters to the DataFrame to get the counts, for example:

val counts = filters.map(filter => {
  df.filter(filter).count
})

I noticed that mapping over the filters is NOT a parallel/distributed operation: the filters are processed one at a time. If I put the filters into an RDD/DataFrame instead, that won't work either, because I would then be performing nested DataFrame operations (which, as I've read on SO, is not allowed in Spark). Something like the following gives a NullPointerException (NPE).

val df = spark.read.parquet("hdfs://box/some-parquet")
val filterRDD = spark.sparkContext.parallelize(List("C0='false'", "C1='true'"))
val counts = filterRDD.map(df.filter(_).count).collect
Caused by: java.lang.NullPointerException
  at org.apache.spark.sql.Dataset.filter(Dataset.scala:1127)
  at $anonfun$1.apply(<console>:27)
  at $anonfun$1.apply(<console>:27)
  at scala.collection.Iterator$$anon$11.next(Iterator.scala:409)
  at scala.collection.Iterator$class.foreach(Iterator.scala:893)
  at scala.collection.AbstractIterator.foreach(Iterator.scala:1336)
  at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:59)
  at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:104)
  at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:48)
  at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:310)
  at scala.collection.AbstractIterator.to(Iterator.scala:1336)
  at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:302)
  at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1336)
  at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:289)
  at scala.collection.AbstractIterator.toArray(Iterator.scala:1336)
  at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$13.apply(RDD.scala:912)
  at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$13.apply(RDD.scala:912)
  at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1899)
  at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1899)
  at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:70)
  at org.apache.spark.scheduler.Task.run(Task.scala:86)
  at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:274)
  at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
  at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
  at java.lang.Thread.run(Thread.java:745)

Is there any way to parallelize/distribute these filtered counts on a DataFrame in Spark? By the way, I am on Spark v2.0.2.

Assuming what you want to achieve is a single pass over the input data (otherwise there may be no gain to expect from this), I'd rework the filter functions into UDFs that return 1 (filter match) or 0 (no match), add one column per UDF to the DataFrame, and sum the added columns. That results in a one-row DataFrame holding all the counts. – GPI
Could you show an example? – Jane Wayne

Answer


With this approach, the only gain to expect (and it can be very substantial) is that the input data is read only once.

I would do it like this (a programmatic solution, but equivalent SQL is possible):

  1. Convert your filters to UDFs that return 1 or 0.
  2. Add one column for each of these UDFs.
  3. Sum the added columns over the whole dataset (no grouping key is needed).

A sample Spark shell session looks like this:

scala> val data = spark.createDataFrame(Seq("A", "BB", "CCC").map(Tuple1.apply)).withColumnRenamed("_1", "input")

data: org.apache.spark.sql.DataFrame = [input: string]

scala> data.show
+-----+
|input|
+-----+
|    A|
|   BB|
|  CCC|
+-----+

scala> val containsBFilter = udf((input: String) => if(input.contains("B")) 1 else 0)
containsBFilter: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,IntegerType,Some(List(StringType)))

scala> val lengthFilter = udf((input: String) => if (input.length < 3) 1 else 0)
lengthFilter: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,IntegerType,Some(List(StringType)))

scala> data.withColumn("inputLength", lengthFilter($"input")).withColumn("containsB", containsBFilter($"input")).select(sum($"inputLength"), sum($"containsB")).show

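+----------------+--------------+
|sum(inputLength)|sum(containsB)|
+----------------+--------------+
|               2|             1|
+----------------+--------------+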
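The question's filters are already SQL predicate strings (e.g. C0='true' and C1='false'), so one way to apply the same single-pass idea to them, sketched here as an assumption rather than as the answerer's code, is to skip the UDFs and turn each string into a conditional column with Spark's built-in expr and when, then aggregate all of them in one agg call. count(when(cond, true)) counts the rows where cond is true; it is equivalent to summing a 1/0 column but returns 0 instead of null on empty input. The SQL variant builds the same aggregation as COUNT(CASE WHEN ... THEN 1 END) over a temporary view. The object name FilterCounts, the view name input_data, and the toy data standing in for the parquet file are hypothetical.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{count, expr, when}

object FilterCounts {
  // DataFrame API variant: one conditional column per filter string, all aggregated
  // in a single agg() call, so Spark scans the input once regardless of filter count.
  def countAll(df: DataFrame, filters: Seq[String]): Seq[Long] = {
    require(filters.nonEmpty, "need at least one filter")
    val aggs = filters.zipWithIndex.map { case (f, i) =>
      // when(...) is null where the predicate is false or null, and count() skips nulls,
      // so this counts exactly the rows that df.filter(f) would keep.
      count(when(expr(f), true)).as(s"count_$i")
    }
    val row = df.agg(aggs.head, aggs.tail: _*).head()
    filters.indices.map(i => row.getLong(i))
  }

  // SQL variant of the same aggregation over a temporary view.
  // "input_data" is a hypothetical view name; filter strings are spliced in verbatim,
  // so they must come from a trusted source.
  def countAllSql(spark: SparkSession, df: DataFrame, filters: Seq[String]): Seq[Long] = {
    require(filters.nonEmpty, "need at least one filter")
    df.createOrReplaceTempView("input_data")
    val selectList = filters.zipWithIndex
      .map { case (f, i) => s"COUNT(CASE WHEN $f THEN 1 END) AS count_$i" }
      .mkString(", ")
    val row = spark.sql(s"SELECT $selectList FROM input_data").head()
    filters.indices.map(i => row.getLong(i))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("filter-counts").getOrCreate()
    import spark.implicits._

    // Toy stand-in for spark.read.parquet(...); column names mirror the question's filters.
    val df = Seq(
      ("true", "false", "x", "true"),
      ("false", "true", "y", "true"),
      ("true", "true", "z", "false")
    ).toDF("C0", "C1", "C2", "C3")

    val filters = Seq("C0='true' and C1='false'", "C0='false' and C3='true'")
    println(countAll(df, filters))
    println(countAllSql(spark, df, filters))
    spark.stop()
  }
}
```

With real data, df would be the parquet DataFrame from the question and filters the result of getFilters(). Because everything is computed by one aggregation job, caching the input is not needed for this step.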