使用Sort等对数据进行排序,其中用到了TimSort
这篇博文我们就来深入理解下TimSort
可视化
推荐先观看下 Youtube 上关于TimSort可视化的视频。对TimSort
有个感性的了解。
理解timsort
看完视频后也许你会发现TimSort
和MergeSort
非常像。没错,这里推荐先阅读关于理解timsort的博文,你就会发现它其实只是对归并排序进行了一系列的改进。其中有一些是很聪明的,而也有一些是相当简单直接的。这些大大小小的改进聚集起来使得算法的效率变得十分的吸引人。
Spark TimSort 源码分析
其实OpenJDK
在Java SE 7
的Arrays
关于Object
元素数组的sort
也使用了TimSort
,而Spark
的org.apache.spark.util.collection
包中的用Java
编写的TimSort
也和Java SE 7
中的TimSort
没有太大区别。
public void sort(Buffer a, int lo, int hi, Comparator<? super K> c) { assert c != null; // 未排序的数组长度 int nRemaining = hi - lo; // 若数组大小为 0 或者 1 // 那么就以及排序了 if (nRemaining < 2) return; // 若是小数组 // 则不使用归并排序 if (nRemaining < MIN_MERGE) { // 得到递增序列的长度 int initRunLen = countRunAndMakeAscending(a, lo, hi, c); // 二分插入排序 binarySort(a, lo, hi, lo + initRunLen, c); return; } // 栈 SortState sortState = new SortState(a, c, hi - lo); // 得到最小run长度 int minRun = minRunLength(nRemaining); do { // 得到递增序列的长度 int runLen = countRunAndMakeAscending(a, lo, hi, c); // 若run太小, // 使用二分插入排序 if (runLen < minRun) { int force = nRemaining <= minRun ? nRemaining : minRun; binarySort(a, lo, lo + force, lo + runLen, c); runLen = force; } // 入栈 sortState.pushRun(lo, runLen); // 可能进行归并 sortState.mergeCollapse(); // 查找下一run的预操作 lo += runLen; nRemaining -= runLen; } while (nRemaining != 0); // 归并所有剩余的run,完成排序 assert lo == hi; sortState.mergeForceCollapse(); assert sortState.stackSize == 1; }
我们接下来逐个深入的讲解:
countRunAndMakeAscending
private int countRunAndMakeAscending(Buffer a, int lo, int hi, Comparator<? super K> c) { assert lo < hi; int runHi = lo + 1; if (runHi == hi) return 1; K key0 = s.newKey(); K key1 = s.newKey(); // 找到run的尾部 if (c.compare(s.getKey(a, runHi++, key0), s.getKey(a, lo, key1)) < 0) { // 若是递减的,找到尾部反转run while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0) runHi++; reverseRange(a, lo, runHi); } else { while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0) runHi++; } // 返回run的长度 return runHi - lo; }
binarySort
private void binarySort(Buffer a, int lo, int hi, int start, Comparator<? super K> c) { assert lo <= start && start <= hi; if (start == lo) start++; K key0 = s.newKey(); K key1 = s.newKey(); Buffer pivotStore = s.allocate(1); // 将位置[start,hi)上的元素二分插入排序到已经有序的[lo,start)序列中 for ( ; start < hi; start++) { s.copyElement(a, start, pivotStore, 0); K pivot = s.getKey(pivotStore, 0, key0); int left = lo; int right = start; assert left <= right; while (left < right) { int mid = (left + right) >>> 1; if (c.compare(pivot, s.getKey(a, mid, key1)) < 0) right = mid; else left = mid + 1; } assert left == right; int n = start - left; // 对插入做简单的优化 switch (n) { case 2: s.copyElement(a, left + 1, a, left + 2); case 1: s.copyElement(a, left, a, left + 1); break; default: s.copyRange(a, left, a, left + 1, n); } s.copyElement(pivotStore, 0, a, left); } }
minRunLength
private int minRunLength(int n) { assert n >= 0; int r = 0; // 这里 MIN_MERGE 为 2 的某次方 // if n < MIN_MERGE , // then 直接返回 n // else if n >= MIN_MERGE 且 n(>1) 为 2 的某次方, // then n 的二进制低位第1位 为 0,r |= (n & 1) 一直为 0 ,即返回的是 MIN_MERGE / 2 // else r 为之后一次循环的n的二进制低位第1位值 k ,返回的值 MIN_MERGE/2< k < MIN_MERGE while (n >= MIN_MERGE) { r |= (n & 1); n >>= 1; } return n + r; }
SortState.pushRun
入栈
private void pushRun(int runBase, int runLen) { this.runBase[stackSize] = runBase; this.runLen[stackSize] = runLen; stackSize++; }
SortState.mergeCollapse
这部分代码OpenJDK
中存在着bug
,我们先来看一下Java SE 7
是如何实现的:
private void mergeCollapse() { while (stackSize > 1) { int n = stackSize - 2; if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) { if (runLen[n - 1] < runLen[n + 1]) n--; mergeAt(n); } else if (runLen[n] <= runLen[n + 1]) { mergeAt(n); } else { break; } } }
我们来举个例子:
当栈中的片段长度为:
120, 80, 25, 20
我们插入长度的30的片段,由于25 < 20 + 30
并且 25 < 30
,所以得到:
120, 80, 45, 30
现在,由于80 > 45 + 30
并且 45 > 30
,于是合并结束。但这并不完全符合根据不变式的重存储,因为120 < 80 + 45
!
更多细节可以参阅相关博文,Spark
也对此bug
进行了修复,修复后的代码如下:
private void mergeCollapse() { while (stackSize > 1) { int n = stackSize - 2; if ( (n >= 1 && runLen[n-1] <= runLen[n] + runLen[n+1]) || (n >= 2 && runLen[n-2] <= runLen[n] + runLen[n-1])) { if (runLen[n - 1] < runLen[n + 1]) n--; } else if (runLen[n] > runLen[n + 1]) { break; } mergeAt(n); } }
SortState. mergeAt
private void mergeAt(int i) { assert stackSize >= 2; assert i >= 0; assert i == stackSize - 2 || i == stackSize - 3; int base1 = runBase[i]; int len1 = runLen[i]; int base2 = runBase[i + 1]; int len2 = runLen[i + 1]; assert len1 > 0 && len2 > 0; assert base1 + len1 == base2; // 若 i 是从栈顶数第3个位置 // 则 将栈顶元素 赋值到 从栈顶数第2个位置 runLen[i] = len1 + len2; if (i == stackSize - 3) { runBase[i + 1] = runBase[i + 2]; runLen[i + 1] = runLen[i + 2]; } stackSize--; K key0 = s.newKey(); // 从 run1 中找到 run2的第1个元素的位置 // 在这之前的run1的元素都可以被忽略 int k = gallopRight(s.getKey(a, base2, key0), a, base1, len1, 0, c); assert k >= 0; base1 += k; len1 -= k; if (len1 == 0) return; // 从 run2 中找到 run1的最后1个元素的位置 // 在这之后的run2的元素都可以被忽略 len2 = gallopLeft(s.getKey(a, base1 + len1 - 1, key0), a, base2, len2, len2 - 1, c); assert len2 >= 0; if (len2 == 0) return; // 归并run // 使用 min(len1, len2) 长度的临时数组 if (len1 <= len2) mergeLo(base1, len1, base2, len2); else mergeHi(base1, len1, base2, len2); }
SortState. gallopRight
// key: run2的第1个值 // a: 数组 // base: run1的起始为位置 // len: run1的长度 // hint: 从run1的hint位置开始查找,这里我们传入的值为 0 private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator<? super K> c) { assert len > 0 && hint >= 0 && hint < len; // 对二分查找的优化: // 我们要从 run1中 截取出这样一段数组 // lastOfs = k+1 // ofs = 2×k+1 // run1[lastOfs] <= key <= run1[ofs] // 即在[lastOfs,ofs],做二分查找 int ofs = 1; int lastOfs = 0; K key1 = s.newKey(); // 若 run2的第1个值 < run1的第1个值 // 其实我知道,可以直接返回 0 // 但这里还是走了完整的算法流程 if (c.compare(key, s.getKey(a, base + hint, key1)) < 0) { // maxOfs = 1 int maxOfs = hint + 1; // 不进入循环 while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key1)) < 0) { lastOfs = ofs; ofs = (ofs << 1) + 1; if (ofs <= 0) ofs = maxOfs; } // 不进入 if (ofs > maxOfs) ofs = maxOfs; // tmp = 0 int tmp = lastOfs; // lastOfs = -1 lastOfs = hint - ofs; // ofs = 0 ofs = hint - tmp; } else { // 这种情况下,算法才会发挥真正的作用 // maxOfs = len int maxOfs = len - hint; while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key1)) >= 0) { // 更新 lastOfs 和 ofs lastOfs = ofs; ofs = (ofs << 1) + 1; // 防止溢出 if (ofs <= 0) ofs = maxOfs; } if (ofs > maxOfs) ofs = maxOfs; // 这里都不会变 lastOfs += hint; ofs += hint; } assert -1 <= lastOfs && lastOfs < ofs && ofs <= len; // 进行二分查找 lastOfs++; while (lastOfs < ofs) { int m = lastOfs + ((ofs - lastOfs) >>> 1); if (c.compare(key, s.getKey(a, base + m, key1)) < 0) // key < a[b + m] ofs = m; else // a[b + m] <= key lastOfs = m + 1; } assert lastOfs == ofs; return ofs; }
gallopLeft
和上述代码类似,就不再做讲解。
SortState. mergeLo
private void mergeLo(int base1, int len1, int base2, int len2) { assert len1 > 0 && len2 > 0 && base1 + len1 == base2; // 使用 min(len1, len2) 长度的临时数组 // 这里 len1 会较小 Buffer a = this.a; Buffer tmp = ensureCapacity(len1); s.copyRange(a, base1, tmp, 0, len1); // tmp(run1) 上的指针 int cursor1 = 0; // run2 上的指针 int cursor2 = base2; // 合并结果 上的指针 int dest = base1; // Move first element of second run and deal with degenerate cases // 优化: // 注意: run2 的第一个元素比 run1的第一个元素小 // run1 的最后一个元素 比 run2的最后一个元素大 // 把 run2 的第1个 元素复制到 最终结果的第1个位置 s.copyElement(a, cursor2++, a, dest++); if (--len2 == 0) { // 若 len2 为 1 // 直接 把 run1 拷贝到 最终结果中 s.copyRange(tmp, cursor1, a, dest, len1); return; } if (len1 == 1) { // 若 len1 为 1 // 把 run2 剩余的部分 拷贝到 最终结果中 // 再把 run1 拷贝到 最终结果中 s.copyRange(a, cursor2, a, dest, len2); s.copyElement(tmp, cursor1, a, dest + len2); return; } K key0 = s.newKey(); K key1 = s.newKey(); Comparator<? super K> c = this.c; // 对归并排序的优化: int minGallop = this.minGallop; outer: while (true) { // 主要思想为 使用 count1 count2 对插入进行计数 int count1 = 0; int count2 = 0; do { // 归并 assert len1 > 1 && len2 > 0; if (c.compare(s.getKey(a, cursor2, key0), s.getKey(tmp, cursor1, key1)) < 0) { s.copyElement(a, cursor2++, a, dest++); count2++; count1 = 0; if (--len2 == 0) break outer; } else { s.copyElement(tmp, cursor1++, a, dest++); count1++; count2 = 0; if (--len1 == 1) break outer; } // 若某个run连续拷贝的次数超过minGallop // 退出循环 } while ((count1 | count2) < minGallop); // 我们认为若某个run连续拷贝的次数超过minGallop, // 则可能还会出现更若某个run连续拷贝的次数超过minGallop // 所有需要重新进行类似于mergeAt中的操作, // 截取出按“段”进行归并 // 直到 count1 或者 count2 < MIN_GALLOP do { assert len1 > 1 && len2 > 0; count1 = gallopRight(s.getKey(a, cursor2, key0), tmp, cursor1, len1, 0, c); if (count1 != 0) { s.copyRange(tmp, cursor1, a, dest, count1); dest += count1; cursor1 += count1; len1 -= count1; if (len1 <= 1) // len1 == 1 || len1 == 0 break outer; } s.copyElement(a, cursor2++, a, dest++); if (--len2 == 0) break outer; count2 = gallopLeft(s.getKey(tmp, cursor1, key0), a, cursor2, len2, 0, c); if (count2 != 0) { s.copyRange(a, cursor2, a, dest, count2); dest += count2; cursor2 += count2; len2 -= count2; if (len2 == 0) break outer; } s.copyElement(tmp, cursor1++, a, dest++); if (--len1 == 1) break outer; minGallop--; } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP); // 调整 minGallop if (minGallop < 0) minGallop = 0; minGallop += 2; } // 退出 outer 循环 this.minGallop = minGallop < 1 ? 1 : minGallop; // 把尾部写入最终结果 if (len1 == 1) { assert len2 > 0; s.copyRange(a, cursor2, a, dest, len2); s.copyElement(tmp, cursor1, a, dest + len2); } else if (len1 == 0) { throw new IllegalArgumentException( "Comparison method violates its general contract!"); } else { assert len2 == 0; assert len1 > 1; s.copyRange(tmp, cursor1, a, dest, len1); } }
mergeHi
与上述类似,就不再讲解。
SortState.mergeForceCollapse
private void mergeForceCollapse() { // 将所有的run合并 while (stackSize > 1) { int n = stackSize - 2; // 若第3个run 长度 小于 栈顶的run // 先归并第2,3个 run if (n > 0 && runLen[n - 1] < runLen[n + 1]) n--; mergeAt(n); } }
总结
Spark TimSort
中 对MergeSort
大致有一下几点:
元素:不像
MergeSort
惰性的有原来的长度为1,再由归并自动的生成新的归并元素。TimSort
是预先按连续递增(或者将连续递减的片段反转)的片段作为一个归并元素,即run
。插入排序:若是长度小的run,
TimSort
会改用二分的InsertSort
以及对再它进行一些小优化,而不使用MergeSort
归并的时机:
MergeSort
的归并时机是定死的,而TimSort
中的时机是(n >= 1 && runLen[n-1] <= runLen[n] + runLen[n+1])|| (n >= 2 && runLen[n-2] <= runLen[n] + runLen[n-1])
。以及,若从栈顶开始第3个run
长度 小于 栈顶的run
,先归并第2,3个run
。截取出需要归并的片段:run1是头部和
run2
的尾部都是会有可以不用进行归并的部分。 如TimSort
从run1
中 截取出这样一段片段:lastOfs = k+1
,ofs = 2×k+1
,run1[lastOfs] <= key <= run1[ofs]
。再从该片段上进行二分查找,得到run1
中需要归并的起始位置归并的优化:对
run
长度为1时,进行了小优化。实现了按单个值和按片段归并的协同。
作者:小爷Souljoy
链接:https://www.jianshu.com/p/599346731bcf
共同学习,写下你的评论
评论加载中...
作者其他优质文章