
I need to collect partitions/batches from a large PySpark DataFrame so that I can feed them into a neural network iteratively.

My idea was to 1) partition the data, 2) iteratively collect each partition, and 3) transform each collected partition with toPandas().
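For step 1, I assume something like repartition would do (num_batches is just a placeholder value):

num_batches = 64                  # hypothetical number of batches
df = df.repartition(num_batches)  # split the DataFrame into num_batches partitions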

I am a bit confused by methods like foreachPartition and mapPartitions because they don't give me anything I can iterate over. Any ideas?


1 Answer


You can use mapPartitions to map each partition to a list of its elements, then pull those lists back to the driver one at a time with toLocalIterator:

# Wrap each partition in a one-element list so that toLocalIterator
# yields one Python list per partition on the driver.
for partition in rdd.mapPartitions(lambda part: [list(part)]).toLocalIterator():
    print(len(partition))  # or do something else :-)
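If you need each batch as a pandas DataFrame (step 3 in the question), you can build one from the collected rows. A minimal sketch, assuming df is the original DataFrame, so each collected element is a Row:

import pandas as pd

# Sketch: turn each collected partition of Rows into a pandas
# DataFrame before feeding it to the neural network.
for rows in df.rdd.mapPartitions(lambda part: [list(part)]).toLocalIterator():
    batch = pd.DataFrame([row.asDict() for row in rows])
    # feed `batch` into the network here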