I have a DataFrame that contains 7 days of hourly data (24 hours per day), so it has 168 hourly columns plus an id column.

id     d1h1  d1h2   d1h3 .....  d7h24 
aaa    21     24     8   .....   14       
bbb    16     12     2   .....   4
ccc    21      2     7   .....   6

What I want to do is find the top 3 values for each day:

id    d1        d2       d3  ....   d7
aaa  [22,2,2] [17,2,2] [21,8,3]    [32,11,2]
bbb  [32,22,12] [47,22,2] [31,14,3]    [32,11,2]
ccc  [12,7,4] [28,14,7] [11,2,1]    [19,14,7] 
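To make the requested transformation concrete, here is a plain-Scala sketch (no Spark) of what should happen to a single row. It assumes the 168 hourly values are integers ordered from d1h1 through d7h24. The names TopPerDay and top3PerDay are hypothetical.

```scala
object TopPerDay {
  val HoursPerDay = 24

  // Split one row of hourly values into days (24 values each) and keep the
  // three largest values of each day, largest first.
  // Assumption: values are ordered d1h1, d1h2, ..., d7h24.
  def top3PerDay(hourly: Seq[Int]): Seq[Seq[Int]] = {
    require(hourly.length % HoursPerDay == 0, "expected a whole number of days")
    hourly
      .grouped(HoursPerDay)                                   // one chunk per day
      .map(day => day.sorted(Ordering[Int].reverse).take(3))  // descending, top 3
      .toList
  }
}
```

Both answers below do the same thing column-wise inside Spark rather than row by row.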

2 Answers

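import org.apache.spark.sql.functions._

var df = ...  // your DataFrame with columns id, d1h1 .. d7h24

// Keep the first 3 elements of an (already sorted) array; values assumed to be Double
val first3 = udf((list: Seq[Double]) => list.slice(0, 3))

for (i <- 1 to 7) {                                    // days 1..7 inclusive
  val columns = (1 to 24).map(x => "d" + i + "h" + x)  // hours 1..24 inclusive
  df = df
    // collect the day's hourly columns into an array, sort descending, keep the top 3
    .withColumn("d" + i, first3(sort_array(array(columns.head, columns.tail: _*), false)))
    .drop(columns: _*)
}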

This should give you what you want. For each day, the 24 hourly columns are collected into an array column, which is sorted in descending order; the first 3 elements of that array are the day's top 3 values.
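The loop bounds matter: in Scala `a until b` excludes `b` while `a to b` includes it, so the ranges must use `to` to cover all 7 days and all 24 hours. A quick plain-Scala check of the generated column names (the object DayColumns and helper hourColumns are hypothetical):

```scala
object DayColumns {
  // Hourly column names for one day: d<day>h1 .. d<day>h24.
  // `1 to 24` is inclusive; `1 until 24` would stop at h23.
  def hourColumns(day: Int): Seq[String] =
    (1 to 24).map(h => s"d${day}h$h")

  // All day prefixes d1 .. d7 (again inclusive).
  val days: Seq[String] = (1 to 7).map(d => s"d$d")
}
```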

Define pattern:

val p = "^(d[1-7])h[0-9]{1,2}$".r
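The pattern's single capture group extracts the day prefix, so the regex can be used directly as an extractor in a match. A small sketch (the helper dayOf is hypothetical) that returns None for non-hourly names such as id; with a partial `case p(d) => d`, as in the grouping step, such a name would throw a MatchError instead, which is why the id column is skipped there with tail.

```scala
object DayPattern {
  val p = "^(d[1-7])h[0-9]{1,2}$".r

  // Day prefix ("d1" .. "d7") for an hourly column name, or None otherwise.
  def dayOf(name: String): Option[String] = name match {
    case p(d) => Some(d)
    case _    => None
  }
}
```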

Group columns:

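import org.apache.spark.sql.functions._

val cols = df.columns.tail          // skip the leading id column
  .groupBy { case p(d) => d }
  .toSeq
  .sortBy(_._1)                     // groupBy returns an unordered Map; keep d1 .. d7 in order
  .map { case (c, cs) =>
    // sort the day's hourly values in descending order and keep the top 3
    val sorted = sort_array(array(cs.map(col): _*), false)
    array(sorted(0), sorted(1), sorted(2)).as(c)
  }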
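Because groupBy returns a Map, whose iteration order is not guaranteed, the day columns can come out in arbitrary order unless the keys are sorted. A plain-Scala sketch of just the grouping step on column names, which drops non-hourly names and returns the groups ordered d1 to d7 (the object GroupColumns and method byDay are hypothetical):

```scala
object GroupColumns {
  private val p = "^(d[1-7])h[0-9]{1,2}$".r

  // Group hourly column names by day prefix, returning the groups sorted by day.
  def byDay(columns: Seq[String]): Seq[(String, Seq[String])] =
    columns
      .collect { case c @ p(d) => (d, c) }   // keep only hourly columns, drop "id"
      .groupBy(_._1)
      .map { case (d, pairs) => (d, pairs.map(_._2)) }
      .toSeq
      .sortBy(_._1)                          // single-digit days sort correctly as strings
}
```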

And select:

import spark.implicits._  // needed for the $"id" syntax (spark is your SparkSession)

df.select($"id" +: cols.toSeq: _*)