34
votes

Reading the documentation for the Spark method sortByKey:

sortByKey([ascending], [numTasks])   When called on a dataset of (K, V) pairs where K implements Ordered, returns a dataset of (K, V) pairs sorted by keys in ascending or descending order, as specified in the boolean ascending argument.

Is it possible to return just N results, so that instead of returning all results it returns only the top 10? I could convert the sorted collection to an Array and use the take method, but since that is an O(N) operation, is there a more efficient method?

So you know how to sort, and you are asking how to take the top N. Can I suggest editing the question summary? - Daniel Darabos

3 Answers

19
votes

Most likely you have already perused the source code:

  class OrderedRDDFunctions {
    // <snip>
    def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = {
      val part = new RangePartitioner(numPartitions, self, ascending)
      val shuffled = new ShuffledRDD[K, V, P](self, part)
      shuffled.mapPartitions(iter => {
        val buf = iter.toArray
        if (ascending) {
          buf.sortWith((x, y) => x._1 < y._1).iterator
        } else {
          buf.sortWith((x, y) => x._1 > y._1).iterator
        }
      }, preservesPartitioning = true)
    }
    // <snip>
  }

And, as you suspected, all of the data must go through the shuffle stage, as the snippet shows.
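To make that cost concrete, here is a small Spark-free Python simulation of what the snippet does. It handles ascending order only. `sort_by_key` and its `bounds` argument are hypothetical names for illustration. The range boundaries are passed in explicitly, whereas Spark's RangePartitioner estimates them by sampling the RDD. Note that every record is moved during the simulated shuffle.

```python
import bisect
from typing import Any, List, Tuple

Record = Tuple[int, Any]


def sort_by_key(partitions: List[List[Record]], bounds: List[int]) -> List[List[Record]]:
    """Hypothetical local simulation of sortByKey (ascending only).

    bounds holds the upper key limit of every output partition except the last;
    Spark's RangePartitioner would estimate these by sampling instead.
    """
    shuffled: List[List[Record]] = [[] for _ in range(len(bounds) + 1)]
    # "Shuffle": every record, from every input partition, moves to the
    # output partition that owns its key range.
    for part in partitions:
        for k, v in part:
            shuffled[bisect.bisect_left(bounds, k)].append((k, v))
    # Local sort inside each output partition, like the mapPartitions step.
    return [sorted(p, key=lambda kv: kv[0]) for p in shuffled]


data = [[(5, "e"), (1, "a"), (9, "i")], [(3, "c"), (7, "g"), (2, "b")]]
result = sort_by_key(data, bounds=[3, 6])
# Concatenating the output partitions gives a globally sorted sequence.
flat = [r for p in result for r in p]
print(result)  # [[(1, 'a'), (2, 'b'), (3, 'c')], [(5, 'e')], [(7, 'g'), (9, 'i')]]
```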

However, your concern about subsequently calling take(K) may be misplaced: this operation does NOT iterate over all N items:

  /**
   * Take the first num elements of the RDD. It works by first scanning one partition, and use the
   * results from that partition to estimate the number of additional partitions needed to satisfy
   * the limit.
   */
  def take(num: Int): Array[T] = {
    // <snip>
  }

So then, it would seem:

O(myRdd.take(K)) << O(myRdd.sortByKey()) ~= O(myRdd.sortByKey().take(K)) (at least for small K) << O(myRdd.sortByKey().collect())
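The scanning behaviour described in the take doc comment can be sketched locally in Python. This is a simplified, hypothetical model. `take_sketch` multiplies the number of partitions to try by a fixed factor each round, whereas Spark estimates that number from how many items the partitions scanned so far produced. The sketch shows that for small K only a few partitions are ever touched.

```python
from typing import List, Tuple


def take_sketch(partitions: List[List[int]], num: int, scale_up: int = 4) -> Tuple[List[int], int]:
    """Simplified model of RDD.take: returns (items, number of partitions scanned)."""
    buf: List[int] = []
    scanned = 0
    parts_to_try = 1
    while len(buf) < num and scanned < len(partitions):
        # Scan the next batch of partitions, taking only as many items as still needed.
        for part in partitions[scanned:scanned + parts_to_try]:
            buf.extend(part[: num - len(buf)])
        scanned = min(scanned + parts_to_try, len(partitions))
        # Simplification: grow the batch by a fixed factor instead of estimating it.
        parts_to_try *= scale_up
    return buf, scanned


# 50 partitions of 100 items each (5,000 items in total).
parts = [list(range(i * 100, (i + 1) * 100)) for i in range(50)]
items, scanned = take_sketch(parts, 10)
print(len(items), scanned)  # 10 1
```

Asking this model for 10,000 items scans all 50 partitions. That is why take is only cheap when K is small relative to the data.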

51
votes

If you only need the top 10, use rdd.top(10). It avoids sorting, so it is faster.

rdd.top makes one parallel pass through the data, collecting the top N in each partition in a heap, and then merges the heaps. It is an O(rdd.count) operation. Sorting would be O(rdd.count log rdd.count) and would incur a lot of data transfer: it does a shuffle, so all of the data is transmitted over the network.
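The per-partition heap and merge described above can be modelled locally in Python. This is a sketch, not Spark's code. The partitions are plain lists, the per-partition step runs sequentially rather than in parallel, and `bounded_heap` and `top_sketch` are hypothetical names.

```python
import heapq
from typing import Iterable, List


def bounded_heap(items: Iterable[int], n: int) -> List[int]:
    """One pass over items, keeping a min-heap of the n largest seen (O(len * log n))."""
    if n <= 0:
        return []
    heap: List[int] = []
    for x in items:
        if len(heap) < n:
            heapq.heappush(heap, x)
        elif x > heap[0]:
            # x beats the smallest item kept so far, so it replaces it.
            heapq.heapreplace(heap, x)
    return heap


def top_sketch(partitions: List[List[int]], n: int) -> List[int]:
    # Each partition is reduced to its own heap of at most n items.
    heaps = [bounded_heap(p, n) for p in partitions]
    # Merge the heaps: only len(partitions) * n items take part, not the whole dataset.
    merged = bounded_heap((x for h in heaps for x in h), n)
    return sorted(merged, reverse=True)


print(top_sketch([[5, 1, 9], [3, 7, 2], [8, 6, 4]], 3))  # [9, 8, 7]
```

Because only the small per-partition heaps reach the merge step, no full shuffle or global sort is needed.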

8
votes

Another option, available in PySpark since at least 1.2.0, is takeOrdered.

In ascending order:

rdd.takeOrdered(10)

In descending order:

rdd.takeOrdered(10, lambda x: -x)

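Top 10 (k, v) pairs by value (Python 3 removed tuple-parameter unpacking, so the form lambda (k, v): -v only works on Python 2):

rdd.takeOrdered(10, lambda kv: -kv[1])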
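For intuition, takeOrdered(n, key) returns the n smallest elements under key. The same result can be reproduced locally by applying Python's heapq.nsmallest to each partition and then once more over the per-partition results. `take_ordered_sketch` is a hypothetical helper operating on plain lists, not PySpark code.

```python
import heapq
from typing import Any, Callable, List, Optional


def take_ordered_sketch(partitions: List[List[Any]], n: int,
                        key: Optional[Callable[[Any], Any]] = None) -> List[Any]:
    """Local analogue of rdd.takeOrdered(n, key): the n smallest items under key."""
    # Step 1: each partition keeps only its own n smallest items.
    candidates = [heapq.nsmallest(n, part, key=key) for part in partitions]
    # Step 2: the global answer is among the per-partition candidates.
    return heapq.nsmallest(n, (x for c in candidates for x in c), key=key)


nums = [[5, 1, 9], [3, 7, 2], [8, 6, 4]]
print(take_ordered_sketch(nums, 3))                    # [1, 2, 3]
print(take_ordered_sketch(nums, 3, key=lambda x: -x))  # [9, 8, 7]

pairs = [[("a", 3), ("b", 10)], [("c", 7), ("d", 1)]]
print(take_ordered_sketch(pairs, 2, key=lambda kv: -kv[1]))  # [('b', 10), ('c', 7)]
```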