MapWithStateDStream
MapWithStateDStream is the DStream returned by the mapWithState operator. Beyond the usual DStream operations, it exposes one extra method for reading the full state:

def stateSnapshots(): DStream[(KeyType, StateType)]
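A minimal usage sketch (assuming pairStream is a DStream[(String, Int)] and the state keeps a running sum per key):

// Mapping function: add each new value to the running sum stored in state.
val spec = StateSpec.function((key: String, value: Option[Int], state: State[Long]) => {
  val sum = state.getOption().getOrElse(0L) + value.getOrElse(0)
  state.update(sum)
  (key, sum)
})
val mapped = pairStream.mapWithState(spec)
// stateSnapshots() emits the complete key -> state mapping after every batch,
// not just the keys that were updated in this batch:
val snapshots: DStream[(String, Long)] = mapped.stateSnapshots()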
MapWithStateDStream is a sealed abstract class, so all of its implementations are visible in its source file; MapWithStateDStreamImpl is the only implementation of MapWithStateDStream.
What the sealed keyword does: a trait or class marked sealed can only be extended within the file that declares it. The point of marking a class sealed is that the Scala compiler then knows every possible subtype when checking a pattern match, so it can verify exhaustiveness at compile time and warn about any missing case, reducing programming errors.
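A minimal, self-contained illustration (not from the Spark codebase):

// All subtypes of Shape must live in this file, so the compiler
// knows the match below is exhaustive.
sealed trait Shape
case class Circle(r: Double) extends Shape
case class Square(side: Double) extends Shape

def area(s: Shape): Double = s match {
  case Circle(r)    => math.Pi * r * r
  case Square(side) => side * side
  // Omitting a case here would produce a "match may not be exhaustive" warning.
}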
MapWithStateDStreamImpl
MapWithStateDStreamImpl is internal (private[streaming]); its parent is a key-value DStream, and its implementation delegates to the InternalMapWithStateDStream class. Its slideDuration and dependencies values are both taken from the internalStream variable, as the sketch below shows.
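An abridged sketch of that delegation (paraphrased from Spark's source; type parameters and bodies simplified):

private[streaming] class MapWithStateDStreamImpl[K, V, S, E](
    dataStream: DStream[(K, V)],
    spec: StateSpecImpl[K, V, S, E])
  extends MapWithStateDStream[K, V, S, E](dataStream.context) {

  private val internalStream = new InternalMapWithStateDStream[K, V, S, E](dataStream, spec)

  // Both values are simply forwarded from the internal stream.
  override def slideDuration: Duration = internalStream.slideDuration
  override def dependencies: List[DStream[_]] = List(internalStream)

  // Each batch's output is the mappedData collected in the internal records.
  override def compute(validTime: Time): Option[RDD[E]] =
    internalStream.getOrCompute(validTime).map { _.flatMap[E] { _.mappedData } }

  // The snapshot view reads every (key, state) pair out of the state maps.
  override def stateSnapshots(): DStream[(K, S)] =
    internalStream.flatMap { _.stateMap.getAll().map { case (k, s, _) => (k, s) }.toTraversable }
}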
InternalMapWithStateDStream
InternalMapWithStateDStream backs the implementation of MapWithStateDStreamImpl. It extends DStream[MapWithStateRDDRecord[K, S, E]] and uses the MEMORY_ONLY storage level by default. It partitions its data with the partitioner from the StateSpec, falling back to a HashPartitioner. It forces checkpointing (override val mustCheckpoint = true), and if checkpointDuration is unset it derives one from the slideDuration, as sketched below.
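The checkpoint defaulting looks roughly like this (paraphrased from Spark's source, where the multiplier constant is 10):

// If the user did not set a checkpoint interval, default it to a
// multiple of the slide duration.
override def initialize(time: Time): Unit = {
  if (checkpointDuration == null) {
    checkpointDuration = slideDuration * DEFAULT_CHECKPOINT_DURATION_MULTIPLIER
  }
  super.initialize(time)
}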
InternalMapWithStateDStream.compute()
/** Method that generates an RDD for the given time */
// Generates the RDD for the given batch time; its main job is to turn the
// state operation into a MapWithStateRDD.
override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = {
  // Get the previous state or create a new empty state RDD
  val prevStateRDD = getOrCompute(validTime - slideDuration) match {
    case Some(rdd) =>
      if (rdd.partitioner != Some(partitioner)) {
        // If the RDD is not partitioned the right way, let us repartition it using the
        // partition index as the key. This is to ensure that the state RDD is always
        // partitioned before another state RDD is created from it.
        MapWithStateRDD.createFromRDD[K, V, S, E](
          rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
      } else {
        rdd
      }
    case None =>
      MapWithStateRDD.createFromPairRDD[K, V, S, E](
        spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
        partitioner,
        validTime
      )
  }

  // Compute the new state RDD with the previous state RDD and the partitioned data RDD.
  // Even if there is no data RDD, use an empty one to create the new state RDD.
  val dataRDD = parent.getOrCompute(validTime).getOrElse {
    context.sparkContext.emptyRDD[(K, V)]
  }
  val partitionedDataRDD = dataRDD.partitionBy(partitioner)
  val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
    (validTime - interval).milliseconds
  }
  Some(new MapWithStateRDD(
    prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime))
}
Next, let's look at the MapWithStateRDD.createFromPairRDD method:
def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
    pairRDD: RDD[(K, S)],
    partitioner: Partitioner,
    updateTime: Time): MapWithStateRDD[K, V, S, E] = {

  // Convert each partition of pairRDD into a MapWithStateRDDRecord
  val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions({ iterator =>
    val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
    iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) }
    Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E]))
  }, preservesPartitioning = true)

  val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)

  val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None

  new MapWithStateRDD[K, V, S, E](
    stateRDD, emptyDataRDD, noOpFunc, updateTime, None)
}
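This is the path taken when the user seeds the stream with pre-existing state via StateSpec.initialState; a brief sketch (ssc and mappingFunc are assumed from context):

// Keys "a" and "b" carry state before the first batch even runs.
val initialRDD = ssc.sparkContext.parallelize(Seq(("a", 5L), ("b", 10L)))
val spec = StateSpec.function(mappingFunc _).initialState(initialRDD)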
MapWithStateRDD
extends RDD; its dependencies are prevStateRDD and partitionedDataRDD:
RDD[MapWithStateRDDRecord[K, S, E]](
  partitionedDataRDD.sparkContext,
  List(
    new OneToOneDependency[MapWithStateRDDRecord[K, S, E]](prevStateRDD),
    new OneToOneDependency(partitionedDataRDD))
)
Its compute() logic:
override def compute(
    partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = {

  val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition]
  val prevStateRDDIterator = prevStateRDD.iterator(
    stateRDDPartition.previousSessionRDDPartition, context)
  val dataIterator = partitionedDataRDD.iterator(
    stateRDDPartition.partitionedDataRDDPartition, context)

  val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
  val newRecord = MapWithStateRDDRecord.updateRecordWithData(
    prevRecord,
    dataIterator,
    mappingFunction,
    batchTime,
    timeoutThresholdTime,
    removeTimedoutData = doFullScan // remove timed-out data only when full scan is enabled
  )
  Iterator(newRecord)
}
Its core work is delegated to the MapWithStateRDDRecord.updateRecordWithData method, which produces a single-record Iterator: in the resulting record, stateMap holds each key's state and mappedData collects the values returned by the mapping function.
// Create a new state map by cloning the previous one (if it exists) or by creating
// an empty one. The state map holds the Key -> State mapping.
val newStateMap = prevRecord.map { _.stateMap.copy() }.getOrElse { new EmptyStateMap[K, S]() }

// Collects the values returned by the mapping function
val mappedData = new ArrayBuffer[E]
// Reusable wrapper implementation around a State value
val wrappedState = new StateImpl[S]()

// Call the mapping function on each record in the data iterator, and accordingly
// update the states touched, and collect the data returned by the mapping function
dataIterator.foreach { case (key, value) =>
  wrappedState.wrap(newStateMap.get(key))
  val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
  if (wrappedState.isRemoved) {
    newStateMap.remove(key)
  } else if (wrappedState.isUpdated
      || (wrappedState.exists && timeoutThresholdTime.isDefined)) {
    newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
  }
  mappedData ++= returned
}

// Get the timed-out state records, call the mapping function on each and collect the
// data returned. This is where user-defined timeout handling runs: every timed-out key
// is passed to the mapping function one last time (with value = None), then removed.
if (removeTimedoutData && timeoutThresholdTime.isDefined) {
  newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
    wrappedState.wrapTimingOutState(state)
    val returned = mappingFunction(batchTime, key, None, wrappedState)
    mappedData ++= returned
    newStateMap.remove(key)
  }
}

MapWithStateRDDRecord(newStateMap, mappedData)
StateMap
/** Internal interface for defining the map that keeps track of sessions. */
private[streaming] abstract class StateMap[K, S] extends Serializable {

  /** Get the state for a key if it exists */
  def get(key: K): Option[S]

  /** Get all the keys and states whose updated time is older than the given threshold time */
  def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)]

  /** Get all the keys and states in this map. */
  def getAll(): Iterator[(K, S, Long)]

  /** Add or update state */
  def put(key: K, state: S, updatedTime: Long): Unit

  /** Remove a key */
  def remove(key: K): Unit

  /**
   * Shallow copy `this` map to create a new state map.
   * Updates to the new map should not mutate `this` map.
   */
  def copy(): StateMap[K, S]

  def toDebugString(): String = toString()
}
It lives at org.apache.spark.streaming.util.StateMap and is the class that stores Spark Streaming's state. Two implementations are provided by default: EmptyStateMap and OpenHashMapBasedStateMap. The underlying OpenHashMap is a HashMap that supports nullable keys; it performs more than 5x faster than the JDK's default HashMap, but users need to be careful when dealing with 0.0/0/0L/non-existent values, as illustrated below.
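The reason for that caveat, as a hypothetical illustration (OpenHashMap is private[spark], so this is not user-facing code; lookups of missing keys return the value type's zero rather than throwing):

val counts = new OpenHashMap[String, Long]()
counts("a")            // returns 0L even though "a" was never inserted
counts.update("b", 0L)
counts("b")            // also 0L: by value alone, a stored zero is
                       // indistinguishable from a missing key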
Demo
import org.apache.spark.SparkContext
import org.apache.spark.streaming.{Milliseconds, Minutes, State, StateSpec, StreamingContext}
import scalaz.{-\/, \/-}

// UserEvent, UserSession, SparkConfiguration, loadConfigOrThrow and the
// decodeEither JSON syntax come from the original project and are assumed
// to be on the classpath.

object SparkStatefulRunner {
  /**
   * Aggregates User Sessions using Stateful Streaming transformations.
   *
   * Usage: SparkStatefulRunner <hostname> <port>
   * <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data.
   */
  def main(args: Array[String]): Unit = {
    if (args.length < 2) {
      System.err.println("Usage: SparkStatefulRunner <hostname> <port>")
      System.exit(1)
    }

    val sparkConfig = loadConfigOrThrow[SparkConfiguration]("spark")

    val sparkContext = new SparkContext(sparkConfig.sparkMasterUrl, "Spark Stateful Streaming")
    val ssc = new StreamingContext(sparkContext, Milliseconds(4000))
    ssc.checkpoint(sparkConfig.checkpointDirectory)

    val stateSpec =
      StateSpec
        .function(updateUserEvents _)
        .timeout(Minutes(sparkConfig.timeoutInMinutes))

    ssc
      .socketTextStream(args(0), args(1).toInt)
      .map(deserializeUserEvent)
      .filter(_._2 != UserEvent.empty) // drop events that failed to deserialize
      .mapWithState(stateSpec)
      .foreachRDD { rdd =>
        if (!rdd.isEmpty()) {
          rdd.foreach(maybeUserSession =>
            maybeUserSession.foreach { userSession =>
              // Store user session here
              println(userSession)
            })
        }
      }

    ssc.start()
    ssc.awaitTermination()
  }

  def deserializeUserEvent(json: String): (Int, UserEvent) = {
    json.decodeEither[UserEvent] match {
      case \/-(userEvent) => (userEvent.id, userEvent)
      case -\/(error) =>
        println(s"Failed to parse user event: $error")
        (UserEvent.empty.id, UserEvent.empty)
    }
  }

  def updateUserEvents(key: Int,
                       value: Option[UserEvent],
                       state: State[UserSession]): Option[UserSession] = {
    def updateUserSessions(newEvent: UserEvent): Option[UserSession] = {
      val existingEvents: Seq[UserEvent] =
        state
          .getOption()
          .map(_.userEvents)
          .getOrElse(Seq[UserEvent]())

      val updatedUserSessions = UserSession(newEvent +: existingEvents)

      updatedUserSessions.userEvents.find(_.isLast) match {
        case Some(_) =>
          // The session is complete: drop the state and emit the session downstream.
          state.remove()
          Some(updatedUserSessions)
        case None =>
          // Session still open: keep accumulating events in state.
          state.update(updatedUserSessions)
          None
      }
    }

    value match {
      case Some(newEvent) => updateUserSessions(newEvent)
      case _ if state.isTimingOut() => state.getOption()
    }
  }
}
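The demo's domain types are not shown in the post; a hypothetical minimal shape consistent with how they are used (not the original project's definitions):

// Hypothetical domain classes matching the demo's usage.
case class UserEvent(id: Int, data: String, isLast: Boolean)
object UserEvent { val empty = UserEvent(-1, "", isLast = false) }

case class UserSession(userEvents: Seq[UserEvent])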
Author: 分裂四人组
Source: https://www.jianshu.com/p/f5efa9a4c10c