3
votes

I have a Spark DataFrame with several columns that looks like this:

id Color
1 Red, Blue, Black
2 Red, Green
3 Blue, Yellow, Green
...

I also have a mapping file that looks like this:
Red,0
Blue,1
Green,2
Black,3
Yellow,4

What I need to do is map the color names to indicator arrays based on these ids, for example mapping "Red, Blue, Black" to the array [1,1,0,1,0]. I wrote the following code:

def mapColor(label_string: String): Array[Int] = {
  // split on commas and trim, since the input has a space after each comma
  val labels = label_string.split(",").map(_.trim)
  val index_array = new Array[Int](COLOR_LENGTH)
  for (label <- labels) {
    if (COLOR_MAP.contains(label)) {
      index_array(COLOR_MAP(label)) = 1
    } else {
      // dictionary does not contain the label, so the last index is set to one
      index_array(COLOR_LENGTH - 1) = 1
    }
  }
  index_array
}

COLOR_LENGTH is the length of the dictionary, and COLOR_MAP is the dictionary that holds the string -> id mapping.

I call this function like this:

 val color_function = udf(mapColor:(String)=>Array[Int])
 sql.withColumn("color_idx",color_function(col("Color")))

I have multiple columns that need this operation, and each column needs a different dictionary. Currently I duplicate this function for each column, changing only the dictionary and the length, which is tedious. Is there a way to pass the length and the dictionary into the mapping function, such as

def map(label_string:String,map:Map[String,Integer],len:Int):Array[Int] 

But how should I call this function on the Spark DataFrame? I do not see a way to pass the extra parameters in the declaration:

val color_function = udf(mapColor:(String)=>Array[Int])

2 Answers

8
votes

You can define a function that takes the color Map as its argument and returns a UDF, as in the following example:

val df = Seq(
  (1, "Red, Blue, Black"),
  (2, "Red, Green"),
  (3, "Blue, Yellow, Green")
).toDF("id", "color")

val colorMap = Map("Red"-> 0, "Blue"->1, "Green"->2, "Black"->3, "Yellow"->4)

def mapColorCode(m: Map[String, Int]) = udf( (s: String) =>
  s.split("""\s*,\s*""").map(c => m.getOrElse(c, -99))
)

df.select($"id", mapColorCode(colorMap)($"color").as("colorcode")).show
// +---+----------+
// | id| colorcode|
// +---+----------+
// |  1| [0, 1, 3]|
// |  2|    [0, 2]|
// |  3| [1, 4, 2]|
// +---+----------+
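
Note that the example above returns the ids themselves rather than the indicator array from the question. A minimal sketch of the same idea that also takes the array length, in plain Scala so it can be checked without a Spark session, might look like the following. The names oneHot and OneHotSketch are hypothetical, and sending unknown labels to the last slot is an assumption carried over from the question's code:

```scala
object OneHotSketch {
  // Takes the dictionary and output length first, then the label string.
  // Labels missing from `m` set the last slot (assumption: same fallback
  // as the question's mapColor).
  def oneHot(m: Map[String, Int], len: Int)(s: String): Array[Int] = {
    val out = new Array[Int](len) // zero-initialized
    s.split(",").map(_.trim).filter(_.nonEmpty).foreach { label =>
      out(m.getOrElse(label, len - 1)) = 1
    }
    out
  }

  def main(args: Array[String]): Unit = {
    val colorMap = Map("Red" -> 0, "Blue" -> 1, "Green" -> 2, "Black" -> 3, "Yellow" -> 4)
    // Fix the dictionary and length once; the result is a String => Array[Int].
    val mapColor: String => Array[Int] = oneHot(colorMap, colorMap.size)
    println(mapColor("Red, Blue, Black").mkString("[", ",", "]")) // [1,1,0,1,0]
  }
}
```

In Spark this could be wrapped the same way as mapColorCode, for instance def oneHotUdf(m: Map[String, Int], len: Int) = udf((s: String) => oneHot(m, len)(s)), and then called with a different map for each column. With len equal to the dictionary size, unknown labels share the last slot with Yellow; passing colorMap.size + 1 would reserve a separate slot for them.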
-1
votes

Here is the full code:

val colrMapList = List("Red" -> 0, "Blue" -> 1, "Green" -> 2).toMap

def getColor = udf((colors: Seq[String]) => if (colors.nonEmpty) colors.map(color => colrMapList.getOrElse(color, 0)).mkString(",") else "0")

val colors = List((1, Array("Red","Blue","Black")),(2,Array("Red", "Green")))
val colrDF = sc.parallelize(colors).toDF("id", "colors")

colrDF.withColumn("colorMap", getColor($"colors")).show

Explanation

  1. Create a map for the color-to-integer mapping.
  2. The getColor function looks up the corresponding integers for the given colors; colors missing from the map fall back to 0.
  3. Finally, apply the function to colrDF to get the output.