为了账号安全,请及时绑定邮箱和手机立即绑定

线段树入门

标签:
Python 算法

高级数据结构,线段树入门

一、线段树的基本思想

线段树是一种常用来维护区间信息的数据结构,它适用于对区间内进行单点查询、更新、求最值等操作,且时间复杂度能控制到 O(logN)。它的构建过程用到了二分的思想,通过不 段的二分将区间分成两段,并分别对应左孩子和右孩子。

下面举例来说明:比如有一个数组 [1, 2, 5, 7, 8, 10, 13, 15],它的长度是 8,所以范围是 [1, 8]。如果用二分的思想来分解构造出的线段树如下所示:

接下来我们来看看怎么定义线段树的数据结构。通常有两种方式,一种方式是定义一个 class,一种方式是使用连续的数组。首先我们来看下自定义 class 的方式,这里使用 Python 代码:

class SegTree:
    def __init__(self, left, right):
        # 当前结点的左边界
        self.lo = left
        # 当前结点的右边界
        self.hi = right
        # 记录额外的信息,这里通常可以是最值或是区间和,根据题目需求来定义
        self.other_inf = 0
        # 左、右孩子
        self.left = None
        self.right = None

这种定义方式比较直接,但遍历起来稍麻烦一点。而第二种方式是使用连续的数组。从上图我们构造的线段树可以看出,抛开叶子结点,树是一个满二叉树,所以可以使用连续的数组来存储,且父子结点的关系为

parent[i]
parent.left = parent[i*2]
parent.right = parent[i*2+1]

# 对于 i * 2 和 i*2+1,使用位运算可快速得到
i * 2 = i << 1
i * 2 + 1 = i << 1 | 1

使用数组时需要注意,叶子结点其实就是对应的给定数组的值,但数组的长度不一定能满足满二叉树叶子结点的个数,这个时候代码编写上就比较灵活了,一般有两种方式:

  • 使用满二叉树的数组个数,不足处补 0

这种思路的其实相对比较好理解,因为给定的数组都需要放到叶子结点,那如果想要树是一棵满二叉树,则叶子结点的个数必须是 2^n。所以我们需要找到第一个大于等于数组长度的 2 的 n 次幂。对于求第一个大于等于数组长度的 2 的 n 次幂的方法有很多,通过几个位运算就能实现的,可以参考 Java HashMap 的源码,也可以看 Integer 的 highestOneBit 方法,代码如下(这里不解释具体原因):

public static int highestOneBit(int i) {
        // HD, Figure 3-1
        i |= (i >>  1);
        i |= (i >>  2);
        i |= (i >>  4);
        i |= (i >>  8);
        i |= (i >> 16);
        return i - (i >>> 1);
}

而用一个我们比较好理解的方法,如下:

n = 1
while n < len(nums):
   n <<= 1

找到这个数值之后,就可以进行初始化:

 # 因为 n 是第一个大于或等于 len(nums) 的 2 次幂,它是等于它之前所有结点和 + 1 的
# 而一般在线段树中第 0 位通常不用
# 因此 [0] * n 即所有非叶结点的初始化
# [nums] 则是初始化数组,并将其分配到叶子结点
# [0] * (n - len(nums)) 叶子结点未被分配到值的用 0 补全
self.seg_tree = [0] * n + nums + [0] * (n - len(nums))
# 初始化赋值, 根据父子关系的公式
for k in range(n - 1, 0, -1):
    self.seg_tree[k] = self.seg_tree[2 * k] + self.seg_tree[2 * k + 1]

# 这里的做法参考:https://leetcode-cn.com/problems/range-sum-query-mutable/solution/python-shu-zhuang-shu-zu-binary-indexed-tree-by-ze/
  • 根据推导规律,在保证父子结点关系的情况下,初始化 2*n 长度的数组

因为如果长度为 n 的数组都需要放到叶子结点上,则它的上层有 n/2 个结点,再上层 n/4…,根据等比求和公式很容易得出所有结点个数一定小于 2n。所以我们整个线段树数组的值设置为 2n 就足够使用了。代码如下:

n = len(nums)
self._n = n
self._tree = [0] * (n << 1)
# 将最后 n 位放置到叶子结点,也就是数组的最后 n 位
for i in range(n, len(self._tree)):
    self._tree[i] = nums[i - n]

for i in range(n - 1, 0, -1):
    # 父结点 = 左结点(父结点序号*2) + 右结点(父结果序号*2+1)
    self._tree[i] = self._tree[i << 1] + self._tree[i << 1 | 1]

此外,线段树常用的方法有:

# 将第 i 个位置的数值更新为 val
def update(i, val)
# 将第 i 个位置的数值加上 val
def add(i, val)
# 查询区间[i, j]上的区间和或最值,根据具体需求来具体分析
def query(i, j)

这里就不一一实现这些方法,通过两个题目来具体实践下。

二、实践

首先来看一个简单点的题目:

给定一个整数数组  nums,求出数组从索引 i 到 j  (i ≤ j) 范围内元素的总和,包含 i,  j 两点。

update(i, val) 函数可以通过将下标为 i 的数值更新为 val,从而对数列进行修改。

示例:

Given nums = [1, 3, 5]

sumRange(0, 2) -> 9
update(1, 2)
sumRange(0, 2) -> 8
说明:

数组仅可以在 update 函数下进行修改。
你可以假设 update 函数与 sumRange 函数的调用次数是均匀分布的。

来源:力扣(LeetCode)
链接:https://leetcode-cn.com/problems/range-sum-query-mutable
著作权归领扣网络所有。商业转载请联系官方授权,非商业转载请注明出处。

这个题目其实很简单,首先抛开线段树的思想,其实直接通过 Python list 就可以实现,也是可以 AC 的:

from typing import List
class NumArray:
    def __init__(self, nums: List[int]):
        if not nums:
            self._nums = []
        else:
            self._nums = nums

    def update(self, i: int, val: int) -> None:
        if i >= len(self._nums) or i < 0:
            return
        self._nums[i] = val

    def sumRange(self, i: int, j: int) -> int:
        return sum(self._nums[i:j + 1])

那如果使用线段树呢?这里我们使用数组存储的方式来实现。首先线段数数组初始化可以直接套用上面说到的两种方式,关键是 update 与 sumRange。而对于上面提到的两种方式其实 update 和 sumRange 在解决的的时候实际代码是一样的。这里我们以第二种方式初始化方式为例(毕竟会减少空间的消耗):

class NumArray:
    def __init__(self, nums: List[int]):
        if not nums:
            self._tree = []
            return
        n = len(nums)
        self._n = n
        self._tree = [0] * (n << 1)
        # 将最后 n 位放置到叶子结点,也就是数组的最后 n 位
        for i in range(n, len(self._tree)):
            self._tree[i] = nums[i - n]

        for i in range(n - 1, 0, -1):
            # 父结点 = 左结点(父结点序号*2) + 右结点(父结果序号*2+1)
            self._tree[i] = self._tree[i << 1] + self._tree[i << 1 | 1]
        # print(self._tree)

接下搂我们来看下 update 方法:

def update(self, i: int, val: int) -> None:
    if i < 0 or i >= self._n:
        return
    # 更新数组的第 i 个位置,即更新 self._tree 的第 n + i 个位置,i 是从 0 开始
    # 记录和改变了多少
    change = val - self._tree[self._n + i]
    self._tree[self._n + i] = val
    # 更新 n+i 结点的所有父结点
    parent = (self._n + i) // 2
    while parent:
        self._tree[parent] += change
        parent //= 2
    # print(self._tree)

update 方法其实思路也比较简单,先去更新叶子结点上的数值,并记录改变差值,然后依次更新父结点的记录和。我们再来看下 sumRange:

def sumRange(self, l: int, r: int) -> int:
    # 做一些特殊边界情况判断
    if l > r:
        return 0
    if r < 0:
        return 0
    if l > self._n:
        return 0
    if r < 0:
        i = 0
    if l >= self._n:
        j = self._n - 1

    l += self._n
    r += self._n
    result = 0
    # 当 l <= r 时
    while l <= r:
        # 如果左边界是右孩子,则说明不能加它的父结点的值,所以它的值需要单独加
        if l % 2 == 1:
            result += self._tree[l]
            # 加完之后,l向后移动,则移到了父结点右孩子的左孩子结点
            l += 1
        # 如果右边界在左孩子,则左孩子需要单独加
        if r % 2 == 0:
            result += self._tree[r]
            r -= 1
        l //= 2
        r //= 2
    return result

我个人认为 sumRange 比 update 要难理解一点,它的主要思想在于如果当前要求值的范围比当前结点记录的范围要大(即既需要左孩子,也需要右孩子),则找父结点,如果只需要当前结点,就加上当前结点。

至此,这个题目就解决了。使用数组的话,代码在理解上会复杂一点,主要是要对父子关系的灵活运用。

接下来,我们来看另外一个题目,我也是在刷这个题目时了解到线段树这个数据结构:

给定一个整数数组 nums,返回区间和在 [lower, upper] 之间的个数,包含 lower 和 upper。
区间和 S(i, j) 表示在 nums 中,位置从 i 到 j 的元素之和,包含 i 和 j (i ≤ j)。

说明:
最直观的算法复杂度是 O(n2) ,请在此基础上优化你的算法。

示例:

输入: nums = [-2,5,-1], lower = -2, upper = 2,
输出: 3 
解释: 3个区间分别是: [0,0], [2,2], [0,2],它们表示的和分别为: -2, -1, 2。

来源:力扣(LeetCode)
链接:https://leetcode-cn.com/problems/count-of-range-sum
著作权归领扣网络所有。商业转载请联系官方授权,非商业转载请注明出处。

这个题目在分析时,我们需要将式子做一个转换,一旦做了这个转换,基本就成功一半了。转换关系如下:

求:lower <=  sum(i, j) <= upper
而 sum(i, j) = nums[0] + nums[1] + ... nums[j] - (nums[0] + nums[1] + ... + nums[i-1])
即 sum(i, j) = prefixSum[j] - prefixSum[i-1] , 这里 prefixSum 为 nums 的前缀和数组

所以题目可以转换为求:
lower <=  prefixSum[j] - prefixSum[i-1] <= upper

-> 

# lower + prefixSum[i-1] <= prefixSum[j] <= upper + prefixSum[i-1]
# 或是
# prefixSum[j] - upper <= prefixSum[i-1] <= prefixSum[j] - lower

# 如果是第一种移动方法,则表示当给定一个 i 位置的前缀和时,需要找从 i+1 位置往后,满足前缀和在 lower + prefixSum[i] 和 upper + prefixSum[i] 的个数
# 如果是第二种移动方法,则表示当给定一个 j 位置的前缀各时,需要找从 0 到 j-1 位置的前缀各,在 prefixSum[j] - upper 和 prefixSum[j] - lower 范围内的

当我们分析到这一步时,我们可以发现,其实我们要求的就是当给定一个数值,然后求在一个范围内的数值中,在指定范围的数值有多少个。比如题目中的例子:nums = [-2,5,-1], lower = -2, upper = 2,前缀和数组为 [-2, 3, 2],因此我们可以遍历前缀和数组,如果是从后往前遍历,则是利用 lower + prefixSum[i-1] <= prefixSum[j] <= upper + prefixSum[i-1] 这个转换,如果是从前往后遍历,则是使用 prefixSum[j] - upper <= prefixSum[i-1] <= prefixSum[j] - lower 转换(主要是需要确定范围,所以固定的是式子左右两边的变量)。

不管使用哪种方式,其实思路都是一样的。我们首先找到一个基准的前缀和,然后从当前这个基准向前(或向后)找在范围内的个数,找到之后,将当前这个基准加入到某种数据结构中,在这个数据结构里记录的就是当前基准以前所有的前缀和。而且我们需要这个数据结构来保证,在这个数据结构中查询在指定范围内的数值个数时,性能很高,此外因为还会不断做插入,也要保证插入的性能。

因此,我们明确了,解题需要前缀和数组,和一个能在区间内快速做查询和插入(也可以是更新)的数据结构。显然和我们线段树的适用范围是很相似的。直接看代码吧。

# 使用线段树,第一种移动方式,即 lower + prefixSum[i-1] <= prefixSum[j] <= upper + prefixSum[i-1]
class SegTree:
    def __init__(self, left, right):
        # 当前结点的左边界
        self.lo = left
        # 当前结点的右边界
        self.hi = right
        # 记录在当前范围内的数有多少个
        self.count = 0
        # 左、右孩子
        self.left = None
        self.right = None


class Solution:
    def buildSegTree(self, left: int, right: int) -> SegTree:
        # 感觉也可以用数组来代替
        node = SegTree(left, right)
        if left == right:
            return node
        mid = (left + right) // 2
        # 左边一半
        left = self.buildSegTree(left, mid)
        # 右边一半
        right = self.buildSegTree(mid + 1, right)
        node.left = left
        node.right = right
        return node

    def countOfSegTree(self, node: SegTree, left: int, right: int) -> int:
        """统计在线段树中,在 [left, right] 范围内的数值
        """
        # 如果范围比当前 lo hi 的范围大,则直接返回 count 值
        if left <= node.lo and node.hi <= right:
            return node.count
        # 如果没在当前范围内
        if left > node.hi or right < node.lo:
            return 0
        return self.countOfSegTree(node.left, left, right) + self.countOfSegTree(node.right, left, right)

    def insertToSegTree(self, node: SegTree, val: int) -> None:
        """在线段树中,插入一个 val 的值
        """
        node.count += 1
        if node.lo == node.hi == val:
            return
        mid = (node.lo + node.hi) // 2
        if val <= mid:
            self.insertToSegTree(node.left, val)
        else:
            self.insertToSegTree(node.right, val)

    def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
        if not nums:
            return 0

        # 第一步,求前缀和,注意第一个元素为0,这样在遍历时,第一个元素到最后一个元素的情况也会考虑进去
        prefix_sum = [0]
        for n in nums:
            prefix_sum.append(prefix_sum[-1] + n)
        # 第二步,求出所有 lower + prefixSum[i-1] 和 upper + prefixSum[i-1],以及 prefix_sum 本身
        allNumbers = set()
        for n in prefix_sum:
            allNumbers.add(n)
            allNumbers.add(lower + n)
            allNumbers.add(upper + n)
        # 将 allNumbers 通过 hash 离散到一个连续的数组中
        nums_map = {}
        for i, n in enumerate(sorted(allNumbers)):
            nums_map[n] = i
        root = self.buildSegTree(0, len(nums_map))
        res = 0
        # 因为这里是看 lower + prefixSum[i-1] <= prefixSum[j] <= upper + prefixSum[i-1],每次都是从当前位置往后看所有的前缀和,所以在遍历时,应该从最后一个前缀和往前遍历
        for n in prefix_sum[::-1]:
            left, right = nums_map[lower + n], nums_map[upper + n]
            res += self.countOfSegTree(root, left, right)
            self.insertToSegTree(root, nums_map[n])
        return res


s = Solution()
print(s.countRangeSum([-2, 5, -1], -2, 2))

# 代码参考了官方题解的java版本,官方版本用的是第二种移动方式

代码中需要注意 allNumber 和 nums_map 的理解,这里线段树主要记录的是在 left,right 范围内的数值个数,因为前缀和可能比较散乱,所以对数值做了映射处理,将它映射到一个连接的数组中。

在题解中也看到了另外一种解决方法,使用的是有序数组+二分来代替的这种线段树这种数据结构。代码如下:

# 作者:fan-cai
# 链接:https://leetcode-cn.com/problems/count-of-range-sum/solution/python3-6xing-dai-ma-jian-ji-qian-zhui-he-er-fen-c/
# 来源:力扣(LeetCode)
# 著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
class Solution:
    def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
        """按照上面的思路,一种解法就是
        计算累积和数组 sums 的,其中 sum[i] = nums[0] + nums[1] + ... + nums[i],对于某个i来说,只有那些满足 lower <= sum[j] - sum[i] <= upper 的 j 能形成一个区间 [i, j] 满足题意,则有:sum[i] + lower =< sum[j] <= sum[i] + upper,目标就是来找到有多少个这样的 j满足上述条件。从后向前遍历累加和数组,相当于固定sum[i]后,算出有多少的sum[j]满足左右条件。因为必须满足0 =< i <= j,所以sum[j]的范围一定是由sum[i]之后的元素组成的数组。对sum[j]的范围数组排序,l是找数组中第一个大于等于给定值(左条件)的数,而 r 是找数组中最后一个小于等于给定值(右条件)的数,那么两者相减,就是j的个数。
        """

        import bisect
        res, pre, now = 0, [0], 0
        for n in nums:
            # now 相当于前缀和
            now += n
            # 这种解法是针对上面的第二种移动方法
            # pre记录了从当前前缀和位置往前所有的前缀和,而且是排好序的,只需要在 pre 里找到对应的左边界和右边界即可
            res += bisect.bisect_right(pre, now - lower) - \
                bisect.bisect_left(pre, now - upper)
            bisect.insort(pre, now)
        return res

三、总结

  1. 线段树采用了二分的思想,适用在区间范围内做查询、更新,见到类似在区间内获取和、最值等问题,都可以使用线段树
  2. 个人认为线段树问题难点在于如果构造线段树。而如果采用连续数组的方式来存储,要充分利用数组要存放在叶子结点这一特性
  3. leetcode官方题解比较难理解(可能因为都是高手写的),关键还是需要多看代码,多 debug
  4. 看过关于线段树的其实实现版本,有做懒更新与懒插入,后续有机会再详细总结下

四、参考资料

为了解决上面说的第二个问题,花了几天时间,主要是连别人的题解都看不懂,参考了一些别人的解决思路,最后 debug 官方的 java 实现版本

点击查看更多内容
1人点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消