
I am writing a UDF that takes two DataFrame columns plus an extra parameter (a constant value) and adds a new column to the DataFrame. My function looks like:

def udf_test(column1, column2, constant_var):
    if column1 == column2:
        return column1
    else:
        return constant_var

Also, I am doing the following to pass in multiple columns:

apply_test = udf(udf_test, StringType())
df = df.withColumn('new_column', apply_test('column1', 'column2'))

This does not work unless I remove constant_var as my function's third argument, but I really need it. So I have tried something like the following:

constant_var = 'TEST'
apply_test = udf(lambda x: udf_test(x, constant_var), StringType())
df = df.withColumn('new_column', apply_test(constant_var)(col('column1', 'column2')))

and

apply_test = udf(lambda x,y: udf_test(x, y, constant_var), StringType())

None of the above have worked for me. I got those ideas from this and this Stack Overflow post, and I think it is clear how my question differs from both of them. Any help would be much appreciated.

NOTE: I have simplified the function here just for the sake of discussion; the actual function is more complex. I know this operation could be done using when() and otherwise().
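Why the original apply_test('column1', 'column2') call fails can be reproduced without Spark: a Python UDF is called with one positional argument per column passed to it, so two columns supply only two values and constant_var is never filled. A minimal sketch that simulates a single row in plain Python (the row tuple is made up for illustration):

```python
def udf_test(column1, column2, constant_var):
    if column1 == column2:
        return column1
    else:
        return constant_var

# Illustrative row: a two-column UDF hands the function two positional values.
row_values = ('a', 'b')

try:
    udf_test(*row_values)
    error_message = None
except TypeError as exc:
    # CPython reports the missing 'constant_var' parameter here.
    error_message = str(exc)

print(error_message)
```

Inside Spark the same TypeError would be raised in the Python worker and surface as a failed job.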

You can use .when() and .otherwise(), right? - pvy4917
@Prazy The function is actually more complicated; I changed it to this just to simplify the problem. But you are right, in this case I could use when and otherwise. - ahajib
What is constant_var? - pvy4917

Answer


You do not have to use a user-defined function. You can use the functions when() and otherwise():

from pyspark.sql import functions as f
df = df.withColumn('new_column', 
                   f.when(f.col('col1') == f.col('col2'), f.col('col1'))
                    .otherwise('other_value'))

Another way to do it is to create a user-defined function. However, UDFs hurt performance, since every row has to be serialized from the JVM to a Python worker and the result deserialized back. To pass a constant into a UDF, you need a function that returns a user-defined function, with the constant captured in a closure. For example:

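from pyspark.sql import functions as f
from pyspark.sql.types import StringType

def generate_udf(constant_var):
    # constant_var is captured by the closure, so the UDF itself only takes the two columns
    def test(col1, col2):
        if col1 == col2:
            return col1
        else:
            return constant_var
    return f.udf(test, StringType())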

df = df.withColumn('new_column', 
                   generate_udf('default_value')(f.col('col1'), f.col('col2')))