为了账号安全,请及时绑定邮箱和手机立即绑定

显示导致测试失败的数组条目

显示导致测试失败的数组条目

GCT1015 2022-09-06 16:42:22
作为测试套件的一部分,我必须检查函数返回的numpy数组是否正确。使用返回一个关于所有数组元素是否相同的布尔值很容易进行此检查。np.array_equal如果测试失败,错误消息对于了解导致失败的原因不是特别有用。import unittestimport numpy as npclass TestArray(unittest.TestCase):    def test_values(self):        x = np.array([1, 2])        self.assertTrue(np.array_equal(x, [1, 3]))if __name__ == "__main__":    unittest.main()测试失败消息:Traceback (most recent call last):  File "test.py", line 7, in test_values    self.assertTrue(np.array_equal(x, [1, 3]))AssertionError: False is not true有没有一种简单的方法来检查条目是否相等,以显示第一个不相等条目的索引和值?我想要一条错误消息,如下所示:AssertionError: Arrays not equal at index 1 (2 != 3) 
查看完整描述

1 回答

?
慕尼黑5688855

TA贡献1848条经验 获得超2个赞

从我们可以获取代码并重写它,在最后添加另一个检查np.array_equal


def array_equal(a1, a2):

    try:

        a1, a2 = asarray(a1), asarray(a2)

    except Exception:

        return False

    if a1.shape != a2.shape:

        return False

    eq = asarray(a1 == a2) # [ True False False True]

    if not bool(eq.all()):

        errors = [f"idx:{idx} ({vals[0]}!={vals[1]})"

                  for idx, vals in enumerate(zip(a1, a2))

                  if not eq[idx]]

        raise AssertionError("Arrays not equal " + " ".join(errors))

    return True


class TestArray(unittest.TestCase):

    def test_values(self):

        x = np.array([1, 1, 1, 1])

        self.assertTrue(array_equal(x, [1, 2, 3, 1]))


if __name__ == "__main__":

    unittest.main()

给AssertionError: Arrays not equal idx:1 (1!=2) idx:2 (1!=3)


查看完整回答
反对 回复 2022-09-06
  • 1 回答
  • 0 关注
  • 68 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信