I have a Spark Scala streaming app that sessionizes user-generated events coming from Kafka, using mapWithState. I want to mature the setup by making it possible to pause and resume the app for maintenance. I’m already writing Kafka offset information to a database, so when restarting the app I can pick up at the last offset processed. But I also want to keep the state information.
So my goals are to:
1) materialize session information after a key identifying the user times out.
2) materialize the .stateSnapshots() when I gracefully shut down the application, so I can use that data when restarting the app by feeding it as a parameter to StateSpec.
Goal 1 is working; with goal 2 I have issues.
For the sake of completeness, I also describe 1 because I’m always interested in a better solution for it:
1) materializing session info after key time out
Inside my update function for mapWithState, I have:
if (state.isTimingOut()) {
// if key is timing out.
val output = (key, stateFilterable(isTimingOut = true
, start = state.get().start
, end = state.get().end
, duration = state.get().duration
))
return Some(output)
}
I then use that isTimingOut boolean later on to filter the stream:
streamParsed
.filter(a => a._2.isTimingOut)
.foreachRDD(rdd =>
rdd
.map(stuff => Model(user_key = stuff._1,
start = stuff._2.start,
duration = stuff._2.duration))
.saveToCassandra(keyspaceName, tableName)
)
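Since I asked about alternatives for step 1: one option might be to fold the filter and the map into a single pass with a partial function. Below is a minimal sketch on a plain Seq standing in for one batch. `SessionState`, `SessionRow` and `timedOutRows` are hypothetical, trimmed stand-ins for stateFilterable and Model, not part of the app.

```scala
// Hypothetical, trimmed stand-ins for stateFilterable and Model.
case class SessionState(isTimingOut: Boolean, start: Double, end: Double, duration: Int)
case class SessionRow(user_key: String, start: Double, duration: Int)

object TimedOutSessions {
  // Keep only the timed-out sessions and convert them to rows in one pass.
  def timedOutRows(batch: Seq[(String, SessionState)]): Seq[SessionRow] =
    batch.collect {
      case (key, s) if s.isTimingOut => SessionRow(key, s.start, s.duration)
    }

  def main(args: Array[String]): Unit = {
    // A plain Seq plays the role of one batch of (key, state) pairs.
    val batch = Seq(
      "user-a" -> SessionState(isTimingOut = true, start = 100.0, end = 160.0, duration = 60),
      "user-b" -> SessionState(isTimingOut = false, start = 120.0, end = 130.0, duration = 10)
    )
    timedOutRows(batch).foreach(println)
  }
}
```

RDDs also offer a collect variant that takes a partial function, so the same shape should carry over to the body of foreachRDD, but I have not measured whether it makes any difference compared to the filter-then-map version.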
2) materialize the .stateSnapshots() with graceful shutdown
Materializing the snapshot info doesn’t work. What I tried:
// define a class Listener
class Listener(ssc: StreamingContext, state: DStream[(String, stateFilterable)]) extends Runnable {
def run {
if( ssc == null )
System.out.println("The spark context is null")
else
System.out.println("The spark context is fine!!!")
var input = "continue"
while( !input.equals("D")) {
input = readLine("Press D to kill: ")
System.out.println(input + " " + input.equals("D"))
}
System.out.println("Accessing snapshot and saving:")
state.foreachRDD(rdd =>
rdd
.map(stuff => Model(user_key = stuff._1,
start = stuff._2.start,
duration = stuff._2.duration))
.saveToCassandra("some_keyspace", "some_table")
)
System.out.println("Stopping context!")
ssc.stop(true, true)
System.out.println("We have stopped!")
}
}
// Inside the app object:
val state = streamParsed.stateSnapshots()
var listener = new Thread(new Listener(ssc, state))
listener.start()
So the full code becomes:
package main.scala.cassandra_sessionizing
import java.text.SimpleDateFormat
import java.util.Calendar
import org.apache.spark.streaming.dstream.{DStream, MapWithStateDStream}
import scala.collection.immutable.Set
import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.streaming._
import org.apache.spark.streaming.Duration
import org.apache.spark.streaming.kafka.KafkaUtils
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.{StructType, StructField, StringType, DoubleType, LongType, ArrayType, IntegerType}
import _root_.kafka.serializer.StringDecoder
import com.datastax.spark.connector._
import com.datastax.spark.connector.cql.CassandraConnector
case class userAction(datetimestamp: Double
, action_name: String
, user_key: String
, page_id: Integer
)
case class actionTuple(pages: scala.collection.mutable.Set[Int]
, start: Double
, end: Double)
case class stateFilterable(isTimingOut: Boolean
, start: Double
, end: Double
, duration: Int
, pages: Set[Int]
, events: Int
)
case class Model(user_key: String
, start: Double
, duration: Int
, pages: Set[Int]
, events: Int
)
class Listener(ssc: StreamingContext, state: DStream[(String, stateFilterable)]) extends Runnable {
def run {
var input = "continue"
while( !input.equals("D")) {
input = readLine("Press D to kill: ")
System.out.println(input + " " + input.equals("D"))
}
// Accessing snapshot and saving:
state.foreachRDD(rdd =>
rdd
.map(stuff => Model(user_key = stuff._1,
start = stuff._2.start,
duration = stuff._2.duration,
pages = stuff._2.pages,
events = stuff._2.events))
.saveToCassandra("keyspace1", "snapshotstuff")
)
// Stopping context
ssc.stop(true, true)
}
}
object cassandra_sessionizing {
// where we'll store the stuff in Cassandra
val tableName = "sessionized_stuff"
val keyspaceName = "keyspace1"
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("cassandra-sessionizing")
.set("spark.cassandra.connection.host", "10.10.10.10")
.set("spark.cassandra.auth.username", "keyspace1")
.set("spark.cassandra.auth.password", "blabla")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
// setup the cassandra connector and recreate the table we'll use for storing the user session data.
val cc = CassandraConnector(conf)
cc.withSessionDo { session =>
session.execute(s"""DROP TABLE IF EXISTS $keyspaceName.$tableName;""")
session.execute(
s"""CREATE TABLE IF NOT EXISTS $keyspaceName.$tableName (
user_key TEXT
, start DOUBLE
, duration INT
, pages SET<INT>
, events INT
, PRIMARY KEY(user_key, start)) WITH CLUSTERING ORDER BY (start DESC)
;""")
}
// setup the streaming context and make sure we can checkpoint, given we're using mapWithState.
val ssc = new StreamingContext(sc, Seconds(60))
ssc.checkpoint("hdfs:///user/keyspace1/streaming_stuff/")
// Defining the stream connection to Kafka.
val kafkaStream = {
KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](ssc,
Map("metadata.broker.list" -> "kafka1.prod.stuff.com:9092,kafka2.prod.stuff.com:9092"), Set("theTopic"))
}
// this schema definition is needed so the json string coming from Kafka can be parsed into a dataframe using spark read.json.
// if an event does not conform to this structure, it will result in all null values, which are filtered out later.
val struct = StructType(
StructField("datetimestamp", DoubleType, nullable = true) ::
StructField("action", StructType( // field name has to match the action.* paths used in selectExpr below
StructField("user_key", StringType, nullable = true) ::
StructField("page_id", IntegerType, nullable = true) ::
StructField("name", StringType, nullable = true) :: Nil), nullable = true) :: Nil
)
/*
this is the function needed to keep track of a user key's session.
3 options:
1) key already exists, and new values are coming in to be added to the state.
2) key is new, so initialize the state with the incoming value
3) key is timing out, so mark it with a boolean that can be used by filtering later on. Given the boolean, the data can be materialized to cassandra.
*/
def trackStateFunc(batchTime: Time
, key: String
, value: Option[actionTuple]
, state: State[stateFilterable])
: Option[(String, stateFilterable)] = {
// 1 : if key already exists and we have a new value for it
if (state.exists() && value.orNull != null) {
var current_set = state.getOption().get.pages
var current_start = state.getOption().get.start
var current_end = state.getOption().get.end
if (value.get.pages != null) {
current_set ++= value.get.pages
}
current_start = Array(current_start, value.get.start).min // the starting epoch is used to initialize the state, but maybe some earlier events are processed a bit later.
current_end = Array(current_end, value.get.end).max // always update the end time of the session with new events coming in.
val new_event_counter = state.getOption().get.events + 1
val new_output = stateFilterable(isTimingOut = false
, start = current_start
, end = current_end
, duration = (current_end - current_start).toInt
, pages = current_set
, events = new_event_counter)
val output = (key, new_output)
state.update(new_output)
return Some(output)
}
// 2: if key does not exist and we have a new value for it
else if (value.orNull != null) {
var new_set: Set[Int] = Set()
val current_value = value.get.pages
if (current_value != null) {
new_set ++= current_value
}
val event_counter = 1
val current_start = value.get.start
val current_end = value.get.end
val new_output = stateFilterable(isTimingOut = false
, start = current_start
, end = current_end
, duration = (current_end - current_start).toInt
, pages = new_set
, events = event_counter)
val output = (key, new_output)
state.update(new_output)
return Some(output)
}
// 3: if key is timing out
if (state.isTimingOut()) {
val output = (key, stateFilterable(isTimingOut = true
, start = state.get().start
, end = state.get().end
, duration = state.get().duration
, pages = state.get().pages
, events = state.get().events
))
return Some(output)
}
// this part of the function should never be reached.
throw new Error(s"Entered dead end with $key $value")
}
// defining the state specification used later on as a step in the stream pipeline.
val stateSpec = StateSpec.function(trackStateFunc _)
.numPartitions(16)
.timeout(Seconds(4000))
// RDD 1
val streamParsedRaw = kafkaStream
.map { case (k, v: String) => v } // key is empty, so get the value containing the json string.
.transform { rdd =>
val df = sqlContext.read.schema(struct).json(rdd) // apply schema defined above and parse the json into a dataframe,
.selectExpr("datetimestamp"
, "action.name AS action_name"
, "action.user_key"
, "action.page_id"
)
df.as[userAction].rdd // transform the dataframe into a Spark Dataset so we can easily cast to the case class userAction.
}
val initialCount = actionTuple(pages = collection.mutable.Set(), start = 0.0, end = 0.0)
val addToCounts = (left: actionTuple, ua: userAction) => {
val current_start = ua.datetimestamp
val current_end = ua.datetimestamp
if (ua.page_id != null) left.pages += ua.page_id
actionTuple(left.pages, current_start, current_end)
}
val sumPartitionCounts = (p1: actionTuple, p2: actionTuple) => {
val current_start = Array(p1.start, p2.start).min
val current_end = Array(p1.end, p2.end).max
actionTuple(p1.pages ++= p2.pages, current_start, current_end)
}
// RDD 2: add the mapWithState part.
val streamParsed = streamParsedRaw
.map(s => (s.user_key, s)) // create key value tuple so we can apply the mapWithState to the user_key.
.transform(rdd => rdd.aggregateByKey(initialCount)(addToCounts, sumPartitionCounts)) // reduce to one row per user key for each batch.
.mapWithState(stateSpec)
// RDD 3: if the app is shut down, this rdd should be materialized.
val state = streamParsed.stateSnapshots()
state.print(2)
// RDD 4: Crucial: look up sessions timing out, extract the fields that we want to keep and materialize in Cassandra.
streamParsed
.filter(a => a._2.isTimingOut)
.foreachRDD(rdd =>
rdd
.map(stuff => Model(user_key = stuff._1,
start = stuff._2.start,
duration = stuff._2.duration,
pages = stuff._2.pages,
events = stuff._2.events))
.saveToCassandra(keyspaceName, tableName)
)
// add a listener hook that we can use to gracefully shut down the app and materialize the RDD containing the state snapshots.
var listener = new Thread(new Listener(ssc, state))
listener.start()
ssc.start()
ssc.awaitTermination()
}
}
But when running this (launching the app, waiting several minutes for some state information to build up, and then entering 'D'), I get the exception below. So apparently I can't add anything 'new' to a DStream once the streaming context has been started. I hoped to move from a DStream RDD to a regular RDD, stop the streaming context, and wrap up by saving the normal RDD, but I don't know how. Hope someone can help!
Exception in thread "Thread-52" java.lang.IllegalStateException: Adding new inputs, transformations, and output operations after starting a context is not supported
at org.apache.spark.streaming.dstream.DStream.validateAtInit(DStream.scala:222)
at org.apache.spark.streaming.dstream.DStream.<init>(DStream.scala:64)
at org.apache.spark.streaming.dstream.ForEachDStream.<init>(ForEachDStream.scala:34)
at org.apache.spark.streaming.dstream.DStream.org$apache$spark$streaming$dstream$DStream$$foreachRDD(DStream.scala:687)
at org.apache.spark.streaming.dstream.DStream$$anonfun$foreachRDD$1.apply$mcV$sp(DStream.scala:661)
at org.apache.spark.streaming.dstream.DStream$$anonfun$foreachRDD$1.apply(DStream.scala:659)
at org.apache.spark.streaming.dstream.DStream$$anonfun$foreachRDD$1.apply(DStream.scala:659)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:150)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:111)
at org.apache.spark.SparkContext.withScope(SparkContext.scala:714)
at org.apache.spark.streaming.StreamingContext.withScope(StreamingContext.scala:260)
at org.apache.spark.streaming.dstream.DStream.foreachRDD(DStream.scala:659)
at main.scala.feaUS.Listener.run(feaUS.scala:119)
at java.lang.Thread.run(Thread.java:745)
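Given that the exception complains about output operations being added after the context has started, one direction that might work is to register the snapshot output operation before ssc.start() and let it do nothing until a shutdown flag is set. The sketch below is untested against this app and assumes Spark Streaming is on the classpath. `SnapshotOnShutdown`, `register`, `requestStop`, `save` and `batchIntervalMs` are hypothetical names, and sleeping for one batch interval before stopping is an assumption meant to give one more batch the chance to run the save.

```scala
import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.dstream.DStream

// Hypothetical helper, not part of the original app.
object SnapshotOnShutdown {
  // Set by the listener thread; read on the driver by the output operation below.
  val shutdownRequested = new AtomicBoolean(false)

  // Has to be called BEFORE ssc.start(), because it adds an output operation.
  // The function passed to foreachRDD runs on the driver once per batch.
  def register[K, S](snapshots: DStream[(K, S)])(save: RDD[(K, S)] => Unit): Unit =
    snapshots.foreachRDD { rdd =>
      if (shutdownRequested.get()) save(rdd) // a no-op until shutdown is requested
    }

  // Called from the listener thread instead of calling foreachRDD there.
  def requestStop(ssc: StreamingContext, batchIntervalMs: Long): Unit = {
    shutdownRequested.set(true)
    // Assumption: waiting one batch interval lets a batch run with the flag set.
    Thread.sleep(batchIntervalMs)
    ssc.stop(stopSparkContext = true, stopGracefully = true)
  }
}
```

In the app above this would mean calling SnapshotOnShutdown.register(state) with the Model mapping and saveToCassandra call as the save function before ssc.start(), and calling SnapshotOnShutdown.requestStop(ssc, 60000L) from Listener.run once 'D' is entered. If batches are queued up when the flag flips, the snapshot may be saved more than once, so the write should be safe to repeat.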