
I have a Spark dataframe which looks a bit like this:

id  country  date        action
 1    A   2019-01-01   suppress
 1    A   2019-01-02   suppress
 2    A   2019-01-03   bid-up
 2    A   2019-01-04   bid-down
 3    C   2019-01-01   no-action
 3    C   2019-01-02   bid-up
 4    D   2019-01-01   suppress

I want to reduce this dataframe by grouping by id and country, and collecting the unique values of the 'action' column into an array, but this array should be ordered by the date column.

E.g.

id  country action_arr
 1    A      [suppress]
 2    A      [bid-up, bid-down]
 3    C      [no-action, bid-up]
 4    D      [suppress]

To explain this a little more concisely, I have some SQL (Presto) code that does exactly what I want... I'm just struggling to do this in PySpark or Spark SQL:

SELECT id, country, array_distinct(array_agg(action ORDER BY date ASC)) AS actions
FROM table
GROUP BY id, country
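For reference, here is a plain-Python model (not Spark code) of what the Presto query computes on the sample rows: sort each (id, country) group by date, take the actions in that order, then keep only the first occurrence of each. The names `rows` and `ordered_distinct_actions` are made up for illustration. In Spark 2.4+ the same steps are commonly expressed with `collect_list(struct('date', 'action'))`, `sort_array` and `array_distinct`, but this sketch does not exercise Spark itself.

```python
# Plain-Python model (not Spark code) of the Presto query:
#   array_distinct(array_agg(action ORDER BY date ASC)) ... GROUP BY id, country
from collections import defaultdict

rows = [
    (1, "A", "2019-01-01", "suppress"),
    (1, "A", "2019-01-02", "suppress"),
    (2, "A", "2019-01-03", "bid-up"),
    (2, "A", "2019-01-04", "bid-down"),
    (3, "C", "2019-01-01", "no-action"),
    (3, "C", "2019-01-02", "bid-up"),
    (4, "D", "2019-01-01", "suppress"),
]


def ordered_distinct_actions(rows):
    """Group by (id, country); list each action once, in order of first date."""
    groups = defaultdict(list)
    for id_, country, date, action in rows:
        groups[(id_, country)].append((date, action))
    result = {}
    for key, pairs in groups.items():
        # array_agg(action ORDER BY date): ISO date strings sort chronologically
        ordered = [action for _, action in sorted(pairs)]
        # array_distinct: keep the first occurrence of each value, preserving order
        result[key] = list(dict.fromkeys(ordered))
    return result


for key, actions in sorted(ordered_distinct_actions(rows).items()):
    print(key, actions)
```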

Now here's my attempt in PySpark:

from pyspark.sql import functions as F
from pyspark.sql import Window

w = Window.partitionBy('action').orderBy('date')

sorted_list_df = df.withColumn('sorted_list', F.collect_set('action').over(w))

Then I want to find out the number of occurrences of each set of actions by group:

df = sorted_list_df.select('country', 'sorted_list').groupBy('country', 'sorted_list').agg(F.count('sorted_list'))
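To make the counting step concrete, this plain-Python model takes the expected per-group arrays from the table above and counts each (country, action array) pair, which is roughly what `groupBy('country', 'sorted_list')` plus a count should yield once every (id, country) has exactly one ordered array. `action_arr` and `counts` are illustrative names, not part of the original code.

```python
# Plain-Python model of groupBy('country', 'sorted_list').agg(count(...)),
# applied to the expected per-(id, country) arrays from the question.
from collections import Counter

action_arr = {
    (1, "A"): ["suppress"],
    (2, "A"): ["bid-up", "bid-down"],
    (3, "C"): ["no-action", "bid-up"],
    (4, "D"): ["suppress"],
}

# Lists are unhashable, so each ordered array becomes a tuple before counting.
# Country is part of the key, so A/[suppress] and D/[suppress] stay separate.
counts = Counter((country, tuple(arr)) for (_, country), arr in action_arr.items())

for (country, arr), n in sorted(counts.items()):
    print(country, list(arr), n)
```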

The code runs, but in the output the sorted_list column is basically the same as action, without any array aggregation. Can someone help?
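A plain-Python model of the first attempt may help show why `sorted_list` just mirrors `action`: with `partitionBy('action')` every window partition contains rows with a single action value, so `collect_set` can only ever return that one value. The helper `running_set_by_action` is hypothetical and only mimics the window semantics (with an `orderBy`, the default frame runs from the partition start up to the current date).

```python
# Plain-Python model of:
#   F.collect_set('action').over(Window.partitionBy('action').orderBy('date'))
rows = [
    (1, "A", "2019-01-01", "suppress"),
    (1, "A", "2019-01-02", "suppress"),
    (2, "A", "2019-01-03", "bid-up"),
    (2, "A", "2019-01-04", "bid-down"),
    (3, "C", "2019-01-01", "no-action"),
    (3, "C", "2019-01-02", "bid-up"),
    (4, "D", "2019-01-01", "suppress"),
]


def running_set_by_action(rows):
    out = []
    for id_, country, date, action in rows:
        # The window partition holds only rows that share this row's action...
        partition = [r for r in rows if r[3] == action]
        # ...and the default frame ends at the current date (peers included).
        frame = [r for r in partition if r[2] <= date]
        # sorted() only for stable display; collect_set itself is unordered.
        out.append((id_, action, sorted({r[3] for r in frame})))
    return out


for id_, action, sorted_list in running_set_by_action(rows):
    print(id_, action, sorted_list)
```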

EDIT: I managed to get pretty much what I want, but the results don't fully match the Presto results. Can anyone explain why? Solution below:

from pyspark.sql import functions as F
from pyspark.sql import Window

w = Window.partitionBy('id').orderBy('date')

df_2 = df.withColumn('sorted_list', F.collect_set('action').over(w))

test = df_2.select('id', 'country', 'sorted_list')\
           .dropDuplicates()\
           .select('country', 'sorted_list')\
           .groupBy('country', 'sorted_list')\
           .agg(F.count('sorted_list'))
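A likely reason the EDIT's numbers differ from Presto: because the window is ordered by date, each row gets the set of actions seen so far for its id, not the final set. `dropDuplicates()` then keeps those intermediate prefix sets as separate rows, and they get counted too. On top of that, `collect_set` does not guarantee any element order. The plain-Python model below (hypothetical names, only ids 2 and 3 from the sample) reproduces the effect.

```python
# Plain-Python model of the EDIT's pipeline:
#   collect_set('action').over(Window.partitionBy('id').orderBy('date'))
#   -> select('id', 'country', 'sorted_list') -> dropDuplicates()
rows = [
    (2, "A", "2019-01-03", "bid-up"),
    (2, "A", "2019-01-04", "bid-down"),
    (3, "C", "2019-01-01", "no-action"),
    (3, "C", "2019-01-02", "bid-up"),
]


def running_sets_per_id(rows):
    out = []
    for id_, country, date, action in rows:
        # Default frame with orderBy: partition start up to the current date.
        frame = [r for r in rows if r[0] == id_ and r[2] <= date]
        # frozenset models collect_set: unordered and hashable for dedup.
        out.append((id_, country, frozenset(r[3] for r in frame)))
    return out


# set() plays the role of dropDuplicates(); sort only for readable output.
deduped = sorted(set(running_sets_per_id(rows)), key=lambda t: (t[0], len(t[2])))
for id_, country, actions in deduped:
    print(id_, country, sorted(actions))
```

For country A this leaves both {bid-up} and {bid-up, bid-down} for id 2, so the later groupBy('country', 'sorted_list') counts a set that the Presto query never produces.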
Why is it returning wrong results, and what kind of results would you deem correct? - Grzegorz Skibinski
@GrzegorzSkibinski - edited description to clarify - Tim496

1 Answer


IMO, your window definition is wrong. You should partition by the columns you want to group by, and then collect the set of unique values per group.

IIUC, you just need to do:

w = Window.partitionBy(['id', 'country']).orderBy('date')

sorted_list_df = df.withColumn('sorted_list', F.collect_set('action').over(w))

df_new = sorted_list_df.select('id', 'country', 'sorted_list').withColumn("count_of_elems", F.size("sorted_list"))

DRAWBACK:

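If you use a window, you get a new set for every row, so your row count stays the same as in the old df. There is no aggregation per se, and I don't think that's what you want. Also, because the window has an orderBy, its default frame runs from the start of the partition up to the current row, so each row's set only contains the actions seen up to that date.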

The next line reduces each group to a single row by taking the max of its sorted_list arrays. I'm hoping it gets you exactly what you want:

df_new = sorted_list_df.groupby('id', 'country').agg(F.max('sorted_list').alias('sorted_list')).withColumn("count_of_elems", F.size("sorted_list"))
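One caveat with `F.max('sorted_list')`: Spark compares arrays element by element (and only then by length), the same way Python compares lists, so `max` is not guaranteed to pick the longest, complete set. Since `collect_set` gives no ordering guarantee, assume (hypothetically) that it emitted id 3's final set as ['bid-up', 'no-action']; the plain-Python check below shows that `max` would then keep the one-element prefix instead.

```python
# Plain-Python illustration of array ordering as used by max().
# Assumption: collect_set happened to emit id 3's full set in this order;
# the real element order of collect_set is unspecified.
running_sets_id3 = [
    ["no-action"],            # frame ending 2019-01-01
    ["bid-up", "no-action"],  # frame ending 2019-01-02 (hypothetical order)
]

# "no-action" > "bid-up" on the first element, so the shorter list wins.
picked = max(running_sets_id3)
print(picked)
```

Taking the set from each group's latest-dated row, or aggregating with groupBy instead of a window, avoids relying on that comparison (though a plain collect_set is still unordered).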