为了账号安全,请及时绑定邮箱和手机立即绑定

Spark矩阵相乘原理

标签:
Spark

1 hadoop中矩阵相乘原理



如果想要了解Spark中的矩阵相乘原理,需要先大体了解一下MapReduce的矩阵相乘过程,设

webp


webp

,那么

webp


矩阵乘法要求左矩阵A!的列数与右矩阵B的行数相等。
现在我们来分析一下,哪些操作是相互独立的(从而可以进行分布式计算)。很显然,C中各个元素的计算都是相互独立的。这样,我们在Map阶段,可以把计算C中每个元素所需要的元素集中到同一个key中,然后,在Reduce阶段就可以从中解析出各个元素来计算。
整个计算过程如下所示:

webp

MapReduce中的矩阵相乘过程

2 Spark中的矩阵相乘原理

在Spark中,Spark自带的org.apache.spark.mllib.linalg.distributed.BlockMatrix实现了分布式矩阵乘法,BlockMatrix是使用内积法实现的分布式分块矩阵的乘法。
Spark中的分块矩阵乘法不使用外积法实现,主要考虑到外积法内存占用量大。
Spark自带BlockMatrix乘法源码分析:
必要的注释已经在源码中给出。

def multiply(other: BlockMatrix): BlockMatrix = {
.......
if  (colsPerBlock == other.rowsPerBlock) {//GridPartitioner一共分为numRowBlocks*other.numColBlocks个partitionval resultPartitioner = GridPartitioner(numRowBlocks, other.numColBlocks,
math.max(blocks.partitions.length, other.blocks.partitions.length))// 这里是计算每个leftDestinations和rightDestinations的类型都是Map[(Int,Int),Set[Int]],也就是先计算左右矩阵的// 每一块会shuffle到哪个partitionval (leftDestinations, rightDestinations) = simulateMultiply(other, resultPartitioner)// Each block of A must be multiplied with the corresponding blocks in the columns of B.val flatA = blocks.flatMap { case ((blockRowIndex, blockColIndex), block) =>val destinations = leftDestinations.getOrElse((blockRowIndex, blockColIndex), Set.empty)
destinations.map(j => (j, (blockRowIndex, blockColIndex, block)))
}// Each block of B must be multiplied with the corresponding blocks in each row of A.val flatB = other.blocks.flatMap { case ((blockRowIndex, blockColIndex), block) =>val destinations = rightDestinations.getOrElse((blockRowIndex, blockColIndex), Set.empty)
destinations.map(j => (j, (blockRowIndex, blockColIndex, block)))
}// GridPartitioner一共有numRowBlocks*other.numColBlocks 个分区,所以在cogroup的时候,在计算A*B=C的时候,C矩阵所用到的所有A和B中的//分块都会在一个partition中,在reduceByKey的时候就可以进行combineByKey进行优化,事实上在reduceByKey的过程中,只有相加的过程,// 没有shuffle的过程。val newBlocks = flatA.cogroup(flatB, resultPartitioner).flatMap { case (pId, (a, b)) =>a.flatMap { case(leftRowIndex, leftColIndex, leftBlock) =>b.filter(_._1 == leftColIndex).map { case (rightRowIndex, rightColIndex, rightBlock) =>//在进行矩阵乘法实现的时候,本地矩阵计算使用com.github.fommil.netlib包提供的矩阵算法,矩阵加法调用的是scalanlp包提供的矩阵加法val C = rightBlock match{case dense: DenseMatrix => leftBlock.multiply(dense)case sparse: SparseMatrix => leftBlock.multiply(sparse.toDense)case _ =>throw new SparkException(s"Unrecognized matrix type ${rightBlock.getClass}.")
}
 ((leftRowIndex, rightColIndex), C.toBreeze)
}
}
}.reduceByKey(resultPartitioner, (a, b) => a + b).mapValues(Matrices.fromBreeze)// TODO: Try to use aggregateByKey instead of reduceByKey to get rid of intermediate matricesnew BlockMatrix(newBlocks, rowsPerBlock, other.colsPerBlock, numRows(), other.numCols())
} else {.......
}
}

以上代码有一个simulateMultiply方法比较重要,源码注释如下:

private[distributed] def simulateMultiply(
other: BlockMatrix,partitioner: GridPartitioner): (BlockDestinations, BlockDestinations) = {
val leftMatrix = blockInfo.keys.collect() // blockInfo should already be cachedval rightMatrix = other.blocks.keys.collect()//以下这段代码这样理解,假设A*B=C,因为A11在计算C11到C1n的时候会用到,所以A11在计算C11到C1n的机器都会存放一份。val leftDestinations = leftMatrix.map { case (rowIndex, colIndex) =>//左矩阵中列号会和右矩阵行号相同的块相乘,得到所有右矩阵中行索引和左矩阵中列索引相同的矩阵的位置。// 由于有这个判断,右矩阵中没有值的快左矩阵就不会重复复制了,避免了零值计算。val rightCounterparts = rightMatrix.filter(_._1 == colIndex)// 因为矩阵乘完之后还有相加的操作(reduceByKey),相加的操作如果在同一部机器上可以用combineBy进行优化,// 这里直接得到每一个分块在进行完乘法之后会在哪些partition中用到。            val partitions = rightCounterparts.map(b => partitioner.getPartition((rowIndex, b._2)))        ((rowIndex, colIndex), partitions.toSet)
}.toMap
val rightDestinations = rightMatrix.map { case (rowIndex, colIndex) =>val leftCounterparts = leftMatrix.filter(_._2 == rowIndex)
val partitions = leftCounterparts.map(b => partitioner.getPartition((b._1, colIndex)))  
  ((rowIndex, colIndex), partitions.toSet)}
.toMap
(leftDestinations, rightDestinations)}

从代码中可以知道,Spark中自带的分块矩阵乘法要求每个Executor的内存最少能够存下左矩阵一行中所有非零块和右矩阵一列中的所有非零块。这样使得BlockMatrix乘法算法更高效,能够避免不必要的零值计算。在计算的过程中只需要一次shuffle。在实践中,使用Spark自带的BlockMatrix算法要注意内存的使用,分块的时候,块的大小是多少除了注意内存之外,还要注意令子块中的数据能够尽量的紧凑,减少零值计算。



作者:九七学姐
链接:https://www.jianshu.com/p/f0c8e5a1c309


点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
数据库工程师
手记
粉丝
42
获赞与收藏
203

关注作者,订阅最新文章

阅读免费教程

  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消