From 878ce911651e71f9e00722c01266b44e1bed79c5 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 7 Jan 2022 10:38:10 +0800 Subject: [PATCH] fix(mge/test): replace equal with allclose to fix rng test for ci GitOrigin-RevId: 12758cf5d5e22c883d0de893fca9e8acc6ec0556 --- imperative/python/test/unit/random/test_rng.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/imperative/python/test/unit/random/test_rng.py b/imperative/python/test/unit/random/test_rng.py index a33a5840..1083e947 100644 --- a/imperative/python/test/unit/random/test_rng.py +++ b/imperative/python/test/unit/random/test_rng.py @@ -226,7 +226,7 @@ def test_UniformRNG(): out2 = m2.uniform(size=(100,)) out3 = m3.uniform(size=(100,)) - np.testing.assert_equal(out1.numpy(), out2.numpy()) + np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6) assert out1.device == "xpu0" and out2.device == "xpu1" assert not (out1.numpy() == out3.numpy()).all() assert not (out1.numpy() == out1_.numpy()).all() @@ -254,7 +254,7 @@ def test_NormalRNG(): out2 = m2.normal(size=(100,)) out3 = m3.normal(size=(100,)) - np.testing.assert_equal(out1.numpy(), out2.numpy()) + np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6) assert out1.device == "xpu0" and out2.device == "xpu1" assert not (out1.numpy() == out3.numpy()).all() assert not (out1.numpy() == out1_.numpy()).all() @@ -283,7 +283,7 @@ def test_GammaRNG(): out2 = m2.gamma(2, size=(100,)) out3 = m3.gamma(2, size=(100,)) - np.testing.assert_equal(out1.numpy(), out2.numpy()) + np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6) assert out1.device == "xpu0" and out2.device == "xpu1" assert not (out1.numpy() == out3.numpy()).all() assert not (out1.numpy() == out1_.numpy()).all() @@ -316,7 +316,7 @@ def test_BetaRNG(): out2 = m2.beta(2, 1, size=(100,)) out3 = m3.beta(2, 1, size=(100,)) - np.testing.assert_equal(out1.numpy(), out2.numpy()) + np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6) assert out1.device == "xpu0" and out2.device == "xpu1" assert not (out1.numpy() == out3.numpy()).all() assert not (out1.numpy() == out1_.numpy()).all() @@ -351,7 +351,7 @@ def test_PoissonRNG(): out2 = m2.poisson(lam.to("xpu1"), size=(100,)) out3 = m3.poisson(lam.to("xpu0"), size=(100,)) - np.testing.assert_equal(out1.numpy(), out2.numpy()) + np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6) assert out1.device == "xpu0" and out2.device == "xpu1" assert not (out1.numpy() == out3.numpy()).all() @@ -381,7 +381,7 @@ def test_PermutationRNG(symbolic): out2 = m2.permutation(1000) out3 = m3.permutation(1000) - np.testing.assert_equal(out1.numpy(), out2.numpy()) + np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6) assert out1.device == "xpu0" and out2.device == "xpu1" assert not (out1.numpy() == out3.numpy()).all() assert not (out1.numpy() == out1_.numpy()).all() @@ -443,7 +443,7 @@ def test_ShuffleRNG(): m2.shuffle(out2) m3.shuffle(out3) - np.testing.assert_equal(out1.numpy(), out2.numpy()) + np.testing.assert_allclose(out1.numpy(), out2.numpy(), atol=1e-6) assert out1.device == "xpu0" and out2.device == "xpu1" assert not (out1.numpy() == out3.numpy()).all() @@ -465,7 +465,7 @@ def test_seed(): set_global_seed(10) out3 = uniform(size=[10, 10]) - np.testing.assert_equal(out1.numpy(), out3.numpy()) + np.testing.assert_allclose(out1.numpy(), out3.numpy(), atol=1e-6) set_global_seed(11) out4 = uniform(size=[10, 10])