From c9830d20d03cd897d7ed4e41c7e15afece55fb6a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 1 Apr 2020 16:13:46 +0800 Subject: [PATCH] fix(mge/test): fix tolerance GitOrigin-RevId: 58c029b394edc17e1d6b1abb261cd76d58c6ab4f --- python_module/test/integration/test_correctness.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python_module/test/integration/test_correctness.py b/python_module/test/integration/test_correctness.py index 5ea6c5ed..778ebd77 100644 --- a/python_module/test/integration/test_correctness.py +++ b/python_module/test/integration/test_correctness.py @@ -63,7 +63,10 @@ def train(data, label, net, opt): def update_model(model_path): """ - Update the dumped model with test cases for new reference values + Update the dumped model with test cases for new reference values. + + The model with pre-trained weights is trained for one iter with the test data attached. + The loss and updated net state dict is dumped. """ net = MnistNet(has_bn=True) checkpoint = mge.load(model_path) @@ -89,9 +92,6 @@ def run_test(model_path, use_jit, use_symbolic): """ Load the model with test cases and run the training for one iter. The loss and updated weights are compared with reference value to verify the correctness. - The model with pre-trained weights is trained for one iter and the net state dict is dumped. - The test cases is appended to the model file. The reference result is obtained - by running the train for one iter. Dump a new file with updated result by calling update_model if you think the test fails due to numerical rounding errors instead of bugs. @@ -109,7 +109,7 @@ def run_test(model_path, use_jit, use_symbolic): data.set_value(checkpoint["data"]) label.set_value(checkpoint["label"]) - max_err = 0.0 + max_err = 1e-1 train_func = train if use_jit: