1 回答
TA贡献1848条经验 获得超2个赞
我们可以从 np.array_equal 获取源码并重写它,在最后添加另一个检查:
def array_equal(a1, a2):
    """Compare two array-likes element-wise, raising on value mismatches.

    Adapted from ``numpy.array_equal``, with an extra step at the end that
    reports *which* elements differ instead of silently returning ``False``.

    Parameters
    ----------
    a1, a2 : array_like
        Inputs to compare. Anything ``numpy.asarray`` accepts works, including
        scalars and nested lists of any dimensionality.

    Returns
    -------
    bool
        ``True`` if both inputs have the same shape and all elements compare
        equal. ``False`` if either input cannot be converted to an array or
        the shapes differ (shape mismatches are reported by return value,
        not raised).

    Raises
    ------
    AssertionError
        If the shapes match but at least one element differs. The message
        lists every mismatching index with its two values, e.g.
        ``"Arrays not equal idx:1 (1!=2) idx:2 (1!=3)"``. One-dimensional
        indices are shown as plain integers, higher-dimensional ones as
        tuples such as ``idx:(1, 1)``, and a 0-d mismatch as ``idx:()``.
    """
    import numpy as np

    try:
        a1, a2 = np.asarray(a1), np.asarray(a2)
    except Exception:
        # Mirrors numpy.array_equal: unconvertible input is "not equal".
        return False
    if a1.shape != a2.shape:
        return False
    # Cast to bool so object-dtype comparison results can be inverted safely
    # (``~`` on Python ``True`` inside an object array would give -2).
    eq = np.asarray(a1 == a2, dtype=bool)
    if not bool(eq.all()):
        # np.argwhere yields the full N-d index of every mismatch and also
        # handles 0-d arrays. Iterating zip(a1, a2) only walked the first
        # axis and failed on 0-d or multi-dimensional input.
        errors = []
        for row in np.argwhere(~eq):
            idx = tuple(int(i) for i in row)
            # Keep the historical "idx:N" format for 1-D arrays.
            label = idx[0] if len(idx) == 1 else idx
            errors.append(f"idx:{label} ({a1[idx]}!={a2[idx]})")
        raise AssertionError("Arrays not equal " + " ".join(errors))
    return True
class TestArray(unittest.TestCase):
    """Demonstrates the diagnostic message produced by ``array_equal``."""

    def test_values(self):
        """Compare an all-ones array with a list that differs at indices 1 and 2.

        This test fails on purpose. ``array_equal`` raises ``AssertionError``
        naming each mismatching index and its pair of values.
        """
        ones = np.array([1] * 4)
        expected = [1, 2, 3, 1]
        self.assertTrue(array_equal(ones, expected))
# Run the test case(s) above when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
运行后给出:AssertionError: Arrays not equal idx:1 (1!=2) idx:2 (1!=3)
添加回答
举报