
I need to convert a Python (pandas) script to PySpark, and it's proving to be a tough task for me.

I'm trying to remove null values from a dataframe (without removing the entire column or row) and shift the following values left into the vacated columns. Example:

       CLIENT | ANIMAL_1 | ANIMAL_2 | ANIMAL_3 | ANIMAL_4
ROW_1    1    | cow      | frog     | null     | dog
ROW_2    2    | pig      | null     | cat      | null

My goal is to have:

       CLIENT | ANIMAL_1 | ANIMAL_2 | ANIMAL_3 | ANIMAL_4
ROW_1    1    | cow      | frog     | dog      | null
ROW_2    2    | pig      | cat      | null     | null

In pandas I use the following code (which I found here on Stack Overflow):

import pandas as pd

df_out = df.apply(lambda x: pd.Series(x.dropna().to_numpy()), axis=1)

Then I rename the columns. But I have no idea how to do this in PySpark.
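For reference, the per-row transformation that the pandas one-liner performs can be sketched in plain Python. This is only an illustration: it models one row's animal values as a list with None standing in for null, and `shift_nulls_right` is a hypothetical helper name. Non-null values keep their relative order and move to the front; the row is then padded back to its original width, which is what happens when pandas reassembles the shorter per-row Series into a DataFrame.

```python
from typing import Any, List, Optional


def shift_nulls_right(row: List[Optional[Any]]) -> List[Optional[Any]]:
    """Move non-null values to the front (keeping their order) and pad with None."""
    values = [v for v in row if v is not None]
    return values + [None] * (len(row) - len(values))


# Animal columns per CLIENT, taken from the example table (None = null).
rows = {
    1: ["cow", "frog", None, "dog"],
    2: ["pig", None, "cat", None],
}
for client, animals in rows.items():
    print(client, shift_nulls_right(animals))
# 1 ['cow', 'frog', 'dog', None]
# 2 ['pig', 'cat', None, None]
```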

1 Answer


Here's a way to do this for Spark version 2.4+:

Create an array of the columns you want and sort it by your conditions, which are the following:

  1. Sort non-null values first
  2. Among those, keep values in the order their columns appear

We can do the sorting with array_sort. To sort on multiple conditions, use arrays_zip to combine them into an array of structs, which array_sort then compares field by field. To make it easy to extract the value you want (i.e. the animal in this example), zip in the column value as well.
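To see why this ordering works, here is a plain-Python analogue of the zip-and-sort step. It is an illustration only, not Spark code: nulls are modeled as None and `zip_and_sort` is a hypothetical helper name. Python tuples, like Spark structs in array_sort, compare field by field, so the null flag decides first (False sorts before True) and the unique position breaks ties, which means the animal values themselves are never compared.

```python
from typing import Any, List, Optional, Tuple


def zip_and_sort(row: List[Optional[Any]]) -> List[Tuple[bool, int, Optional[Any]]]:
    # Analogue of arrays_zip(isNull flags, positions, values): one tuple per column.
    zipped = [(v is None, i, v) for i, v in enumerate(row)]
    # Field-by-field comparison: non-null entries first, then original column order.
    return sorted(zipped)


print(zip_and_sort(["cow", "frog", None, "dog"]))
# [(False, 0, 'cow'), (False, 1, 'frog'), (False, 3, 'dog'), (True, 2, None)]
```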

from pyspark.sql.functions import array, array_sort, arrays_zip, col, lit

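animal_cols = df.columns[1:]  # every column except CLIENT
N = len(animal_cols)

df_out = df.select(
    df.columns[0],
    # arrays_zip builds one struct per animal column. array_sort compares the
    # structs field by field: the null flag first (false sorts before true),
    # then the original position, so non-null values come first in column order.
    array_sort(
        arrays_zip(
            array([col(c).isNull() for c in animal_cols]),  # field '0': is the value null?
            array([lit(i) for i in range(N)]),               # field '1': original column position
            array([col(c) for c in animal_cols])             # field '2': the animal itself
        )
    ).alias('sorted')
)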
df_out.show(truncate=False)
#+------+----------------------------------------------------------------+
#|CLIENT|sorted                                                          |
#+------+----------------------------------------------------------------+
#|1     |[[false, 0, cow], [false, 1, frog], [false, 3, dog], [true, 2,]]|
#|2     |[[false, 0, pig], [false, 2, cat], [true, 1,], [true, 3,]]      |
#+------+----------------------------------------------------------------+

Now that things are in the right order, you just need to extract the values. In this case, that's the field named '2' (the zipped animal value) of the i-th element of the sorted column.
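Continuing the plain-Python analogue (again only an illustration, not Spark code), the extraction step picks field 2 of the i-th sorted entry for each position i and names it after the i-th animal column. The `sorted_row` list below is hard-coded from the CLIENT 2 row of the output above.

```python
from typing import Dict, List, Optional, Tuple

animal_cols = ["ANIMAL_1", "ANIMAL_2", "ANIMAL_3", "ANIMAL_4"]

# Sorted (is_null, position, value) entries for CLIENT 2, copied from the output above.
sorted_row: List[Tuple[bool, int, Optional[str]]] = [
    (False, 0, "pig"),
    (False, 2, "cat"),
    (True, 1, None),
    (True, 3, None),
]

# Analogue of col("sorted")[i]['2'].alias(animal_cols[i]): the value of the
# i-th sorted entry becomes the i-th animal column.
row_out: Dict[str, Optional[str]] = {
    animal_cols[i]: sorted_row[i][2] for i in range(len(animal_cols))
}
print(row_out)
# {'ANIMAL_1': 'pig', 'ANIMAL_2': 'cat', 'ANIMAL_3': None, 'ANIMAL_4': None}
```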

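df_out = df_out.select(
    df.columns[0],
    # For each position i, take field '2' (the animal) of the i-th sorted struct
    # and name the result after the i-th original animal column.
    *[col("sorted")[i]['2'].alias(animal_cols[i]) for i in range(N)]
)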
df_out.show(truncate=False)
#+------+--------+--------+--------+--------+
#|CLIENT|ANIMAL_1|ANIMAL_2|ANIMAL_3|ANIMAL_4|
#+------+--------+--------+--------+--------+
#|1     |cow     |frog    |dog     |null    |
#|2     |pig     |cat     |null    |null    |
#+------+--------+--------+--------+--------+