I have a Spark dataframe, like so:

# For the sake of simplicity, only one user (uid) is shown, but there are multiple users
+-------------------+-----+-------+
|start_date         |uid  |count  |
+-------------------+-----+-------+
|2020-11-26 08:30:22|user1|  4    |
|2020-11-26 10:00:00|user1|  3    |
|2020-11-22 08:37:18|user1|  3    |
|2020-11-22 13:32:30|user1|  2    |
|2020-11-20 16:04:04|user1|  2    |
|2020-11-16 12:04:04|user1|  1    |
+-------------------+-----+-------+

I want to create a new boolean column whose value is True if the event has count >= x, or if it is one of the events just before such an event (as described below), and False otherwise. For example, for x=3 I expect to get:

+-------------------+-----+-------+--------------+
|start_date         |uid  |count  | marked_event |
+-------------------+-----+-------+--------------+
|2020-11-26 08:30:22|user1|  4    |  True        |
|2020-11-26 10:00:00|user1|  3    |  True        |
|2020-11-22 08:37:18|user1|  3    |  True        |
|2020-11-22 13:32:30|user1|  2    |  True        |
|2020-11-20 16:04:04|user1|  2    |  True        |
|2020-11-16 12:04:04|user1|  1    |  False       |
+-------------------+-----+-------+--------------+

That is, I need to mark each event with count >= 3 as True, and also the previous events, up to 3 events in total. Only the last event of user1 is False, because I mark only the 3 events before (and including) the event on start_date = 2020-11-22 08:37:18.
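To pin down the intended rule, here is a minimal plain-Python sketch (no Spark) that reproduces the expected marked_event column for the sample rows. It assumes the rows of one uid are processed in the order shown in the table, that "previous" means rows further down that list, and that 3 rows are marked starting at the last row with count >= x; the helper `expected_marks` and its parameters `x` and `n_marked` are hypothetical names.

```python
# Sample rows in the order shown in the question's table: (start_date, uid, count)
rows = [
    ("2020-11-26 08:30:22", "user1", 4),
    ("2020-11-26 10:00:00", "user1", 3),
    ("2020-11-22 08:37:18", "user1", 3),
    ("2020-11-22 13:32:30", "user1", 2),
    ("2020-11-20 16:04:04", "user1", 2),
    ("2020-11-16 12:04:04", "user1", 1),
]


def expected_marks(counts, x=3, n_marked=3):
    """Hypothetical helper: one boolean per row, True if count >= x or if the
    row is among the n_marked rows starting at the last row with count >= x."""
    hits = [i for i, c in enumerate(counts) if c >= x]
    if not hits:
        return [False] * len(counts)
    last = hits[-1]
    return [c >= x or last <= i < last + n_marked for i, c in enumerate(counts)]


marks = expected_marks([c for _, _, c in rows])
for (start_date, uid, count), mark in zip(rows, marks):
    print(start_date, uid, count, mark)
```

With multiple users, the same rule would be applied to each uid's rows separately, which is what `Window.partitionBy('uid')` does in Spark.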

Any ideas on how to approach this? My intuition is to use a window function with lag, but I'm new to Spark and not sure how to do it.


EDIT:

I ended up using a variation on @mck's solution, with a small bug fix. The original solution has the

F.max(F.col('begin')).over(w.rowsBetween(0, Window.unboundedFollowing))

condition, which ends up marking all events after 'begin', regardless of whether the 'count' condition is fulfilled. Instead, I changed the solution so that the window only marks events that happened before 'begin':

event = (
    F.max(F.col('begin')).over(w.rowsBetween(-2, 0))
).alias('event_post_only')
# The number of events to mark is 3 starting from 'begin',
# including the event itself, so the frame starts at -2.
df_marked_events = df_marked_events.select('*', event)

Then mark True for all events that have count >= 3 OR were True in 'event_post_only':

df_marked_events = df_marked_events.withColumn(
    'event', (F.col('count') >= 3) | F.col('event_post_only')
)

This avoids marking as True everything upstream of the row where 'begin' == True.
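The fixed logic can be checked without a cluster. The sketch below is a plain-Python emulation (not Spark itself) of 'begin' from @mck's answer plus the two expressions above, for a single uid. It assumes the counts are already in the answer's window order (count descending, then start_date); `emulate_fixed_marking`, `x` and `back` are hypothetical names, and the null that Spark's lead() returns on the last row is treated as False here.

```python
def emulate_fixed_marking(counts, x=3, back=2):
    """Plain-Python emulation of the Spark expressions for one uid.

    `counts` must already be in the window's order (count desc, then start_date).
    The null from lead() on the last row is treated as False (an assumption).
    """
    n = len(counts)
    # begin: count >= x and the next row's count (lead) is < x
    begin = [counts[i] >= x and i + 1 < n and counts[i + 1] < x for i in range(n)]
    # event_post_only: max(begin) over rowsBetween(-back, 0), i.e. rows i-back..i
    event_post_only = [any(begin[max(0, i - back): i + 1]) for i in range(n)]
    # event: (count >= x) | event_post_only
    return [counts[i] >= x or event_post_only[i] for i in range(n)]


# Counts for user1 in the answer's window order (count desc, start_date asc)
counts = [4, 3, 3, 2, 2, 1]
print(emulate_fixed_marking(counts))  # [True, True, True, True, True, False]
```

For the sample data this gives the same result as the expected marked_event column in the question.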

1 Answer

import pyspark.sql.functions as F
from pyspark.sql.window import Window

w = Window.partitionBy('uid').orderBy(F.col('count').desc(), F.col('start_date'))

# find the beginning point of >= 3 events
begin = (
    (F.col('count') >= 3) &
    (F.lead(F.col('count')).over(w) < 3)
).alias('begin')
df = df.select('*', begin)

# Mark a row as an event if 'begin' is True in this row or any row after it,
# or in one of the two rows before it
event = (
    F.max(F.col('begin')).over(w.rowsBetween(0, Window.unboundedFollowing)) | 
    F.max(F.col('begin')).over(w.rowsBetween(-2,0))
).alias('event')
df = df.select('*', event)

df.show()
+-------------------+-----+-----+-----+-----+
|         start_date|  uid|count|begin|event|
+-------------------+-----+-----+-----+-----+
|2020-11-26 08:30:22|user1|  4.0|false| true|
|2020-11-22 08:37:18|user1|  3.0|false| true|
|2020-11-26 10:00:00|user1|  3.0| true| true|
|2020-11-20 16:04:04|user1|  2.0|false| true|
|2020-11-22 13:32:30|user1|  2.0|false| true|
|2020-11-16 12:04:04|user1|  1.0|false|false|
+-------------------+-----+-----+-----+-----+
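To see what each of the two frames in this answer contributes, the following plain-Python sketch (not Spark) emulates `F.max(...).over(w.rowsBetween(start, end))` on the 'begin' column from the output above. `frame_max` is a hypothetical helper; `None` stands in for `Window.unboundedFollowing`/`unboundedPreceding`, and an empty frame yields False here, whereas Spark would return null.

```python
def frame_max(flags, start, end):
    """Emulate max(flags) over a ROWS frame [start, end] relative to each row.

    None for start/end means unbounded. An empty frame gives False here
    (Spark would give null).
    """
    n = len(flags)
    out = []
    for i in range(n):
        lo = 0 if start is None else max(0, i + start)
        hi = n - 1 if end is None else min(n - 1, i + end)
        out.append(any(flags[lo:hi + 1]))
    return out


# 'begin' column from the df.show() output above, in the window's row order
begin = [False, False, True, False, False, False]

following = frame_max(begin, 0, None)  # rowsBetween(0, Window.unboundedFollowing)
trailing = frame_max(begin, -2, 0)     # rowsBetween(-2, 0)
event = [a or b for a, b in zip(following, trailing)]

print(following)  # [True, True, True, False, False, False]
print(trailing)   # [False, False, True, True, True, False]
print(event)      # [True, True, True, True, True, False]
```

In this window order, the first frame is True for the 'begin' row and every row before it, and the second is True for 'begin' and the two rows after it; their OR matches the 'event' column shown above.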