9
votes

Context: I am using Apache Spark to aggregate a running count of different event types from logs. The logs are stored in both Cassandra for historical analysis purposes and Kafka for realtime analysis purposes. Each log has a date and event type. For the purposes of simplicity, let's assume I wanted to keep track of the number of logs of a single type for each day.

We have two RDDs, an RDD of batch data from Cassandra and another streaming RDD from Kafka. Pseudocode:

CassandraJavaRDD<CassandraRow> cassandraRowsRDD = CassandraJavaUtil.javaFunctions(sc).cassandraTable(KEYSPACE, TABLE).select("date", "type");

JavaPairRDD<String, Integer> batchRDD = cassandraRowsRDD.mapToPair(new PairFunction<CassandraRow, String, Integer>() {
    @Override
    public Tuple2<String, Integer> call(CassandraRow row) {
        return new Tuple2<String, Integer>(row.getString("date"), 1);
    }
}).reduceByKey(new Function2<Integer, Integer, Integer>() {
    @Override
    public Integer call(Integer count1, Integer count2) {
        return count1 + count2;
    }
});

save(batchRDD); // Assume this saves the batch RDD somewhere

...

// Assume we read a chunk of logs from the Kafka stream every x seconds.
JavaPairReceiverInputDStream<String, String> kafkaStream =  KafkaUtils.createStream(...);
JavaPairDStream<String, Integer> streamRDD = kafkaStream.flatMapToPair(new PairFlatMapFunction<Tuple2<String, String>, String, Integer>() {
    @Override
    public Iterable<Tuple2<String, Integer>> call(Tuple2<String, String> data) {
        String jsonString = data._2;
        JSON jsonObj = JSON.parse(jsonString);
        Date eventDate = ... // get date from json object
        // Assume startTime is broadcast variable that is set to the time when the job started.
        if (eventDate.after(startTime.value())) { 
            ArrayList<Tuple2<String, Integer>> pairs = new ArrayList<Tuple2<String, Integer>>();
            pairs.add(new Tuple2<String, Integer>(jsonObj.get("date"), 1));
            return pairs;
        } else {
            return new ArrayList<Tuple2<String, Integer>>(0); // Return empty list when we ignore some logs
        }
    }
}).reduceByKey(new Function2<Integer, Integer, Integer>() {
    @Override
    public Integer call(Integer count1, Integer count2) {
        return count1 + count2;
    }
}).updateStateByKey(new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() { // keep a running total per date across stream batches
    @Override
    public Optional<Integer> call(List<Integer> counts, Optional<Integer> state) {
        Integer previousValue = state.or(0);
        Integer currentValue = ... // Sum of counts
        return Optional.of(previousValue + currentValue);
    }
});
save(streamRDD); // Assume this saves the stream RDD somewhere

// The streaming context (here assumed to be ssc, the JavaStreamingContext behind kafkaStream) is what gets started and awaited, not the SparkContext.
ssc.start();
ssc.awaitTermination();

Question: How do I combine the results from the streamRDD with the batchRDD? Let's say that batchRDD has the following data and this job was run on 2014-10-16:

("2014-10-15", 1000000)
("2014-10-16", 2000000)

Since the Cassandra query only includes data up to the start time of the batch job, we must start reading from Kafka once the query is finished, considering only logs that arrive after the job's start time (we assume the query takes a long time). This means I need to combine the historical results with the streaming results.

For illustration:

    |------------------------|-------------|--------------|--------->
tBatchStart             tStreamStart   streamBatch1  streamBatch2

Then suppose that in the first stream batch we got this data:

("2014-10-19", 1000)

Then I want to combine the batch RDD with this stream RDD so that the stream RDD now has the value:

("2014-10-19", 2001000)

Then suppose that in the second stream batch we got this data:

("2014-10-19", 4000)

Then the stream RDD should be updated to have the value:

("2014-10-19", 2005000)

And so on...

It's possible to use streamRDD.transformToPair(...) to combine the streamRDD data with the batchRDD data using a join, but if we do this for each stream chunk, we would be adding the count from the batchRDD on every stream chunk, making the state value "double counted", when it should only be added to the first stream chunk.
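For concreteness, a rough sketch of that per-interval join (in Scala for brevity, with hypothetical batchCounts and streamTotals standing in for batchRDD and the stateful streamRDD):

// Hypothetical names: batchCounts is the historical RDD[(String, Int)] read from Cassandra,
// streamTotals is the stateful DStream[(String, Int)] produced by updateStateByKey.
val joinedEachInterval = streamTotals.transform { rdd =>
  rdd.join(batchCounts).mapValues { case (streamTotal, batchTotal) => streamTotal + batchTotal }
}
// The join runs on every micro-batch, so the batch count is added to whatever is emitted on
// every interval; if that sum ever flows back into the running state, the baseline is counted twice.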


2 Answers

5
votes

To address this case, I'd union the base RDD with the result of the aggregated StateDStream that keeps the running totals of the streaming data. This effectively provides a baseline for the data reported on every streaming interval, without counting that baseline multiple times.

I tried that idea using the sample WordCount, and it works. Drop this into the REPL for a live example:

(use nc -lk 9876 in a separate shell to provide input to the socketTextStream)

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.StreamingContext._
import org.apache.spark.storage.StorageLevel

@transient val defaults = List("magic" -> 2, "face" -> 5, "dust" -> 7 )
val defaultRdd = sc.parallelize(defaults)

@transient val ssc = new StreamingContext(sc, Seconds(10))
ssc.checkpoint("/tmp/spark")

val lines = ssc.socketTextStream("localhost", 9876, StorageLevel.MEMORY_AND_DISK_SER)
val words = lines.flatMap(_.split(" "))
val wordCount = words.map(x => (x, 1)).reduceByKey(_ + _)
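// updateStateByKey keeps a running total per word across batches, covering the streamed data only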
val historicCount = wordCount.updateStateByKey[Int]{(newValues: Seq[Int], runningCount: Option[Int]) => 
    Some(newValues.sum + runningCount.getOrElse(0))
}
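// Union the static baseline (defaultRdd, standing in for the question's batchRDD) into the
// running totals on each interval; the baseline never enters the state, so it is only added once.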
val runningTotal = historicCount.transform{ rdd => rdd.union(defaultRdd)}.reduceByKey( _+_ )

wordCount.print()
historicCount.print()
runningTotal.print()
ssc.start()
1
votes

You could give updateStateByKey a try:

import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.StreamingContext._

def main(args: Array[String]) {

    val updateFunc = (values: Seq[Int], state: Option[Int]) => {
        val currentCount = values.foldLeft(0)(_ + _)
        val previousCount = state.getOrElse(0)
        Some(currentCount + previousCount)
    }

    // stream
    val ssc = new StreamingContext("local[2]", "NetworkWordCount", Seconds(1))
    ssc.checkpoint(".")
    val lines = ssc.socketTextStream("127.0.0.1", 9999)
    val words = lines.flatMap(_.split(" "))
    val pairs = words.map(word => (word, 1))
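    // Maintain a running count per word across all stream batches.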
    val stateWordCounts = pairs.updateStateByKey[Int](updateFunc)
    stateWordCounts.print()
    ssc.start()
    ssc.awaitTermination()
}
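To also fold in a historical baseline like the question's batchRDD, one option is the union trick from the other answer: keep the batch totals in a plain RDD and union it into the stateful stream on each interval. A rough sketch, reusing that answer's hypothetical defaults and assuming it is placed inside main before ssc.start():

// Hypothetical baseline standing in for the question's batchRDD.
val batchCounts = ssc.sparkContext.parallelize(List("magic" -> 2, "face" -> 5, "dust" -> 7))

// Union the baseline into the running totals on every interval; it never enters the
// updateStateByKey state, so it is not double counted.
val combinedCounts = stateWordCounts.transform(rdd => rdd.union(batchCounts)).reduceByKey(_ + _)
combinedCounts.print()

combinedCounts then reports the baseline plus the running streaming total on every interval.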