1
votes

So, I've done enough research and haven't found a post that addresses what I want to do.

I have a PySpark DataFrame my_df, which is sorted by the value column in descending order:

+----+-----+                                                                    
|name|value|
+----+-----+
|   A|   30|
|   B|   25|
|   C|   20|
|   D|   18|
|   E|   18|
|   F|   15|
|   G|   10|
+----+-----+

The sum of the value column is 136. I want to keep the rows, taken from the top, whose combined value stays within x% of 136. In this example, let's say x=80. Then the target sum = 0.8*136 = 108.8. Hence, the new DataFrame should consist of all the leading rows whose running total is <= 108.8.

In our example, this comes down to rows A through D (the combined value up to D = 30+25+20+18 = 93, while adding E would give 111, which exceeds 108.8).

However, the hard part is that I also want to include the immediately following rows with duplicate values. In this case, I also want to include row E since it has the same value as row D i.e. 18.

I want to slice my_df by passing a percentage variable x, for example 80 as discussed above. The new DataFrame should consist of the following rows:

+----+-----+                                                                    
|name|value|
+----+-----+
|   A|   30|
|   B|   25|
|   C|   20|
|   D|   18|
|   E|   18|
+----+-----+
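
For reference, the arithmetic above can be checked in plain Python without Spark. This is only a sketch of the intended logic, assuming the rows are already sorted by value in descending order; `slice_by_percentage` is a hypothetical helper name, not a PySpark function.

```python
from itertools import accumulate

# Example rows from the question, already sorted by value descending.
rows = [("A", 30), ("B", 25), ("C", 20), ("D", 18), ("E", 18), ("F", 15), ("G", 10)]

def slice_by_percentage(rows, x):
    """Keep rows whose running total is <= x% of the grand total,
    plus any rows that share a value with a kept row."""
    total = sum(value for _, value in rows)
    target = total * x / 100
    running = list(accumulate(value for _, value in rows))
    # Values of the rows that pass the running-total cutoff.
    kept_values = {value for (_, value), cum in zip(rows, running) if cum <= target}
    # Pull tied rows back in by matching on value.
    return [(name, value) for name, value in rows if value in kept_values]

print(slice_by_percentage(rows, 80))
# [('A', 30), ('B', 25), ('C', 20), ('D', 18), ('E', 18)]
```

This loops over the data on a single machine, so it only illustrates the expected result; it is not a substitute for a distributed solution on ~360k rows.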

One thing I could do here is iterate through the DataFrame (which is ~360k rows), but I guess that defeats the purpose of Spark.

Is there a concise function for what I want here?

2
added a better description - kev
Can you share the code that you use to sort the DataFrame? Is it sorted by value only, or by value and name? - pault
It's sorted by value. - kev

2 Answers

2
votes

Your requirements are quite strict, so it's difficult to formulate an efficient solution to your problem. Nevertheless, here is one approach:

First calculate the cumulative sum and the total sum of the value column, then filter the DataFrame on the target-percentage condition you specified (df below is your my_df). Let's call this result df_filtered:

import pyspark.sql.functions as f
from pyspark.sql import Window

w = Window.orderBy(f.col("value").desc(), "name").rangeBetween(Window.unboundedPreceding, 0)
target = 0.8

df_filtered = df.withColumn("cum_sum", f.sum("value").over(w))\
    .withColumn("total_sum", f.sum("value").over(Window.partitionBy()))\
    .where(f.col("cum_sum") <= f.col("total_sum")*target)

df_filtered.show()
#+----+-----+-------+---------+
#|name|value|cum_sum|total_sum|
#+----+-----+-------+---------+
#|   A|   30|     30|      136|
#|   B|   25|     55|      136|
#|   C|   20|     75|      136|
#|   D|   18|     93|      136|
#+----+-----+-------+---------+

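Then join this filtered DataFrame back to the original on the value column. Because the rows are ordered by value, the only extra rows the join can bring back are those tied with the last kept row (here E, which ties with D), so the final output contains the rows you want.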

df.alias("r")\
    .join(
    df_filtered.select("value").distinct().alias('l'),  # one row per value, so ties do not multiply rows
    on="value"
).select("r.name", "r.value").sort(f.col("value").desc(), "name").show()
#+----+-----+
#|name|value|
#+----+-----+
#|   A|   30|
#|   B|   25|
#|   C|   20|
#|   D|   18|
#|   E|   18|
#+----+-----+
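
One caveat when joining back on value: an inner join emits one output row per matching pair, so if the filtered side ever holds several rows with the same value, the tied rows get repeated. The plain-Python sketch below (no Spark) illustrates this on made-up rows; `inner_join_on_value` is a hypothetical stand-in for an inner equi-join, not a real API.

```python
def inner_join_on_value(left, right_values):
    """Mimic an inner equi-join on value:
    one output row per (left row, matching right value) pair."""
    return [(name, value) for name, value in left for rv in right_values if value == rv]

original = [("A", 30), ("B", 30), ("C", 20)]  # made-up rows with a tie at 30
filtered_values = [30, 30]                    # assume both tied rows passed the filter

# Joining against the raw filtered values repeats each tied row.
duplicated = inner_join_on_value(original, filtered_values)

# Joining against the distinct values returns each original row once.
deduplicated = inner_join_on_value(original, sorted(set(filtered_values)))

print(len(duplicated), deduplicated)
```

Joining against the distinct values of the filtered side is one way to avoid the repetition.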

The total_sum and cum_sum columns are calculated using a Window function.

The Window w orders on the value column descending, followed by the name column. The name column is used to break ties: without it, both rows D and E would have the same cumulative sum of 111 = 75+18+18 and you'd incorrectly lose both of them in the filter.

w = (
    Window                                       # Define Window
    .orderBy(                                    # This will define ordering
        f.col("value").desc(),                   # First sort by value descending
        "name"                                   # Sort on name second
    )
    .rangeBetween(Window.unboundedPreceding, 0)  # Extend back to beginning of window
)

The rangeBetween(Window.unboundedPreceding, 0) specifies that the Window frame should include every row from the start of the ordering (defined by the orderBy) up to and including the current row. This is what makes it a cumulative sum.
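
To see why the tie-breaker matters, here is a plain-Python emulation (no Spark) of a sum over a range frame that runs from the start of the ordering to the current row. In a range frame, rows with equal ordering keys are peers and are all included together. `range_cumsum` is a hypothetical helper written for this illustration only.

```python
rows = [("A", 30), ("B", 25), ("C", 20), ("D", 18), ("E", 18), ("F", 15), ("G", 10)]

def range_cumsum(rows, key):
    """Emulate sum(value) over a RANGE frame from unbounded preceding to
    the current row: every row whose ordering key is <= the current
    row's key is summed, so tied rows (peers) are counted together."""
    result = {}
    for row in rows:
        current_key = key(row)
        result[row[0]] = sum(v for n, v in rows if key((n, v)) <= current_key)
    return result

# Ordering by value only (descending): D and E are peers.
by_value = range_cumsum(rows, key=lambda r: -r[1])

# Ordering by value descending, then name: ties are broken.
by_value_name = range_cumsum(rows, key=lambda r: (-r[1], r[0]))

print(by_value["D"], by_value["E"])            # 111 111
print(by_value_name["D"], by_value_name["E"])  # 93 111
```

With the value-only ordering both D and E exceed the 108.8 target, whereas the tie-broken ordering keeps D at 93.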

3
votes

Use pyspark SQL functions to do this concisely.

result = my_df.filter(my_df.value > target).select(my_df.name,my_df.value)
result.show()

Edit: Based on the OP's edit to the question, compute a running sum and keep rows until the target value is reached. Note that this returns rows up to D, not E, which seems like a strange requirement.

from pyspark.sql import Window
from pyspark.sql import functions as f

# Total sum of all `value`s (use f.sum, not Python's built-in sum)
total_sum = my_df.agg(f.sum("value")).collect()[0][0]

# Ideally this should be a column that specifies ordering among rows
w = Window.orderBy(my_df.name)
running_sum_df = my_df.withColumn('rsum', f.sum(my_df.value).over(w))
running_sum_df.filter(running_sum_df.rsum <= 0.8 * total_sum).show()