5
votes

I'm looking for the PySpark equivalent of this question: "How to get the number of elements in partition?".

Specifically, I want to programmatically count the number of elements in each partition of a PySpark RDD or DataFrame (I know this information is available in the Spark Web UI).

This attempt:

df.foreachPartition(lambda iter: sum(1 for _ in iter))

results in:

AttributeError: 'NoneType' object has no attribute '_jvm'

I do not want to collect the contents of the iterator into memory.


1 Answer

12
votes

If you are asking whether we can get the number of elements in an iterator without iterating through it, the answer is no.

But we don't have to store the partition's contents in memory while counting, as in the post you mentioned:

def count_in_a_partition(idx, iterator):
  # Consume the iterator lazily; only the running count is kept in memory.
  count = 0
  for _ in iterator:
    count += 1
  # The function must produce an iterable; yielding the (index, count)
  # pair keeps each partition's result together in the collected list.
  yield idx, count

data = sc.parallelize([
    1, 2, 3, 4
], 4)

data.mapPartitionsWithIndex(count_in_a_partition).collect()
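Under the hood, mapPartitionsWithIndex treats whatever the function returns as an iterable and flattens its elements into the output. A minimal Spark-free sketch of that contract (the flattening loop below is a simplified stand-in for what Spark does per partition, not Spark's actual implementation) shows why yielding the (index, count) pair keeps the results paired:

```python
def count_in_a_partition(idx, iterator):
    # Count lazily; nothing from the partition is buffered.
    count = 0
    for _ in iterator:
        count += 1
    yield idx, count

# Simulate 4 partitions of [1, 2, 3, 4], one element each, roughly the
# way sc.parallelize([1, 2, 3, 4], 4) would split them.
partitions = [[1], [2], [3], [4]]

# mapPartitionsWithIndex chains the iterables returned for each partition.
result = [item
          for idx, part in enumerate(partitions)
          for item in count_in_a_partition(idx, iter(part))]
print(result)  # [(0, 1), (1, 1), (2, 1), (3, 1)]
```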

EDIT

Note that your code is very close to the solution: mapPartitions just needs a function that returns an iterator:

def count_in_a_partition(iterator):
  yield sum(1 for _ in iterator)

data.mapPartitions(count_in_a_partition).collect()
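The same contract explains why the generator works: mapPartitions calls the function once per partition and chains whatever it yields, so yielding a single count produces exactly one number per partition. Another Spark-free sketch (again, the loop is a simplified stand-in for Spark's per-partition dispatch):

```python
def count_in_a_partition(iterator):
    # Yield exactly one value per partition: its element count.
    yield sum(1 for _ in iterator)

# Four simulated partitions, one element each.
partitions = [[1], [2], [3], [4]]
result = [n for part in partitions
          for n in count_in_a_partition(iter(part))]
print(result)  # [1, 1, 1, 1]
```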