11
votes

I need to implement the following SQL logic using the Spark DataFrame API:

SELECT KEY,
    CASE WHEN tc in ('a','b') THEN 'Y'
         WHEN tc in ('a') AND amt > 0 THEN 'N'
         ELSE NULL END REASON
FROM dataset1;

My input DataFrame is as below:

val dataset1 = Seq((66, "a", "4"), (67, "a", "0"), (70, "b", "4"), (71, "d", "4")).toDF("KEY", "tc", "amt")

dataset1.show()
+---+---+---+
|KEY| tc|amt|
+---+---+---+
| 66|  a|  4|
| 67|  a|  0|
| 70|  b|  4|
| 71|  d|  4|
+---+---+---+

I have implemented the nested case when statement as:

dataset1.withColumn("REASON", when(col("tc").isin("a", "b"), "Y")
  .otherwise(when(col("tc").equalTo("a") && col("amt").gt(0), "N")
    .otherwise(null))).show()
+---+---+---+------+
|KEY| tc|amt|REASON|
+---+---+---+------+
| 66|  a|  4|     Y|
| 67|  a|  0|     Y|
| 70|  b|  4|     Y|
| 71|  d|  4|  null|
+---+---+---+------+

The readability of the above logic with "otherwise" gets messy as the when statements are nested further.

Is there any better way of implementing nested case when statements in Spark DataFrames?


3 Answers

23
votes

There is no nesting here, therefore there is no need for otherwise. All you need is chained when:

import org.apache.spark.sql.functions.when
import spark.implicits._

when($"tc" isin ("a", "b"), "Y")
  .when($"tc" === "a" && $"amt" > 0, "N")

ELSE NULL is implicit so you can omit it completely.

The pattern you use is more applicable to folding over a data structure:

val cases = Seq(
  ($"tc" isin ("a", "b"), "Y"),
  ($"tc" === "a" && $"amt" > 0, "N")
)

Here each when / otherwise step naturally follows a recursive pattern, and null provides the base case:

cases.foldLeft(lit(null)) {
  case (acc, (expr, value)) => when(expr, value).otherwise(acc)
}
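One detail of the fold above is easy to miss: foldLeft wraps each new case around the accumulator, so the last entry of cases ends up outermost and is tested first. The sketch below is one possible variant, not part of the original answer, that uses foldRight so that the cases are tested in the order they are listed. The local SparkSession, the names orderedCases, reason and folded, and the explicit cast of amt to int are assumptions made for illustration; the comparison uses amt > 0 as in the SQL from the question.

```scala
import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.functions.{lit, when}

// Assumption: a local session for illustration; in spark-shell `spark` already exists.
val spark = SparkSession.builder().master("local[1]").appName("case-when-fold").getOrCreate()
import spark.implicits._

val dataset1 = Seq((66, "a", "4"), (67, "a", "0"), (70, "b", "4"), (71, "d", "4"))
  .toDF("KEY", "tc", "amt")

// Cases listed in the order they should be tested, most specific first.
// amt is a string column, so it is cast explicitly before the comparison.
val orderedCases: Seq[(Column, String)] = Seq(
  ($"tc" === "a" && $"amt".cast("int") > 0, "N"),
  ($"tc".isin("a", "b"), "Y")
)

// foldRight starts from the last case, so the first case ends up outermost
// and is evaluated first, like the first WHEN of a SQL CASE.
val reason: Column = orderedCases.foldRight(lit(null): Column) {
  case ((cond, value), acc) => when(cond, value).otherwise(acc)
}

// Collect (KEY, REASON) pairs; a null REASON becomes None.
val folded = dataset1.withColumn("REASON", reason)
  .orderBy("KEY")
  .collect()
  .map(r => (r.getInt(0), Option(r.getAs[String]("REASON"))))
  .toSeq
```

With this ordering, adding a new rule is a matter of inserting a tuple at the position where it should be tested.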

Please note that the "N" outcome is unreachable with the original order of conditions. If tc equals "a", the row is captured by the first clause. If it does not, the second predicate cannot match either, so the result is "Y" or NULL. Test the more specific condition first instead:

when($"tc" === "a" && $"amt" > 0, "N")
 .when($"tc" isin ("a", "b"), "Y")
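As a minimal end-to-end sketch, the reordered chain can be applied to the question's dataset1 as follows. The local SparkSession, the names reasonCol and chained, and the explicit cast of amt to int are assumptions made for illustration; the comparison uses amt > 0 as in the SQL from the question, and no otherwise is needed because unmatched rows default to null.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.when

// Assumption: a local session for illustration; in spark-shell `spark` already exists.
val spark = SparkSession.builder().master("local[1]").appName("case-when-chain").getOrCreate()
import spark.implicits._

val dataset1 = Seq((66, "a", "4"), (67, "a", "0"), (70, "b", "4"), (71, "d", "4"))
  .toDF("KEY", "tc", "amt")

// The specific condition comes first, the generic one second.
// Rows matching neither branch get null (the implicit ELSE NULL).
val reasonCol = when($"tc" === "a" && $"amt".cast("int") > 0, "N")
  .when($"tc".isin("a", "b"), "Y")

// Collect (KEY, REASON) pairs; a null REASON becomes None.
val chained = dataset1.withColumn("REASON", reasonCol)
  .orderBy("KEY")
  .collect()
  .map(r => (r.getInt(0), Option(r.getAs[String]("REASON"))))
  .toSeq
```

On the sample data, KEY 66 gets "N", KEYs 67 and 70 get "Y", and KEY 71 stays null.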
3
votes

For more complex logic, I prefer to use UDFs for better readability:

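import org.apache.spark.sql.functions.{col, udf}

// Test the more specific condition first; otherwise "N" can never be returned.
val selectCase = udf((tc: String, amt: String) =>
  if (tc == "a" && amt.toInt > 0) "N"
  else if (Seq("a", "b").contains(tc)) "Y"
  else null
)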


dataset1.withColumn("REASON", selectCase(col("tc"), col("amt")))
  .show
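A UDF that calls amt.toInt directly throws on a null or non-numeric amt. The sketch below is one possible null-safe variant, not the answer's own code: the logic lives in a plain Scala function so it can be checked without Spark, and it returns an Option, which Spark maps to a nullable column. The names reasonFor, reasonUdf and viaUdf and the local SparkSession are hypothetical, and treating an unparsable amt as "not greater than 0" is an assumption.

```scala
import scala.util.Try
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

// Assumption: a local session for illustration; in spark-shell `spark` already exists.
val spark = SparkSession.builder().master("local[1]").appName("case-when-udf").getOrCreate()
import spark.implicits._

val dataset1 = Seq((66, "a", "4"), (67, "a", "0"), (70, "b", "4"), (71, "d", "4"))
  .toDF("KEY", "tc", "amt")

// Hypothetical helper: plain Scala, so it can be unit tested on its own.
def reasonFor(tc: String, amt: String): Option[String] = {
  // A null or non-numeric amt yields None instead of an exception.
  val amount = Try(amt.trim.toInt).toOption
  if (tc == "a" && amount.exists(_ > 0)) Some("N")
  else if (tc == "a" || tc == "b") Some("Y")
  else None
}

// Wrap the function as a UDF; None becomes null in the resulting column.
val reasonUdf = udf(reasonFor _)

val viaUdf = dataset1.withColumn("REASON", reasonUdf(col("tc"), col("amt")))
```

Keep in mind that a UDF is opaque to Spark's optimizer, so the built-in when expressions are usually the better choice when the logic can be expressed with them.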
0
votes

You can simply use selectExpr on your dataset:

dataset1.selectExpr("*", "CASE WHEN tc in ('a') AND amt > 0 THEN 'N' WHEN tc in ('a','b') THEN 'Y' ELSE NULL END REASON").show()

+---+---+---+------+
|KEY| tc|amt|REASON|
+---+---+---+------+
| 66|  a|  4|     N|
| 67|  a|  0|     Y|
| 70|  b|  4|     Y|
| 71|  d|  4|  null|
+---+---+---+------+

The second condition should be placed before the first one, because the first condition is the more generic one:

WHEN tc in ('a') AND amt > 0 THEN 'N'
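If only one derived column is needed, the same SQL fragment can also be passed to functions.expr and combined with withColumn. The following is a sketch under assumptions: the local SparkSession and the names caseSql and viaExpr are hypothetical, and amt is cast explicitly because it is a string column in the question's data.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.expr

// Assumption: a local session for illustration; in spark-shell `spark` already exists.
val spark = SparkSession.builder().master("local[1]").appName("case-when-expr").getOrCreate()
import spark.implicits._

val dataset1 = Seq((66, "a", "4"), (67, "a", "0"), (70, "b", "4"), (71, "d", "4"))
  .toDF("KEY", "tc", "amt")

// A triple-quoted string keeps a longer CASE expression readable.
val caseSql =
  """CASE WHEN tc IN ('a') AND CAST(amt AS INT) > 0 THEN 'N'
    |     WHEN tc IN ('a', 'b') THEN 'Y'
    |     ELSE NULL END""".stripMargin

// Collect (KEY, REASON) pairs; a null REASON becomes None.
val viaExpr = dataset1.withColumn("REASON", expr(caseSql))
  .orderBy("KEY")
  .collect()
  .map(r => (r.getInt(0), Option(r.getAs[String]("REASON"))))
  .toSeq
```

This keeps the SQL text close to the original query while still composing with other DataFrame operations.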