0 votes

I am attempting to build a Spark function that recursively rewrites ArrayType columns:

import org.apache.spark.sql.{DataFrame, Column}
import org.apache.spark.sql.functions._
import spark.implicits._  // for .toDF; already in scope in notebooks

val arrayHead = udf((sequence: Seq[String]) => sequence.head)
val arrayTail = udf((sequence: Seq[String]) => sequence.tail)

// reproduces the ArrayType column recursively
val rewriteArrayCol = (c: Column) => {

  def helper(elementsRemaining: Column, outputAccum: Column): Column = {

    when(size(elementsRemaining) === lit(0), outputAccum)
    .otherwise(helper(arrayTail(elementsRemaining), concat(outputAccum, array(arrayHead(elementsRemaining)))))
  }

  helper(c, array())
}


// Test
val df = 
  Seq("100"  -> Seq("a", "b", "b", "b", "b", "b", "c", "c", "d"))
  .toDF("id", "sequence")
//  .withColumn("test_tail", arrayTail($"sequence"))   //head & tail udfs work
//  .withColumn("test", rewriteArrayCol($"sequence"))  //stackoverflow if uncommented

display(df)

Unfortunately, I keep getting a StackOverflowError. One area where I believe the function is lacking is that it is not tail-recursive; i.e. the whole when().otherwise() block is not the same as an if/else block. That said, the function currently throws a StackOverflowError even when applied to tiny dataframes, so I figure there must be more wrong with it than just the lack of tail recursion.

I have not been able to find any examples of a similar function online, so I thought I'd ask here. The only implementations of Column => Column functions that I've been able to find are very simple ones that were not helpful for this use case.

Note: I am able to achieve the functionality of the above using a UDF. The reason I am attempting to write a Column => Column function is that Spark is better able to optimize these compared to UDFs (as far as I am aware).
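
For reference, the single-UDF version I mean looks roughly like this (rewriteArrayUdf and the output column name are just illustrative):

// A single UDF that rebuilds the array in one pass - it works, but it is
// opaque to the optimizer, which is what I am trying to avoid
val rewriteArrayUdf = udf((sequence: Seq[String]) => sequence.map(identity))

df.withColumn("rewritten", rewriteArrayUdf($"sequence"))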


1 Answer

2 votes

That's not going to work, because there is no meaningful stop condition here. when / otherwise are not language-level control-flow blocks (so they cannot break out of the recursion), and the helper function will simply recurse forever while building the expression tree.
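
You can see the mechanism with nothing Spark-specific going on: the argument to otherwise is a plain, strictly evaluated Scala expression, so your recursive call to helper fires while the Column is still being constructed, before Spark ever looks at a single row:

import org.apache.spark.sql.functions._

// Both branches of when/otherwise are ordinary (strict) Scala method
// arguments, evaluated at expression-build time, regardless of what the
// condition will later evaluate to per row:
val c = when(lit(true), lit(1)).otherwise {
  println("otherwise branch is built eagerly")  // prints immediately
  lit(2)
}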

In fact, it won't terminate even for an empty array, entirely outside any SQL evaluation context:

rewriteArrayCol(array())

Furthermore, your assumption is incorrect. Leaving aside the fact that your code deserializes the data twice (once for arrayHead, once for arrayTail), which is far worse than calling a single udf (though that could be avoided with slice), very complex expressions come with their own issues, one of which is the code generation size limit.
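
For reference, a sketch of the slice-based head/tail I mean, using only built-ins (assuming Spark 2.4+, where element_at and slice are available):

import org.apache.spark.sql.functions._

// head and tail as built-in expressions - no UDF, no deserialization:
val head = element_at(col("sequence"), 1)
val tail = expr("slice(sequence, 2, size(sequence) - 1)")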

Don't despair, though - there is already a valid solution for this: transform. See How to use transform higher-order function?
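
For example, the identity rewrite from your question becomes a one-liner (a sketch assuming Spark 2.4+, where transform is available via expr; Spark 3.0+ also exposes a Scala transform function):

// transform maps a lambda over every element inside the SQL engine -
// no UDF, no recursion, fully visible to the optimizer:
df.withColumn("rewritten", expr("transform(sequence, x -> x)"))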