Spark's programming model is not ideal for what you are trying to achieve, if we take the general problem of "aggregating elements depending on something that can only be known by inspecting previous elements", for two reasons:
- Spark does not, generally speaking, impose an ordering over the data (though it can be made to)
- Spark deals with data in partitions, and the sizes of the partitions do not usually (e.g. by default) depend on the contents of the data; they are set by a default partitioner whose role is to divide the data evenly into partitions.
So it's not really a question of whether it's possible (it is), it's rather a question of how much it costs (CPU / memory / time), for what it buys you.
A draft for an exact solution
If I were to shoot for an exact solution (by exact, I mean: preserving element order, defined by e.g. a timestamp in the JSONs, and grouping exactly consecutive inputs into the largest groups that stay under the boundary), I would:
- Impose an ordering on the RDD (there is a sortBy function, which does that): this is a full data shuffle, so it IS expensive.
- Give each row an id after the sort (there is an RDD version of zipWithIndex which respects the ordering of the RDD, if it has one. There is also a faster DataFrame equivalent, monotonically_increasing_id, that creates monotonically increasing indexes, albeit non-consecutive ones).
- Collect the fraction of the result that is necessary to compute the size boundaries (the boundaries being expressed as the ids defined at step 2), pretty much as you did. This again is a full pass over the data.
- Create a partitioner that respects these boundaries (i.e. make sure that all elements of a single group end up in the same partition), and apply it to the RDD obtained at step 2 (another full shuffle of the data). You just got yourself partitions that are logically equivalent to what you expect, i.e. groups of elements whose sum of sizes is under a certain limit. But the ordering inside each partition may have been lost in the repartitioning process, so you are not done yet!
- Then I would mapPartitions on this result to:
5.1. re-sort the data locally within each partition,
5.2. group items into the data structure I need once sorted.
One of the keys being not to apply anything that messes with partitioning between steps 4 and 5.
As long as the "partition map" fits into the driver's memory, this is almost a practical solution, but a very costly one (a sketch of the five steps follows below).
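To make this concrete, here is a minimal sketch of those five steps. It is not a drop-in implementation: timestampOf and sizeOf are hypothetical helpers (they are not in your code), MAX_SIZE is the same threshold you already use, and for simplicity each group gets a partition of its own:

import org.apache.spark.Partitioner
import org.apache.spark.rdd.RDD

// Hypothetical helpers: adapt to your actual records
def timestampOf(json: String): Long = ??? // e.g. parse the timestamp field
def sizeOf(json: String): Int = json.size

val raw: RDD[String] = ??? // your RDD of JSON strings

// Step 1: impose a global ordering (full shuffle, so it IS expensive)
val sorted = raw.sortBy(timestampOf)

// Step 2: give each row an id that respects that ordering
val indexed: RDD[(Long, String)] = sorted.zipWithIndex().map(_.swap)

// Step 3: collect only the (id, size) pairs and derive the boundary ids,
// i.e. the first id of each group (this "partition map" must fit in the driver)
val sizes = indexed.mapValues(sizeOf).collect().sortBy(_._1)
val boundaries = collection.mutable.ArrayBuffer(0L)
var running = 0
for ((id, size) <- sizes) {
  if (running + size > MAX_SIZE) { boundaries += id; running = size }
  else running += size
}

// Step 4: a partitioner that sends all ids of a group to the same partition
class BoundaryPartitioner(bounds: Array[Long]) extends Partitioner {
  def numPartitions: Int = bounds.length
  def getPartition(key: Any): Int = {
    val i = java.util.Arrays.binarySearch(bounds, key.asInstanceOf[Long])
    if (i >= 0) i else -(i + 1) - 1 // index of the last boundary <= key
  }
}
val partitioned = indexed.partitionBy(new BoundaryPartitioner(boundaries.toArray))

// Step 5: re-sort locally, then materialize each partition as one group
val groups: RDD[Seq[String]] = partitioned.mapPartitions { it =>
  Iterator(it.toSeq.sortBy(_._1).map(_._2))
}

A real partitioner would pack several groups per partition to keep the partition count reasonable; the essential point, as said above, is that nothing reshuffles the data between partitionBy and mapPartitions.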
A simpler version (with relaxed constraints)
If it is OK for groups not to reach an optimal size, then the solution becomes much simpler (and it respects the ordering of the RDD if you have set one): it is pretty much what you would code if there were no Spark at all, just an Iterator of JSON files.
Personally, I'd define a recursive accumulator function (nothing Spark related) like so (I guess you could write a shorter, more efficient version using takeWhile, see the note after the code):
import scala.annotation.tailrec

/**
 * Recursively aggregates the contents of an iterator into a Seq[Seq[String]]
 * @param remainingJSONs the remaining original JSON contents to be aggregated
 * @param currentAccSize the size of the active accumulation
 * @param currentAcc the current aggregation of JSON strings
 * @param resultAccumulation the result of aggregated JSON strings
 */
@tailrec
def acc(remainingJSONs: Iterator[String], currentAccSize: Int, currentAcc: Seq[String], resultAccumulation: Seq[Seq[String]]): Seq[Seq[String]] = {
  // If there is nothing more in the current partition
  if (remainingJSONs.isEmpty) {
    // And we're not in the process of accumulating
    if (currentAccSize == 0)
      // Then return what was accumulated before
      resultAccumulation
    else
      // Return what was accumulated before, plus what was being accumulated
      resultAccumulation :+ currentAcc
  } else {
    // We still have JSON items to process
    val itemToAggregate = remainingJSONs.next()
    // Is this item too large for the current accumulation?
    if (currentAccSize + itemToAggregate.size > MAX_SIZE) {
      // Finish the current aggregation (skipping it when empty), and start a fresh one
      acc(remainingJSONs, itemToAggregate.size, Seq(itemToAggregate),
          if (currentAcc.isEmpty) resultAccumulation else resultAccumulation :+ currentAcc)
    } else {
      // Accumulate the current item on top of the current aggregation
      acc(remainingJSONs, currentAccSize + itemToAggregate.size, currentAcc :+ itemToAggregate, resultAccumulation)
    }
  }
}
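As a side note on the takeWhile idea: Iterator.takeWhile consumes (and loses) the first element that fails the test, so a peeking BufferedIterator is safer. A possible sketch, with groupBySize being a made-up name:

def groupBySize(it: Iterator[String]): Iterator[Seq[String]] = new Iterator[Seq[String]] {
  private val buffered = it.buffered
  def hasNext: Boolean = buffered.hasNext
  def next(): Seq[String] = {
    // Greedily take items while the running size stays within MAX_SIZE
    // (a single oversized item still becomes a group of its own)
    val group = collection.mutable.ArrayBuffer[String]()
    var size = 0
    while (buffered.hasNext && (group.isEmpty || size + buffered.head.size <= MAX_SIZE)) {
      val item = buffered.next()
      group += item
      size += item.size
    }
    group.toSeq
  }
}

Unlike the recursive version, this stays lazy and produces groups on demand, so it can be passed directly to mapPartitions.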
Now you take this accumulating code, and make it run on each partition of Spark's RDD:
val jsonRDD = ...
val groupedJSONs = jsonRDD.mapPartitions(aPartition => {
  acc(aPartition, 0, Seq(), Seq()).iterator
})
This will turn your RDD[String] into an RDD[Seq[String]], where each Seq[String] is made of consecutive RDD elements (which ones may be predictable if the RDD has been sorted, and may not be otherwise), whose total length is below the threshold.
What may be "sub-optimal" is that, at the end of each partition, there may lie a Seq[String] holding just a few (possibly a single) JSONs, while at the beginning of the following partition a full one was created.
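If you want to verify the outcome, a quick sanity check (a sketch, using the same MAX_SIZE threshold) could be:

// Total size per group: the max should not exceed the threshold (barring a
// single JSON larger than MAX_SIZE); the min reveals the per-partition leftovers
val groupSizes = groupedJSONs.map(group => group.map(_.size).sum)
println(s"largest group: ${groupSizes.max()}, threshold: $MAX_SIZE")
println(s"smallest group: ${groupSizes.min()}")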