I am attempting to build a Spark function that recursively rewrites ArrayType columns:
import org.apache.spark.sql.{DataFrame, Column}
import org.apache.spark.sql.functions._

val arrayHead = udf((sequence: Seq[String]) => sequence.head)
val arrayTail = udf((sequence: Seq[String]) => sequence.tail)

// Re-produces the ArrayType column recursively
val rewriteArrayCol = (c: Column) => {
  def helper(elementsRemaining: Column, outputAccum: Column): Column = {
    when(size(elementsRemaining) === lit(0), outputAccum)
      .otherwise(helper(arrayTail(elementsRemaining), concat(outputAccum, array(arrayHead(elementsRemaining)))))
  }
  helper(c, array())
}

// Test
val df =
  Seq("100" -> Seq("a", "b", "b", "b", "b", "b", "c", "c", "d"))
    .toDF("id", "sequence")
    // .withColumn("test_tail", arrayTail($"sequence")) // head & tail UDFs work
    // .withColumn("test", rewriteArrayCol($"sequence")) // StackOverflowError if uncommented

display(df)
Unfortunately, I keep getting a StackOverflowError. One area where I believe the function is lacking is that it's not tail-recursive; more importantly, when().otherwise() is an ordinary method call rather than a lazily evaluated if/else, so both branches are constructed before the condition is ever checked. That said, the function throws a StackOverflowError even on tiny DataFrames, so I figure there must be more wrong with it than just the lack of tail recursion.
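To illustrate what I mean, here is a minimal sketch (loops and bounded are names I've made up for illustration; the sketch reuses the arrayTail UDF defined above):

// Both arguments to when(...).otherwise(...) are plain Scala method
// arguments, evaluated eagerly while the Column expression tree is being
// *constructed* -- before Spark ever sees a row. The recursive call below
// therefore never stops recursing:
def loops(c: Column): Column =
  when(size(c) === lit(0), c)
    .otherwise(loops(arrayTail(c))) // evaluated unconditionally => StackOverflowError

// By contrast, recursion bounded by a plain Scala value terminates,
// because the stopping condition is decided at construction time:
def bounded(c: Column, depth: Int): Column =
  if (depth == 0) c
  else when(size(c) === lit(0), c).otherwise(bounded(arrayTail(c), depth - 1))

The bounded variant only shows why construction terminates; it is not a correct rewrite of my function, since it needs an upper bound on the array length up front.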
I have not been able to find any examples of a similar function online, so I thought I'd ask here. The only implementations of Column => Column functions I've been able to find are very simple ones, which were not helpful for this use case.
Note: I am able to achieve the functionality of the above by using a UDF. The reason I am attempting to write a Column => Column function instead is that, as far as I am aware, Spark's Catalyst optimizer can optimize native Column expressions, whereas UDFs are opaque to it.
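For reference, the working UDF version looks roughly like this (rewriteArrayUdf is a made-up name, and the element-by-element rebuild shown is just a placeholder for the real rewrite logic):

// The whole rebuild happens in plain Scala per row inside the UDF, so the
// Column expression tree stays a single opaque node -- no StackOverflowError,
// but also no visibility for the optimizer:
val rewriteArrayUdf = udf { (sequence: Seq[String]) =>
  sequence.foldLeft(Seq.empty[String])((acc, x) => acc :+ x) // element-by-element rebuild
}

// df.withColumn("rewritten", rewriteArrayUdf($"sequence"))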