2 votes

Given a dataframe, say one with 4 columns and 3 rows, I want to write a function that returns the columns in which all values equal 1.

This is Scala code. I want to use Spark transformations to transform or filter the input dataframe, and the filter should be implemented in a function.

case class Grade(c1: Integer, c2: Integer, c3: Integer, c4: Integer)

val example = Seq(
  Grade(1, 3, 1, 1),
  Grade(1, 1, null, 1),
  Grade(1, 10, 2, 1)
)

val dfInput = spark.createDataFrame(example)

After I call the function filterColumns()

val dfOutput = dfInput.filterColumns()

it should return a 3-row, 2-column dataframe in which every value is 1.
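Since DataFrame has no built-in filterColumns, what I have in mind is exposing it as an extension method, roughly like the sketch below (FilterColumnsOps is just a placeholder name; the body is the part I am asking about):

import org.apache.spark.sql.DataFrame

// Placeholder wrapper; implementing filterColumns() is the question
implicit class FilterColumnsOps(df: DataFrame) {
  def filterColumns(): DataFrame = {
    ??? // keep only the columns where every value equals 1
  }
}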


3 Answers

1 vote

One option is a reduce over the underlying RDD:

import org.apache.spark.sql.Row
import spark.implicits._

val df = Seq(("1","A","3","4"), ("1","2","?","4"), ("1","2","3","4")).toDF()
df.show()

val diffStr = "#"    // marker: the column differs somewhere or is not the target
val targetStr = "1"  // the value every cell of a kept column must equal

// Copy a Row of strings into an Array[String]
def rowToArray(row: Row): Array[String] = {
  val arr = new Array[String](row.length)
  for (i <- 0 until row.length) {
    arr(i) = row.getString(i)
  }
  arr
}

// Element-wise merge of two rows: keep a value only when both rows agree on it
// and it equals the target; otherwise mark that position with diffStr
def compareArrays(a1: Array[String], a2: Array[String]): Array[String] = {
  val arr = new Array[String](a1.length)
  for (i <- 0 until a1.length) {
    arr(i) = if (a1(i).equals(a2(i)) && a1(i).equals(targetStr)) a1(i) else diffStr
  }
  arr
}

// After reducing all rows, targetStr survives only at positions where every row is "1"
val diff = df.rdd
  .map(rowToArray)
  .reduce(compareArrays)

// Keep only the columns that were not marked as differing
val cols = (df.columns zip diff).filter(!_._2.equals(diffStr)).map(s => df(s._1))

df.select(cols: _*).show()
The first show() prints the input:

+---+---+---+---+
| _1| _2| _3| _4|
+---+---+---+---+
|  1|  A|  3|  4|
|  1|  2|  ?|  4|
|  1|  2|  3|  4|
+---+---+---+---+

and the final select() keeps only the all-"1" column:

+---+
| _1|
+---+
|  1|
|  1|
|  1|
+---+
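Note that this example sidesteps the question's nulls by using "?" placeholders. With real nulls, a1(i).equals(...) would throw a NullPointerException; a null-safe sketch of compareArrays (relying on Scala's null-safe ==) could look like this:

// Null-safe variant: treat null as "not the target value"
def compareArraysNullSafe(a1: Array[String], a2: Array[String]): Array[String] =
  a1.zip(a2).map { case (x, y) =>
    if (x != null && x == y && x == targetStr) x else diffStr
  }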
1 vote

I would try to prepare the dataset for processing without nulls. With few columns, this simple iterative approach might work fine (don't forget to import the Spark implicits first: import spark.implicits._):

import org.apache.spark.sql.Dataset
import spark.implicits._

val example = spark.sparkContext.parallelize(Seq(
    Grade(1, 3, 1, 1),
    Grade(1, 1, 0, 1),
    Grade(1, 10, 2, 1)
)).toDS().cache()

// A column qualifies when its only distinct value is 1
def allOnes(colName: String, ds: Dataset[Grade]): Boolean = {
    val rows = ds.select(colName).distinct().collect()
    rows.length == 1 && rows.head.getInt(0) == 1
}

val resultColumns = example.columns.filter(col => allOnes(col, example))
example.selectExpr(resultColumns: _*).show()

The result is:

+---+---+
| c1| c4|
+---+---+
|  1|  1|
|  1|  1|
|  1|  1|
+---+---+

If nulls are unavoidable, use an untyped dataset (i.e. a plain dataframe):

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val schema = StructType(Seq(
    StructField("c1", IntegerType, nullable = true),
    StructField("c2", IntegerType, nullable = true),
    StructField("c3", IntegerType, nullable = true),
    StructField("c4", IntegerType, nullable = true)
))

val example = spark.sparkContext.parallelize(Seq(
    Row(1, 3, 1, 1),
    Row(1, 1, null, 1),
    Row(1, 10, 2, 1)
))

val dfInput = spark.createDataFrame(example, schema).cache()

// As above, but also guard against a column whose single distinct value is null
def allOnes(colName: String, df: DataFrame): Boolean = {
    val rows = df.select(colName).distinct().collect()
    rows.length == 1 && !rows.head.isNullAt(0) && rows.head.getInt(0) == 1
}

val resultColumns = dfInput.columns.filter(col => allOnes(col, dfInput))
dfInput.selectExpr(resultColumns: _*).show()
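Both allOnes variants launch a separate distinct() job per column. A possible single-pass alternative (just a sketch using the standard org.apache.spark.sql.functions API): compute, for all columns at once, a flag saying whether any value is null or differs from 1, then keep the unflagged columns:

import org.apache.spark.sql.functions.{col, lit, max, when}

// One aggregation pass: per column, 1 if any value is null or != 1, else 0
val flags = dfInput.select(dfInput.columns.map { c =>
  max(when(col(c).isNull || col(c) =!= 1, lit(1)).otherwise(lit(0))).as(c)
}: _*).first()

val allOneCols = dfInput.columns.filter(c => flags.getAs[Int](c) == 0)
dfInput.select(allOneCols.map(col): _*).show()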
1 vote

A somewhat more readable approach, using a typed Dataset[Grade]:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.col
import scala.collection.mutable
import spark.implicits._

// Replace every non-1 value with null, then keep the columns that have no nulls
val tmp = dfInput.as[Grade].map(grade => grade.dropWhenNotEqualsTo(1))
val rowsCount = dfInput.count()

val colsToRetain = mutable.Set[Column]()
for (column <- tmp.columns) {
  val withoutNullsCount = tmp.select(column).na.drop().count()
  if (rowsCount == withoutNullsCount) colsToRetain += col(column)
}

dfInput.select(colsToRetain.toArray: _*).show()

+---+---+
| c4| c1|
+---+---+
|  1|  1|
|  1|  1|
|  1|  1|
+---+---+

And the case class:

case class Grade(c1: Integer, c2: Integer, c3: Integer, c4: Integer) {
  def dropWhenNotEqualsTo(n: Integer): Grade = {
    Grade(nullOrValue(c1, n), nullOrValue(c2, n), nullOrValue(c3, n), nullOrValue(c4, n))
  }
  // Scala's == on boxed Integers is null-safe value equality
  def nullOrValue(c: Integer, n: Integer): Integer = if (c == n) c else null
}
Step by step:

  1. grade.dropWhenNotEqualsTo(1) returns a new Grade in which every value that does not satisfy the condition is replaced by null:

+---+----+----+---+
| c1|  c2|  c3| c4|
+---+----+----+---+
|  1|null|   1|  1|
|  1|   1|null|  1|
|  1|null|null|  1|
+---+----+----+---+

  2. for (column <- tmp.columns) iterates over the columns.

  3. tmp.select(column).na.drop() drops the rows containing nulls, e.g. for c2 this returns:

+---+
| c2|
+---+
|  1|
+---+

  4. if (rowsCount == withoutNullsCount) colsToRetain += col(column) retains a column only when no rows were dropped, i.e. the column contained no nulls after step 1.
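For completeness, a sketch of how dfInput can be built as a typed Dataset[Grade] (assuming the case class above), so that the map over Grade values compiles:

import spark.implicits._

val dfInput = Seq(
  Grade(1, 3, 1, 1),
  Grade(1, 1, null, 1),
  Grade(1, 10, 2, 1)
).toDS()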