You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

fuzzing.py 24 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Fuzzing.
  16. """
  17. from random import choice
  18. import numpy as np
  19. from mindspore import Model
  20. from mindspore import Tensor
  21. from mindarmour.utils._check_param import check_model, check_numpy_param, \
  22. check_param_multi_types, check_norm_level, check_param_in_range, \
  23. check_param_type, check_int_positive
  24. from mindarmour.utils.logger import LogUtil
  25. from ..adv_robustness.attacks import FastGradientSignMethod, \
  26. MomentumDiverseInputIterativeMethod, ProjectedGradientDescent
  27. from .image_transform import Contrast, Brightness, Blur, \
  28. Noise, Translate, Scale, Shear, Rotate
  29. from .model_coverage_metrics import ModelCoverageMetrics
  30. LOGGER = LogUtil.get_instance()
  31. TAG = 'Fuzzer'
  32. def _select_next(initial_seeds):
  33. """ Randomly select a seed from `initial_seeds`."""
  34. seed_num = choice(range(len(initial_seeds)))
  35. seed = initial_seeds[seed_num]
  36. del initial_seeds[seed_num]
  37. return seed, initial_seeds
  38. def _coverage_gains(coverages):
  39. """ Calculate the coverage gains of mutated samples. """
  40. gains = [0] + coverages[:-1]
  41. gains = np.array(coverages) - np.array(gains)
  42. return gains
  43. def _is_trans_valid(seed, mutate_sample):
  44. """ Check a mutated sample is valid. If the number of changed pixels in
  45. a seed is less than pixels_change_rate*size(seed), this mutate is valid.
  46. Else check the infinite norm of seed changes, if the value of the
  47. infinite norm less than pixel_value_change_rate*255, this mutate is
  48. valid too. Otherwise the opposite.
  49. """
  50. is_valid = False
  51. pixels_change_rate = 0.02
  52. pixel_value_change_rate = 0.2
  53. diff = np.array(seed - mutate_sample).flatten()
  54. size = np.shape(diff)[0]
  55. l0_norm = np.linalg.norm(diff, ord=0)
  56. linf = np.linalg.norm(diff, ord=np.inf)
  57. if l0_norm > pixels_change_rate*size:
  58. if linf < 256:
  59. is_valid = True
  60. else:
  61. if linf < pixel_value_change_rate*255:
  62. is_valid = True
  63. return is_valid
  64. def _check_eval_metrics(eval_metrics):
  65. """ Check evaluation metrics."""
  66. if isinstance(eval_metrics, (list, tuple)):
  67. eval_metrics_ = []
  68. available_metrics = ['accuracy', 'attack_success_rate', 'kmnc', 'nbc', 'snac']
  69. for elem in eval_metrics:
  70. if elem not in available_metrics:
  71. msg = 'metric in list `eval_metrics` must be in {}, but ' \
  72. 'got {}.'.format(available_metrics, elem)
  73. LOGGER.error(TAG, msg)
  74. raise ValueError(msg)
  75. eval_metrics_.append(elem.lower())
  76. elif isinstance(eval_metrics, str):
  77. if eval_metrics != 'auto':
  78. msg = "the value of `eval_metrics` must be 'auto' if it's type " \
  79. "is str, but got {}.".format(eval_metrics)
  80. LOGGER.error(TAG, msg)
  81. raise ValueError(msg)
  82. eval_metrics_ = 'auto'
  83. else:
  84. msg = "the type of `eval_metrics` must be str, list or tuple, " \
  85. "but got {}.".format(type(eval_metrics))
  86. LOGGER.error(TAG, msg)
  87. raise TypeError(msg)
  88. return eval_metrics_
  89. class Fuzzer:
  90. """
  91. Fuzzing test framework for deep neural networks.
  92. Reference: `DeepHunter: A Coverage-Guided Fuzz Testing Framework for Deep
  93. Neural Networks <https://dl.acm.org/doi/10.1145/3293882.3330579>`_
  94. Args:
  95. target_model (Model): Target fuzz model.
  96. train_dataset (numpy.ndarray): Training dataset used for determining
  97. the neurons' output boundaries.
  98. neuron_num (int): The number of testing neurons.
  99. segmented_num (int): The number of segmented sections of neurons'
  100. output intervals. Default: 1000.
  101. Examples:
  102. >>> net = Net()
  103. >>> mutate_config = [{'method': 'Blur',
  104. >>> 'params': {'auto_param': [True]}},
  105. >>> {'method': 'Contrast',
  106. >>> 'params': {'factor': [2]}},
  107. >>> {'method': 'Translate',
  108. >>> 'params': {'x_bias': [0.1, 0.2], 'y_bias': [0.2]}},
  109. >>> {'method': 'FGSM',
  110. >>> 'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}]
  111. >>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
  112. >>> model_fuzz_test = Fuzzer(model, train_images, 10, 1000)
  113. >>> samples, labels, preds, strategies, report = model_fuzz_test.fuzz_testing(mutate_config, initial_seeds)
  114. """
  115. def __init__(self, target_model, train_dataset, neuron_num,
  116. segmented_num=1000):
  117. self._target_model = check_model('model', target_model, Model)
  118. train_dataset = check_numpy_param('train_dataset', train_dataset)
  119. self._coverage_metrics = ModelCoverageMetrics(target_model,
  120. neuron_num,
  121. segmented_num,
  122. train_dataset)
  123. # Allowed mutate strategies so far.
  124. self._strategies = {'Contrast': Contrast,
  125. 'Brightness': Brightness,
  126. 'Blur': Blur,
  127. 'Noise': Noise,
  128. 'Translate': Translate,
  129. 'Scale': Scale,
  130. 'Shear': Shear,
  131. 'Rotate': Rotate,
  132. 'FGSM': FastGradientSignMethod,
  133. 'PGD': ProjectedGradientDescent,
  134. 'MDIIM': MomentumDiverseInputIterativeMethod}
  135. self._affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate']
  136. self._pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur',
  137. 'Noise']
  138. self._attacks_list = ['FGSM', 'PGD', 'MDIIM']
  139. self._attack_param_checklists = {
  140. 'FGSM': {'eps': {'dtype': [float],
  141. 'range': [0, 1]},
  142. 'alpha': {'dtype': [float],
  143. 'range': [0, 1]},
  144. 'bounds': {'dtype': [tuple]}},
  145. 'PGD': {'eps': {'dtype': [float],
  146. 'range': [0, 1]},
  147. 'eps_iter': {
  148. 'dtype': [float],
  149. 'range': [0, 1]},
  150. 'nb_iter': {'dtype': [int],
  151. 'range': [0, 100000]},
  152. 'bounds': {'dtype': [tuple]}},
  153. 'MDIIM': {'eps': {'dtype': [float],
  154. 'range': [0, 1]},
  155. 'norm_level': {'dtype': [str, int],
  156. 'range': [1, 2, '1', '2', 'l1', 'l2',
  157. 'inf', 'np.inf']},
  158. 'prob': {'dtype': [float],
  159. 'range': [0, 1]},
  160. 'bounds': {'dtype': [tuple]}}}
  161. def fuzzing(self, mutate_config, initial_seeds, coverage_metric='KMNC',
  162. eval_metrics='auto', max_iters=10000, mutate_num_per_seed=20):
  163. """
  164. Fuzzing tests for deep neural networks.
  165. Args:
  166. mutate_config (list): Mutate configs. The format is
  167. [{'method': 'Blur',
  168. 'params': {'radius': [0.1, 0.2], 'auto_param': [True, False]}},
  169. {'method': 'Contrast',
  170. 'params': {'factor': [1, 1.5, 2]}},
  171. {'method': 'FGSM',
  172. 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1]}},
  173. ...].
  174. The supported methods list is in `self._strategies`, and the
  175. params of each method must within the range of optional parameters. 
  176. Supported methods are grouped in three types:
  177. Firstly, pixel value based transform methods include:
  178. 'Contrast', 'Brightness', 'Blur' and 'Noise'. Secondly, affine
  179. transform methods include: 'Translate', 'Scale', 'Shear' and
  180. 'Rotate'. Thirdly, attack methods include: 'FGSM', 'PGD' and 'MDIIM'.
  181. `mutate_config` must have method in the type of pixel value based
  182. transform methods. The way of setting parameters for first and
  183. second type methods can be seen in 'mindarmour/fuzz_testing/image_transform.py'.
  184. For third type methods, the optional parameters refer to
  185. `self._attack_param_checklists`.
  186. initial_seeds (list[list]): Initial seeds used to generate mutated
  187. samples. The format of initial seeds is [[image_data, label],
  188. [...], ...].
  189. coverage_metric (str): Model coverage metric of neural networks. All
  190. supported metrics are: 'KMNC', 'NBC', 'SNAC'. Default: 'KMNC'.
  191. eval_metrics (Union[list, tuple, str]): Evaluation metrics. If the
  192. type is 'auto', it will calculate all the metrics, else if the
  193. type is list or tuple, it will calculate the metrics specified
  194. by user. All supported evaluate methods are 'accuracy',
  195. 'attack_success_rate', 'kmnc', 'nbc', 'snac'. Default: 'auto'.
  196. max_iters (int): Max number of select a seed to mutate.
  197. Default: 10000.
  198. mutate_num_per_seed (int): The number of mutate times for a seed.
  199. Default: 20.
  200. Returns:
  201. - list, mutated samples in fuzz_testing.
  202. - list, ground truth labels of mutated samples.
  203. - list, preds of mutated samples.
  204. - list, strategies of mutated samples.
  205. - dict, metrics report of fuzzer.
  206. Raises:
  207. TypeError: If the type of `eval_metrics` is not str, list or tuple.
  208. TypeError: If the type of metric in list `eval_metrics` is not str.
  209. ValueError: If `eval_metrics` is not equal to 'auto' when it's type is str.
  210. ValueError: If metric in list `eval_metrics` is not in ['accuracy',
  211. 'attack_success_rate', 'kmnc', 'nbc', 'snac'].
  212. """
  213. # Check parameters.
  214. eval_metrics_ = _check_eval_metrics(eval_metrics)
  215. if coverage_metric not in ['KMNC', 'NBC', 'SNAC']:
  216. msg = "coverage_metric must be in ['KMNC', 'NBC', 'SNAC'], " \
  217. "but got {}.".format(coverage_metric)
  218. LOGGER.error(TAG, msg)
  219. raise ValueError(msg)
  220. max_iters = check_int_positive('max_iters', max_iters)
  221. mutate_num_per_seed = check_int_positive('mutate_num_per_seed',
  222. mutate_num_per_seed)
  223. mutate_config = self._check_mutate_config(mutate_config)
  224. mutates = self._init_mutates(mutate_config)
  225. initial_seeds = check_param_type('initial_seeds', initial_seeds, list)
  226. if not initial_seeds:
  227. msg = 'initial_seeds must not be empty.'
  228. raise ValueError(msg)
  229. for seed in initial_seeds:
  230. check_param_type('seed', seed, list)
  231. if len(seed) != 2:
  232. msg = 'seed in initial seeds must have two element image and ' \
  233. 'label, but got {} element.'.format(len(seed))
  234. raise ValueError(msg)
  235. check_numpy_param('seed[0]', seed[0])
  236. check_numpy_param('seed[1]', seed[1])
  237. seed.append(0)
  238. seed, initial_seeds = _select_next(initial_seeds)
  239. fuzz_samples = []
  240. true_labels = []
  241. fuzz_preds = []
  242. fuzz_strategies = []
  243. iter_num = 0
  244. while initial_seeds and iter_num < max_iters:
  245. # Mutate a seed.
  246. mutate_samples, mutate_strategies = self._metamorphic_mutate(seed,
  247. mutates,
  248. mutate_config,
  249. mutate_num_per_seed)
  250. # Calculate the coverages and predictions of generated samples.
  251. coverages, predicts = self._get_coverages_and_predict(mutate_samples,
  252. coverage_metric)
  253. coverage_gains = _coverage_gains(coverages)
  254. for mutate, cov, pred, strategy in zip(mutate_samples,
  255. coverage_gains,
  256. predicts, mutate_strategies):
  257. fuzz_samples.append(mutate[0])
  258. true_labels.append(mutate[1])
  259. fuzz_preds.append(pred)
  260. fuzz_strategies.append(strategy)
  261. # if the mutate samples has coverage gains add this samples in
  262. # the initial seeds to guide new mutates.
  263. if cov > 0:
  264. initial_seeds.append(mutate)
  265. seed, initial_seeds = _select_next(initial_seeds)
  266. iter_num += 1
  267. metrics_report = None
  268. if eval_metrics_ is not None:
  269. metrics_report = self._evaluate(fuzz_samples,
  270. true_labels,
  271. fuzz_preds,
  272. fuzz_strategies,
  273. eval_metrics_)
  274. return fuzz_samples, true_labels, fuzz_preds, fuzz_strategies, metrics_report
  275. def _get_coverages_and_predict(self, mutate_samples,
  276. coverage_metric="KNMC"):
  277. """ Calculate the coverages and predictions of generated samples."""
  278. samples = [s[0] for s in mutate_samples]
  279. samples = np.array(samples)
  280. coverages = []
  281. predictions = self._target_model.predict(
  282. Tensor(samples.astype(np.float32)))
  283. predictions = predictions.asnumpy()
  284. for index in range(len(samples)):
  285. mutate = samples[:index + 1]
  286. self._coverage_metrics.calculate_coverage(mutate.astype(np.float32))
  287. if coverage_metric == 'KMNC':
  288. coverages.append(self._coverage_metrics.get_kmnc())
  289. if coverage_metric == 'NBC':
  290. coverages.append(self._coverage_metrics.get_nbc())
  291. if coverage_metric == 'SNAC':
  292. coverages.append(self._coverage_metrics.get_snac())
  293. return coverages, predictions
  294. def _metamorphic_mutate(self, seed, mutates, mutate_config,
  295. mutate_num_per_seed):
  296. """Mutate a seed using strategies random selected from mutate_config."""
  297. mutate_samples = []
  298. mutate_strategies = []
  299. for _ in range(mutate_num_per_seed):
  300. only_pixel_trans = seed[2]
  301. strategy = choice(mutate_config)
  302. # Choose a pixel value based transform method
  303. if only_pixel_trans:
  304. while strategy['method'] not in self._pixel_value_trans_list:
  305. strategy = choice(mutate_config)
  306. transform = mutates[strategy['method']]
  307. params = strategy['params']
  308. method = strategy['method']
  309. selected_param = {}
  310. for p in params:
  311. selected_param[p] = choice(params[p])
  312. if method in list(self._pixel_value_trans_list + self._affine_trans_list):
  313. if method == 'Shear':
  314. shear_keys = selected_param.keys()
  315. if 'factor_x' in shear_keys and 'factor_y' in shear_keys:
  316. selected_param[choice(['factor_x', 'factor_y'])] = 0
  317. transform.set_params(**selected_param)
  318. mutate_sample = transform.transform(seed[0])
  319. else:
  320. for param_name in selected_param:
  321. transform.__setattr__('_' + str(param_name),
  322. selected_param[param_name])
  323. mutate_sample = transform.generate([seed[0].astype(np.float32)],
  324. [seed[1]])[0]
  325. if method not in self._pixel_value_trans_list:
  326. only_pixel_trans = 1
  327. mutate_sample = [mutate_sample, seed[1], only_pixel_trans]
  328. if _is_trans_valid(seed[0], mutate_sample[0]):
  329. mutate_samples.append(mutate_sample)
  330. mutate_strategies.append(method)
  331. if not mutate_samples:
  332. mutate_samples.append(seed)
  333. mutate_strategies.append(None)
  334. return np.array(mutate_samples), mutate_strategies
  335. def _check_mutate_config(self, mutate_config):
  336. """Check whether the mutate_config meet the specification."""
  337. mutate_config = check_param_type('mutate_config', mutate_config, list)
  338. has_pixel_trans = False
  339. for config in mutate_config:
  340. check_param_type("config", config, dict)
  341. if set(config.keys()) != {'method', 'params'}:
  342. msg = "The key of each config must be in ('method', 'params'), " \
  343. "but got {}.".format(set(config.keys()))
  344. LOGGER.error(TAG, msg)
  345. raise KeyError(msg)
  346. method = config['method']
  347. params = config['params']
  348. # Method must be in the optional range.
  349. if method not in self._strategies.keys():
  350. msg = "Config methods must be in {}, but got {}." \
  351. .format(self._strategies.keys(), method)
  352. LOGGER.error(TAG, msg)
  353. raise ValueError(msg)
  354. if config['method'] in self._pixel_value_trans_list:
  355. has_pixel_trans = True
  356. check_param_type('params', params, dict)
  357. # Check parameters of attack methods. The parameters of transformed
  358. # methods will be verified in transferred parameters.
  359. if method in self._attacks_list:
  360. self._check_attack_params(method, params)
  361. else:
  362. for key in params.keys():
  363. check_param_type(str(key), params[key], list)
  364. # Methods in `metate_config` should at least have one in the type of
  365. # pixel value based transform methods.
  366. if not has_pixel_trans:
  367. msg = "mutate methods in mutate_config should at least have one " \
  368. "in {}".format(self._pixel_value_trans_list)
  369. raise ValueError(msg)
  370. return mutate_config
  371. def _check_attack_params(self, method, params):
  372. """Check input parameters of attack methods."""
  373. allow_params = self._attack_param_checklists[method].keys()
  374. for param_name in params:
  375. if param_name not in allow_params:
  376. msg = "parameters of {} must in {}".format(method, allow_params)
  377. raise ValueError(msg)
  378. check_param_type(param_name, params[param_name], list)
  379. for param_value in params[param_name]:
  380. if param_name == 'bounds':
  381. bounds = check_param_multi_types('bounds', param_value, [tuple])
  382. if len(bounds) != 2:
  383. msg = 'The format of bounds must be format (lower_bound, upper_bound),' \
  384. 'but got its length as{}'.format(len(bounds))
  385. raise ValueError(msg)
  386. for bound_value in bounds:
  387. _ = check_param_multi_types('bound', bound_value,
  388. [int, float])
  389. if bounds[0] >= bounds[1]:
  390. msg = "upper bound must more than lower bound, " \
  391. "but upper bound got {}, lower bound " \
  392. "got {}".format(bounds[0], bounds[1])
  393. raise ValueError(msg)
  394. elif param_name == 'norm_level':
  395. _ = check_norm_level(param_value)
  396. else:
  397. allow_type = self._attack_param_checklists[method][param_name]['dtype']
  398. allow_range = self._attack_param_checklists[method][param_name]['range']
  399. _ = check_param_multi_types(str(param_name), param_value, allow_type)
  400. _ = check_param_in_range(str(param_name),
  401. param_value,
  402. allow_range[0],
  403. allow_range[1])
  404. def _init_mutates(self, mutate_config):
  405. """ Check whether the mutate_config meet the specification."""
  406. mutates = {}
  407. for mutate in mutate_config:
  408. method = mutate['method']
  409. if method not in self._attacks_list:
  410. mutates[method] = self._strategies[method]()
  411. else:
  412. network = self._target_model._network
  413. loss_fn = self._target_model._loss_fn
  414. mutates[method] = self._strategies[method](network,
  415. loss_fn=loss_fn)
  416. return mutates
  417. def _evaluate(self, fuzz_samples, true_labels, fuzz_preds,
  418. fuzz_strategies, metrics):
  419. """
  420. Evaluate generated fuzz_testing samples in three dimensions: accuracy,
  421. attack success rate and neural coverage.
  422. Args:
  423. fuzz_samples ([numpy.ndarray, list]): Generated fuzz_testing samples
  424. according to seeds.
  425. true_labels ([numpy.ndarray, list]): Ground truth labels of seeds.
  426. fuzz_preds ([numpy.ndarray, list]): Predictions of generated fuzz samples.
  427. fuzz_strategies ([numpy.ndarray, list]): Mutate strategies of fuzz samples.
  428. metrics (Union[list, tuple, str]): evaluation metrics.
  429. Returns:
  430. dict, evaluate metrics include accuracy, attack success rate
  431. and neural coverage.
  432. """
  433. true_labels = np.asarray(true_labels)
  434. fuzz_preds = np.asarray(fuzz_preds)
  435. temp = np.argmax(true_labels, axis=1) == np.argmax(fuzz_preds, axis=1)
  436. metrics_report = {}
  437. if metrics == 'auto' or 'accuracy' in metrics:
  438. if temp.any():
  439. acc = np.sum(temp) / np.size(temp)
  440. else:
  441. acc = 0
  442. metrics_report['Accuracy'] = acc
  443. if metrics == 'auto' or 'attack_success_rate' in metrics:
  444. cond = [elem in self._attacks_list for elem in fuzz_strategies]
  445. temp = temp[cond]
  446. if temp.any():
  447. attack_success_rate = 1 - np.sum(temp) / np.size(temp)
  448. else:
  449. attack_success_rate = None
  450. metrics_report['Attack_success_rate'] = attack_success_rate
  451. if metrics == 'auto' or 'kmnc' in metrics or 'nbc' in metrics or 'snac' in metrics:
  452. self._coverage_metrics.calculate_coverage(
  453. np.array(fuzz_samples).astype(np.float32))
  454. if metrics == 'auto' or 'kmnc' in metrics:
  455. kmnc = self._coverage_metrics.get_kmnc()
  456. metrics_report['Neural_coverage_KMNC'] = kmnc
  457. if metrics == 'auto' or 'nbc' in metrics:
  458. nbc = self._coverage_metrics.get_nbc()
  459. metrics_report['Neural_coverage_NBC'] = nbc
  460. if metrics == 'auto' or 'snac' in metrics:
  461. snac = self._coverage_metrics.get_snac()
  462. metrics_report['Neural_coverage_SNAC'] = snac
  463. return metrics_report

MindArmour关注AI的安全和隐私问题。致力于增强模型的安全可信、保护用户的数据隐私。主要包含3个模块:对抗样本鲁棒性模块、Fuzz Testing模块、隐私保护与评估模块。 对抗样本鲁棒性模块 对抗样本鲁棒性模块用于评估模型对于对抗样本的鲁棒性,并提供模型增强方法用于增强模型抗对抗样本攻击的能力,提升模型鲁棒性。对抗样本鲁棒性模块包含了4个子模块:对抗样本的生成、对抗样本的检测、模型防御、攻防评估。