I want to filter out the rows that have zero values in all the columns of a given list. Suppose, for example, we have the following df:
df = spark.createDataFrame([(0, 1, 1, 2, 1), (0, 0, 1, 0, 1), (1, 0, 1, 1, 1)], ['a', 'b', 'c', 'd', 'e'])
+---+---+---+---+---+
| a| b| c| d| e|
+---+---+---+---+---+
| 0| 1| 1| 2| 1|
| 0| 0| 1| 0| 1|
| 1| 0| 1| 1| 1|
+---+---+---+---+---+
and the list of columns is ['a', 'b', 'd'], so the filtered DataFrame should be:
+---+---+---+---+---+
| a| b| c| d| e|
+---+---+---+---+---+
| 0| 1| 1| 2| 1|
| 1| 0| 1| 1| 1|
+---+---+---+---+---+
This is what I have tried:
# sum() here is Python's builtin, so it builds one nested '+' Column expression per column
df = df.withColumn('total', sum(df[col] for col in ['a', 'b', 'd']))
df = df.filter(df.total > 0).drop('total')
This works fine for a short list of columns, but when col_list is very long it fails, presumably because the deeply nested expression tree overflows the stack during analysis. The error is:
java.lang.StackOverflowError at org.apache.spark.sql.catalyst.analysis.ResolveLambdaVariables.org$apache$spark$sql$catalyst$analysis$ResolveLambdaVariables$$resolve(higher...
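A flatter expression might sidestep the nesting, since filtering out rows where all listed columns are zero is equivalent to keeping rows where any of them is non-zero. This is a sketch of what I mean (assuming Spark 2.4+ for array_max; untested on my full column list):

from pyspark.sql import functions as F

cols = ['a', 'b', 'd']
# array(...) takes a flat argument list, so the expression tree depth
# stays constant no matter how many columns are in cols.
# abs() guards against negative values; array_max > 0 holds
# iff at least one of the listed columns is non-zero.
df = df.filter(F.array_max(F.array(*[F.abs(F.col(c)) for c in cols])) > 0)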
I can also think of a pandas UDF solution, but my df is very large and that might be a bottleneck.
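For reference, the pandas UDF approach I have in mind is roughly this (a sketch, assuming Spark 3.0+ for the type-hint style of pandas_udf; any_nonzero is just an illustrative name):

import pandas as pd
from pyspark.sql import functions as F
from pyspark.sql.functions import pandas_udf

@pandas_udf('boolean')
def any_nonzero(arr: pd.Series) -> pd.Series:
    # each element of `arr` holds the values of the listed columns for one row
    return arr.apply(lambda xs: any(x != 0 for x in xs))

df = df.filter(any_nonzero(F.array('a', 'b', 'd')))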
Edit:
When using @Psidom's answer, I get the following error:
py4j.protocol.Py4JJavaError: An error occurred while calling o2508.filter.
: java.lang.StackOverflowError
    at org.apache.spark.sql.catalyst.expressions.Expression.references(Expression.scala:88)
    at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$references$1.apply(Expression.scala:88)
    at org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$references$1.apply(Expression.scala:88)
    at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241)
    at scala.collection.TraversableLike$$anonfun$flatMap$1.apply(TraversableLike.scala:241)
    at scala.collection.immutable.List.foreach(List.scala:392)
    at scala.collection.TraversableLike$class.flatMap(TraversableLike.scala:241)
    at scala.collection.immutable.List.flatMap(List.scala:355)