2 votes

I have PySpark code that effectively groups up rows numerically, and increments when a certain condition is met. I'm having trouble figuring out how to transform this code, efficiently, into one that can be applied to groups.

Take this sample dataframe df

df = sqlContext.createDataFrame(
    [
        (33, [], '2017-01-01'),
        (33, ['apple', 'orange'], '2017-01-02'),
        (33, [], '2017-01-03'),
        (33, ['banana'], '2017-01-04')
    ],
    ('ID', 'X', 'date')
)

This code achieves what I want for this sample df, which is to order by date and to create groups ('grp') that increment when the size column goes back to 0.

from pyspark.sql import Window
from pyspark.sql.functions import col, size, sum  # note: this sum shadows Python's builtin sum

df \
.withColumn('size', size(col('X'))) \
.withColumn(
    "grp",
    sum((col('size') == 0).cast("int")).over(Window.orderBy('date'))
).show()

This is partly based on Pyspark - Cumulative sum with reset condition

Now what I am trying to do is apply the same approach to a dataframe that has multiple IDs, achieving a result that looks like this:

df2 = sqlContext.createDataFrame(
    [
        (33, [], '2017-01-01', 0, 1),
        (33, ['apple', 'orange'], '2017-01-02', 2, 1),
        (33, [], '2017-01-03', 0, 2),
        (33, ['banana'], '2017-01-04', 1, 2),
        (55, ['coffee'], '2017-01-01', 1, 1),
        (55, [], '2017-01-03', 0, 2)
    ],
    ('ID', 'X', 'date', 'size', 'group')
)

Edit for clarity:

1) For the first date of each ID - the group should be 1 - regardless of what shows up in any other column.

2) However, for each subsequent date, I need to check the size column. If the size column is 0, then I increment the group number. If it is any non-zero, positive integer, then I continue the previous group number.

I've seen a few ways to handle this in pandas, but I'm having difficulty understanding how to apply them in PySpark, and how grouped data differs between pandas and Spark (e.g. do I need to use something called UDAFs?)
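
For reference, the pandas version I have in mind looks roughly like this (just a sketch - the column names match the sample data, everything else is illustrative):

import pandas as pd

pdf = pd.DataFrame({
    'ID':   [33, 33, 33, 33, 55, 55],
    'size': [0, 2, 0, 1, 1, 0],
    'date': ['2017-01-01', '2017-01-02', '2017-01-03', '2017-01-04',
             '2017-01-01', '2017-01-03'],
}).sort_values(['ID', 'date'])

# flag the first row of each ID and every row where size == 0,
# then take a running sum of those flags within each ID
reset = pdf['size'].eq(0) | pdf.groupby('ID').cumcount().eq(0)
pdf['group'] = reset.astype(int).groupby(pdf['ID']).cumsum()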

Why does the coffee row have a group value of 1? Shouldn't it be 0? – cronoik
@cronoik It is 1 because the ID has changed, and this row is the first date for that ID - so it should be group 1. – mcharl02
So the start value is the size of X, and it is incremented every time the size of X is 0? – cronoik
The start value is always 1 - for the first ID + date. Then that value increments only when the size of X is 0. The size of X for the first ID + date doesn't matter (in my actual data, it will always be 0 or missing, but I have tried to simplify that here). I will edit this info into the main question. – mcharl02

2 Answers

1 vote

Create a column zero_or_first by checking whether the size is zero or the row is the first row for its ID, then take a running sum of it over the same window.

df2 = sqlContext.createDataFrame(
    [
        (33, [], '2017-01-01', 0, 1),
        (33, ['apple', 'orange'], '2017-01-02', 2, 1),
        (33, [], '2017-01-03', 0, 2),
        (33, ['banana'], '2017-01-04', 1, 2),
        (55, ['coffee'], '2017-01-01', 1, 1),
        (55, [], '2017-01-03', 0, 2),
        (55, ['banana'], '2017-01-01', 1, 1)
    ],
    ('ID', 'X', 'date', 'size', 'group')
)


from pyspark.sql import Window
from pyspark.sql import functions as F

w = Window.partitionBy('ID').orderBy('date')
df2 = df2.withColumn('row', F.row_number().over(w))  # position of the row within its ID
df2 = df2.withColumn('zero_or_first', F.when((F.col('size') == 0) | (F.col('row') == 1), 1).otherwise(0))
df2 = df2.withColumn('grp', F.sum('zero_or_first').over(w))  # running sum per ID
df2.orderBy('ID').show()

Here's the output. You can see that the column grp matches group, where group is the expected result.

+---+---------------+----------+----+-----+---+-------------+---+
| ID|              X|      date|size|group|row|zero_or_first|grp|
+---+---------------+----------+----+-----+---+-------------+---+
| 33|             []|2017-01-01|   0|    1|  1|            1|  1|
| 33|       [banana]|2017-01-04|   1|    2|  4|            0|  2|
| 33|[apple, orange]|2017-01-02|   2|    1|  2|            0|  1|
| 33|             []|2017-01-03|   0|    2|  3|            1|  2|
| 55|       [coffee]|2017-01-01|   1|    1|  1|            1|  1|
| 55|       [banana]|2017-01-01|   1|    1|  2|            0|  1|
| 55|             []|2017-01-03|   0|    2|  3|            1|  2|
+---+---------------+----------+----+-----+---+-------------+---+
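
Once grp is in place you can aggregate per (ID, grp) as usual. For example (just an illustration; F.flatten needs Spark 2.4+):

df2.groupBy('ID', 'grp') \
   .agg(F.min('date').alias('start_date'),
        F.flatten(F.collect_list('X')).alias('all_X')) \
   .orderBy('ID', 'grp') \
   .show(truncate=False)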


0 votes

I added a window function and created an index within each ID, then expanded the conditional statement to also reference that index. The following seems to produce my desired output dataframe, but I am interested in knowing if there is a more efficient way to do this.

from pyspark.sql import Window
from pyspark.sql.functions import col, rank, size, sum

window = Window.partitionBy('ID').orderBy('date')
df \
.withColumn('size', size(col('X'))) \
.withColumn('index', rank().over(window)) \
.withColumn(
    "grp",
    sum(((col('size') == 0) | (col('index') == 1)).cast("int")).over(window)
).show()

which yields

+---+---------------+----------+----+-----+---+
| ID|              X|      date|size|index|grp|
+---+---------------+----------+----+-----+---+
| 33|             []|2017-01-01|   0|    1|  1|
| 33|[apple, orange]|2017-01-02|   2|    2|  1|
| 33|             []|2017-01-03|   0|    3|  2|
| 33|       [banana]|2017-01-04|   1|    4|  2|
| 55|       [coffee]|2017-01-01|   1|    1|  1|
| 55|             []|2017-01-03|   0|    2|  2|
+---+---------------+----------+----+-----+---+
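
One thing to watch, if an ID could ever have two rows with the same date: rank gives tied rows the same index, so more than one row could get index == 1 and the cumulative sum would increment once for each of them. row_number breaks ties and avoids that. A variant using row_number (same idea, just swapping the ranking function) would be:

from pyspark.sql.functions import row_number

df \
.withColumn('size', size(col('X'))) \
.withColumn('index', row_number().over(window)) \
.withColumn(
    "grp",
    sum(((col('size') == 0) | (col('index') == 1)).cast("int")).over(window)
).show()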