I want to use Spark SQL window functions to do some aggregations and windowing.

Suppose I'm using the example table provided here: https://databricks.com/blog/2015/07/15/introducing-window-functions-in-spark-sql.html

I want a query that returns the top 2 revenues for each category, together with the number of products in each category.

When I run this query:

SELECT
  product,
  category,
  revenue,
  count
FROM (
  SELECT
    product,
    category,
    revenue,
    dense_rank() OVER (PARTITION BY category ORDER BY revenue DESC) as rank,
    count(*) OVER (PARTITION BY category ORDER BY revenue DESC) as count
  FROM productRevenue) tmp
WHERE
  rank <= 2

I get a table like this:

product  category  revenue  count
pro2     tablet    6500     1
mini     tablet    5500     2

instead of the result I expected:

product  category  revenue  count
pro2     tablet    6500     5
mini     tablet    5500     5

How should I write the query to get the correct count for each category without adding a separate GROUP BY statement?

2 Answers

2
votes

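In Spark, when a window clause has an ORDER BY, the frame defaults to RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW. count(*) therefore only counts the rows from the start of the partition up to the current row (plus any rows tied with it on revenue), which is why pro2 gets 1 and mini gets 2.

For your case, add ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING to the count(*) window clause so that the frame covers the whole partition.

Try this: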

SELECT
  product,
  category,
  revenue,
  count
FROM (
  SELECT
    product,
    category,
    revenue,
    dense_rank() OVER (PARTITION BY category ORDER BY revenue DESC) as rank,
    count(*) OVER (
      PARTITION BY category ORDER BY revenue DESC
      ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
    ) as count
  FROM productRevenue) tmp
WHERE
  rank <= 2
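
To see why the frame matters, here is a minimal plain-Scala model of count(*) OVER (PARTITION BY category ORDER BY revenue DESC ...) under three frames. It does not use Spark; the Sale case class, the Frame cases, the WindowModel object and the sample rows are hypothetical, with the rows modeled on the productRevenue table from the linked blog post. RangeToCurrentRow stands for Spark's default frame when a window has an ORDER BY (RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), which ends at the current row's last peer.

```scala
// Plain-Scala model of Spark window frames; no Spark dependency.
// Hypothetical sample rows, modeled on the blog's productRevenue table.
final case class Sale(product: String, category: String, revenue: Int)

sealed trait Frame
// Spark's default when the window has ORDER BY:
// RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW (includes the current row's peers).
case object RangeToCurrentRow extends Frame
// ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW (stops at the current row itself).
case object RowsToCurrentRow extends Frame
// ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING (the whole partition).
case object WholePartition extends Frame

object WindowModel {
  val sales: Seq[Sale] = Seq(
    Sale("Thin", "cell phone", 6000),
    Sale("Normal", "tablet", 1500),
    Sale("Mini", "tablet", 5500),
    Sale("Ultra thin", "cell phone", 5000),
    Sale("Very thin", "cell phone", 6000),
    Sale("Big", "tablet", 2500),
    Sale("Bendable", "cell phone", 3000),
    Sale("Foldable", "cell phone", 3000),
    Sale("Pro", "tablet", 4500),
    Sale("Pro2", "tablet", 6500)
  )

  // Models count(*) OVER (PARTITION BY category ORDER BY revenue DESC <frame>)
  // and returns product -> count. Assumes product names are unique.
  def countOver(rows: Seq[Sale], frame: Frame): Map[String, Int] =
    rows.groupBy(_.category).values.flatMap { part =>
      // ORDER BY revenue DESC; sortBy is stable, so tied rows keep input order.
      val ordered = part.sortBy(s => -s.revenue)
      ordered.zipWithIndex.map { case (row, i) =>
        val frameSize = frame match {
          // Frame ends at the last row with the same revenue (the last peer).
          case RangeToCurrentRow => ordered.lastIndexWhere(_.revenue == row.revenue) + 1
          // Frame ends at the current row's own position.
          case RowsToCurrentRow => i + 1
          // Frame is the whole partition.
          case WholePartition => ordered.size
        }
        row.product -> frameSize
      }
    }.toMap
}
```

With these rows, RangeToCurrentRow gives Pro2 1 and Mini 2 (the result in the question), while WholePartition gives 5 for every product. The tied rows Thin and Very thin (both 6000) get 2 each under RangeToCurrentRow but 1 and 2 under RowsToCurrentRow; in Spark, which tied row comes first under a ROWS frame is not guaranteed.
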
0
votes

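Change count(*) OVER (PARTITION BY category ORDER BY revenue DESC) as count to count(*) OVER (PARTITION BY category ORDER BY category DESC) as count and you will get the expected result. Every row in a partition has the same category, so all rows are peers under this ordering, and the default frame (which runs from the start of the partition through the current row's last peer) covers the whole partition.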

Try the code below:

scala> spark.sql("""SELECT
     |   product,
     |   category,
     |   revenue,
     |   rank,
     |   count
     | FROM (
     |   SELECT
     |     product,
     |     category,
     |     revenue,
     |     dense_rank() OVER (PARTITION BY category ORDER BY revenue DESC) as rank,
     |     count(*) OVER (PARTITION BY category ORDER BY category DESC) as count
     |   FROM productRevenue) tmp
     | WHERE
     |   tmp.rank <= 2 """).show(false)

+----------+----------+-------+----+-----+
|product   |category  |revenue|rank|count|
+----------+----------+-------+----+-----+
|Pro2      |tablet    |6500   |1   |5    |
|Mini      |tablet    |5500   |2   |5    |
|Thin      |cell phone|6000   |1   |5    |
|Very thin |cell phone|6000   |1   |5    |
|Ultra thin|cell phone|5000   |2   |5    |
+----------+----------+-------+----+-----+
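
Why ordering by the partition key works can be checked with a small plain-Scala model. It does not use Spark; Sale, PeerModel and the sample rows are hypothetical, with the rows modeled on the blog's productRevenue table. rangeCount counts the rows that sort before the current row or tie with it under a descending key, which is what the default RANGE frame includes; with category as the key, every row in a partition is a peer, so the count is the partition size. denseRank and topTwo reproduce the rank <= 2 filter.

```scala
// Plain-Scala model of the second answer's query; no Spark dependency.
// Hypothetical sample rows, modeled on the blog's productRevenue table.
final case class Sale(product: String, category: String, revenue: Int)

object PeerModel {
  val sales: Seq[Sale] = Seq(
    Sale("Thin", "cell phone", 6000),
    Sale("Normal", "tablet", 1500),
    Sale("Mini", "tablet", 5500),
    Sale("Ultra thin", "cell phone", 5000),
    Sale("Very thin", "cell phone", 6000),
    Sale("Big", "tablet", 2500),
    Sale("Bendable", "cell phone", 3000),
    Sale("Foldable", "cell phone", 3000),
    Sale("Pro", "tablet", 4500),
    Sale("Pro2", "tablet", 6500)
  )

  // count(*) OVER (PARTITION BY category ORDER BY <key> DESC) with the default
  // RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW frame: a row is in the
  // frame if it sorts before the current row (key >= under DESC) or ties with it.
  def rangeCount[K](part: Seq[Sale], row: Sale, key: Sale => K)(implicit ord: Ordering[K]): Int =
    part.count(r => ord.gteq(key(r), key(row)))

  // dense_rank() OVER (PARTITION BY category ORDER BY revenue DESC):
  // 1 + the number of distinct revenues strictly above the current one.
  def denseRank(part: Seq[Sale], row: Sale): Int =
    part.map(_.revenue).distinct.count(_ > row.revenue) + 1

  // The whole query: rank by revenue, count ordered by category, keep rank <= 2.
  // Returns (row, rank, count) triples.
  def topTwo(rows: Seq[Sale]): Seq[(Sale, Int, Int)] =
    rows.groupBy(_.category).values.toSeq.flatMap { part =>
      part.map(row => (row, denseRank(part, row), rangeCount(part, row, (s: Sale) => s.category)))
    }.filter { case (_, rank, _) => rank <= 2 }
}
```

With these rows, topTwo returns the same five (product, rank, count) combinations as the spark-shell output above, while rangeCount with revenue as the key gives 1 for Pro2, as in the question. Leaving ORDER BY out of the count window entirely, count(*) OVER (PARTITION BY category), should also give 5, because without ordering Spark defaults to a frame that spans the whole partition.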