I am trying to add a set of string values (Business, Casual) to a category column in a Spark DataFrame.

My DataFrame is:

+----------+----+--------+
|    source|live|category|
+----------+----+--------+
|      Ford|   Y|        |
|      Ford|   Y|        |
|  Caddilac|   Y|        |
|  Caddilac|   Y|        |
| Chevrolet|   Y|        |
| Chevrolet|   Y|        |
|     Skoda|   Y|        |
|     Skoda|   Y|        |
|      Fiat|   Y|        |
|      Fiat|   Y|        |
|Alfa Romeo|   Y|        |
|Alfa Romeo|   Y|        |
+----------+----+--------+

What I want is the two values repeated down the category column, like this:

+----------+----+--------+
|    source|live|category|
+----------+----+--------+
|      Ford|   Y|Business|
|      Ford|   Y|  Casual|
|  Caddilac|   Y|Business|
|  Caddilac|   Y|  Casual|
| Chevrolet|   Y|Business|
| Chevrolet|   Y|  Casual|
|     Skoda|   Y|Business|
|     Skoda|   Y|  Casual|
|      Fiat|   Y|Business|
|      Fiat|   Y|  Casual|
|Alfa Romeo|   Y|Business|
|Alfa Romeo|   Y|  Casual|
+----------+----+--------+

I have tried adding the "category" column using withColumn with lit(), but lit() takes only a single value. I also tried explode(array(...)), but that doubles the number of rows.
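
For reference, the two attempts look roughly like this (a sketch, assuming the DataFrame above is df):

import org.apache.spark.sql.functions.{array, explode, lit}

// lit() can only supply one constant, so every row gets the same value:
df.withColumn("category", lit("Business"))

// explode(array(...)) emits one output row per array element,
// so the 12 input rows become 24:
df.withColumn("category", explode(array(lit("Business"), lit("Casual"))))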

The values in the "category" column are constant and repeated; they do not depend on any other column.

Any help would be appreciated. Thanks

3 Answers

Try this:

withColumn("category", expr("element_at(array('Business', 'Casual'), row_number() over(partition by source, live order by source, live))"))

row_number() numbers the rows 1 and 2 within each (source, live) partition, and element_at (which is 1-based) maps those positions to 'Business' and 'Casual'.

Test:

Load the test data

import org.apache.spark.sql.functions.expr
import spark.implicits._ // assumes a SparkSession in scope as `spark`

val data =
  """
    |    source|live
    |      Ford|   Y
    |      Ford|   Y
    |  Caddilac|   Y
    |  Caddilac|   Y
    | Chevrolet|   Y
    | Chevrolet|   Y
    |     Skoda|   Y
    |     Skoda|   Y
    |      Fiat|   Y
    |      Fiat|   Y
    |Alfa Romeo|   Y
    |Alfa Romeo|   Y
  """.stripMargin

// turn the fixed-width text into comma-separated lines, then read it as CSV
val stringDS = data.split(System.lineSeparator())
  .map(_.split("\\|").map(_.replaceAll("""^[ \t]+|[ \t]+$""", "")).mkString(","))
  .toSeq.toDS()
val df = spark.read
  .option("sep", ",")
  .option("inferSchema", "true")
  .option("header", "true")
  .option("nullValue", "null")
  .csv(stringDS)

df.show(false)
df.printSchema()

    /**
      * +----------+----+
      * |source    |live|
      * +----------+----+
      * |Ford      |Y   |
      * |Ford      |Y   |
      * |Caddilac  |Y   |
      * |Caddilac  |Y   |
      * |Chevrolet |Y   |
      * |Chevrolet |Y   |
      * |Skoda     |Y   |
      * |Skoda     |Y   |
      * |Fiat      |Y   |
      * |Fiat      |Y   |
      * |Alfa Romeo|Y   |
      * |Alfa Romeo|Y   |
      * +----------+----+
      *
      * root
      * |-- source: string (nullable = true)
      * |-- live: string (nullable = true)
      */

Derive the category column:

df.withColumn("category", expr("element_at(array('Business', 'Casual'), " +
    "row_number() over(partition by source, live order by source, live))"))
  .show(false)

    /**
      * +----------+----+--------+
      * |source    |live|category|
      * +----------+----+--------+
      * |Alfa Romeo|Y   |Business|
      * |Alfa Romeo|Y   |Casual  |
      * |Caddilac  |Y   |Business|
      * |Caddilac  |Y   |Casual  |
      * |Chevrolet |Y   |Business|
      * |Chevrolet |Y   |Casual  |
      * |Ford      |Y   |Business|
      * |Ford      |Y   |Casual  |
      * |Skoda     |Y   |Business|
      * |Skoda     |Y   |Casual  |
      * |Fiat      |Y   |Business|
      * |Fiat      |Y   |Casual  |
      * +----------+----+--------+
      */
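
Note that element_at is 1-based and, with ANSI mode off, returns null when the index exceeds the array length, so the expression assumes each (source, live) pair occurs exactly twice. If a pair could occur more often, a variation (my sketch, not part of the answer above) cycles through the two values with pmod:

df.withColumn("category", expr("element_at(array('Business', 'Casual'), " +
    "cast(pmod(row_number() over(partition by source, live order by source, live) - 1, 2) + 1 as int))"))
  .show(false)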

Another option: use functions.rand() to generate a random number for each row and take it modulo 2; map 0 and 1 to the two category values.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataTypes

object RandColumn {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().master("local[*]").getOrCreate()

    import spark.implicits._

    val df = List(("Ford", "Y"), ("Ford", "Y"),
      ("Caddilac", "Y"), ("Caddilac", "Y")
    ).toDF("source", "live")

    // rand() is uniform in [0, 1); after scaling by 100 and taking % 2,
    // the integer part is 0 or 1, which selects one of the two labels
    df
      .withColumn("category",
        when(((rand() * 100) % 2).cast(DataTypes.IntegerType) === 0, "Business")
          .otherwise("Casual"))
      .show()
  }

}
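
One caveat: rand() flips an independent coin for every row, so a given source can end up with the same category twice; the split is only roughly 50/50 overall, not exactly one of each per source. Seeding the generator at least makes the assignment reproducible across runs (a small variation on the code above, assuming the same df):

// rand(seed) is deterministic for a fixed seed and partitioning
df.withColumn("category",
    when(((rand(42) * 100) % 2).cast(DataTypes.IntegerType) === 0, "Business")
      .otherwise("Casual"))
  .show()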


Another way of doing it:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
// assumes spark.implicits._ is in scope for the $ syntax

df.withColumn("rn", row_number().over(Window.partitionBy($"source").orderBy($"source")))
  .withColumn("category", when($"rn" % 2 === 0, "Casual").otherwise("Business"))
  .drop("rn")
  .show

+--------+----+--------+
|  source|live|category|
+--------+----+--------+
|Caddilac|   Y|Business|
|Caddilac|   Y|  Casual|
|    Ford|   Y|Business|
|    Ford|   Y|  Casual|
+--------+----+--------+
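
Worth noting: the window is partitioned by source and also ordered by source, so the rows inside each partition are tied, and which physical row gets rn = 1 is not guaranteed to be stable across runs. Ordering by a unique surrogate such as monotonically_increasing_id() pins the numbering down (my adjustment on top of this answer, with the same imports in scope):

df.withColumn("id", monotonically_increasing_id())
  .withColumn("rn", row_number().over(Window.partitionBy($"source").orderBy($"id")))
  .withColumn("category", when($"rn" % 2 === 0, "Casual").otherwise("Business"))
  .drop("rn", "id")
  .show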