From ec9d46a7e1e7dbf3895fc02e7f67a3d9346bb577 Mon Sep 17 00:00:00 2001 From: ZhidanLiu Date: Fri, 29 May 2020 15:13:44 +0800 Subject: [PATCH] fix parameter check --- mindarmour/diff_privacy/mechanisms/mechanisms.py | 5 ++++- mindarmour/utils/_check_param.py | 6 +++--- tests/ut/python/attacks/test_iterative_gradient_method.py | 2 +- tests/ut/python/detectors/test_region_based_detector.py | 6 +++--- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/mindarmour/diff_privacy/mechanisms/mechanisms.py b/mindarmour/diff_privacy/mechanisms/mechanisms.py index f0608c4..750e51b 100644 --- a/mindarmour/diff_privacy/mechanisms/mechanisms.py +++ b/mindarmour/diff_privacy/mechanisms/mechanisms.py @@ -159,7 +159,10 @@ class AdaGaussianRandom(Mechanisms): alpha = check_param_type('alpha', alpha, float) self._alpha = Tensor(np.array(alpha, np.float32)) - self._decay_policy = check_param_type('decay_policy', decay_policy, str) + if decay_policy not in ['Time', 'Step']: + raise NameError("The decay_policy must be in ['Time', 'Step'], but " + "get {}".format(decay_policy)) + self._decay_policy = decay_policy self._mean = 0.0 self._sub = P.Sub() self._mul = P.Mul() diff --git a/mindarmour/utils/_check_param.py b/mindarmour/utils/_check_param.py index 06a37c6..36ebc87 100644 --- a/mindarmour/utils/_check_param.py +++ b/mindarmour/utils/_check_param.py @@ -43,7 +43,7 @@ def check_param_type(arg_name, arg_value, valid_type): valid_type, type(arg_value).__name__) LOGGER.error(TAG, msg) - raise ValueError(msg) + raise TypeError(msg) return arg_value @@ -54,7 +54,7 @@ def check_param_multi_types(arg_name, arg_value, valid_types): msg = 'type of {} must be in {}, but got {}' \ .format(arg_name, valid_types, type(arg_value).__name__) LOGGER.error(TAG, msg) - raise ValueError(msg) + raise TypeError(msg) return arg_value @@ -157,7 +157,7 @@ def check_numpy_param(arg_name, arg_value): msg = 'type of {} must be in (list, tuple, numpy.ndarray)'.format( arg_name) LOGGER.error(TAG, msg) - raise ValueError(msg) + raise TypeError(msg) return arg_value diff --git a/tests/ut/python/attacks/test_iterative_gradient_method.py b/tests/ut/python/attacks/test_iterative_gradient_method.py index 3a9fcb0..34cc2e3 100644 --- a/tests/ut/python/attacks/test_iterative_gradient_method.py +++ b/tests/ut/python/attacks/test_iterative_gradient_method.py @@ -167,7 +167,7 @@ def test_momentum_diverse_input_iterative_method(): @pytest.mark.env_card @pytest.mark.component_mindarmour def test_error(): - with pytest.raises(ValueError): + with pytest.raises(TypeError): # check_param_multi_types assert IterativeGradientMethod(Net(), bounds=None) attack = IterativeGradientMethod(Net(), bounds=(0.0, 1.0)) diff --git a/tests/ut/python/detectors/test_region_based_detector.py b/tests/ut/python/detectors/test_region_based_detector.py index c958749..f4b891a 100644 --- a/tests/ut/python/detectors/test_region_based_detector.py +++ b/tests/ut/python/detectors/test_region_based_detector.py @@ -100,16 +100,16 @@ def test_value_error(): with pytest.raises(ValueError): assert RegionBasedDetector(model, search_step=0) - with pytest.raises(ValueError): + with pytest.raises(TypeError): assert RegionBasedDetector(model, sparse='False') detector = RegionBasedDetector(model) - with pytest.raises(ValueError): + with pytest.raises(TypeError): # radius must not empty assert detector.detect(adv) radius = detector.fit(ori, labels) detector.set_radius(radius) - with pytest.raises(ValueError): + with pytest.raises(TypeError): # adv type should be in (list, tuple, numpy.ndarray) assert detector.detect(adv.tostring())