I have a Spark dataframe like the one below. I want to aggregate it by different columns independently of each other, each time computing a statistic over a single column.

val df = (Seq((1, "a", "1"),
              (1,"b", "3"),
              (1,"c", "6"),
              (2, "a", "9"),
              (2,"c", "10"),
              (1,"b","8" ),
              (2, "c", "3"),
              (3,"r", "19")).toDF("col1", "col2", "col3"))

df.show()
+----+----+----+
|col1|col2|col3|
+----+----+----+
|   1|   a|   1|
|   1|   b|   3|
|   1|   c|   6|
|   2|   a|   9|
|   2|   c|  10|
|   1|   b|   8|
|   2|   c|   3|
|   3|   r|  19|
+----+----+----+

I want to group by col1 and by col2 separately, take the average of col3 within each group, and attach both averages to every original row, giving the following output dataframe:

+----+----+----+---------+---------+
|col1|col2|col3|mean_col1|mean_col2|
+----+----+----+---------+---------+
|   1|   a|   1|      4.5|      5.0|
|   1|   b|   3|      4.5|      5.5|
|   1|   c|   6|      4.5|     6.33|
|   2|   a|   9|     7.33|      5.0|
|   2|   c|  10|     7.33|     6.33|
|   1|   b|   8|      4.5|      5.5|
|   2|   c|   3|     7.33|     6.33|
|   3|   r|  19|     19.0|     19.0|
+----+----+----+---------+---------+

This can be done using the following operations:

val col1df = df.groupBy("col1").agg(round(mean("col3"),2).alias("mean_col1"))

val col2df = df.groupBy("col2").agg(round(mean("col3"),2).alias("mean_col2"))

df.join(col1df, "col1").join(col2df, "col2").select($"col1",$"col2",$"col3",$"mean_col1",$"mean_col2").show()

However, if I have many more columns to group by, I'd need one expensive join per column, and building a separate grouped dataframe for each column before joining is rather cumbersome (see the sketch below). What's the best way to get the output dataframe while minimizing (and preferably eliminating) the join operations, and without having to generate the intermediate dataframes col1df and col2df?
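To make the scaling concrete, here is a minimal sketch of how the join-based approach generalizes, assuming the grouping columns are collected in a hypothetical groupCols list and reusing df and the functions already used above:

val groupCols = Seq("col1", "col2") // hypothetical list; imagine many more entries

// One groupBy plus one join per grouping column
val joined = groupCols.foldLeft(df) { (acc, c) =>
  val means = df.groupBy(c).agg(round(mean("col3"), 2).alias(s"mean_$c"))
  acc.join(means, c)
}

Every entry in groupCols adds another aggregation and another join, which is exactly what I'd like to avoid.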


1 Answer


Since you want your final table to contain all of the original rows, this can be done with window functions.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val df = (Seq((1, "a", "1"),
    (1,"b", "3"),
    (1,"c", "6"),
    (2, "a", "9"),
    (2,"c", "10"),
    (1,"b","8" ),
    (2, "c", "3"),
    (3,"r", "19")).toDF("col1", "col2", "col3"))

df.show(false)

// One window per grouping column: every row sees all rows sharing its value in that column
val col1Window = Window.partitionBy("col1").rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)
val col2Window = Window.partitionBy("col2").rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)

// Attach both means as extra columns; no joins or intermediate dataframes needed
val res = df
  .withColumn("mean_col1", round(mean("col3").over(col1Window), 2))
  .withColumn("mean_col2", round(mean("col3").over(col2Window), 2))

res.show(false)

In the context of window functions, partitionBy is similar to groupBy, and rangeBetween defines the frame of the window. Here the frame covers all rows with the same value in the partitioned column, so each mean is computed over exactly the rows that a groupBy on that column would put together.
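If there are many grouping columns, the window approach extends without any joins. A minimal sketch, assuming the grouping columns sit in a hypothetical groupCols list:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val groupCols = Seq("col1", "col2") // hypothetical list of grouping columns

// One window per grouping column; with no orderBy the default frame is already
// the whole partition, so the explicit rangeBetween can also be omitted
val resAll = groupCols.foldLeft(df) { (acc, c) =>
  acc.withColumn(s"mean_$c", round(mean("col3").over(Window.partitionBy(c)), 2))
}

resAll.show(false)

No joins are involved, although each distinct partitioning still triggers its own shuffle when the windows are evaluated.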