1 vote

I want to filter out the rows that have zero values for all of the columns in a list.

Suppose, for example, we have the following df:

df = spark.createDataFrame([(0, 1, 1, 2, 1), (0, 0, 1, 0, 1), (1, 0, 1, 1, 1)], ['a', 'b', 'c', 'd', 'e'])
+---+---+---+---+---+                                                           
|  a|  b|  c|  d|  e|
+---+---+---+---+---+
|  0|  1|  1|  2|  1|
|  0|  0|  1|  0|  1|
|  1|  0|  1|  1|  1|
+---+---+---+---+---+

and the list of columns is ['a', 'b', 'd'], so the filtered dataframe should be:

+---+---+---+---+---+                                                           
|  a|  b|  c|  d|  e|
+---+---+---+---+---+
|  0|  1|  1|  2|  1|
|  1|  0|  1|  1|  1|
+---+---+---+---+---+

This is what I have tried:

df = df.withColumn('total', sum(df[col] for col in ['a', 'b', 'd']))
df = df.filter(df.total > 0).drop('total')

This works fine for small datasets, but if the column list is very long it fails with the following error:

java.lang.StackOverflowError at org.apache.spark.sql.catalyst.analysis.ResolveLambdaVariables.org$apache$spark$sql$catalyst$analysis$ResolveLambdaVariables$$resolve(higher...

I can think of a pandas UDF solution, but my df is very large and that might be a bottleneck.
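
For reference, here is a minimal sketch of what such a pandas UDF could look like. It is hard-coded to the example columns 'a', 'b', 'd', the helper name any_nonzero is made up for illustration, and whether the Arrow overhead is acceptable on a large df is exactly the concern above:

import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import BooleanType

# sketch only: a scalar pandas UDF that flags rows where at least one of the
# listed columns is non-zero; hard-coded to 'a', 'b', 'd' for illustration
@pandas_udf(BooleanType())
def any_nonzero(a: pd.Series, b: pd.Series, d: pd.Series) -> pd.Series:
    return (a != 0) | (b != 0) | (d != 0)

df.filter(any_nonzero('a', 'b', 'd')).show()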

Edit:

When using @Psidom's answer, I get the following error:

py4j.protocol.Py4JJavaError: An error occurred while calling o2508.filter.
: java.lang.StackOverflowError
    at org.apache.spark.sql.catalyst.expressions.Expression.references(Expression.scala:88)
    at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$references$1.apply(Expression.scala:88)
    at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$references$1.apply(Expression.scala:88)
    at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241)
    at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241)
    at scala.collection.immutable.List.foreach(List.scala:392)
    at scala.collection.TraversableLike$class.flatMap(TraversableLike.scala:241)
    at scala.collection.immutable.List.flatMap(List.scala:355)

4 Answers

2 votes

You can pass the columns as an array to a UDF, check whether all the values are zero, and then filter on the result:

from pyspark.sql.types import BooleanType
from pyspark.sql.functions import udf, array, col

all_zeros_udf = udf(lambda arr: arr.count(0) == len(arr), BooleanType())

df = spark.createDataFrame([(0, 1, 1, 2, 1), (0, 0, 1, 0, 1), (1, 0, 1, 1, 1)], ['a', 'b', 'c', 'd', 'e'])

(df
    .withColumn('all_zeros', all_zeros_udf(array('a', 'b', 'd')))  # pass the columns as an array
    .filter(~col('all_zeros'))                                     # keep only rows where not all values are zero
    .drop('all_zeros')                                             # drop the helper column
    .show())

Result:

+---+---+---+---+---+
|  a|  b|  c|  d|  e|
+---+---+---+---+---+
|  0|  1|  1|  2|  1|
|  1|  0|  1|  1|  1|
+---+---+---+---+---+
2 votes

functools.reduce could be useful here:

df = spark.createDataFrame([(0, 1, 1, 2, 1), (0, 0, 1, 0, 1), (1, 0, 1, 1, 1)],
                           ['a', 'b', 'c', 'd', 'e'])
cols = ['a', 'b', 'd']

Use reduce to create the filter expression:

from functools import reduce
predicate = reduce(lambda a, b: a | b, [df[x] != 0 for x in cols])

print(predicate)
# Column<b'(((NOT (a = 0)) OR (NOT (b = 0))) OR (NOT (d = 0)))'>

Then filter with the predicate:

df.where(predicate).show()
+---+---+---+---+---+
|  a|  b|  c|  d|  e|
+---+---+---+---+---+
|  0|  1|  1|  2|  1|
|  1|  0|  1|  1|  1|
+---+---+---+---+---+
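
If this left-to-right reduce still overflows the stack for thousands of columns (as the question's edit reports), one possible workaround is to combine the same predicates as a balanced tree, so the expression depth grows logarithmically instead of linearly. This is only a sketch, with balanced_or as a made-up helper name:

def balanced_or(exprs):
    # OR together a list of Column predicates as a balanced tree,
    # keeping the resulting Catalyst expression shallow
    if len(exprs) == 1:
        return exprs[0]
    mid = len(exprs) // 2
    return balanced_or(exprs[:mid]) | balanced_or(exprs[mid:])

predicate = balanced_or([df[x] != 0 for x in cols])
df.where(predicate).show()
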
1 vote

Here is a different solution. I haven't tried it on a large set of columns, so please let me know if it works.

from pyspark.sql import functions as F

list_of_cols = ['a', 'b', 'd']

df = spark.createDataFrame([(0, 1, 1, 2, 1), (0, 0, 1, 0, 1), (1, 0, 1, 1, 1)], ['a', 'b', 'c', 'd', 'e'])
df.show()

+---+---+---+---+---+
|  a|  b|  c|  d|  e|
+---+---+---+---+---+
|  0|  1|  1|  2|  1|
|  0|  0|  1|  0|  1|
|  1|  0|  1|  1|  1|
+---+---+---+---+---+

df = df.withColumn("Concat_cols" , F.concat(*list_of_cols)) # concat the list of columns 
df.show()

+---+---+---+---+---+-----------+
|  a|  b|  c|  d|  e|Concat_cols|
+---+---+---+---+---+-----------+
|  0|  1|  1|  2|  1|        012|
|  0|  0|  1|  0|  1|        000|
|  1|  0|  1|  1|  1|        101|
+---+---+---+---+---+-----------+

pattern = '0' * len(list_of_cols)  # a string of zeros, one per column in the list

df1 = df.where(df['Concat_cols'] != pattern)
df1.show()

+---+---+---+---+---+-----------+
|  a|  b|  c|  d|  e|Concat_cols|
+---+---+---+---+---+-----------+
|  0|  1|  1|  2|  1|        012|
|  1|  0|  1|  1|  1|        101|
+---+---+---+---+---+-----------+
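
Two caveats with the concat approach, and a possible tweak (a sketch, not tested here): concat returns null as soon as any input column is null, and it relies on the integer columns being cast to strings. Casting explicitly keeps that part under your control:

# possible variant: cast each column to string explicitly before concatenating
df = df.withColumn(
    "Concat_cols",
    F.concat(*[F.col(c).cast("string") for c in list_of_cols])
)
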
0 votes

If the intent is just to check for zeros across all the columns and the long column list is what causes the problem, you can concatenate the columns 1000 at a time and then test the result for a non-zero occurrence.

from pyspark.sql import functions as F

# all columns, or whatever subset you would like to test
columns = df.columns

# number of columns to concatenate at a time
split = 1000

# each block concatenates up to 1000 columns into a single string column
blocks = [F.concat(*columns[i*split:(i+1)*split])
          for i in range((len(columns) + split - 1) // split)]

# strip the zeros out of the concatenated string; if nothing is left,
# every tested column in that row was zero
(df.select("*")
   .where(F.regexp_replace(F.concat(*blocks), "0", "") != "")
   .show(10, False))