I have a PySpark DataFrame:
df1 = spark.createDataFrame(
    [
        ("u1", 1),
        ("u1", 2),
        ("u2", 1),
        ("u2", 1),
        ("u2", 1),
        ("u3", 3),
    ],
    ['user_id', 'var1'],
)
df1.printSchema()
df1.show(truncate=False)
Output:
root
 |-- user_id: string (nullable = true)
 |-- var1: long (nullable = true)
+-------+----+
|user_id|var1|
+-------+----+
|u1     |1   |
|u1     |2   |
|u2     |1   |
|u2     |1   |
|u2     |1   |
|u3     |3   |
+-------+----+
Now I want to group by user_id and show the number of distinct var1 values for each user in a new column. The desired output would look like:
+-------+---------------+
|user_id|num_unique_var1|
+-------+---------------+
|u1     |2              |
|u2     |1              |
|u3     |1              |
+-------+---------------+
I could use collect_set and write a UDF to find the length of each set, but I think there must be a better way. How do I achieve this in one line of code?