I need a UDF that takes an array column of a DataFrame and checks whether the two string elements in it are equal. My DataFrame looks like this:

ID date options
1 2021-01-06 ['red', 'green']
2 2021-01-07 ['Blue', 'Blue']
3 2021-01-08 ['Blue', 'Yellow']
4 2021-01-09 nan

I have tried this:

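from pyspark.sql import functions as f, types as t

def equality_check(options: list):
    # 1 if the first two elements match, 0 if they differ,
    # -1 if the array is missing (None) or has fewer than two elements.
    try:
        if options[0] == options[1]:
            return 1
        else:
            return 0
    except (TypeError, IndexError):
        return -1

equality_udf = f.udf(equality_check, t.IntegerType())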

But it throws an index-out-of-range error. I am confident that the options column is an array of strings. The expected output is:

ID date options equality_check
1 2021-01-06 ['red', 'green'] 0
2 2021-01-07 ['Blue', 'Blue'] 1
3 2021-01-08 ['Blue', 'Yellow'] 0
4 2021-01-09 nan -1

1 Answer

Instead of relying on try/except, you can check explicitly whether the options list is missing or has fewer than two elements. Here's a working example:

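from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import IntegerType

# Reuse the active session (e.g. in a notebook or the pyspark shell) or create one.
spark = SparkSession.builder.getOrCreate()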

data = [
    (1, "2021-01-06", ['red', 'green']),
    (2, "2021-01-07", ['Blue', 'Blue']),
    (3, "2021-01-08", ['Blue', 'Yellow']),
    (4, "2021-01-09", None),
]
df = spark.createDataFrame(data, ["ID", "date", "options"])

def equality_check(options: list):
    if not options or len(options) < 2:
        return -1

    return int(options[0] == options[1])

equality_udf = F.udf(equality_check, IntegerType())

df1 = df.withColumn("equality_check", equality_udf(F.col("options")))
df1.show()

#+---+----------+--------------+--------------+
#| ID|      date|       options|equality_check|
#+---+----------+--------------+--------------+
#|  1|2021-01-06|  [red, green]|             0|
#|  2|2021-01-07|  [Blue, Blue]|             1|
#|  3|2021-01-08|[Blue, Yellow]|             0|
#|  4|2021-01-09|          null|            -1|
#+---+----------+--------------+--------------+
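Since the UDF only wraps a plain Python function, the guard can be checked without a Spark session. Below is a minimal sketch that restates equality_check from above and runs it on a few inputs. The empty and single-element lists are assumed edge cases that do not appear in the sample data, and the Optional type hint is added only for illustration.

```python
from typing import List, Optional


def equality_check(options: Optional[List[str]]) -> int:
    # A null array arrives as None; None, empty, and single-element
    # inputs cannot be compared, so they all map to -1.
    if not options or len(options) < 2:
        return -1
    # bool -> int: True becomes 1, False becomes 0.
    return int(options[0] == options[1])


# Rows from the sample data.
print(equality_check(["red", "green"]))  # 0
print(equality_check(["Blue", "Blue"]))  # 1
print(equality_check(None))              # -1

# Assumed edge cases, not present in the sample data.
print(equality_check([]))                # -1
print(equality_check(["Blue"]))          # -1
```

Testing the function this way before registering it with F.udf makes it easier to separate logic errors from Spark serialization or schema problems.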

However, I would advise against using a UDF here, since you can do the same with built-in functions alone:

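df1 = df.withColumn(
    "equality_check",
    # size() of a null array is -1 or null depending on Spark settings
    # (spark.sql.legacy.sizeOfNull / ANSI mode), so test for null explicitly.
    F.when(F.col("options").isNull() | (F.size(F.col("options")) < 2), -1)
        .when(F.col("options")[0] == F.col("options")[1], 1)
        .otherwise(0)
)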