4
votes

Context: I need to filter a dataframe, keeping or dropping the rows whose values appear in a column of another dataframe, using the isin function.

For Python users working with pandas, that would be: isin().
For R users, that would be: %in%.
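For reference, a minimal sketch of the pandas behaviour I am trying to mimic (the frames here are just toy data):

```python
import pandas as pd

# Toy frames: keep the rows of df1 whose id appears in df2's id column.
df1 = pd.DataFrame({'id': [1, 2, 3, 4]})
df2 = pd.DataFrame({'id': [2, 4]})

# pandas isin accepts any iterable, including another frame's column.
filtered = df1[df1['id'].isin(df2['id'])]
print(filtered['id'].tolist())  # [2, 4]
```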

So I have a simple spark dataframe with id and value columns:

l = [(1, 12), (1, 44), (1, 3), (2, 54), (3, 18), (3, 11), (4, 13), (5, 78)]
df = spark.createDataFrame(l, ['id', 'value'])
df.show()

+---+-----+
| id|value|
+---+-----+
|  1|   12|
|  1|   44|
|  1|    3|
|  2|   54|
|  3|   18|
|  3|   11|
|  4|   13|
|  5|   78|
+---+-----+

I want to get all ids that appear multiple times. As a first step, here's a dataframe of the ids that occur only once in df:

from pyspark.sql.functions import col

unique_ids = df.groupBy('id').count().where(col('count') < 2)
unique_ids.show()

+---+-----+
| id|count|
+---+-----+
|  5|    1|
|  2|    1|
|  4|    1|
+---+-----+

So the logical operation would be:

df = df[~df.id.isin(unique_ids.id)]
# This is equivalent to:
df = df[df.id.isin(unique_ids.id) == False]

However, I get an empty dataframe:

df.show()

+---+-----+
| id|value|
+---+-----+
+---+-----+ 

Strangely, the non-negated filter:

df[df.id.isin(unique_ids.id)]

returns all the rows of df.

1
Don't use isin here; use join. For example: df.join(unique_ids, on="id").show(). You can only use isin with literal values (e.g. df.where(df["id"].isin([1, 2, 3]))), not with a column. – pault

1 Answer

11
votes

The expression df.id.isin(unique_ids.id) == False evaluates the column Column<b'((id IN (id)) = false)'>, which can never be true, because id is always in id. Conversely, df.id.isin(unique_ids.id) evaluates Column<b'(id IN (id))'>, which is always true, and that is why it returns the whole dataframe. The root cause is that unique_ids.id is a Column, not a list of values.

isin(*cols) expects literal values as arguments, not a Column, so to make this approach work you first have to collect the ids into a plain Python list:

ids = unique_ids.rdd.map(lambda x: x.id).collect()  # the ids that occur only once
df[df.id.isin(ids)].collect()  # or .show()

and you will obtain:

[Row(id=2, value=54), Row(id=4, value=13), Row(id=5, value=78)]

In any case, I think it would be better if you join both data frames:

df_ = df.join(unique_ids, on='id')

getting:

df_.show()
+---+-----+-----+
| id|value|count|
+---+-----+-----+
|  5|   78|    1|
|  2|   54|    1|
|  4|   13|    1|
+---+-----+-----+