Browse Source

!30 solve DI: [MA][diff_privacy][Func] decay_policy not checked and error should be TypeError https://gitee.com/mindspore/dashboard/issues?id=I1IS7V

Merge pull request !30 from ZhidanLiu/master
tags/v1.2.1
mindspore-ci-bot Gitee 5 years ago
parent
commit
93eae12155
4 changed files with 11 additions and 8 deletions
  1. +4
    -1
      mindarmour/diff_privacy/mechanisms/mechanisms.py
  2. +3
    -3
      mindarmour/utils/_check_param.py
  3. +1
    -1
      tests/ut/python/attacks/test_iterative_gradient_method.py
  4. +3
    -3
      tests/ut/python/detectors/test_region_based_detector.py

+ 4
- 1
mindarmour/diff_privacy/mechanisms/mechanisms.py View File

@@ -159,7 +159,10 @@ class AdaGaussianRandom(Mechanisms):
alpha = check_param_type('alpha', alpha, float)
self._alpha = Tensor(np.array(alpha, np.float32))

self._decay_policy = check_param_type('decay_policy', decay_policy, str)
if decay_policy not in ['Time', 'Step']:
raise NameError("The decay_policy must be in ['Time', 'Step'], but "
"get {}".format(decay_policy))
self._decay_policy = decay_policy
self._mean = 0.0
self._sub = P.Sub()
self._mul = P.Mul()


+ 3
- 3
mindarmour/utils/_check_param.py View File

@@ -43,7 +43,7 @@ def check_param_type(arg_name, arg_value, valid_type):
valid_type,
type(arg_value).__name__)
LOGGER.error(TAG, msg)
raise ValueError(msg)
raise TypeError(msg)

return arg_value

@@ -54,7 +54,7 @@ def check_param_multi_types(arg_name, arg_value, valid_types):
msg = 'type of {} must be in {}, but got {}' \
.format(arg_name, valid_types, type(arg_value).__name__)
LOGGER.error(TAG, msg)
raise ValueError(msg)
raise TypeError(msg)

return arg_value

@@ -157,7 +157,7 @@ def check_numpy_param(arg_name, arg_value):
msg = 'type of {} must be in (list, tuple, numpy.ndarray)'.format(
arg_name)
LOGGER.error(TAG, msg)
raise ValueError(msg)
raise TypeError(msg)
return arg_value




+ 1
- 1
tests/ut/python/attacks/test_iterative_gradient_method.py View File

@@ -167,7 +167,7 @@ def test_momentum_diverse_input_iterative_method():
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_error():
with pytest.raises(ValueError):
with pytest.raises(TypeError):
# check_param_multi_types
assert IterativeGradientMethod(Net(), bounds=None)
attack = IterativeGradientMethod(Net(), bounds=(0.0, 1.0))


+ 3
- 3
tests/ut/python/detectors/test_region_based_detector.py View File

@@ -100,16 +100,16 @@ def test_value_error():
with pytest.raises(ValueError):
assert RegionBasedDetector(model, search_step=0)

with pytest.raises(ValueError):
with pytest.raises(TypeError):
assert RegionBasedDetector(model, sparse='False')

detector = RegionBasedDetector(model)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
# radius must not empty
assert detector.detect(adv)

radius = detector.fit(ori, labels)
detector.set_radius(radius)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
# adv type should be in (list, tuple, numpy.ndarray)
assert detector.detect(adv.tostring())

Loading…
Cancel
Save