Spark Streaming: A Look Inside mapWithState

Having hit a few setbacks lately, I found some spare time today to write up my notes on state management, specifically a dive into mapWithState. mapWithState is a new state-management mechanism introduced in Spark 1.6. It supports emitting both the full state and just the updated state, managing state timeouts, and letting you choose exactly what to output.

An example

package stateParse

import org.apache.spark.streaming._
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Author: shawn pross
 * Date: 2018/09/10
 * Description: A minimal mapWithState word-count example.
 */
object TestMapWithState {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    conf.setAppName(s"${this.getClass.getSimpleName}")
    conf.setMaster("local[2]")
    val sc = new SparkContext(conf)
    val ssc = new StreamingContext(sc, Seconds(3))
    ssc.checkpoint("/checkpoint/")

    val line = ssc.socketTextStream("127.0.0.1", 9999)
    val wordDStream = line.flatMap(_.split(",")).map(x => (x, 1))

    // State update function: `output` is what gets emitted, `state` is the kept state
    val mappingFunc = (userId: String, value: Option[Int], state: State[Int]) => {
      val sum = value.getOrElse(0) + state.getOption().getOrElse(0)
      val output = (userId, sum)
      state.update(sum)
      output
    }

    // Update state via mapWithState, with a state timeout of 1 hour
    val stateDStream = wordDStream.mapWithState(
      StateSpec.function(mappingFunc).timeout(Minutes(60)))
    stateDStream.print()

    ssc.start()
    ssc.awaitTermination()
  }
}

mapWithState takes a StateSpec object, which encapsulates the state-management function. We define a state update function, mappingFunc, which updates the state for a given key and returns the updated result; we pass that function to mapWithState and set a state timeout. On each batch interval, Spark Streaming uses the update function we defined to update the state it maintains internally, and returns a stream of the results produced by mappingFunc.
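The example prints only the incremental output produced by mappingFunc. If you also want the full state of every tracked key in each batch, the returned MapWithStateDStream exposes stateSnapshots(). A minimal sketch, reusing stateDStream from the example above:

// Incremental output: one record per input element, as returned by mappingFunc
stateDStream.print()

// Full snapshot: one (key, state) pair for every key currently tracked
stateDStream.stateSnapshots().print()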

Let's look at how it is implemented.

Source code analysis

As the example shows, the state-management function is wrapped in a StateSpec; let's dig in and see how this is implemented. We find that the mapWithState function creates a MapWithStateDStreamImpl object. Following the chain further, MapWithStateDStreamImpl extends MapWithStateDStream, and the parent class of MapWithStateDStream is in turn DStream.

// mapWithState creates a MapWithStateDStreamImpl object
def mapWithState[StateType: ClassTag, MappedType: ClassTag](
    spec: StateSpec[K, V, StateType, MappedType]
  ): MapWithStateDStream[K, V, StateType, MappedType] = {
  new MapWithStateDStreamImpl[K, V, StateType, MappedType](
    self,
    spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]]
  )
}

/**
 * MapWithStateDStreamImpl extends MapWithStateDStream and creates
 * internalStream, an object of type InternalMapWithStateDStream
 */
private[streaming] class MapWithStateDStreamImpl[
    KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, MappedType: ClassTag](
    dataStream: DStream[(KeyType, ValueType)],
    spec: StateSpecImpl[KeyType, ValueType, StateType, MappedType])
  extends MapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream.context) {

  private val internalStream =
    new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec)

  ...

  override def compute(validTime: Time): Option[RDD[MappedType]] = {
    internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } }
  }
}

// The parent class of MapWithStateDStream is actually DStream
sealed abstract class MapWithStateDStream[KeyType, ValueType, StateType, MappedType: ClassTag](
    ssc: StreamingContext) extends DStream[MappedType](ssc) {

  /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */
  def stateSnapshots(): DStream[(KeyType, StateType)]
}

As we can see, MapWithStateDStreamImpl creates internalStream, an object of type InternalMapWithStateDStream, and MapWithStateDStreamImpl's compute method calls internalStream's getOrCompute method.

InternalMapWithStateDStream itself does not define getOrCompute; the call resolves to getOrCompute in the parent class DStream, which ultimately invokes the compute method overridden in InternalMapWithStateDStream.
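For reference, the caching behavior of DStream.getOrCompute can be paraphrased roughly as follows (a simplified sketch, not the verbatim Spark source, which additionally handles persistence and checkpointing):

// Simplified paraphrase of DStream.getOrCompute:
// return the RDD already generated for this batch time if there is one,
// otherwise ask the subclass (here InternalMapWithStateDStream) to compute it and cache it
private[streaming] final def getOrCompute(time: Time): Option[RDD[T]] = {
  generatedRDDs.get(time).orElse {
    if (isTimeValid(time)) {
      val rddOption = compute(time)                               // calls the overridden compute
      rddOption.foreach { rdd => generatedRDDs.put(time, rdd) }   // cache for this batch
      rddOption
    } else {
      None
    }
  }
}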

/** Method that generates an RDD for the given time */
override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = {
  // Get the previous state or create a new empty state RDD
  val prevStateRDD = getOrCompute(validTime - slideDuration) match {
    case Some(rdd) =>
      if (rdd.partitioner != Some(partitioner)) {
        // If the RDD is not partitioned the right way, let us repartition it using the
        // partition index as the key. This is to ensure that state RDD is always partitioned
        // before creating another state RDD using it
        MapWithStateRDD.createFromRDD[K, V, S, E](
          rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
      } else {
        rdd
      }
    case None =>
      MapWithStateRDD.createFromPairRDD[K, V, S, E](
        spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
        partitioner,
        validTime
      )
  }

  // Compute the new state RDD with previous state RDD and partitioned data RDD
  // Even if there is no data RDD, use an empty one to create a new state RDD
  val dataRDD = parent.getOrCompute(validTime).getOrElse {
    context.sparkContext.emptyRDD[(K, V)]
  }
  val partitionedDataRDD = dataRDD.partitionBy(partitioner)
  val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
    (validTime - interval).milliseconds
  }
  Some(new MapWithStateRDD(
    prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime))
}

The compute method overridden in InternalMapWithStateDStream is the key part. Given a batch time, it generates a MapWithStateRDD: it first obtains prevStateRDD (the RDD holding the previous state) and dataRDD (the RDD for the current batch), then repartitions dataRDD with the same partitioner as the previous state RDD to get partitionedDataRDD, and finally passes prevStateRDD, partitionedDataRDD, the user-defined mappingFunction, the batch time, and the timeout threshold to a new MapWithStateRDD object, which it returns as the result.

Now let's look at the key compute method of MapWithStateRDD.

override def compute(
    partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = {

  val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition]
  val prevStateRDDIterator = prevStateRDD.iterator(
    stateRDDPartition.previousSessionRDDPartition, context)
  val dataIterator = partitionedDataRDD.iterator(
    stateRDDPartition.partitionedDataRDDPartition, context)
  // prevRecord represents the previous record (state) of one partition
  val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
  val newRecord = MapWithStateRDDRecord.updateRecordWithData(
    prevRecord,
    dataIterator,
    mappingFunction,
    batchTime,
    timeoutThresholdTime,
    removeTimedoutData = doFullScan // remove timed-out data only when full scan is enabled
  )
  Iterator(newRecord)
}

A MapWithStateRDDRecord corresponds to one partition of a MapWithStateRDD: stateMap stores the state of each key, and mappedData stores the values returned by the mappingFunc function.

private[streaming] case class MapWithStateRDDRecord[K, S, E](
    var stateMap: StateMap[K, S], var mappedData: Seq[E])

Finally, let's look at MapWithStateRDDRecord.updateRecordWithData. newStateMap is a new StateMap, copied from the previous record if one exists, otherwise created as a fresh, empty StateMap. The method ultimately returns a MapWithStateRDDRecord object to MapWithStateRDD's compute function, which wraps it in an Iterator and returns it.

def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
    prevRecord: Option[MapWithStateRDDRecord[K, S, E]],
    dataIterator: Iterator[(K, V)],
    mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
    batchTime: Time,
    timeoutThresholdTime: Option[Long],
    removeTimedoutData: Boolean
  ): MapWithStateRDDRecord[K, S, E] = {
  // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
  val newStateMap = prevRecord.map { _.stateMap.copy() }.getOrElse { new EmptyStateMap[K, S]() }

  val mappedData = new ArrayBuffer[E]
  val wrappedState = new StateImpl[S]()

  // Call the mapping function on each record in the data iterator, and accordingly
  // update the states touched, and collect the data returned by the mapping function
  dataIterator.foreach { case (key, value) =>
    // Fetch the current state for this key
    wrappedState.wrap(newStateMap.get(key))
    // Call mappingFunction to get the return value
    val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
    // Maintain the entry for this key in newStateMap
    if (wrappedState.isRemoved) {
      newStateMap.remove(key)
    } else if (wrappedState.isUpdated
        || (wrappedState.exists && timeoutThresholdTime.isDefined)) {
      newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
    }
    mappedData ++= returned
  }

  // Get the timed out state records, call the mapping function on each and collect the
  // data returned
  if (removeTimedoutData && timeoutThresholdTime.isDefined) {
    newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
      wrappedState.wrapTimingOutState(state)
      val returned = mappingFunction(batchTime, key, None, wrappedState)
      mappedData ++= returned
      newStateMap.remove(key)
    }
  }

  MapWithStateRDDRecord(newStateMap, mappedData)
}

At this point, the execution path of mapWithState through the source code is clear.