为了账号安全,请及时绑定邮箱和手机立即绑定

深入理解Spark 2.1 Core (十一):Shuffle Reduce 端的原理与源码分析

标签:
Spark

为了有效地实现容错,RDD提供了一种高度受限的共享内存,即RDD是只读的,并且只能通过其他RDD上的批量操作来创建(注:还可以由外部存储系数据集创建,如HDFS)

可知,我们在第九,第十篇博文所讲的是传统Hadoop MapReduce类似的,在最初从HDFS中读取数据生成HadoopRDD的过程。而RDD可以通过其他RDD上的批量操作来创建,所以这里的HadoopRDD对于下一个生成的ShuffledRDD可以视为Map端,当然下一个生成的ShuffledRDD可以被下下个ShuffledRDD视为Map端。反过来说,下一个ShuffledRDD可以被``HadoopRDD视作Reduce`端。

这篇博文,我们就来讲下ShuffleReduce端。其实在RDD迭代部分和第九篇博文类似,不同的是,这里调用的是rdd.ShuffledRDD.compute:

  override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {  // 得到依赖
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]    // 调用getReader,传入dep.shuffleHandle 分区 上下文 
    // 得到Reader,调用read()
    // 得到迭代器
     SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
      .read()
      .asInstanceOf[Iterator[(K, C)]]
  }

这里调用的是shuffle.sort.SortShuffleManagergetReader

  override def getReader[K, C](
      handle: ShuffleHandle,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext): ShuffleReader[K, C] = {      // 生成返回 BlockStoreShuffleReader
    new BlockStoreShuffleReader(
      handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
  }

shuffle.BlockStoreShuffleReader.read:

  override def read(): Iterator[Product2[K, C]] = {  // 实例化ShuffleBlockFetcherIterator
    val blockFetcherItr = new ShuffleBlockFetcherIterator(
      context,
      blockManager.shuffleClient,
      blockManager,      // 通过消息发送获取 ShuffleMapTask 存储数据位置的元数据
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),      // 设置每次传输的大小
      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,      // // 设置Int的大小
      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))    // 基于配置的压缩和加密来包装流
    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
      serializerManager.wrapStream(blockId, inputStream)
    }    val serializerInstance = dep.serializer.newInstance()    // 对每个流生成 k/v 迭代器
    val recordIter = wrappedStreams.flatMap { wrappedStream =>
      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }    // 每条记录读取后更新任务度量
    val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()    // 生成完整的迭代器
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      recordIter.map { record =>
        readMetrics.incRecordsRead(1)
        record
      },
      context.taskMetrics().mergeShuffleReadMetrics())    // 传入metricIter到可中断的迭代器
    // 为了能取消迭代
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {    // 若需要对数据进行聚合
      if (dep.mapSideCombine) {        // 若需要进行Map端(对于下一个Shuffle来说)的合并
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)        // 若只需要进行Reduce端(对于下一个Shuffle来说)的合并
      } else {        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }

   
    dep.keyOrdering match {      case Some(keyOrd: Ordering[K]) =>      // 若需要排序
      // 若spark.shuffle.spill设置为否的话
      // 将不会spill到磁盘
        val sorter =          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
        sorter.insertAll(aggregatedIter)
        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
        context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())      case None =>
        aggregatedIter
    }
  }

类调用关系图:

webp

这里写图片描述

下面我们来深入讲解下实例化ShuffleBlockFetcherIterator的过程:

  // 实例化ShuffleBlockFetcherIterator
    val blockFetcherItr = new ShuffleBlockFetcherIterator(
      context,
      blockManager.shuffleClient,
      blockManager,      // 通过消息发送获取 ShuffleMapTask 存储数据位置的元数据
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),      // 设置每次传输的大小
      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,      // // 设置Int的大小
      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))

获取元数据

mapOutputTracker.getMapSizesByExecutorId

首先我们会调用mapOutputTracker.getMapSizesByExecutorId

  def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")    // 得到元数据
    val statuses = getStatuses(shuffleId)    // 返回格式为:
    // Seq[BlockManagerId,Seq[(shuffle block id, shuffle block size)]]
    statuses.synchronized {      return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
    }
  }

mapOutputTracker.getStatuses

  private def getStatuses(shuffleId: Int): Array[MapStatus] = {  // 尝试从本地获取数据
    val statuses = mapStatuses.get(shuffleId).orNull    if (statuses == null) {    // 若本地无数据
      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")      val startTime = System.currentTimeMillis      var fetchedStatuses: Array[MapStatus] = null
      fetching.synchronized {        // 若以及有其他人也准备远程获取这数据的话
        // 则等待
        while (fetching.contains(shuffleId)) {          try {
            fetching.wait()
          } catch {            case e: InterruptedException =>
          }
        }        // 尝试直接获取数据
        fetchedStatuses = mapStatuses.get(shuffleId).orNull        if (fetchedStatuses == null) {          // 若还是不得不远程获取,
          // 则将shuffleId加入fetching
          fetching += shuffleId
        }
      }      if (fetchedStatuses == null) {
        logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)        try {        // 远程获取
          val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))          // 反序列化
          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
          logInfo("Got the output locations")          // 将数据加入mapStatuses
          mapStatuses.put(shuffleId, fetchedStatuses)
        } finally {
          fetching.synchronized {
            fetching -= shuffleId
            fetching.notifyAll()
          }
        }
      }
      logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +        s"${System.currentTimeMillis - startTime} ms")      if (fetchedStatuses != null) {      // 若直接获取,则直接返回
        return fetchedStatuses
      } else {
        logError("Missing all output locations for shuffle " + shuffleId)        throw new MetadataFetchFailedException(
          shuffleId, -1, "Missing all output locations for shuffle " + shuffleId)
      }
    } else {    // 若直接获取,则直接返回
      return statuses
    }
  }

mapOutputTracker.askTracker

trackerEndpoint发送消息GetMapOutputStatuses(shuffleId)

  protected def askTracker[T: ClassTag](message: Any): T = {    try {
      trackerEndpoint.askWithRetry[T](message)
    } catch {      case e: Exception =>
        logError("Error communicating with MapOutputTracker", e)        throw new SparkException("Error communicating with MapOutputTracker", e)
    }
  }

MapOutputTrackerMasterEndpoint.receiveAndReply

    case GetMapOutputStatuses(shuffleId: Int) =>      val hostPort = context.senderAddress.hostPort
      logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)      val mapOutputStatuses = tracker.post(new GetMapOutputMessage(shuffleId, context))

可以看到,这里并不是直接返回消息,而是调用tracker.post:

  def post(message: GetMapOutputMessage): Unit = {
    mapOutputRequests.offer(message)
  }

mapOutputRequests加入GetMapOutputMessage(shuffleId, context)消息。这里的mapOutputRequests是链式阻塞队列。

  private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage]

MapOutputTrackerMaster.MessageLoop.run

MessageLoop启一个线程不断的参数从mapOutputRequests读取数据:

  private class MessageLoop extends Runnable {    override def run(): Unit = {      try {        while (true) {          try {            val data = mapOutputRequests.take()             if (data == PoisonPill) {
              mapOutputRequests.offer(PoisonPill)              return
            }            val context = data.context            val shuffleId = data.shuffleId            val hostPort = context.senderAddress.hostPort
            logDebug("Handling request to send map output locations for shuffle " + shuffleId +              " to " + hostPort)              // 若读到数据
              // 则序列化
            val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId)            // 返回数据
            context.reply(mapOutputStatuses)
          } catch {            case NonFatal(e) => logError(e.getMessage, e)
          }
        }
      } catch {        case ie: InterruptedException => // exit
      }
    }
  }

MapOutputTracker.convertMapStatuses

我们回到mapOutputTracker.getMapSizesByExecutorId中返回的MapOutputTracker.convertMapStatuses

  private def convertMapStatuses(
      shuffleId: Int,
      startPartition: Int,
      endPartition: Int,
      statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
    assert (statuses != null)    val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]]    for ((status, mapId) <- statuses.zipWithIndex) {      if (status == null) {        val errorMessage = s"Missing an output location for shuffle $shuffleId"
        logError(errorMessage)        throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
      } else {        for (part <- startPartition until endPartition) {        // 返回的Seq中的结构是status.location,Seq[ShuffleBlockId,SizeForBlock]
          splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
            ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part)))
        }
      }
    }    // 对Seq根据status.location进行排序
    splitsByAddress.toSeq
  }

划分本地和远程Block

让我回到new ShuffleBlockFetcherIterator

storage.ShuffleBlockFetcherIterator.initialize

当我们实例化ShuffleBlockFetcherIterator时,会调用initialize:

  private[this] def initialize(): Unit = {
    context.addTaskCompletionListener(_ => cleanup())    // 划分本地和远程的blocks
    val remoteRequests = splitLocalRemoteBlocks()    // 把远程请求随机的添加到队列中
    fetchRequests ++= Utils.randomize(remoteRequests)
    assert ((0 == reqsInFlight) == (0 == bytesInFlight),      "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +      ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)    // 发送远程请求获取blocks
    fetchUpToMaxBytes()    val numFetches = remoteRequests.size - fetchRequests.size
    logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))    // 获取本地的Blocks
    fetchLocalBlocks()
    logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
  }

storage.ShuffleBlockFetcherIterator.splitLocalRemoteBlocks

  private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {    // 是的远程请求最大长度为 maxBytesInFlight / 5
    // maxBytesInFlight: 为单次航班请求的最大字节数
    // 航班: 一批请求
    // 1/5 : 是为了提高请求批发度,允许5个请求分别从5个节点获取数据
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
    logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)    // 缓存需要远程请求的FetchRequest对象
    val remoteRequests = new ArrayBuffer[FetchRequest]    // 总共 blocks 的数量
    var totalBlocks = 0
    // 我们从上文可知blocksByAddress是根据status.location进行排序的
    for ((address, blockInfos) <- blocksByAddress) {
      totalBlocks += blockInfos.size      if (address.executorId == blockManager.blockManagerId.executorId) {        // 若 executorId 相同 与本 blockManagerId.executorId,
        // 则从本地获取
        localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
        numBlocksToFetch += localBlocks.size
      } else {      // 否则 远程请求
      // 得到迭代器
        val iterator = blockInfos.iterator     // 当前累计块的大小
        var curRequestSize = 0L     // 当前累加块
     // 累加: 若向一个节点频繁的请求字节很少的Block,
     // 那么会造成网络阻塞
        var curBlocks = new ArrayBuffer[(BlockId, Long)]        // iterator 中的block 都是同一节点的
        while (iterator.hasNext) {          val (blockId, size) = iterator.next()          if (size > 0) {
            curBlocks += ((blockId, size))
            remoteBlocks += blockId
            numBlocksToFetch += 1
            curRequestSize += size
          } else if (size < 0) {            throw new BlockException(blockId, "Negative block size " + size)
          }          if (curRequestSize >= targetRequestSize) {            // 若累加到大于远程请求的尺寸
            // 往remoteRequests加入FetchRequest
            remoteRequests += new FetchRequest(address, curBlocks)
            curBlocks = new ArrayBuffer[(BlockId, Long)]
            logDebug(s"Creating fetch request of $curRequestSize at $address")
            curRequestSize = 0
          }
        }        // 增加最后的请求
        if (curBlocks.nonEmpty) {
          remoteRequests += new FetchRequest(address, curBlocks)
        }
      }
    }
    logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
    remoteRequests
  }

获取Block

storage.ShuffleBlockFetcherIterator.fetchUpToMaxBytes

我们回到storage.ShuffleBlockFetcherIterator.initializefetchUpToMaxBytes()来深入讲解下如何获取远程的Block

  private def fetchUpToMaxBytes(): Unit = {    // Send fetch requests up to maxBytesInFlight
    // 单次航班请求数要小于最大航班请求数
    // 单次航班字节数数要小于最大航班字节数
    while (fetchRequests.nonEmpty &&
      (bytesInFlight == 0 ||
        (reqsInFlight + 1 <= maxReqsInFlight &&
          bytesInFlight + fetchRequests.front.size <= maxBytesInFlight))) {
      sendRequest(fetchRequests.dequeue())
    }
  }

storage.ShuffleBlockFetcherIterator.sendRequest

  private[this] def sendRequest(req: FetchRequest) {
    logDebug("Sending request for %d blocks (%s) from %s".format(
      req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
    bytesInFlight += req.size
    reqsInFlight += 1

    // 可根据blockID查询block大小
    val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap    val remainingBlocks = new HashSet[String]() ++= sizeMap.keys    val blockIds = req.blocks.map(_._1.toString)    val address = req.address    // 关于shuffleClient.fetchBlocks我们会在之后的博文讲解
    shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,      new BlockFetchingListener {      // 请求成功
        override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {          ShuffleBlockFetcherIterator.this.synchronized {            if (!isZombie) {
              buf.retain()
              remainingBlocks -= blockId
              results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
                remainingBlocks.isEmpty))
              logDebug("remainingBlocks: " + remainingBlocks)
            }
          }
          logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
        }      // 请求失败
        override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
          logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
          results.put(new FailureFetchResult(BlockId(blockId), address, e))
        }
      }
    )
  }

storage.ShuffleBlockFetcherIterator.fetchLocalBlocks

我们再回过头来看获取本地blocks:

  private[this] def fetchLocalBlocks() {  // 获取迭代器
    val iter = localBlocks.iterator    while (iter.hasNext) {      val blockId = iter.next()      try {      // 遍历获取数据
      // blockManager.getBlockData 会在后续博文讲解
        val buf = blockManager.getBlockData(blockId)
        shuffleMetrics.incLocalBlocksFetched(1)
        shuffleMetrics.incLocalBytesRead(buf.size)
        buf.retain()
        results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false))
      } catch {        case e: Exception =>
          logError(s"Error occurred while fetching local blocks", e)
          results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))          return
      }
    }
  }



作者:小爷Souljoy
链接:https://www.jianshu.com/p/2131207184be


点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
数据库工程师
手记
粉丝
42
获赞与收藏
203

关注作者,订阅最新文章

阅读免费教程

  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消