You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_nes.py 7.0 kB

5 years ago
(line numbers 1–217)
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import sys
  15. import numpy as np
  16. import os
  17. import pytest
  18. from mindspore import Tensor
  19. from mindspore import context
  20. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  21. from mindarmour.attacks.black.natural_evolutionary_strategy import NES
  22. from mindarmour.attacks.black.black_model import BlackModel
  23. from mindarmour.utils.logger import LogUtil
  # Make the directory five levels above this file importable so that the
  # `example` package (which provides LeNet5) can be resolved.
  sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)),
                               "../../../../../"))
  from example.mnist_demo.lenet5_net import LeNet5

  context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  LOGGER = LogUtil.get_instance()
  # Log tag for this module; it exercises the NES attack.
  TAG = 'NES'
  class ModelToBeAttacked(BlackModel):
      """Adapt a network to the BlackModel interface used by the attacks."""

      def __init__(self, network):
          super().__init__()
          self._network = network

      def predict(self, inputs):
          """
          Run the wrapped network on `inputs` and return the output as numpy.

          A 3-D input is treated as a single sample and promoted to a batch of
          one by prepending an axis. Inputs are cast to float32 before being
          wrapped in a Tensor.
          """
          batch = inputs[np.newaxis, :] if len(inputs.shape) == 3 else inputs
          output = self._network(Tensor(batch.astype(np.float32)))
          return output.asnumpy()
  def random_target_labels(true_labels):
      """
      Draw one random target label per true label, never equal to it.

      Each target is an integer in [0, 10) drawn from the global ``np.random``
      generator; a draw is repeated until it differs from the true label.

      Returns:
          list: one target label per entry of ``true_labels``, in order.
      """
      def _draw_other(label):
          candidate = np.random.randint(0, 10)
          while candidate == label:
              candidate = np.random.randint(0, 10)
          return candidate

      return [_draw_other(label) for label in true_labels]
  51. def _pseudorandom_target(index, total_indices, true_class):
  52. """ pseudo random_target """
  53. rng = np.random.RandomState(index)
  54. target = true_class
  55. while target == true_class:
  56. target = rng.randint(0, total_indices)
  57. return target
  def create_target_images(dataset, data_labels, target_labels):
      """
      Collect, for each target label, the first sample in ``dataset`` with it.

      Args:
          dataset: indexable collection of samples, aligned with
              ``data_labels``.
          data_labels: 1-D sequence of labels, one per sample in ``dataset``.
          target_labels: labels to look up, in the order results are wanted.

      Returns:
          np.ndarray: one sample per entry of ``target_labels``, in order.

      Raises:
          ValueError: if a target label does not occur in ``data_labels``.
              Previously such labels were silently skipped, which returned
              fewer images than targets.
      """
      labels = np.asarray(data_labels)
      res = []
      for label in target_labels:
          matches = np.flatnonzero(labels == label)
          if matches.size == 0:
              raise ValueError(
                  "target label {} not found in data_labels".format(label))
          # first occurrence, matching the original linear scan
          res.append(dataset[matches[0]])
      return np.array(res)
  def get_model(current_dir):
      """
      Build LeNet5, load its checkpoint and wrap it as a black-box model.

      The checkpoint path is
      ``<current_dir>/../../test_data/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt``.
      Training mode is disabled on the network (``set_train(False)``) before it
      is wrapped.

      Returns:
          ModelToBeAttacked: wrapper around the loaded network.
      """
      network = LeNet5()
      checkpoint_path = os.path.join(
          current_dir,
          '../../test_data/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt')
      load_param_into_net(network, load_checkpoint(checkpoint_path))
      network.set_train(False)
      return ModelToBeAttacked(network)
  def get_dataset(current_dir):
      """
      Load the test images and labels stored as ``.npy`` files.

      Both files live in ``<current_dir>/../../test_data``.

      Returns:
          tuple: ``(test_images, test_labels)`` as loaded by ``np.load``.
      """
      def _load(relative_path):
          return np.load(os.path.join(current_dir, relative_path))

      images = _load('../../test_data/test_images.npy')
      labels = _load('../../test_data/test_labels.npy')
      return images, labels
  def nes_mnist_attack(scene, top_k):
      """
      Run a targeted NES black-box attack against LeNet5 on the test images.

      Logs the clean prediction accuracy on the first few samples. It then
      attacks ``test_length`` images, each toward a random target label, and
      logs the accuracy on the adversarial images plus success/query counts.

      Args:
          scene (str): attack scenario forwarded to NES. The tests in this
              file use 'Query_Limit', 'Partial_Info' and 'Label_Only'.
          top_k (int): forwarded to NES as ``top_k``. It is overridden to -1
              when ``scene`` is 'Query_Limit'.

      Raises:
          ValueError: presumably raised by NES for an invalid scene/top_k
              combination (``test_value_error`` relies on a ValueError here).
              Confirm against the NES implementation.
      """
      current_dir = os.path.dirname(os.path.abspath(__file__))
      test_images, test_labels = get_dataset(current_dir)
      model = get_model(current_dir)

      # prediction accuracy on the first `batch_num` samples before attacking
      batch_num = 5
      predict_labels = np.concatenate(
          [np.argmax(model.predict(img), axis=1)
           for img in test_images[:batch_num]])
      accuracy = np.mean(np.equal(predict_labels, test_labels[:batch_num]))
      LOGGER.info(TAG, "prediction accuracy before attacking is : %s",
                  accuracy)

      # attacking
      if scene == 'Query_Limit':
          top_k = -1
      success = 0
      queries_num = 0
      nes_instance = NES(model, scene, top_k=top_k)
      test_length = 1
      advs = []
      for img_index in range(test_length):
          # initial image and class selection
          initial_img = [test_images[img_index]]
          orig_class = test_labels[img_index]
          target_class = random_target_labels([orig_class])
          target_image = create_target_images(test_images, test_labels,
                                              target_class)
          nes_instance.set_target_images(target_image)
          tag, adv, queries = nes_instance.generate(initial_img, target_class)
          if tag[0]:
              success += 1
              queries_num += queries[0]
          advs.append(adv)
      advs = np.reshape(advs, (len(advs), 1, 32, 32))
      # Compare each adversarial image with the image it was generated from.
      # The previous comparison against test_images[:batch_num] broadcast
      # against unrelated images and so could never fail.
      assert (advs != test_images[:test_length]).any()
      adv_pred = np.argmax(model.predict(advs), axis=1)
      adv_accuracy = np.mean(np.equal(adv_pred, test_labels[:test_length]))
      LOGGER.info(TAG, "prediction accuracy after attacking is : %s",
                  adv_accuracy)
      LOGGER.info(TAG, "attack succeeded on %s of %s images, "
                       "queries used by successful attacks: %s",
                  success, test_length, queries_num)
  @pytest.mark.level0
  @pytest.mark.platform_arm_ascend_training
  @pytest.mark.platform_x86_ascend_training
  @pytest.mark.env_card
  @pytest.mark.component_mindarmour
  def test_nes_query_limit():
      """Run the NES attack in the 'Query_Limit' scene."""
      # valid scenes: 'Query_Limit', 'Partial_Info', 'Label_Only'
      nes_mnist_attack('Query_Limit', top_k=-1)
  @pytest.mark.level0
  @pytest.mark.platform_arm_ascend_training
  @pytest.mark.platform_x86_ascend_training
  @pytest.mark.env_card
  @pytest.mark.component_mindarmour
  def test_nes_partial_info():
      """Run the NES attack in the 'Partial_Info' scene with top_k=5."""
      # valid scenes: 'Query_Limit', 'Partial_Info', 'Label_Only'
      nes_mnist_attack('Partial_Info', top_k=5)
  @pytest.mark.level0
  @pytest.mark.platform_arm_ascend_training
  @pytest.mark.platform_x86_ascend_training
  @pytest.mark.env_card
  @pytest.mark.component_mindarmour
  def test_nes_label_only():
      """Run the NES attack in the 'Label_Only' scene with top_k=5."""
      # valid scenes: 'Query_Limit', 'Partial_Info', 'Label_Only'
      nes_mnist_attack('Label_Only', top_k=5)
  @pytest.mark.level0
  @pytest.mark.platform_arm_ascend_training
  @pytest.mark.platform_x86_ascend_training
  @pytest.mark.env_card
  @pytest.mark.component_mindarmour
  def test_value_error():
      """
      Expect ValueError when the 'Label_Only' scene is given top_k=-1.

      NOTE(review): the error is presumably raised by the NES constructor
      inside nes_mnist_attack; confirm against NES.
      """
      # nes_mnist_attack returns None, so wrapping it in `assert` added nothing.
      with pytest.raises(ValueError):
          nes_mnist_attack('Label_Only', -1)
  @pytest.mark.level0
  @pytest.mark.platform_arm_ascend_training
  @pytest.mark.platform_x86_ascend_training
  @pytest.mark.env_card
  @pytest.mark.component_mindarmour
  def test_none():
      """
      Expect ValueError from NES.generate when set_target_images was never called.

      NOTE(review): the test name suggests the error comes from the unset
      (None) target images in the 'Partial_Info' scene; confirm against
      NES.generate.
      """
      current_dir = os.path.dirname(os.path.abspath(__file__))
      model = get_model(current_dir)
      test_images, test_labels = get_dataset(current_dir)
      nes = NES(model, 'Partial_Info')
      # The call is expected to raise; asserting on its result was meaningless.
      with pytest.raises(ValueError):
          nes.generate(test_images, test_labels)

MindArmour关注AI的安全和隐私问题，致力于增强模型的安全可信、保护用户的数据隐私。主要包含3个模块：对抗样本鲁棒性模块、Fuzz Testing模块、隐私保护与评估模块。

对抗样本鲁棒性模块

对抗样本鲁棒性模块用于评估模型对于对抗样本的鲁棒性，并提供模型增强方法用于增强模型抗对抗样本攻击的能力，提升模型鲁棒性。对抗样本鲁棒性模块包含了4个子模块：对抗样本的生成、对抗样本的检测、模型防御、攻防评估。