I have a PySpark DataFrame with 4 columns.

+-----+-----+-----+-----+
|col1 |col2 |col3 |col4 |
+-----+-----+-----+-----+
|10   | 5.0 | 5.0 | 5.0 |
|20   | 5.0 | 5.0 | 5.0 |
|null | 5.0 | 5.0 | 5.0 |
|30   | 5.0 | 5.0 | 6.0 |
|40   | 5.0 | 5.0 | 7.0 |
|null | 5.0 | 5.0 | 8.0 |
|50   | 5.0 | 6.0 | 9.0 |
|60   | 5.0 | 7.0 | 10.0|
|null | 5.0 | 8.0 | 11.0|
|70   | 6.0 | 9.0 | 12.0|
|80   | 7.0 | 10.0| 13.0|
|null | 8.0 | 11.0| 14.0|
+-----+-----+-----+-----+

Some values in col1 are missing, and I want to fill those missing values using the following approach:

1. Try to set it to the average of col1 of the records that have the same col2, col3 and col4 values.
2. If there is no such record, set it to the average of col1 of the records that have the same col2 and col3 values.
3. If there is still no such record, set it to the average of col1 of the records that have the same col2 value.
4. If none of the above can be found, set it to the average of all other non-missing values in col1.

For example, given the dataframe above, only the first two rows have the same col2, col3, col4 values as row 3, so the null in col1 for row 3 should be replaced by the average of their col1 values: (10 + 20) / 2 = 15. For the null in col1 in row 6, no other row shares its col2, col3, col4 combination, so it falls back to the rows with the same col2 and col3 values (rows 1, 2, 4 and 5): (10 + 20 + 30 + 40) / 4 = 25. And so on...

+-----+-----+-----+-----+
|col1 |col2 |col3 |col4 |
+-----+-----+-----+-----+
|10   | 5.0 | 5.0 | 5.0 |
|20   | 5.0 | 5.0 | 5.0 |
|15   | 5.0 | 5.0 | 5.0 |
|30   | 5.0 | 5.0 | 6.0 |
|40   | 5.0 | 5.0 | 7.0 |
|25   | 5.0 | 5.0 | 8.0 |
|50   | 5.0 | 6.0 | 9.0 |
|60   | 5.0 | 7.0 | 10.0|
|35   | 5.0 | 8.0 | 11.0|
|70   | 6.0 | 9.0 | 12.0|
|80   | 7.0 | 10.0| 13.0|
|45   | 8.0 | 11.0| 14.0|
+-----+-----+-----+-----+
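
For reference, a minimal snippet that builds the input DataFrame above (assuming PySpark is available; the SparkSession is created here):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Example data; None marks the missing col1 values
data = [
    (10, 5.0, 5.0, 5.0), (20, 5.0, 5.0, 5.0), (None, 5.0, 5.0, 5.0),
    (30, 5.0, 5.0, 6.0), (40, 5.0, 5.0, 7.0), (None, 5.0, 5.0, 8.0),
    (50, 5.0, 6.0, 9.0), (60, 5.0, 7.0, 10.0), (None, 5.0, 8.0, 11.0),
    (70, 6.0, 9.0, 12.0), (80, 7.0, 10.0, 13.0), (None, 8.0, 11.0, 14.0),
]
df = spark.createDataFrame(data, ["col1", "col2", "col3", "col4"])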

What's the best way to do this?

For the null value in col1 in row 6, why are you not counting rows 1, 2, 3? It seems that your output is not following your rules. – Steven
@Steven You are right, I edited the question. – shahram kalantari

1 Answer

I do not get exactly the same values as you do, but based on what you said, the code would be something like this:

from pyspark.sql import functions as F

# Pre-compute the fallback averages at each grouping level
# (F.avg ignores null values of col1)
df_2_3_4 = df.groupBy("col2", "col3", "col4").agg(
    F.avg("col1").alias("avg_col1_by_2_3_4")
)
df_2_3 = df.groupBy("col2", "col3").agg(F.avg("col1").alias("avg_col1_by_2_3"))
df_2 = df.groupBy("col2").agg(F.avg("col1").alias("avg_col1_by_2"))

# Global average of all non-null col1 values, collected as a Python value
avg_value = df.groupBy().agg(F.avg("col1").alias("avg_col1")).first().avg_col1


# Join each level of averages back onto the original rows
df_out = (
    df.join(df_2_3_4, how="left", on=["col2", "col3", "col4"])
    .join(df_2_3, how="left", on=["col2", "col3"])
    .join(df_2, how="left", on=["col2"])
)

# Take the first non-null value, from the most specific average to the global one
df_out.select(
    F.coalesce(
        F.col("col1"),
        F.col("avg_col1_by_2_3_4"),
        F.col("avg_col1_by_2_3"),
        F.col("avg_col1_by_2"),
        F.lit(avg_value),
    ).alias("col1"),
    "col2",
    "col3",
    "col4",
).show()

+----+----+----+----+
|col1|col2|col3|col4|
+----+----+----+----+
|10.0| 5.0| 5.0| 5.0|
|15.0| 5.0| 5.0| 5.0|
|20.0| 5.0| 5.0| 5.0|
|30.0| 5.0| 5.0| 6.0|
|40.0| 5.0| 5.0| 7.0|
|25.0| 5.0| 5.0| 8.0|
|50.0| 5.0| 6.0| 9.0|
|60.0| 5.0| 7.0|10.0|
|35.0| 5.0| 8.0|11.0|
|70.0| 6.0| 9.0|12.0|
|80.0| 7.0|10.0|13.0|
|45.0| 8.0|11.0|14.0|
+----+----+----+----+
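
Note that F.avg ignores nulls, so if a group at some level contains only null col1 values, its average is null and coalesce simply falls through to the next level. If you would rather keep the joined DataFrame and just overwrite col1 instead of re-selecting the columns, a variant along these lines should also work (untested sketch, reusing the same intermediate DataFrames as above):

df_filled = df_out.withColumn(
    "col1",
    F.coalesce(
        F.col("col1"),
        F.col("avg_col1_by_2_3_4"),
        F.col("avg_col1_by_2_3"),
        F.col("avg_col1_by_2"),
        F.lit(avg_value),
    ),
).drop("avg_col1_by_2_3_4", "avg_col1_by_2_3", "avg_col1_by_2")

df_filled.show()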