I am fairly inexperienced in Spark and need help with groupBy and aggregate functions on a DataFrame. Consider the following DataFrame:

val df = (Seq((1, "a", "1"),
              (1,"b", "3"),
              (1,"c", "6"),
              (2, "a", "9"),
              (2,"c", "10"),
              (1,"b","8" ),
              (2, "c", "3"),
              (3,"r", "19")).toDF("col1", "col2", "col3"))

df.show()
+----+----+----+
|col1|col2|col3|
+----+----+----+
|   1|   a|   1|
|   1|   b|   3|
|   1|   c|   6|
|   2|   a|   9|
|   2|   c|  10|
|   1|   b|   8|
|   2|   c|   3|
|   3|   r|  19|
+----+----+----+

I need to group by col1 and by col2 separately and calculate the mean of col3 for each, which I can do using:

import org.apache.spark.sql.functions.{mean, round}

val col1df = df.groupBy("col1").agg(round(mean("col3"), 2).alias("mean_col1"))
val col2df = df.groupBy("col2").agg(round(mean("col3"), 2).alias("mean_col2"))

However, on a large DataFrame with a few million rows and tens of thousands of unique values in the grouping columns, this takes a very long time. I also have many more columns to group by, which makes it slower still, and I am looking to reduce the runtime. Is there a better way to do the groupBy followed by the aggregation?

1 Answer

You could use ideas from Multiple Aggregations (grouping sets); it can do everything in a single shuffle, which is the most expensive operation.

Example:

val df = (Seq((1, "a", "1"),
              (1, "b", "3"),
              (1, "c", "6"),
              (2, "a", "9"),
              (2, "c", "10"),
              (1, "b", "8"),
              (2, "c", "3"),
              (3, "r", "19")).toDF("col1", "col2", "col3"))

df.createOrReplaceTempView("data")

val grpRes = spark.sql("""select grouping_id() as gid, col1, col2, round(mean(col3), 2) as res 
                          from data group by col1, col2 grouping sets ((col1), (col2)) """)

grpRes.show(100, false)

Output:

+---+----+----+----+
|gid|col1|col2|res |
+---+----+----+----+
|1  |3   |null|19.0|
|2  |null|b   |5.5 |
|2  |null|c   |6.33|
|1  |1   |null|4.5 |
|2  |null|a   |5.0 |
|1  |2   |null|7.33|
|2  |null|r   |19.0|
+---+----+----+----+

gid is a bit awkward to use, as it is computed as a bit mask over the grouping columns. But if your grouping columns cannot contain nulls, then you can use it to select the correct groups.
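For example, here is a minimal sketch of splitting the combined result back into the two per-column aggregations, assuming the grpRes defined above (the gid values 1 and 2 correspond to the output shown):

// gid = 1: rows grouped by col1 only (col2 was aggregated away)
val meanByCol1 = grpRes.where("gid = 1").selectExpr("col1", "res as mean_col1")

// gid = 2: rows grouped by col2 only (col1 was aggregated away)
val meanByCol2 = grpRes.where("gid = 2").selectExpr("col2", "res as mean_col2")

meanByCol1.show()
meanByCol2.show()

If you use both results in separate actions, it may help to cache grpRes first (grpRes.cache()) so the aggregation is not recomputed for each action.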

Execution Plan:

scala> grpRes.explain
== Physical Plan ==
*(2) HashAggregate(keys=[col1#111, col2#112, spark_grouping_id#108], functions=[avg(cast(col3#9 as double))])
+- Exchange hashpartitioning(col1#111, col2#112, spark_grouping_id#108, 200)
   +- *(1) HashAggregate(keys=[col1#111, col2#112, spark_grouping_id#108], functions=[partial_avg(cast(col3#9 as double))])
      +- *(1) Expand [List(col3#9, col1#109, null, 1), List(col3#9, null, col2#110, 2)], [col3#9, col1#111, col2#112, spark_grouping_id#108]
         +- LocalTableScan [col3#9, col1#109, col2#110]

As you can see, there is a single Exchange operation, i.e. the expensive shuffle.
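Since you mention many more columns to group by, the same pattern should extend by listing every column in the group by clause and adding one single-column grouping set per column, still paying for only one shuffle. A rough sketch, assuming hypothetical extra grouping columns col4 and col5 in the same data view:

// one aggregation covers all listed grouping columns;
// gid is still a bit mask, with a bit set for each column aggregated away in that row
val multiGrpRes = spark.sql("""select grouping_id() as gid, col1, col2, col4, col5, round(mean(col3), 2) as res
                               from data
                               group by col1, col2, col4, col5
                               grouping sets ((col1), (col2), (col4), (col5))""")

multiGrpRes.show(100, false)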