递归算法实战

本节将会以 3 个有意思的 leetcode 编程题来实践递归算法,帮助大家更加深刻理解和掌握递归算法。

1. 常规的递归算法

这道题是 leetcode 的第 70 题,题目名称为爬楼梯。题目内容如下:

假设你正在爬楼梯。需要 n 阶你才能到达楼顶。每次你可以爬 1 或 2 个台阶。你有多少种不同的方法可以爬到楼顶呢?注意:给定 n 是一个正整数。

示例 1

输入: 2
输出: 2
解释: 有两种方法可以爬到楼顶。
1.  1 阶 + 1 阶
2.  2 阶

示例 2

输入: 3
输出: 3
解释: 有三种方法可以爬到楼顶。
1.  1 阶 + 1 阶 + 1 阶
2.  1 阶 + 2 阶
3.  2 阶 + 1 阶

使用前面的递归三套件来解答这个基础的递归问题,即函数 f(n) 为 n 阶楼梯的爬楼总方法数,则有:

终止条件:很明显,当楼梯阶数为 1 时,我们知道答案肯定为 1,即 f(1) = 1;此外 n = 2时,也知 f(2) = 2

递归公式:递归指的是用前面计算出来的 f(n-1), f(n-2),~,f(1) 等等的值递推得到 f(n)

这里思考下,首先我们的台阶往上减一级,即到达 n-1 级台阶的方法一共有 f(n-1) 种,然后只能跨 1 级到达第 n 级台阶,这是一种爬到楼顶的方法;由于我每次可以爬 1 个或者 2 个台阶,那么另一种爬到楼顶的方法是在 n-2 级台阶,然后爬 2 级就到了楼顶,而到达 n-2 级台阶的方法正好有 f(n-2) 种。综合得到递推公式为:

f(n) = f(n-1) + f(n-2)

返回预定结果:每个函数返回的结果是爬到 n 级台阶楼顶的总方法。

综合这三步,我们就可以得到如下的函数:

def climb_stairs(n):
    # 终止条件
    if n <= 2:
        return n
    
    # 递推公式和返回预定结果
    return climb_stairs(n - 1) + climb_stairs(n - 2) 

但是这样的递归算法在 leetcode 上时无法通过的,原因就是我们前面提到的递归算法的可能会导致的一个问题:冗余计算,这样会使得递归算法的时间复杂度随着问题规模呈指数级上升,非常低效。

图片描述

递归超时

我们来分析下这个递归算法造成冗余计算的原因,参考下图:

图片描述

计算f(5)时的冗余计算

可以看到,在上面的递归分解计算图中可以看到,计算 f(5) 时,f(3) 会被重复递归计算。如果是计算 f(6) 时,f(5)f(4) 以及 f(3) 都会被重复计算,具体的图就不画了。而且随着输入的值越大,冗余的数越多,会导致一个 f(k) 可能被重复计算好多次。这也就造成了该算法无法通过题解的原因。改进方法当然比较简单,我们有了递推式,不用递归即可:

def climbStairs(self, n: int) -> int:
    if n <= 2:
        return n

    s = [1, 2]
    for _ in range(3, n + 1):
        s[0], s[1] = s[1], s[0] + s[1]
        return s[1]

因此,有时候递归算法看起来美好但需要慎用,特别对于递推关系式中用到前面多个值时,要小心分析,避免出现冗余计算情况。

2. 二叉树中的递归算法应用

在二叉树的问题中,几乎处处用着递归。最经典的例子就是二叉树的前中后序的遍历,使用递归算法较为简单和明了,而使用非递归算法实现时会显得十分复杂,尤其是后序遍历,非常难写。今天我们来看二叉树中的几个非常简单的问题,全部使用递归方法解决这些问题。

给定两个二叉树,编写一个函数来检验它们是否相同。如果两个树在结构上相同,并且节点具有相同的值,则认为它们是相同的。

示例 1

输入:      1         1
          / \       / \
         2   3     2   3

        [1,2,3],   [1,2,3]

输出: true

示例 2

输入:      1          1
          /           \
         2             2

        [1,2],     [1,null,2]

输出: false

示例 3

输入:       1         1
          / \       / \
         2   1     1   2

        [1,2,1],   [1,1,2]

输出: false

问题也比较简单,leetcode 官方给我们定义了二叉树类:

# Definition for a binary tree node.
class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None

继续我们的递归三要素:终止条件,递推公式,预定输出。首先看看递归函数的输出:相同的树(True) 和不同的树(False),输入是两个待比较二叉树的根节点,那么递归函数这样写:

def is_same_tree(p, q):
    # ...
    
    reture False

然后来看看终止条件,对于二叉树的终止条件就是到输入的两个根节点只要有一个为空即可。当两个根节点都为空是,返回 True;当只有其中一个根节点为空而另一个根节点不为空时,明显两个树不是相同的树,故返回 False:

def is_same_tree(p, q):
    ################### 终止条件   ########################
    if not p and not q:
        return True
    if not p or not q:
        return False
    #####################################################
    # 递归比较
    # ...
    reture False

来看递归公式,判断一棵二叉树是否相同,我们首先是比较根节点的值,如果根节点的值不相同,那就直接返回 False;如果根节点相同,我们递归比较左子树和右子树,左子树或者右子树都相同时,那么这棵二叉树才是相同的:

def is_same_tree(p, q):
    # 终止条件
    # ...
    # 递归比较,返回True/False
    return p.val == q.val and is_same_tree(p.left, q.left) and is_same_tree(p.right, q.right)

看看这个递归的方法是不是非常简洁?那么这种写法会不会存在冗余的计算呢?答案时不会的,因为我们可以看到这里递归计算的左子树和右子树时完全没有重叠的部分,所以不存在冗余计算。因此,对于该问题而言,递归是一种非常优美的写法。完整的递归代码如下:

class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None

def isSameTree(p, q):
    if not p and not q:
    return True

    if not p or not q:
        return False

    return p.val == q.val and isSameTree(p.left, q.left) and isSameTree(p.right, q.right)

3. 递归穷举

我们来看 leetcode 的第 15 题:三数之和。该题的难度为中等,题目内容如下:

给你一个包含 n 个整数的数组 nums,判断 nums 中是否存在三个元素 a,b,c ,使得 a + b + c = 0 ?请你找出所有满足条件且不重复的三元组。**注意:**答案中不可以包含重复的三元组。

示例

给定数组 nums = [-1, 0, 1, 2, -1, -4],

满足要求的三元组集合为:
[
  [-1, 0, 1],
  [-1, -1, 2]
]

我们今天并不打算通过这道题的题解,因为这道题用递归算法是无法通过题解的,原因和之前一样,算法的时间复杂度高,最后会超出时间限制。另外我们去掉后面的注意部分事项,允许答案包含重复的三元组,我们使用递归算法相当于穷举出所有可能的情况,判断三元组的值是否能为 0。首先继续我们的解题三部曲:

函数定义,输入和输出:

def three_sum(nums, target, count):
    """
    输入: 
       num: 输入的数组
       target: 目标值
       count: 在数组中找到几个数之和满足target
    输出:
       []或者[[1,2,3], [-1,4,3]] 这样的满足条件的全部结果
    """
    
    res = []
    # ...
    
    return res

注意: 定义这样的递归函数是经过思考的,因为后续递归调用时需要依赖目标值 (target) 或元素个数 (count) 这样两个参数。返回的参数要么为空,要么是所有找到的满足条件的三元组的集合。

接下来是递归方法的终止条件,首先考虑以下几个终止条件:

  • 如果输入的 nums 列表为空,那么直接返回 [];
  • 如果输入的 count 等于1,就要开始判断了,因为这个时候只需要判断 target 是否在列表中存在即可;

综上,我们写出终止条件的代码:

def three_sum(nums, target, count):
    """
    输入: 
       num: 输入的数组
       target: 目标值
       count: 在数组中找到几个数之和满足target
    输出:
       []或者[[1,2,3], [-1,4,3]] 这样的满足条件的全部结果
    """
    res = []
    ######################  终止条件  ######################################
    if not nums:
        return res
    if count == 1 and target in nums:
        return [[ target ]]
    elif count == 1 and target not in nums:
        # count等于1时,如果target没有出现在剩余的nums中,说明不存在满足条件的数组元素
        return res
    #######################################################################
    
    # 返回值
    return res

接下来最重要的,就是递归的公式了,递归的方向一定要朝着减小目标函数规模进行

很明显,我们的递归应该是这样子:以 nums 的第一个元素为递归点,整个 nums 列表中和为 target 的 count 个元素的结果可以分为包含 nums[0] 和不包含 nums[0] 的结果组成,简单点说就是:

  • 如果包含 nums[0],那么接下来的 nums[1:] 列表中我们就要找值为 target - nums[0] 的 count - 1 个元素,也即 three_sum(nums[1:], target - nums[0], count -1),然后我们还需要在得到的元组集中的最前面位置加上 nums[0];
  • 如果不包含 nums[0],那么就是在 nums[1:] 列表中找值为 target 的 count 个元素,用递归函数表示就是 three_sum(nums[1:], target, count);这样找到的结果正是 count 个元素。

组合上述两个递归得到的结果,就得到了函数 three_sum(nums, target, count) 的结果,代码如下:

res = []
# 包含nums[0]
t1 = three_sum(nums[1:], target - nums[0], count - 1)
# 不包含nums[0]
t2 = three_sum(nums[1:], target, count)
if t1:
    for i in range(len(t1)):
        t = [nums[0]]
        t.extend(t1[i])
        # 每个得到的结果前面加上 nums[0]
        res.append(t)
if t2:
    for j in range(len(t2)):
        res.append(t2[j]) 
        
# 此时得到的res就是递归的最后结果

综合就可以得到递归遍历所有三个元素和的情况并最终找出所有满足条件结果的三元集:

def three_sum(nums, target, count):
    res = []

    # 终止条件
    if not nums:
        return res

    if count == 1 and target in nums:
        # 一定要这样写
        return [[ target ]]
    elif count == 1 and target not in nums:
        return res
    
    # 包含nums[0]
    t1 = three_sum(nums[1:], target - nums[0], count - 1)
    # 不包含nums[0]
    t2 = three_sum(nums[1:], target, count)
    if t1:
        for i in range(len(t1)):
            # 犯了一个巨大的错误,extend() 方法的使用,它无返回,只会扩充原数组
            # res.append([nums[0]].extend(t1[i]))
            t = [nums[0]]
            t.extend(t1[i])
            res.append(t)
    if t2:
        for j in range(len(t2)):
            res.append(t2[j])
     
    return res

调用该函数的方式如下:

nums = [-1, 0, 1, 2, -1, -4]
# 0 为目标值,3为多少个元素和为target
res = three_sum(nums, 0, 3)

这样的递归遍历思想在穷举中用的比较多,因为它以非常优雅的方式简化了穷举代码。不过这道题使用递归算法等价于穷举法,时间复杂度为 O(n3)O(n^3),因此显得并不高效。对于最优的解法读者可以自行思考下然后解决它。

4. 小结

今天我们用 3 道编程题来体验了一把递归算法,可以看到递归算法在编写时会使得代码整体看起来简洁优雅,但是有时候也会存在美丽的陷阱。例如第一个算法题中,使用递归算法会导致大量的冗余计算,使得算法的复杂度呈指数级增长。对于是否会存在冗余计算是在使用递归算法时一定要慎重考虑的,它会极大地影响算法的复杂度,如果存在的话,尽量不要使用递归算法或者想办法避免它。