
I have an RDD with 6 columns, where the last 5 columns might contain NaNs. I want to replace each NaN with the average of the non-NaN values among the last 5 columns of the same row. For instance, given this input:

1, 2, 3, 4, 5, 6
2, 2, 2, NaN, 4, 0
3, NaN, NaN, NaN, 6, 0
4, NaN, NaN, 4, 4, 0 

The output should be:

1, 2, 3, 4, 5, 6
2, 2, 2, 2, 4, 0
3, 3, 3, 3, 6, 0
4, 3, 3, 4, 4, 0
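To make the arithmetic behind this expected output explicit, here is a small plain-Scala sketch (no Spark). It assumes the first column is an id, represents each NaN as `None`, and rounds the row mean to the nearest integer. `fillRow` is a hypothetical name, and rounding is inferred from the last row: its non-missing values 4, 4 and 0 average to 8/3 ≈ 2.67, and the expected output shows 3.

```scala
// Plain-Scala reference for the fill rule (no Spark involved).
// Column 0 is an id; missing values (None) among the remaining columns
// are replaced by the rounded mean of the non-missing values in that row.
def fillRow(row: Seq[Option[Int]]): Seq[Option[Int]] = {
  val (id, values) = (row.head, row.tail)
  val present = values.flatten
  // Assumption: a row with no values at all is left unchanged
  if (present.isEmpty) row
  else {
    val mean = math.round(present.sum.toDouble / present.size).toInt
    id +: values.map(v => v.orElse(Some(mean)))
  }
}

val input: Seq[Seq[Option[Int]]] = Seq(
  Seq(Some(1), Some(2), Some(3), Some(4), Some(5), Some(6)),
  Seq(Some(2), Some(2), Some(2), None, Some(4), Some(0)),
  Seq(Some(3), None, None, None, Some(6), Some(0)),
  Seq(Some(4), None, None, Some(4), Some(4), Some(0))
)

// Prints each filled row as comma-separated values
input.map(fillRow).foreach(r => println(r.flatten.mkString(", ")))
```

Running it prints the four rows of the expected output above.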

I know how to fill those NaNs with the column average after converting the RDD to a DataFrame:

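import org.apache.spark.sql.functions.{col, mean}
val aux1 = df.select(df.columns.map(c => mean(col(c)).alias(c)) :_*).first()
val aux2 = df.na.fill(df.columns.map(c => c -> aux1.getAs[Double](c)).toMap) // fills both null and NaN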

My question is: how can I do the same operation, but fill each NaN with the average of the non-NaN values in a subset of its row instead of the column average?


3 Answers

2
votes

You can do this by defining a function to get the mean, and another function to fill nulls in a row.

Given the DF you presented:

val df = sc.parallelize(List((Some(1),Some(2),Some(3),Some(4),Some(5),Some(6)),(Some(2),Some(2),Some(2),None,Some(4),Some(0)),(Some(3),None,None,None,Some(6),Some(0)),(Some(4),None,None,Some(4),Some(4),Some(0)))).toDF("a","b","c","d","e","f")

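We need a function to get the mean of a Row. It skips the first column (the id, which the question excludes) and rounds the mean of the non-null values to the nearest Int, as the expected output does:

import org.apache.spark.sql.Row
def rowMean(row: Row): Int = {
   // Non-null values in columns 1..n-1 (column 0 is the id)
   val nonNulls = (1 until row.length).filterNot(row.isNullAt).map(row.getInt)
   if (nonNulls.isEmpty) 0 // assumption: 0 when all of them are null
   else math.round(nonNulls.sum.toDouble / nonNulls.length).toInt
}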

And another to fill nulls in a Row:

def rowFillNulls(row: Row, fill: Int): Row = {
   Row((0 until row.length).map(i => if (row.isNullAt(i)) fill else row.getAs[Int](i)) : _*)
}

Now we can first compute each row mean:

val rowWithMean = df.rdd.map(row => (row, rowMean(row))) // .rdd: on Spark 2+, Dataset.map would also need an Encoder for (Row, Int)

And then fill it:

val result = sqlContext.createDataFrame(rowWithMean.map{case (row,mean) => rowFillNulls(row,mean)}, df.schema)

Finally, compare the DataFrame before and after:

df.show
+---+----+----+----+---+---+
|  a|   b|   c|   d|  e|  f|
+---+----+----+----+---+---+
|  1|   2|   3|   4|  5|  6|
|  2|   2|   2|null|  4|  0|
|  3|null|null|null|  6|  0|
|  4|null|null|   4|  4|  0|
+---+----+----+----+---+---+

result.show
+---+---+---+---+---+---+
|  a|  b|  c|  d|  e|  f|
+---+---+---+---+---+---+
|  1|  2|  3|  4|  5|  6|
|  2|  2|  2|  2|  4|  0|
|  3|  3|  3|  3|  6|  0|
|  4|  3|  3|  4|  4|  0|
+---+---+---+---+---+---+

This works for a DataFrame of any width whose columns are all Int. You can adapt it to other data types, even non-numeric ones (hint: inspect the DataFrame's schema!)

1
votes

A bunch of imports:

import org.apache.spark.sql.functions.{col, isnan, isnull, lit, round, when}
import org.apache.spark.sql.Column

A few helper functions:

def nullOrNan(c: Column) = isnan(c) || isnull(c)

def rowMean(cols: Column*): Column = {
  val sum = cols
    .map(c => when(nullOrNan(c), lit(0.0)).otherwise(c))
    .fold(lit(0.0))(_ + _)
  val count = cols
    .map(c => when(nullOrNan(c), lit(0.0)).otherwise(lit(1.0)))
    .fold(lit(0.0))(_ + _)
  sum / count
}
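To check what `rowMean` computes, the same sum-and-count trick can be mirrored in plain Scala (no Spark), with `java.lang.Double` standing in for a nullable column value. `isMissing` and `plainRowMean` are hypothetical names for this sketch. One difference to keep in mind: plain `Double` division gives NaN when every value is missing, whereas Spark's division yields null for a zero divisor unless ANSI mode is enabled.

```scala
// Plain-Scala mirror of the Column-based rowMean (no Spark involved).
// null and NaN both count as missing, like nullOrNan above.
def isMissing(x: java.lang.Double): Boolean = x == null || x.isNaN()

def plainRowMean(values: java.lang.Double*): Double = {
  // Missing values add 0.0 to the sum; present values add themselves
  val sum = values.map(v => if (isMissing(v)) 0.0 else v.doubleValue()).fold(0.0)(_ + _)
  // Missing values add 0.0 to the count; present values add 1.0
  val count = values.map(v => if (isMissing(v)) 0.0 else 1.0).fold(0.0)(_ + _)
  sum / count // NaN if every value is missing
}

// The last five values of the fourth example row, with NaN and null mixed in
val m = plainRowMean(Double.NaN, null, 4.0, 4.0, 0.0)
println(m)             // 2.6666666666666665, i.e. 8.0 / 3.0
println(math.round(m)) // 3, the value round(...).cast("int") produces above
```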

A solution:

val mean = round(
  rowMean(df.columns.tail.map(col): _*)
).cast("int").alias("mean")

val exprs = df.columns.tail.map(
  c => when(nullOrNan(col(c)), mean).otherwise(col(c)).alias(c)
)

val filled = df.select(col(df.columns(0)) +: exprs: _*)

1
votes

Well, this is a fun little problem - I will post my solution, but I will definitely watch and see if someone comes up with a better way of doing it :)

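First I would introduce a couple of UDFs:

import org.apache.spark.sql.functions.{array, udf}

val avg = udf((values: Seq[Integer]) => {
  val notNullValues = values.filter(_ != null).map(_.toInt)
  // Round instead of truncating: Int / Int would turn 8 / 3 into 2, but the expected output has 3
  math.round(notNullValues.sum.toDouble / notNullValues.length).toInt
})

val replaceNullWithAvg = udf((x: Integer, avg: Integer) => if (x == null) avg else x)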

which I would then apply to the DataFrame like this:

val withAvg = df.withColumn("avg", avg(array(df.columns.tail.map(df.col): _*)))
withAvg.select(
  withAvg(df.columns.head) +: df.columns.tail.map(c => replaceNullWithAvg(withAvg(c), withAvg("avg")).as(c)): _*
)

This will get you what you are looking for, but arguably not the most sophisticated code I have ever put together...
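One detail worth checking in any of these approaches is integer division. In the fourth example row, the values present in the last five columns are 4, 4 and 0, and the expected output fills the gaps with 3. That matches rounding, whereas Scala's `Int / Int` truncates toward zero. A minimal plain-Scala illustration (the names are illustrative only):

```scala
// Non-missing values among the last five columns of the fourth example row
val present = Seq(4, 4, 0)

// Int / Int truncates: 8 / 3 == 2
val truncated = present.sum / present.length
// Converting to Double first and rounding gives 3, matching the expected output
val rounded = math.round(present.sum.toDouble / present.length).toInt

println(s"truncated = $truncated, rounded = $rounded") // truncated = 2, rounded = 3
```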