You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_mag_net.py 4.6 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Mag-net detector test.
  16. """
  17. import numpy as np
  18. import pytest
  19. import mindspore.ops.operations as P
  20. from mindspore.nn import Cell
  21. from mindspore.ops.operations import TensorAdd
  22. from mindspore import Model
  23. from mindspore import context
  24. from mindarmour.detectors.mag_net import ErrorBasedDetector
  25. from mindarmour.detectors.mag_net import DivergenceBasedDetector
  # Every test in this module runs in graph mode on an Ascend device.
  context.set_context(device_target="Ascend", mode=context.GRAPH_MODE)
  class Net(Cell):
      """
      Minimal stand-in network whose output is its input added to itself.

      In this module it is wrapped in ``Model`` and handed to the MagNet
      detectors as the reconstruction (encoder) network.
      """

      def __init__(self):
          super().__init__()
          self.add = TensorAdd()

      def construct(self, inputs):
          """
          Return ``inputs + inputs``.

          Args:
              inputs (Tensor): Input data.

          Returns:
              Tensor, the element-wise sum of ``inputs`` with itself.
          """
          return self.add(inputs, inputs)
  class PredNet(Cell):
      """
      Minimal stand-in classifier: flattens each sample, then applies softmax.

      Used in this module as the prediction model of DivergenceBasedDetector.
      """

      def __init__(self):
          super().__init__()
          self.shape = P.Shape()
          self.reshape = P.Reshape()
          self._softmax = P.Softmax()

      def construct(self, inputs):
          """
          Flatten every sample of ``inputs`` and return its softmax.

          Args:
              inputs (Tensor): Input data; axis 0 is kept as the batch axis.

          Returns:
              Tensor with shape ``(batch, -1)`` holding the softmax output.
          """
          batch_size = self.shape(inputs)[0]
          flattened = self.reshape(inputs, (batch_size, -1))
          return self._softmax(flattened)
  @pytest.mark.level0
  @pytest.mark.platform_arm_ascend_training
  @pytest.mark.platform_x86_ascend_training
  @pytest.mark.env_card
  @pytest.mark.component_mindarmour
  def test_mag_net():
      """
      ErrorBasedDetector fitted on one random batch (seed 5) flags every
      sample of a second random batch (seed 6).
      """
      np.random.seed(5)
      benign = np.random.rand(4, 4, 4).astype(np.float32)
      np.random.seed(6)
      suspicious = np.random.rand(4, 4, 4).astype(np.float32)

      detector = ErrorBasedDetector(Model(Net()))
      detector.fit(benign)
      flags = detector.detect(suspicious)

      assert np.all(flags == np.array([1, 1, 1, 1]))
  @pytest.mark.level0
  @pytest.mark.platform_arm_ascend_training
  @pytest.mark.platform_x86_ascend_training
  @pytest.mark.env_card
  @pytest.mark.component_mindarmour
  def test_mag_net_transform():
      """
      ErrorBasedDetector.transform returns data that differs from its input
      in at least one element.
      """
      np.random.seed(6)
      sample = np.random.rand(4, 4, 4).astype(np.float32)

      transformed = ErrorBasedDetector(Model(Net())).transform(sample)

      assert np.any(transformed != sample)
  @pytest.mark.level0
  @pytest.mark.platform_arm_ascend_training
  @pytest.mark.platform_x86_ascend_training
  @pytest.mark.env_card
  @pytest.mark.component_mindarmour
  def test_mag_net_divergence():
      """
      DivergenceBasedDetector, using the threshold returned by ``fit`` on a
      seed-5 batch, produces the expected flags for a seed-6 batch.
      """
      np.random.seed(5)
      benign = np.random.rand(4, 4, 4).astype(np.float32)
      np.random.seed(6)
      suspicious = np.random.rand(4, 4, 4).astype(np.float32)

      detector = DivergenceBasedDetector(Model(Net()), Model(PredNet()))
      detector.set_threshold(detector.fit(benign))
      flags = detector.detect(suspicious)

      assert np.all(flags == np.array([1, 0, 1, 1]))
  @pytest.mark.level0
  @pytest.mark.platform_arm_ascend_training
  @pytest.mark.platform_x86_ascend_training
  @pytest.mark.env_card
  @pytest.mark.component_mindarmour
  def test_mag_net_divergence_transform():
      """
      DivergenceBasedDetector.transform returns data that differs from its
      input in at least one element.
      """
      np.random.seed(6)
      sample = np.random.rand(4, 4, 4).astype(np.float32)

      detector = DivergenceBasedDetector(Model(Net()), Model(PredNet()))
      transformed = detector.transform(sample)

      assert np.any(transformed != sample)
  @pytest.mark.level0
  @pytest.mark.platform_arm_ascend_training
  @pytest.mark.platform_x86_ascend_training
  @pytest.mark.env_card
  @pytest.mark.component_mindarmour
  def test_value_error():
      """
      DivergenceBasedDetector.detect_diff raises NotImplementedError when the
      detector was constructed with an unsupported ``option``.

      Note: despite the test's name, the expected exception is
      NotImplementedError, not ValueError.
      """
      np.random.seed(6)
      adv = np.random.rand(4, 4, 4).astype(np.float32)
      encoder = Model(Net())
      model = Model(PredNet())
      detector = DivergenceBasedDetector(encoder, model, option='bad_op')
      # Only the raised exception matters here; the return value is
      # deliberately not inspected (an `assert` would only muddy failures).
      with pytest.raises(NotImplementedError):
          detector.detect_diff(adv)

MindArmour关注AI的安全和隐私问题，致力于增强模型的安全可信、保护用户的数据隐私。主要包含3个模块：对抗样本鲁棒性模块、Fuzz Testing模块、隐私保护与评估模块。

对抗样本鲁棒性模块

对抗样本鲁棒性模块用于评估模型对于对抗样本的鲁棒性，并提供模型增强方法用于增强模型抗对抗样本攻击的能力，提升模型鲁棒性。对抗样本鲁棒性模块包含了4个子模块：对抗样本的生成、对抗样本的检测、模型防御、攻防评估。