SparkStreaming之解析updateStateByKey

说到Spark Streaming的状态管理,就会想到updateStateByKey,还有mapWithState。今天整理了一下,着重了解一下前者。

状态管理的需求

举一个最简单的需求例子来解释状态(state)管理,现在有这样的一个需求:计算从数据流开始到目前为止单词出现的次数。是不是看起来很眼熟,这其实就是一个升级版的wordcount,只不过需要在每个batchInterval计算当前batch的单词计数,然后对各个批次的计数进行累加。每一个批次的累积的计数就是当前的一个状态值。我们需要把这个状态保存下来,和后面批次单词的计数结果来进行计算,这样我们就能不断的在历史的基础上进行次数的更新。

SparkStreaming提供了两种方法来解决这个问题:updateStateByKey和mapWithState。mapWithState是1.6版本新增的功能,官方说性能较updateStateByKey提升10倍。

updateStateByKey概述

updateStateByKey,统计全局的Key的状态,就算没有数据输入,也会在每一个批次的时候返回之前的key的状态。假设5s产生一个批次的数据,那么就是5s的时候就更新一次key的值,然后返回。如果数据量又比较大,又需要不断的更新每个Key的state, 那么就一定会涉及到状态的保存和容错。所以,要使用updateStateByKey就需要设置一个checkpoint目录,开启checkpoint机制。因为key的State是在内存中维护的,如果宕机,则重启之后之前维护的状态就没有了,所以要长期保存的话则需要启用checkpoint,以便于恢复数据。

updateStateByKey代码例子

现在我们来看看怎么用的,首先看一个updateStateByKey使用的简单例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.streaming.{Seconds, StreamingContext}

/**
* Author: shawn pross
* Date: 2018/06/18
* Description: test updateStateByKey op
*/
object TestUpdateStateByKey {
def main(args: Array[String]): Unit = {
val conf = new SparkConf()
conf.setAppName(s"${this.getClass.getSimpleName}")
conf.setMaster("local[2]")
val sc = new SparkContext(conf)
/**
* 创建context,3 second batch size
* 创建一个接收器(ReceiverInputDStream),接收从机器上的端口,通过socket发过来的数据
*/
val ssc = new StreamingContext(sc, Seconds(3))
val bathLine = ssc.socketTextStream("127.0.0.1", 9999)
/**
* 传入updateStateByKey的函数
*
* 源码定义:
* def updateStateByKey[S: ClassTag](
* updateFunc: (Seq[V], Option[S]) => Option[S]): DStream[(K, S)] = ssc.withScope {
* updateStateByKey(updateFunc, defaultPartitioner())
* }
*/
val updateFunc = (currValues: Seq[Int], prevValueState: Option[Int]) => {
//通过Spark内部的reduceByKey按key规约,然后这里传入某key当前批次的Seq/List,再计算当前批次的总和
val currentCount = currValues.sum
// 已累加的值
val previousCount = prevValueState.getOrElse(0)
// 返回累加后的结果,是一个Option[Int]类型
Some(currentCount + previousCount)
}
//先聚合成键值对的形式
bathLine.map(x=>(x,1)).updateStateByKey(updateFunc).print()

ssc.start()
ssc.awaitTermination()
}
}

代码很简单,注释也比较详细。其中要说明的是updateStateByKey的参数还有几个可选项:

1
2
3
4
5
6
7
8
9
10
11
12
def updateStateByKey[S: ClassTag](
//状态更新功能
updateFunc: (Seq[V], Option[S]) => Option[S],
//用于控制新RDD中每个RDD的分区的分区程序
partitioner: Partitioner,
//是否记住生成的RDD中的分区对象
rememberPartitioner: Boolean,
//每个键的初始状态值
initialRDD: RDD[(K, S)],
): DStream[(K, S)] = ssc.withScope {
...
}

updateStateByKey源码分析

通过上面简单的小例子可以知道,使用updateStateByKey是需要先转换为键值对的形式的,而map返回的是MappedDStream,而进入MappedDStream中也没有updateStateByKey方法,然后其父类DStream中也没有。但是DStream的半生对象中有一个隐式的转换函数toPairDStreamFunctions

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// map实现,返回MappedDStream
def map[U: ClassTag](mapFunc: T => U): DStream[U] = ssc.withScope {
new MappedDStream(this, context.sparkContext.clean(mapFunc))
}

// MappedDStream父类是DStream
class MappedDStream[T: ClassTag, U: ClassTag] (
parent: DStream[T],
mapFunc: T => U
) extends DStream[U](parent.ssc) {
...
}

// DStream中隐式的转换函数
implicit def toPairDStreamFunctions[K, V](stream: DStream[(K, V)])
(implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null):
PairDStreamFunctions[K, V] = {
new PairDStreamFunctions[K, V](stream)
}

看到new PairDStreamFunctions就不陌生了。PairDStreamFunctions中存在updateStateByKey方法,Seq[V]表示当前key对应的所有值,Option[S] 是当前key的历史状态,返回的是新的状态。也就是绕了一个圈子又回到原地。最后updateStateByKey最终会在这里面new出了一个StateDStream对象。

1
2
3
4
5
6
7
8
9
10
11
def updateStateByKey[S: ClassTag](
updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
partitioner: Partitioner,
rememberPartitioner: Boolean,
initialRDD: RDD[(K, S)]): DStream[(K, S)] = ssc.withScope {
val cleanedFunc = ssc.sc.clean(updateFunc)
val newUpdateFunc = (_: Time, it: Iterator[(K, Seq[V], Option[S])]) => {
cleanedFunc(it)
}
new StateDStream(self, newUpdateFunc, partitioner, rememberPartitioner, Some(initialRDD))
}

继续进去StateDStream看看,在其compute方法中,会先获取上一个batch计算出的RDD(包含了至程序开始到上一个batch单词的累计计数),然后在获取本次batch中StateDStream的父类计算出的RDD(本次batch的单词计数)分别是prevStateRDDparentRDD,然后在调用 computeUsingPreviousRDD 方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
private [this] def computeUsingPreviousRDD(
batchTime: Time,
parentRDD: RDD[(K, V)],
prevStateRDD: RDD[(K, S)]) = {
// Define the function for the mapPartition operation on cogrouped RDD;
// first map the cogrouped tuple to tuples of required type,
// and then apply the update function
val updateFuncLocal = updateFunc
val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
val i = iterator.map { t =>
val itr = t._2._2.iterator
val headOption = if (itr.hasNext) Some(itr.next()) else None
(t._1, t._2._1.toSeq, headOption)
}
updateFuncLocal(batchTime, i)
}
val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
Some(stateRDD)
}

最后返回stateRDD结果。至此,updateStateByKey方法源码执行过程水落石出。