
I want to group values based on the unique values in two columns of a PySpark DataFrame. Once a value has been used as a group key, it should not be repeated, even if it also appears in the other column.

    | fruit  | fruits    |
    |--------|-----------|
    | apple  | banana    |
    | banana | apple     |
    | apple  | mango     |
    | orange | guava     |
    | apple  | pineapple |
    | mango  | apple     |
    | banana | mango     |
    | banana | pineapple |

I have tried grouping by a single column, but that either needs to be modified or needs different logic altogether.

Grouping by the single column, I get the following output:

    | fruit  | values                         |
    |--------|--------------------------------|
    | apple  | ['banana','mango','pineapple'] |
    | banana | ['apple']                      |
    | orange | ['guava']                      |
    | mango  | ['apple']                      |

But I want the following output:

    | fruit  | values                         |
    |--------|--------------------------------|
    | apple  | ['banana','mango','pineapple'] |
    | orange | ['guava']                      |

Laurens Koppenol: This really depends on the order of your DataFrame, right? Am I correct to assume that if banana were processed before apple, banana would still be there?

amol desai: Once a value has been grouped, no value from that group should be repeated. For example, if banana comes first, the output should be banana | ['apple','mango','pineapple'] followed by orange | ['guava'].

Laurens Koppenol: Spark DataFrames do not have a guaranteed order. Does that matter for your outcome?

amol desai: No, it doesn't matter for the outcome.

1 Answer


This looks like a connected-components problem: treat each row as an edge between two fruits, and each group you want is one connected component. There are a couple of ways you can go about it.
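To see why the connected-components framing fits, here is a plain-Python sketch (no Spark, so it is neither of the methods below) that runs a small union-find over the sample rows from the question. Picking the alphabetically first member of each component as the key is an assumption; the asker said order does not matter.

```python
# Plain-Python illustration (no Spark): each row is an undirected edge,
# and union-find merges fruits that are connected directly or indirectly.
rows = [
    ("apple", "banana"), ("banana", "apple"), ("apple", "mango"),
    ("orange", "guava"), ("apple", "pineapple"), ("mango", "apple"),
    ("banana", "mango"), ("banana", "pineapple"),
]

parent = {}

def find(x):
    """Return the root of x's component, halving the path on the way up."""
    parent.setdefault(x, x)
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x

def union(a, b):
    """Merge the components containing a and b."""
    ra, rb = find(a), find(b)
    if ra != rb:
        parent[rb] = ra

for a, b in rows:
    union(a, b)

# Collect the members of each component under its root.
components = {}
for fruit in parent:
    components.setdefault(find(fruit), []).append(fruit)

# Report each component as: first member alphabetically -> the other members.
groups = {}
for members in components.values():
    members = sorted(members)
    groups[members[0]] = members[1:]

print(groups)  # {'apple': ['banana', 'mango', 'pineapple'], 'guava': ['orange']}
```

Every fruit ends up in exactly one group, which is the "no value repeats" requirement from the question.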

1. GraphFrames

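You can use the GraphFrames package. Each row of your DataFrame defines an edge, so you can build a graph using df as the edges and a DataFrame of all the distinct fruits as the vertices (GraphFrames expects the edge columns to be named src and dst, and the vertex column to be named id). Then call the connectedComponents method, which labels every vertex with a component ID and requires a checkpoint directory to be set on the SparkContext. You can then manipulate that output to get the grouping you want.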

2. Just Pyspark

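The second method is a bit of a hack: it only collapses mirrored pairs such as (apple, banana) and (banana, apple), so values linked only through a third fruit are not merged. First, create an order-independent "hash" for each row by sorting the two values into an array: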

    from pyspark.sql import functions as F
    hashed_df = df.withColumn('hash', F.sort_array(F.array(F.col('fruit'), F.col('fruits'))))

Then drop the duplicate rows for that column, keeping one row per hash value:

    distinct_df = hashed_df.dropDuplicates(['hash'])
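To make the two steps above concrete, this plain-Python mirror (not Spark) applies the same sort-then-dedupe logic to the sample rows from the question. Spark's dropDuplicates keeps an arbitrary row per key rather than the first, but since the next step rebuilds both columns from the hash, that makes no difference here.

```python
rows = [
    ("apple", "banana"), ("banana", "apple"), ("apple", "mango"),
    ("orange", "guava"), ("apple", "pineapple"), ("mango", "apple"),
    ("banana", "mango"), ("banana", "pineapple"),
]

# Equivalent of F.sort_array(F.array(fruit, fruits)): an order-independent key.
hashed = [tuple(sorted(pair)) for pair in rows]

# Equivalent of dropDuplicates(['hash']): one entry per key (first seen, here).
distinct = list(dict.fromkeys(hashed))

print(distinct)
```

The 8 input rows shrink to 6 unordered pairs; (apple, banana) and (apple, mango) each appeared twice in mirrored form.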

Split the sorted pair back into the two columns:

    revert_df = distinct_df.withColumn('fruit', F.col('hash')[0]) \
        .withColumn('fruits', F.col('hash')[1])

Finally, group by the first column:

    grouped_df = revert_df.groupBy('fruit').agg(F.collect_list('fruits').alias('group'))

You might need to "stringify" your hash using F.concat_ws if PySpark complains, but the idea is the same.