I have a DataFrame with the following schema and some sample records:

// df.printSchema
root
 |-- CUST_NAME: string (nullable = true)
 |-- DIRECTION: string (nullable = true)
 |-- BANK_NAME: string (nullable = true)
 |-- TXN_AMT: double (nullable = false)


// df.show(false)
+---------+---------+---------+-------+
|CUST_NAME|DIRECTION|BANK_NAME|TXN_AMT|
+---------+---------+---------+-------+
|ABC      |D        |Bank1    |300.0  |
|DEF      |C        |Bank2    |10.0   |
|GHI      |C        |Bank3    |12.0   |
|JKL      |D        |Bank4    |500.0  |
+---------+---------+---------+-------+

Now, based on the value in the direction column, I need to conditionally add two new columns:

  1. FROM_BANK
  2. TO_BANK

In terms of simple code, it would look something like this:

var from_bank: String = null
var to_bank: String = null
val direction = "D"
val bank_name = "Test"

direction match {
  case "D" =>
    from_bank = bank_name
    to_bank = null
  case "C" =>
    from_bank = null
    to_bank = bank_name
}

The code above is just an illustration of what I am trying to achieve; I know it's not something that would work on a Spark DataFrame as-is.

I know I can get what I want with multiple when/otherwise clauses as follows:

val df2 = df.withColumn(
  "FROM_BANK",
    when($"DIRECTION" === "D", $"BANK_NAME")
    .otherwise(lit(null))
  )
  .withColumn(
    "TO_BANK",
    when($"DIRECTION" === "C", $"BANK_NAME")
      .otherwise(lit(null))
  )

df2.show(100,false)
//    +---------+---------+---------+-------+---------+-------+
//    |CUST_NAME|DIRECTION|BANK_NAME|TXN_AMT|FROM_BANK|TO_BANK|
//    +---------+---------+---------+-------+---------+-------+
//    |ABC      |D        |Bank1    |300.0  |Bank1    |null   |
//    |DEF      |C        |Bank2    |10.0   |null     |Bank2  |
//    |GHI      |C        |Bank3    |12.0   |null     |Bank3  |
//    |JKL      |D        |Bank4    |500.0  |Bank4    |null   |
//    +---------+---------+---------+-------+---------+-------+

The approach above is simple but very verbose, because in reality I will need to do this for a total of 8 more columns. Another option I have considered is using a .map function on the DataFrame, as follows:

import spark.implicits._
val df3 = test_df.map(row => {
      val direction = row.getAs[String]("DIRECTION")

      if (direction == "D")
        (row.getAs[String]("CUST_NAME"),
          row.getAs[String]("DIRECTION"),
          row.getAs[String]("BANK_NAME"),
          row.getAs[Double]("TXN_AMT"),
          row.getAs[String]("BANK_NAME"), // This will become the FROM_BANK column
          null // This will become the TO_BANK column
        )
      else if (direction == "C")
        (row.getAs[String]("CUST_NAME"),
          row.getAs[String]("DIRECTION"),
          row.getAs[String]("BANK_NAME"),
          row.getAs[Double]("TXN_AMT"),
          null, // This will become the FROM_BANK column
          row.getAs[String]("BANK_NAME") // This will become to the TO_BANK column
        )
    }).toDF("CUST_NAME","DIRECTION","BANK_NAME","TXN_AMT","FROM_BANK","TO_BANK")

However, when running the above, I am getting the following error:

Error:(35, 26) Unable to find encoder for type stored in a Dataset.  Primitive types (Int, String, etc) and Product types (case classes) are supported by importing spark.implicits._  Support for serializing other types will be added in future releases.
    val df3 = test_df.map(row => {

I tried modifying the above by creating a statically typed Dataset, but I still get the same issue:

import spark.implicits._

case class Record(CUST_NAME: String, DIRECTION: String, BANK_NAME: String, TXN_AMT: Double)

val test_df4 = test_df.as[Record].map(row => {
  val direction = row.DIRECTION

  if (direction == "D")
    (
      row.CUST_NAME,
      row.DIRECTION,
      row.BANK_NAME,
      row.TXN_AMT,
      row.BANK_NAME, // This will become the FROM_BANK column
      null // This will become the TO_BANK column
    )
  else if (direction == "C")
    (
      row.CUST_NAME,
      row.DIRECTION,
      row.BANK_NAME,
      row.TXN_AMT,
      null, // This will become the FROM_BANK column
      row.BANK_NAME // This will become the TO_BANK column
    )
}).toDF("CUST_NAME","DIRECTION","BANK_NAME","TXN_AMT","FROM_BANK","TO_BANK")
test_df4.show(100,false)

I know the first option works, but I was hoping for a more programmatic approach, since I need to do this for multiple columns that are all based on the DIRECTION column value. I would appreciate any feedback or suggestions.

Thanks!

1 Answer


You can put the when expressions in a list (or build that list programmatically) and then pass them to a single select, so you don't need to chain a bunch of withColumn calls. Also note that .otherwise(lit(null)) is not necessary, because a when whose conditions don't match returns null by default.

import org.apache.spark.sql.functions.{col, when}
import spark.implicits._

val newcols = List(
    col("*"),
    when($"DIRECTION" === "D", $"BANK_NAME").as("FROM_BANK"),
    when($"DIRECTION" === "C", $"BANK_NAME").as("TO_BANK")
)

val df2 = df.select(newcols: _*)
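
If you end up with many such columns, you can generate the list from a small table of rules instead of writing each when by hand. Below is a minimal sketch of that idea; the (new column name, DIRECTION value, source column) shape of the rules is an assumption based on your description, and any entries beyond FROM_BANK/TO_BANK would be your own:

import org.apache.spark.sql.functions.{col, when}

// Hypothetical rule table: (new column name, DIRECTION value, source column).
// Extend this Seq with the other 8 columns you mentioned.
val rules = Seq(
  ("FROM_BANK", "D", "BANK_NAME"),
  ("TO_BANK",   "C", "BANK_NAME")
)

// Keep all existing columns, then append one conditional column per rule.
val newcols2 = col("*") +: rules.map { case (name, dir, src) =>
  when(col("DIRECTION") === dir, col(src)).as(name)
}

val result = df.select(newcols2: _*)

Each when without an otherwise yields null for non-matching rows, so the output is identical to your withColumn version. As an aside, the encoder error in your .map attempts most likely comes from the if / else if chain having no final else: the expression's inferred type widens to Any (the missing branch returns Unit), and Spark has no encoder for Any. Adding a final else branch (or a match with a catch-all case) should let the tuple encoder from spark.implicits._ be resolved.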