
Is it possible to optimize cross joins in Spark SQL? The requirement is to populate a column band_id based on age ranges defined in another table. So far I have been able to implement this through a cross join and a WHERE clause, but I was hoping there is a better way to code it and alleviate the performance issues. Can I use a broadcast hint? (SQL provided below.)

Customer (10M records):

id | name | age
X1 | John | 22
V2 | Mark | 29
F4 | Peter| 42

Age_band table: (10 records)

band_id | low_age | high_age
B123    |  10     | 19
X745    |  20     | 29
P134    |  30     | 39
Q245    |  40     | 50

Expected Output:

id | name | age | band_id
X1 | John | 22  | X745
V2 | Mark | 29  | X745
F4 | Peter| 42  | Q245

Query:

select a.id, a.name, a.age, b.band_id
from cust a
cross join age_band b
where a.age between b.low_age and b.high_age;

Please advise.


2 Answers


Judging from the SparkStrategies.scala source, in your case you can specify either a cross join or a broadcast hint, but you don't have to: with a non-equi (range) condition and a tiny age_band table, Spark will select a Broadcast Nested Loop Join regardless:

   * ...
   * - Broadcast nested loop join (BNLJ):
   *     Supports both equi-joins and non-equi-joins.
   *     Supports all the join types, but the implementation is optimized for:
   *       1) broadcasting the left side in a right outer join;
   *       2) broadcasting the right side in a left outer, left semi, left anti or existence join;
   *       3) broadcasting either side in an inner-like join.
   *     For other cases, we need to scan the data multiple times, which can be rather slow. 
   * ...
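To make the BNLJ strategy concrete, here is a plain-Scala sketch (no Spark involved) of what an inner broadcast nested loop join does with the question's sample data: the small age_band table plays the "broadcast" side held in memory, and for every customer row the loop scans all band rows and keeps those whose range contains the age. The case classes `Customer` and `AgeBand` and the function `bnljInner` are hypothetical names chosen for illustration; Spark's real operator works on internal rows, not on case classes.

```scala
// Hypothetical row types mirroring the question's two tables.
final case class Customer(id: String, name: String, age: Int)
final case class AgeBand(bandId: String, lowAge: Int, highAge: Int)

object BnljSketch {
  // Inner nested-loop join: for each row of the large (streamed) side,
  // scan every row of the small (broadcast) side and emit the pairs that
  // satisfy age BETWEEN low_age AND high_age (both bounds inclusive).
  def bnljInner(customers: Seq[Customer], broadcastBands: Seq[AgeBand]): Seq[(Customer, String)] =
    for {
      c <- customers
      b <- broadcastBands
      if c.age >= b.lowAge && c.age <= b.highAge
    } yield (c, b.bandId)

  def main(args: Array[String]): Unit = {
    // Sample rows taken from the question.
    val customers = Seq(Customer("X1", "John", 22), Customer("V2", "Mark", 29), Customer("F4", "Peter", 42))
    val bands = Seq(
      AgeBand("B123", 10, 19), AgeBand("X745", 20, 29),
      AgeBand("P134", 30, 39), AgeBand("Q245", 40, 50))
    bnljInner(customers, bands).foreach { case (c, band) =>
      println(s"${c.id} | ${c.name} | ${c.age} | $band")
    }
  }
}
```

The work is one comparison per (customer, band) pair, so 10M customers times 10 bands is about 100M cheap range checks; that is why the nested loop is acceptable here as long as the band table is the side being broadcast.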

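You do not need a cross join; a left join is enough. When I run both versions, the physical plans differ slightly, and I prefer the latter: the cross join plan uses BuildLeft, i.e. it broadcasts the customer side, whereas the left join plan broadcasts age_band (BuildRight), which is the side you want broadcast when the customer table has 10M rows. Keep in mind that a LEFT JOIN also keeps customers whose age falls in no band (with a null band_id), while the cross join with a range condition drops them.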

val df3 = spark.sql("""
    SELECT 
        id, name, age, band_id
    FROM 
        cust a
    CROSS JOIN 
        age_band b
    ON 
        age BETWEEN low_age and high_age
""")

df3.explain

== Physical Plan ==
*(3) Project [id#75, name#76, age#77, band_id#97]
+- BroadcastNestedLoopJoin BuildLeft, Cross, ((age#77 >= low_age#98) AND (age#77 <= high_age#99))
   :- BroadcastExchange IdentityBroadcastMode, [id=#157]
   :  +- *(1) Project [id#75, name#76, age#77]
   :     +- *(1) Filter isnotnull(age#77)
   :        +- FileScan csv [id#75,name#76,age#77] Batched: false, DataFilters: [isnotnull(age#77)], Format: CSV, Location: InMemoryFileIndex[file:/test1.csv], PartitionFilters: [], PushedFilters: [IsNotNull(age)], ReadSchema: struct<id:string,name:string,age:int>
   +- *(2) Project [band_id#97, low_age#98, high_age#99]
      +- *(2) Filter (isnotnull(low_age#98) AND isnotnull(high_age#99))
         +- FileScan csv [band_id#97,low_age#98,high_age#99] Batched: false, DataFilters: [isnotnull(low_age#98), isnotnull(high_age#99)], Format: CSV, Location: InMemoryFileIndex[file:/test2.csv], PartitionFilters: [], PushedFilters: [IsNotNull(low_age), IsNotNull(high_age)], ReadSchema: struct<band_id:string,low_age:int,high_age:int>


val df4 = spark.sql("""
    SELECT  /*+ BROADCAST(b) */ 
        id, name, age, band_id
    FROM 
        cust a
    LEFT JOIN 
        age_band b
    ON 
        age BETWEEN low_age and high_age
""")

df4.explain

== Physical Plan ==
*(2) Project [id#75, name#76, age#77, band_id#97]
+- BroadcastNestedLoopJoin BuildRight, LeftOuter, ((age#77 >= low_age#98) AND (age#77 <= high_age#99))
   :- FileScan csv [id#75,name#76,age#77] Batched: false, DataFilters: [], Format: CSV, Location: InMemoryFileIndex[file:/test1.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:string,name:string,age:int>
   +- BroadcastExchange IdentityBroadcastMode, [id=#192]
      +- *(1) Project [band_id#97, low_age#98, high_age#99]
         +- *(1) Filter (isnotnull(low_age#98) AND isnotnull(high_age#99))
            +- FileScan csv [band_id#97,low_age#98,high_age#99] Batched: false, DataFilters: [isnotnull(low_age#98), isnotnull(high_age#99)], Format: CSV, Location: InMemoryFileIndex[file:/test2.csv], PartitionFilters: [], PushedFilters: [IsNotNull(low_age), IsNotNull(high_age)], ReadSchema: struct<band_id:string,low_age:int,high_age:int>
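The LeftOuter / BuildRight shape of the second plan can be modelled the same way in plain Scala, using `Option` for the nullable band_id. This is only a sketch of the join semantics, not Spark's implementation; the types and the name `bnljLeftOuter` are hypothetical, and the 55-year-old customer "Z9 | Anna" is made-up data added to show the unmatched case (the question's bands stop at 50).

```scala
// Hypothetical row types mirroring the question's two tables.
final case class Customer(id: String, name: String, age: Int)
final case class AgeBand(bandId: String, lowAge: Int, highAge: Int)

object LeftOuterSketch {
  // Left outer nested-loop join with the right side (age_band) as the
  // broadcast build side: every customer is emitted at least once, and
  // when no band range contains the age, the band is None (SQL NULL).
  def bnljLeftOuter(customers: Seq[Customer], broadcastBands: Seq[AgeBand]): Seq[(Customer, Option[String])] =
    customers.flatMap { c =>
      val matches = broadcastBands.filter(b => c.age >= b.lowAge && c.age <= b.highAge)
      if (matches.isEmpty) Seq((c, Option.empty[String]))
      else matches.map(b => (c, Option(b.bandId)))
    }

  def main(args: Array[String]): Unit = {
    // "Z9 | Anna | 55" is hypothetical: no band covers age 55.
    val customers = Seq(Customer("X1", "John", 22), Customer("Z9", "Anna", 55))
    val bands = Seq(
      AgeBand("B123", 10, 19), AgeBand("X745", 20, 29),
      AgeBand("P134", 30, 39), AgeBand("Q245", 40, 50))
    bnljLeftOuter(customers, bands).foreach { case (c, band) =>
      println(s"${c.id} | ${c.name} | ${c.age} | ${band.getOrElse("null")}")
    }
  }
}
```

With the inner join from the first sketch, the 55-year-old customer would simply disappear from the result; with the left outer version it survives with a null band_id, so choose the join type according to which behaviour you want for out-of-range ages.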