
I have a PySpark dataframe-

df1 = spark.createDataFrame([
    ("u1", 1),
    ("u1", 2),
    ("u2", 1),
    ("u2", 1),
    ("u2", 1),
    ("u3", 3),
    ],
    ['user_id', 'var1'])

df1.printSchema()
df1.show(truncate=False)

Output-

root
 |-- user_id: string (nullable = true)
 |-- var1: long (nullable = true)

+-------+----+
|user_id|var1|
+-------+----+
|u1     |1   |
|u1     |2   |
|u2     |1   |
|u2     |1   |
|u2     |1   |
|u3     |3   |
+-------+----+

Now I want to group all the unique users and show the number of unique var1 values for each of them in a new column. The desired output would look like this-

+-------+---------------+
|user_id|num_unique_var1|
+-------+---------------+
|u1     |2              |
|u2     |1              |
|u3     |1              |
+-------+---------------+

I can use collect_set and make a UDF to find the set's length (something like the sketch below), but I think there must be a better way to do it. How do I achieve this in one line of code?
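
For reference, this is roughly the UDF approach I had in mind (the helper name count_unique is just illustrative):

from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

# UDF that returns the length of the collected set of values
count_unique = F.udf(lambda vals: len(vals), IntegerType())

df1.groupBy('user_id') \
   .agg(count_unique(F.collect_set('var1')).alias('num_unique_var1')) \
   .show()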


2 Answers

from pyspark.sql import functions as F

df1.groupBy('user_id').agg(F.countDistinct('var1').alias('num')).show()

countDistinct is exactly what I needed.

Output-

+-------+---+
|user_id|num|
+-------+---+
|     u3|  1|
|     u2|  1|
|     u1|  2|
+-------+---+

countDistinct is surely the best way to do it, but for the sake of completeness, the approach you mentioned in your question is also possible without a UDF. You can use size to get the length of the collect_set:

df1.groupBy('user_id').agg(F.size(F.collect_set('var1')).alias('num'))

This is helpful if you want to use it in a window function, because countDistinct is not supported over a window.

e.g.

df1.withColumn('num', F.countDistinct('var1').over(Window.partitionBy('user_id')))

would fail, but

df1.withColumn('num', F.size(F.collect_set('var1')).over(Window.partitionBy('user_id')))

would work.
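
For completeness, here is a self-contained sketch of the working variant (assuming the same df1 as in the question):

from pyspark.sql import functions as F
from pyspark.sql.window import Window

w = Window.partitionBy('user_id')

# keeps every original row and adds the per-user count of distinct var1 values
df1.withColumn('num', F.size(F.collect_set('var1')).over(w)).show()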