I have a DataFrame df that looks like this:

from pyspark.sql import SparkSession
from pyspark.sql.functions import create_map

spark = SparkSession.builder.getOrCreate()

# Sample data: (id, map1, map2, date, var)
df = spark.createDataFrame(
    [
        ("1", "A", "B", "2020-01-01", 6),
        ("2", "A", "B", "2020-01-01", 6),
        ("3", "A", "C", "2020-01-01", 6),
        ("4", "A", "C", "2020-01-01", 6),
        ("5", "B", "D", "2020-01-01", 10),
        ("6", "B", "D", "2020-01-01", 10),
    ],
    ["id", "map1", "map2", "date", "var"],
)
df.show()

+---+----+----+----------+---+
| id|map1|map2|      date|var|
+---+----+----+----------+---+
|  1|   A|   B|2020-01-01|  6|
|  2|   A|   B|2020-01-01|  6|
|  3|   A|   C|2020-01-01|  6|
|  4|   A|   C|2020-01-01|  6|
|  5|   B|   D|2020-01-01| 10|
|  6|   B|   D|2020-01-01| 10|
+---+----+----+----------+---+

Now I would like to map values using the map1 and map2 columns such that ..., as shown in the screenshot below (image not available).

Note that for each distinct map1 value (A, B), the var value is always the same (6 and 10 respectively). map1 cannot be null, but map2 can be null.
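The note above matters because a map can hold only one value per key: if the same map1 value appeared with two different var values, there would be no single answer for the lookup. A minimal plain-Python sketch (not Spark code) of checking that assumption on the sample rows; the helper name `var_is_constant_per_key` and the extra conflicting row are hypothetical, for illustration only.

```python
from collections import defaultdict

# Sample rows from the question: (id, map1, map2, date, var).
rows = [
    ("1", "A", "B", "2020-01-01", 6),
    ("2", "A", "B", "2020-01-01", 6),
    ("3", "A", "C", "2020-01-01", 6),
    ("4", "A", "C", "2020-01-01", 6),
    ("5", "B", "D", "2020-01-01", 10),
    ("6", "B", "D", "2020-01-01", 10),
]


def var_is_constant_per_key(rows):
    """Hypothetical helper: True if every map1 value pairs with exactly one var."""
    seen = defaultdict(set)
    for _id, map1, _map2, _date, var in rows:
        seen[map1].add(var)
    return all(len(values) == 1 for values in seen.values())


print(var_is_constant_per_key(rows))  # True

# A hypothetical row that breaks the assumption: map1 "A" with var 7.
conflicting = rows + [("7", "A", None, "2020-01-01", 7)]
print(var_is_constant_per_key(conflicting))  # False
```

In Spark itself, a similar check could probably be expressed with `groupBy('map1').agg(countDistinct('var'))` and looking for counts above 1.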

I want to avoid joins, RDDs and UDFs as much as possible and rely only on built-in PySpark functions, for performance.

First, I create a map column (key: map1, value: var):

df = df.withColumn("mapp", create_map('map1', 'var'))


I tried something like the following, but it obviously does not work dynamically:

df = df.withColumn('var_mapped',  df["mapp"].getItem(df['map1']))
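Why this attempt cannot work: `create_map('map1', 'var')` builds a separate single-entry map for each row from that row's own values. A plain-Python model of that behaviour (not Spark code), using simplified (map1, map2, var) tuples from the question plus one hypothetical row with a null map2:

```python
# Simplified rows: (map1, map2, var); the last row is hypothetical (null map2).
rows = [("A", "B", 6), ("A", "C", 6), ("B", "D", 10), ("A", None, 6)]

# create_map('map1', 'var') gives each row its OWN one-entry map.
per_row_maps = [{map1: var} for map1, _map2, var in rows]

# getItem(map1): the key is always present, so it just returns the row's var.
by_map1 = [m.get(map1) for m, (map1, _, _) in zip(per_row_maps, rows)]

# getItem(map2): the only key is map1, so the lookup misses unless map1 == map2.
by_map2 = [m.get(map2) for m, (_, map2, _) in zip(per_row_maps, rows)]

print(by_map1)  # [6, 6, 10, 6]
print(by_map2)  # [None, None, None, None]
```

The keys from the other rows are never visible, which is why the pairs have to be gathered across rows first.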

What are some solutions or functions I could use in this case? Any help would be appreciated.

Comment (AdibP): What Spark version are you using?
Reply (dakjdlajsl): I am using Spark 3.1.1.

1 Answer


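To get all the key-value pairs of the map across rows, you can use a window function. In this case, apply collect_set to a struct of the map1 and var columns over a window that spans all rows, then build a map from the collected entries with map_from_entries. Be aware that a window with an empty partitionBy() moves all rows into a single partition, which can hurt performance on large DataFrames. It would look something like this: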

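from pyspark.sql.functions import map_from_entries, collect_set, struct, col
from pyspark.sql.window import Window

# An empty partitionBy() makes the window span the whole DataFrame, so every
# row receives the same set of distinct (map1, var) pairs, turned into a map.
w = Window.partitionBy()
df = df.withColumn(
    "mapp",
    map_from_entries(collect_set(struct(col("map1"), col("var"))).over(w)),
)
df.show()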

+---+----+----+----------+---+-----------------+
| id|map1|map2|      date|var|             mapp|
+---+----+----+----------+---+-----------------+
|  1|   A|   B|2020-01-01|  6|[B -> 10, A -> 6]|
|  2|   A|   B|2020-01-01|  6|[B -> 10, A -> 6]|
|  3|   A|   C|2020-01-01|  6|[B -> 10, A -> 6]|
|  4|   A|   C|2020-01-01|  6|[B -> 10, A -> 6]|
|  5|   B|   D|2020-01-01| 10|[B -> 10, A -> 6]|
|  6|   B|   D|2020-01-01| 10|[B -> 10, A -> 6]|
+---+----+----+----------+---+-----------------+
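To make the window step concrete, here is a plain-Python model (not Spark code) of `map_from_entries(collect_set(struct('map1', 'var')))` over all rows. The `collect_set` and `map_from_entries` functions below are hypothetical stand-ins with the same names. The duplicate-key check assumes Spark 3.x's default `spark.sql.mapKeyDedupPolicy=EXCEPTION`, under which building a map with a repeated key fails; this is why the question's guarantee that var is constant per map1 is needed.

```python
# (map1, var) pairs from the six rows in the question.
pairs = [("A", 6), ("A", 6), ("A", 6), ("A", 6), ("B", 10), ("B", 10)]


def collect_set(values):
    """Stand-in: distinct values across all rows; order is not guaranteed."""
    return set(values)


def map_from_entries(entries):
    """Stand-in: build a dict from (key, value) pairs, rejecting duplicate keys
    (modelling the assumed default EXCEPTION dedup policy)."""
    result = {}
    for key, value in entries:
        if key in result:
            raise ValueError(f"Duplicate map key {key!r}")
        result[key] = value
    return result


mapp = map_from_entries(collect_set(pairs))
print(mapp == {"A": 6, "B": 10})  # True

# If "A" appeared with two different var values, collect_set would keep both
# pairs and building the map would fail.
try:
    map_from_entries(collect_set(pairs + [("A", 7)]))
    duplicate_rejected = False
except ValueError:
    duplicate_rejected = True
print(duplicate_rejected)  # True
```

Like Spark's collect_set, the Python set has no defined order, so the printed map order (`[B -> 10, A -> 6]` above) should not be relied on.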

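After that, you can look up each row's map2 value in mapp with .getItem(). A missing key (C or D here) or a null map2 gives null, which is then replaced with 0: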

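# Look up map2 in the map; fill only the nulls in res with 0 (a bare
# fillna(0) would also fill nulls in every other numeric column, e.g. var).
df = df.withColumn('res', col('mapp').getItem(col('map2'))).fillna(0, subset=['res'])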
df.show()

+---+----+----+----------+---+-----------------+---+
| id|map1|map2|      date|var|             mapp|res|
+---+----+----+----------+---+-----------------+---+
|  1|   A|   B|2020-01-01|  6|[B -> 10, A -> 6]| 10|
|  2|   A|   B|2020-01-01|  6|[B -> 10, A -> 6]| 10|
|  3|   A|   C|2020-01-01|  6|[B -> 10, A -> 6]|  0|
|  4|   A|   C|2020-01-01|  6|[B -> 10, A -> 6]|  0|
|  5|   B|   D|2020-01-01| 10|[B -> 10, A -> 6]|  0|
|  6|   B|   D|2020-01-01| 10|[B -> 10, A -> 6]|  0|
+---+----+----+----------+---+-----------------+---+
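A plain-Python model (not Spark code) of the final lookup, assuming the map built by the window step and adding one hypothetical row with a null map2 (the question says map2 can be null). `dict.get` stands in for `getItem`, which yields null for a missing or null key; the second step mirrors `fillna(0, subset=['res'])`.

```python
# The same map on every row, produced by the window step.
mapp = {"B": 10, "A": 6}

# map2 values from the question, plus a hypothetical null map2.
map2_values = ["B", "B", "C", "C", "D", "D", None]

# getItem: None (null) when the key is missing or is itself null.
res = [None if key is None else mapp.get(key) for key in map2_values]

# fillna(0, subset=['res']): replace only the nulls in res with 0.
res_filled = [0 if value is None else value for value in res]

print(res)         # [10, 10, None, None, None, None, None]
print(res_filled)  # [10, 10, 0, 0, 0, 0, 0]
```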