I have a Spark DataFrame (PySpark 2.2.0) that contains events, each of which has a timestamp. An additional column holds a tag (A, B, C or Null). For each row I would like to compute, per group of events ordered by timestamp, a running count of tag changes within the current stretch of non-Null tags (a Null tag should reset the count to 0). Example of the df with my desired calculated column, stretch:
event  timestamp  tag   stretch
G1     09:59:00   Null  0
G1     10:00:00   A     1   ---> first non-Null tag starts the count
G1     10:01:00   A     1   ---> no change of tag
G1     10:02:00   B     2   ---> change of tag (A to B)
G1     10:03:00   A     3   ---> change of tag (B to A)
G1     10:04:00   Null  0   ---> Null resets the count
G1     10:05:00   A     1   ---> first non-Null tag restarts the count
G2     10:00:00   B     1   ---> first non-Null tag starts the count
G2     10:01:00   C     2   ---> change of tag (B to C)
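For reproducibility, the sample above can be built like this (a minimal sketch; I keep the timestamps as plain strings for brevity, and None stands in for Null):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [
        ("G1", "09:59:00", None),
        ("G1", "10:00:00", "A"),
        ("G1", "10:01:00", "A"),
        ("G1", "10:02:00", "B"),
        ("G1", "10:03:00", "A"),
        ("G1", "10:04:00", None),
        ("G1", "10:05:00", "A"),
        ("G2", "10:00:00", "B"),
        ("G2", "10:01:00", "C"),
    ],
    ["event", "timestamp", "tag"],
)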
In PySpark I can define a window like this:
window = Window.partitionBy("event").orderBy(col("timestamp").asc())
and calculate, for example, whether the tag changed:
df = df.withColumn("change_of_tag", col("tag") != lag("tag", 1).over(window))
but I cannot find how to compute a cumulative sum of these changes that resets each time a Null tag is encountered. I suspect I need a second window partitioned by event and by the "type" of tag (Null or non-Null), but I don't see how to partition by event, order by timestamp, and then group by that tag type.
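To make the idea concrete, something along these lines is what I am groping toward: derive a block id from a running count of Nulls per event, then window over (event, block). The helper columns block and changed are just throwaway names I invented, and I have not verified that this is correct or idiomatic on 2.2.0:

from pyspark.sql import Window
from pyspark.sql import functions as F

w_event = Window.partitionBy("event").orderBy("timestamp")

# Every Null tag opens a new block: a running count of Nulls seen so far per event.
df = df.withColumn(
    "block",
    F.sum(F.when(F.col("tag").isNull(), 1).otherwise(0)).over(w_event),
)

# Within each (event, block), flag a change: the first non-Null tag after the
# block boundary, or any tag different from the previous one.
w_block = Window.partitionBy("event", "block").orderBy("timestamp")
prev_tag = F.lag("tag").over(w_block)
df = df.withColumn(
    "changed",
    F.when(F.col("tag").isNull(), 0)
     .when(prev_tag.isNull() | (F.col("tag") != prev_tag), 1)
     .otherwise(0),
)

# The stretch is the cumulative sum of the change flags within the block,
# forced back to 0 on the Null rows themselves.
df = df.withColumn(
    "stretch",
    F.when(F.col("tag").isNull(), 0).otherwise(
        F.sum("changed").over(
            w_block.rowsBetween(Window.unboundedPreceding, Window.currentRow)
        )
    ),
).drop("block", "changed")

Is this two-window, helper-column approach a reasonable way to express it, or is there a cleaner way to do the reset-on-Null cumulative sum?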