
Referring to this question: Spark / Scala: forward fill with last observation, I am trying to reproduce the problem and solve it.

I've created a file mre.csv:

Date,B
2015-06-01,33
2015-06-02,
2015-06-03,
2015-06-04,
2015-06-05,22
2015-06-06,
2015-06-07,

Then I read the file:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import spark.implicits._

var df = spark.read.format("csv")
  .option("header", "true")
  .option("inferSchema", "true")
  .load("D:/playground/mre.csv")

df.show()

val rows: RDD[Row] = df.orderBy($"Date").rdd
val schema = df.schema

Then I managed to solve the problem using this code:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{coalesce, lag, lit}

df = df.withColumn("id", lit(1))
var spec = Window.partitionBy("id").orderBy("Date")
val df2 = df.withColumn("B", coalesce((0 to 6).map(i => lag(df.col("B"), i, 0).over(spec)): _*))

df2.show()

Output:

+-------------------+---+---+
|               Date|  B| id|
+-------------------+---+---+
|2015-06-01 00:00:00| 33|  1|
|2015-06-02 00:00:00| 33|  1|
|2015-06-03 00:00:00| 33|  1|
|2015-06-04 00:00:00| 33|  1|
|2015-06-05 00:00:00| 22|  1|
|2015-06-06 00:00:00| 22|  1|
|2015-06-07 00:00:00| 22|  1|
+-------------------+---+---+

The problem, however, is that it is all calculated in a single partition, so I don't really take advantage of Spark here.
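
For reference, the same single-partition fill can also be written without hard-coding the 0 to 6 lag range, by taking the last non-null value over a running window (just a sketch of an alternative formulation; it has the same single-partition limitation, so it does not address that issue):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, last}

// Frame runs from the first row of the partition up to the current row;
// last(..., ignoreNulls = true) picks the most recent non-null B seen so far.
val fillSpec = Window.partitionBy("id").orderBy("Date")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)
val dfFilled = df.withColumn("B", last(col("B"), ignoreNulls = true).over(fillSpec))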

So instead I tried this code:

def notMissing(row: Row): Boolean = { !row.isNullAt(1) }

val toCarry: scala.collection.Map[Int,Option[org.apache.spark.sql.Row]] = rows
  .mapPartitionsWithIndex{ case (i, iter) =>
    Iterator((i, iter.filter(notMissing(_)).toSeq.lastOption)) }
  .collectAsMap

val toCarryBd = sc.broadcast(toCarry)

def fill(i: Int, iter: Iterator[Row]): Iterator[Row] = {
  if (iter.contains(null)) iter.map(row => Row(toCarryBd.value(i).get(1))) else iter
}

val imputed: RDD[Row] = rows
  .mapPartitionsWithIndex{ case (i, iter) => fill(i, iter) }

val df2 = spark.createDataFrame(imputed, schema).toDF()

df2.show()

But the output is disappointing:

+----+---+
|Date|  B|
+----+---+
+----+---+
Hi Alon, since you're working with dates, can we assume the dataset is small enough to be broadcast? The idea would be to broadcast at least the rows having a defined B so you can work in a distributed way, using date intervals to determine which value to assign to B. - baitmbarek
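
A rough sketch of the idea in that comment, assuming the rows with a defined B fit comfortably on the driver (the helper names and the lookup logic are my own, not from the thread):

import java.sql.Timestamp
import org.apache.spark.sql.functions.{coalesce, col, udf}

// Collect only the rows that have a defined B, sorted by date, and broadcast them.
val defined: Array[(Timestamp, Int)] = df.filter(col("B").isNotNull)
  .select(col("Date"), col("B"))
  .collect()
  .map(r => (r.getTimestamp(0), r.getInt(1)))
  .sortBy(_._1.getTime)

val definedBd = sc.broadcast(defined)

// For a given date, the fill value is the latest defined B at or before that date.
val lastDefined = udf { date: Timestamp =>
  definedBd.value.takeWhile(!_._1.after(date)).lastOption.map(_._2)
}

val filledDf = df.withColumn("B", coalesce(col("B"), lastDefined(col("Date"))))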

1 Answer


The implementation of the fill function is wrong here. Take a look at the steps mentioned in the answer to the referenced question.
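
(As far as I can tell, this is also why the output above came back empty: iter.contains(null) consumes the iterator while scanning for a null Row, which it never finds because it is the fields inside each Row that are null, not the Row objects themselves, so by the time the iterator is returned or mapped it is already exhausted.)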

def fill(i: Int, iter: Iterator[Row]): Iterator[Row] = {
  // If it is the beginning of partition and value is missing
  // extract value to fill from toCarryBd.value
  // Remember to correct for empty / only missing partitions
  // otherwise take last not-null from the current partition
}

I have implemented it below:

def notMissing(row: Row): Boolean = { !row.isNullAt(1) }

val toCarryTemp: scala.collection.Map[Int,Option[org.apache.spark.sql.Row]] = rows
  .mapPartitionsWithIndex{ case (i, iter) =>
    Iterator((i, iter.filter(notMissing(_)).toSeq.lastOption)) }
  .collectAsMap

Extract the column B value from each entry of the map, and while traversing it fill in the value from the previous partition whenever the current partition has no non-null value at all. If we skip this step we will end up with output like:

+-------------------+---+
|               Date|  B|
+-------------------+---+
|2015-06-01 00:00:00| 33|
|2015-06-02 00:00:00|  0|
|2015-06-03 00:00:00|  0|
|2015-06-04 00:00:00|  0|
|2015-06-05 00:00:00| 22|
|2015-06-06 00:00:00|  0|
|2015-06-07 00:00:00|  0|
+-------------------+---+

var toCarry = scala.collection.mutable.Map[Int, Int]()

for (i <- 0 until rows.getNumPartitions) {
  toCarry(i) = toCarryTemp(i) match {
    case Some(row)     => row.getInt(1)   // last non-null B in partition i
    case None if i > 0 => toCarry(i - 1)  // partition has no value: inherit from previous
    case None          => 0
  }
}
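
For example, if the first three partitions came out as [33], [null, null] and [null, 22], then toCarryTemp would be Map(0 -> Some(row with 33), 1 -> None, 2 -> Some(row with 22)), and the loop turns it into toCarry = Map(0 -> 33, 1 -> 33, 2 -> 22): an all-null partition simply inherits the value carried out of the partition before it.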

val toCarryBd = sc.broadcast(toCarry)

def fillUtil(row: Row, value: Int) = {
  // Replace a missing B with the supplied fill value, keeping the Date.
  if (!notMissing(row)) Row(row.getTimestamp(0), value)
  else row
}


def fill(i: Int, iter: Iterator[Row]): Iterator[Row] = {
  // The broadcast carry is only needed for rows at the start of a partition;
  // after that we maintain the current value as we go. The value carried into
  // partition i is the one carried out of partition i - 1, so a null at the
  // head of this partition is never filled with a value that only appears
  // later in the same partition.
  val carry = if (i > 0) toCarryBd.value(i - 1) else 0
  if (iter.isEmpty) iter
  else {
    val myListHead :: myListTail = iter.toList
    val resultHead = fillUtil(myListHead, carry)
    var currVal = resultHead.getInt(1)
    val resultTail = myListTail.map { row =>
      val row1 = fillUtil(row, currVal)
      currVal = row1.getInt(1)
      row1
    }
    (resultHead :: resultTail).iterator
  }
}


val imputed: RDD[Row] = rows.mapPartitionsWithIndex{ case (i, iter) => fill(i, iter) }

val df2 = spark.createDataFrame(imputed, schema).toDF()

df2.show()

Output:

+-------------------+---+
|               Date|  B|
+-------------------+---+
|2015-06-01 00:00:00| 33|
|2015-06-02 00:00:00| 33|
|2015-06-03 00:00:00| 33|
|2015-06-04 00:00:00| 33|
|2015-06-05 00:00:00| 22|
|2015-06-06 00:00:00| 22|
|2015-06-07 00:00:00| 22|
+-------------------+---+
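
To sanity-check that the fill really runs across partitions (just a quick diagnostic, not part of the solution), you can look at how many rows each partition ends up holding:

imputed.mapPartitionsWithIndex { case (i, iter) => Iterator((i, iter.size)) }
  .collect()
  .foreach { case (i, n) => if (n > 0) println(s"partition $i: $n rows") }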