From 2174fd12b9c8f69cd3591ec258b272f00ff91ec0 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 1 Apr 2020 14:10:53 +0800 Subject: [PATCH] refactor(mge/test): make test_converge more stable GitOrigin-RevId: 710e3ede406119d9a497c4c92f1e6ca768e05128 --- python_module/test/integration/test_converge.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python_module/test/integration/test_converge.py b/python_module/test/integration/test_converge.py index ec0efa87..e7e4f6c4 100644 --- a/python_module/test/integration/test_converge.py +++ b/python_module/test/integration/test_converge.py @@ -103,6 +103,12 @@ def test_training_converge(): assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough" - data, _ = next(train_dataset) + ngrid = 10 + x = np.linspace(-1.0, 1.0, ngrid) + xx, yy = np.meshgrid(x, x) + xx = xx.reshape((ngrid * ngrid, 1)) + yy = yy.reshape((ngrid * ngrid, 1)) + data = np.concatenate((xx, yy), axis=1).astype(np.float32) + pred = infer(data).numpy() - assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough" + assert calculate_precision(data, pred) == 1.0, "Test precision must be high enough"