
I am trying to filter a DataFrame by comparing two date columns, using Scala and Spark. Calculations then run on top of the filtered DataFrame to compute new columns. Simplified, my DataFrame has the following schema:

|-- received_day: date (nullable = true)
|-- finished: int (nullable = true)

On top of that I create two new columns, t_end and t_start, that will be used for filtering the DataFrame. They lie 10 and 20 days before the original received_day column:

val dfWithDates = df
  .withColumn("t_end", date_sub(col("received_day"), 10))
  .withColumn("t_start", date_sub(col("received_day"), 20))

I now want a new calculated column that indicates, for each row of data, how many rows of the DataFrame fall into its t_start to t_end period. I thought I could achieve this the following way:

val dfWithCount = dfWithDates
  .withColumn("cnt", lit(
    dfWithDates.filter(
      $"received_day".lt(col("t_end"))
      && $"received_day".gt(col("t_start"))).count()))

However, this count only returns 0, and I believe the problem lies in the arguments that I am passing to lt and gt.

Following the issue here, Filtering a spark dataframe based on date, I realized that I need to pass a string value. If I try with hard-coded values like lt(lit("2018-12-15")), the filtering works.
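
For illustration, this is the kind of hard-coded comparison that does filter correctly (a minimal sketch, assuming the dfWithDates defined above):

// Works: lt() receives a fixed string literal instead of a per-row column
dfWithDates.filter($"received_day".lt(lit("2018-12-15"))).show()

So I tried casting my columns to StringType: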

import org.apache.spark.sql.types.DataTypes

val dfWithDates = df
  .withColumn("t_end", date_sub(col("received_day"), 10).cast(DataTypes.StringType))
  .withColumn("t_start", date_sub(col("received_day"), 20).cast(DataTypes.StringType))

But the filter still returns an empty DataFrame. I would assume that I am not handling the data types correctly.

I am running on Scala 2.11.0 with Spark 2.0.2.

"Maybe someone also knows where I can find the documentation on lt() and gt()? I tried searching for it but couldn't find what I was looking for." – Inna
"You can find all functions operable on Column data types here: spark.apache.org/docs/latest/api/scala/…" – philantrovert
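
For reference, lt and gt are methods on Column that accept either another Column or a plain literal (a minimal sketch):

import org.apache.spark.sql.functions.col

// Column-to-column comparison, evaluated per row
col("received_day").lt(col("t_end"))
// Column-to-literal comparison; the literal is wrapped in lit() automatically
col("received_day").gt("2018-12-15")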

2 Answers

1 vote

Yes, you are right. In $"received_day".lt(col("t_end")), each received_day value is compared with the same row's t_end value, not with the whole DataFrame. Since t_end is always 10 days before that row's received_day, the condition is never true, so you always get zero as the count. You can solve this by writing a simple UDF. Here is one way to solve the issue:

Creating a sample input dataset:

import org.apache.spark.sql.{Row, SparkSession}
import java.sql.Date
import org.apache.spark.sql.functions._
import spark.implicits._
val df = Seq((Date.valueOf("2018-10-12"), 1),
             (Date.valueOf("2018-10-13"), 1),
             (Date.valueOf("2018-09-25"), 1),
             (Date.valueOf("2018-10-14"), 1)).toDF("received_day", "finished")

val dfWithDates = df
  .withColumn("t_start", date_sub(col("received_day"), 20))
  .withColumn("t_end", date_sub(col("received_day"), 10))
dfWithDates.show()

+------------+--------+----------+----------+
|received_day|finished|   t_start|     t_end|
+------------+--------+----------+----------+
|  2018-10-12|       1|2018-09-22|2018-10-02|
|  2018-10-13|       1|2018-09-23|2018-10-03|
|  2018-09-25|       1|2018-09-05|2018-09-15|
|  2018-10-14|       1|2018-09-24|2018-10-04|
+------------+--------+----------+----------+

Here, for 2018-09-25 we expect a count of 3, since that date falls inside the t_start to t_end window of each of the other three rows.

Generate output:

val count_udf = udf((received_day: Date) => {
  dfWithDates.filter(col("t_end").gt(s"$received_day") && col("t_start").lt(s"$received_day")).count()
})

val dfWithCount = dfWithDates.withColumn("count", count_udf(col("received_day")))
dfWithCount.show()

+------------+--------+----------+----------+-----+
|received_day|finished|   t_start|     t_end|count|
+------------+--------+----------+----------+-----+
|  2018-10-12|       1|2018-09-22|2018-10-02|    0|
|  2018-10-13|       1|2018-09-23|2018-10-03|    0|
|  2018-09-25|       1|2018-09-05|2018-09-15|    3|
|  2018-10-14|       1|2018-09-24|2018-10-04|    0|
+------------+--------+----------+----------+-----+

To make the computation faster, I would suggest caching dfWithDates, since the same filter and count are repeated for every row.
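
For example (a minimal sketch; cache() only marks the DataFrame for caching, and the first action materializes it):

// Keep dfWithDates in memory so the per-row filter/count inside the UDF reuses it
dfWithDates.cache()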

0 votes

You can format a date value to a string with any pattern using DateTimeFormatter:

import java.time.format.DateTimeFormatter

date.format(DateTimeFormatter.ofPattern("yyyy-MM-dd"))
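
For instance, assuming date is a java.time.LocalDate (a minimal sketch with a hypothetical sample value):

import java.time.LocalDate
import java.time.format.DateTimeFormatter

val date = LocalDate.of(2018, 12, 15)
date.format(DateTimeFormatter.ofPattern("yyyy-MM-dd")) // "2018-12-15"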