diff --git a/examples/ai_fuzzer/lenet5_mnist_coverage.py b/examples/ai_fuzzer/lenet5_mnist_coverage.py index 4e77a32..27150a5 100644 --- a/examples/ai_fuzzer/lenet5_mnist_coverage.py +++ b/examples/ai_fuzzer/lenet5_mnist_coverage.py @@ -14,11 +14,10 @@ import numpy as np from mindspore import Model from mindspore import context -from mindspore.nn import SoftmaxCrossEntropyWithLogits from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindarmour.adv_robustness.attacks import FastGradientSignMethod -from mindarmour.fuzz_testing import ModelCoverageMetrics +from mindarmour.fuzz_testing.model_coverage_metrics import NeuronCoverage, TopKNeuronCoverage, NeuronBoundsCoverage,\ + SuperNeuronActivateCoverage, KMultisectionNeuronCoverage from mindarmour.utils.logger import LogUtil from examples.common.dataset.data_processing import generate_mnist_dataset @@ -46,13 +45,6 @@ def test_lenet_mnist_coverage(): images = data[0].astype(np.float32) train_images.append(images) train_images = np.concatenate(train_images, axis=0) - neuron_num = 10 - segmented_num = 1000 - top_k = 3 - threshold = 0.1 - - # initialize fuzz test with training dataset - model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) # fuzz test with original test data # get test data @@ -67,31 +59,31 @@ def test_lenet_mnist_coverage(): test_images.append(images) test_labels.append(labels) test_images = np.concatenate(test_images, axis=0) - test_labels = np.concatenate(test_labels, axis=0) - model_fuzz_test.calculate_coverage(test_images) - LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) - LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) - LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) - model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold) - LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc()) - LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) + # initialize fuzz test with training dataset + nc = NeuronCoverage(model, threshold=0.1) + nc_metric = nc.get_metrics(test_images) + + tknc = TopKNeuronCoverage(model, top_k=3) + tknc_metrics = tknc.get_metrics(test_images) + + snac = SuperNeuronActivateCoverage(model, train_images) + snac_metrics = snac.get_metrics(test_images) + + nbc = NeuronBoundsCoverage(model, train_images) + nbc_metrics = nbc.get_metrics(test_images) - # generate adv_data - loss = SoftmaxCrossEntropyWithLogits(sparse=True) - attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss) - adv_data = attack.batch_generate(test_images, test_labels, batch_size=32) - model_fuzz_test.calculate_coverage(adv_data, bias_coefficient=0.5) - LOGGER.info(TAG, 'KMNC of this adv data is : %s', model_fuzz_test.get_kmnc()) - LOGGER.info(TAG, 'NBC of this adv data is : %s', model_fuzz_test.get_nbc()) - LOGGER.info(TAG, 'SNAC of this adv data is : %s', model_fuzz_test.get_snac()) + kmnc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100) + kmnc_metrics = kmnc.get_metrics(test_images) - model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold) - LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc()) - LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) + print('KMNC of this test is: ', kmnc_metrics) + print('NBC of this test is: ', nbc_metrics) + print('SNAC of this test is: ', snac_metrics) + print('NC of this test is: ', nc_metric) + print('TKNC of this test is: ', tknc_metrics) if __name__ == '__main__': # device_target can be "CPU", "GPU" or "Ascend" - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") test_lenet_mnist_coverage() diff --git a/examples/ai_fuzzer/lenet5_mnist_fuzzing.py b/examples/ai_fuzzer/lenet5_mnist_fuzzing.py index 7707c44..2c49e39 100644 --- a/examples/ai_fuzzer/lenet5_mnist_fuzzing.py +++ b/examples/ai_fuzzer/lenet5_mnist_fuzzing.py @@ -14,11 +14,11 @@ import numpy as np from mindspore import Model from mindspore import context -from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore import load_checkpoint, load_param_into_net from mindarmour.fuzz_testing import Fuzzer -from mindarmour.fuzz_testing import ModelCoverageMetrics -from mindarmour.utils.logger import LogUtil +from mindarmour.fuzz_testing import KMultisectionNeuronCoverage +from mindarmour.utils import LogUtil from examples.common.dataset.data_processing import generate_mnist_dataset from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5 @@ -52,7 +52,7 @@ def test_lenet_mnist_fuzzing(): 'params': {'auto_param': [True]}}, {'method': 'FGSM', 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1]}} - ] + ] # get training data data_list = "../common/dataset/MNIST/train" @@ -63,11 +63,6 @@ def test_lenet_mnist_fuzzing(): images = data[0].astype(np.float32) train_images.append(images) train_images = np.concatenate(train_images, axis=0) - neuron_num = 10 - segmented_num = 1000 - - # initialize fuzz test with training dataset - model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) # fuzz test with original test data # get test data @@ -88,21 +83,20 @@ def test_lenet_mnist_fuzzing(): # make initial seeds for img, label in zip(test_images, test_labels): initial_seeds.append([img, label]) - + coverage = KMultisectionNeuronCoverage(model, train_images, segmented_num=100, incremental=True) + kmnc = coverage.get_metrics(test_images[:100]) + print('KMNC of initial seeds is: ', kmnc) initial_seeds = initial_seeds[:100] - model_coverage_test.calculate_coverage( - np.array(test_images[:100]).astype(np.float32)) - LOGGER.info(TAG, 'KMNC of this test is : %s', - model_coverage_test.get_kmnc()) + model_fuzz_test = Fuzzer(model) + _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, coverage, evaluate=True, max_iters=10, + mutate_num_per_seed=20) - model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num) - _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, eval_metrics='auto') if metrics: for key in metrics: - LOGGER.info(TAG, key + ': %s', metrics[key]) + print(key + ': ', metrics[key]) if __name__ == '__main__': - # device_target can be "CPU", "GPU" or "Ascend" - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + # device_target can be "CPU"GPU, "" or "Ascend" + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") test_lenet_mnist_fuzzing() diff --git a/examples/common/networks/lenet5/lenet5_net_for_fuzzing.py b/examples/common/networks/lenet5/lenet5_net_for_fuzzing.py index 803edd0..ed12d5c 100644 --- a/examples/common/networks/lenet5/lenet5_net_for_fuzzing.py +++ b/examples/common/networks/lenet5/lenet5_net_for_fuzzing.py @@ -20,19 +20,21 @@ from mindspore.ops import TensorSummary def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): + """Wrap conv.""" weight = weight_variable() - return nn.Conv2d(in_channels, out_channels, - kernel_size=kernel_size, stride=stride, padding=padding, + return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="valid") def fc_with_initialize(input_channels, out_channels): + """Wrap initialize method of full connection layer.""" weight = weight_variable() bias = weight_variable() return nn.Dense(input_channels, out_channels, weight, bias) def weight_variable(): + """Wrap initialize variable.""" return TruncatedNormal(0.05) @@ -50,7 +52,6 @@ class LeNet5(nn.Cell): self.relu = nn.ReLU() self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) self.flatten = nn.Flatten() - self.summary = TensorSummary() def construct(self, x): @@ -59,8 +60,6 @@ class LeNet5(nn.Cell): Returns: x (tensor): network output """ - self.summary('input', x) - x = self.conv1(x) self.summary('1', x) diff --git a/mindarmour/fuzz_testing/__init__.py b/mindarmour/fuzz_testing/__init__.py index 25026dc..871f4c1 100644 --- a/mindarmour/fuzz_testing/__init__.py +++ b/mindarmour/fuzz_testing/__init__.py @@ -16,7 +16,13 @@ This module provides a neuron coverage-gain based fuzz method to evaluate the robustness of given model. """ from .fuzzing import Fuzzer -from .model_coverage_metrics import ModelCoverageMetrics +from .model_coverage_metrics import CoverageMetrics, NeuronCoverage, TopKNeuronCoverage, NeuronBoundsCoverage, \ + SuperNeuronActivateCoverage, KMultisectionNeuronCoverage __all__ = ['Fuzzer', - 'ModelCoverageMetrics'] + 'CoverageMetrics', + 'NeuronCoverage', + 'TopKNeuronCoverage', + 'NeuronBoundsCoverage', + 'SuperNeuronActivateCoverage', + 'KMultisectionNeuronCoverage'] diff --git a/mindarmour/fuzz_testing/fuzzing.py b/mindarmour/fuzz_testing/fuzzing.py index 2ddd9ad..bc57185 100644 --- a/mindarmour/fuzz_testing/fuzzing.py +++ b/mindarmour/fuzz_testing/fuzzing.py @@ -21,15 +21,14 @@ from mindspore import Model from mindspore import Tensor from mindspore import nn -from mindarmour.utils._check_param import check_model, check_numpy_param, \ - check_param_multi_types, check_norm_level, check_param_in_range, \ - check_param_type, check_int_positive +from mindarmour.utils._check_param import check_model, check_numpy_param, check_param_multi_types, check_norm_level, \ + check_param_in_range, check_param_type, check_int_positive, check_param_bounds from mindarmour.utils.logger import LogUtil from ..adv_robustness.attacks import FastGradientSignMethod, \ MomentumDiverseInputIterativeMethod, ProjectedGradientDescent from .image_transform import Contrast, Brightness, Blur, \ Noise, Translate, Scale, Shear, Rotate -from .model_coverage_metrics import ModelCoverageMetrics +from .model_coverage_metrics import CoverageMetrics, KMultisectionNeuronCoverage LOGGER = LogUtil.get_instance() TAG = 'Fuzzer' @@ -43,11 +42,22 @@ def _select_next(initial_seeds): return seed, initial_seeds -def _coverage_gains(coverages): - """ Calculate the coverage gains of mutated samples.""" - gains = [0] + coverages[:-1] +def _coverage_gains(pre_coverage, coverages): + """ + Calculate the coverage gains of mutated samples. + + Args: + pre_coverage (float): Last value of coverages for previous mutated samples. + coverages (list): Coverage of mutated samples. + + Returns: + - list, coverage gains for mutated samples. + + - float, last value in parameter coverages. + """ + gains = [pre_coverage] + coverages[:-1] gains = np.array(coverages) - np.array(gains) - return gains + return gains, coverages[-1] def _is_trans_valid(seed, mutate_sample): @@ -65,37 +75,22 @@ def _is_trans_valid(seed, mutate_sample): size = np.shape(diff)[0] l0_norm = np.linalg.norm(diff, ord=0) linf = np.linalg.norm(diff, ord=np.inf) - if l0_norm > pixels_change_rate*size: + if l0_norm > pixels_change_rate * size: if linf < 256: is_valid = True else: - if linf < pixel_value_change_rate*255: + if linf < pixel_value_change_rate * 255: is_valid = True return is_valid -def _check_eval_metrics(eval_metrics): - """ Check evaluation metrics.""" - if isinstance(eval_metrics, (list, tuple)): - eval_metrics_ = [] - available_metrics = ['accuracy', 'attack_success_rate', 'kmnc', 'nbc', 'snac'] - for elem in eval_metrics: - if elem not in available_metrics: - msg = 'metric in list `eval_metrics` must be in {}, but got {}.'.format(available_metrics, elem) - LOGGER.error(TAG, msg) - raise ValueError(msg) - eval_metrics_.append(elem.lower()) - elif isinstance(eval_metrics, str): - if eval_metrics != 'auto': - msg = "the value of `eval_metrics` must be 'auto' if it's type is str, but got {}.".format(eval_metrics) - LOGGER.error(TAG, msg) - raise ValueError(msg) - eval_metrics_ = 'auto' +def _gain_threshold(coverage): + """Get threshold for given neuron coverage class.""" + if coverage is isinstance(coverage, KMultisectionNeuronCoverage): + gain_threshold = 0.1 / coverage.segmented_num else: - msg = "the type of `eval_metrics` must be str, list or tuple, but got {}.".format(type(eval_metrics)) - LOGGER.error(TAG, msg) - raise TypeError(msg) - return eval_metrics_ + gain_threshold = 0 + return gain_threshold class Fuzzer: @@ -113,6 +108,7 @@ class Fuzzer: Examples: >>> net = Net() + >>> model = Model(net) >>> mutate_config = [{'method': 'Blur', >>> 'params': {'auto_param': [True]}}, >>> {'method': 'Contrast', @@ -121,18 +117,15 @@ class Fuzzer: >>> 'params': {'x_bias': [0.1, 0.2], 'y_bias': [0.2]}}, >>> {'method': 'FGSM', >>> 'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}] - >>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) - >>> neuron_num = 10 - >>> segmented_num = 1000 - >>> model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num) - >>> samples, labels, preds, strategies, report = model_fuzz_test.fuzz_testing(mutate_config, initial_seeds) + >>> nc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100) + >>> model_fuzz_test = Fuzzer(model) + >>> samples, gt_labels, preds, strategies, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, + >>> nc, max_iters=100) """ - def __init__(self, target_model, train_dataset, neuron_num, - segmented_num=1000): + def __init__(self, target_model): self._target_model = check_model('model', target_model, Model) - train_dataset = check_numpy_param('train_dataset', train_dataset) - self._coverage_metrics = ModelCoverageMetrics(target_model, neuron_num, segmented_num, train_dataset) + # Allowed mutate strategies so far. self._strategies = {'Contrast': Contrast, 'Brightness': Brightness, @@ -161,8 +154,7 @@ class Fuzzer: 'prob': {'dtype': [float], 'range': [0, 1]}, 'bounds': {'dtype': [tuple]}}} - def fuzzing(self, mutate_config, initial_seeds, coverage_metric='KMNC', - eval_metrics='auto', max_iters=10000, mutate_num_per_seed=20): + def fuzzing(self, mutate_config, initial_seeds, coverage, evaluate=True, max_iters=10000, mutate_num_per_seed=20): """ Fuzzing tests for deep neural networks. @@ -175,32 +167,20 @@ class Fuzzer: {'method': 'FGSM', 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1]}}, ...]. - The supported methods list is in `self._strategies`, and the - params of each method must within the range of optional parameters.  - Supported methods are grouped in three types: - Firstly, pixel value based transform methods include: - 'Contrast', 'Brightness', 'Blur' and 'Noise'. Secondly, affine - transform methods include: 'Translate', 'Scale', 'Shear' and - 'Rotate'. Thirdly, attack methods include: 'FGSM', 'PGD' and 'MDIIM'. - `mutate_config` must have method in the type of pixel value based - transform methods. The way of setting parameters for first and - second type methods can be seen in 'mindarmour/fuzz_testing/image_transform.py'. - For third type methods, the optional parameters refer to + The supported methods list is in `self._strategies`, and the params of each method must within the + range of optional parameters. Supported methods are grouped in three types: Firstly, pixel value based + transform methods include: 'Contrast', 'Brightness', 'Blur' and 'Noise'. Secondly, affine transform + methods include: 'Translate', 'Scale', 'Shear' and 'Rotate'. Thirdly, attack methods include: 'FGSM', + 'PGD' and 'MDIIM'. `mutate_config` must have method in the type of pixel value based transform methods. + The way of setting parameters for first and second type methods can be seen in + 'mindarmour/fuzz_testing/image_transform.py'. For third type methods, the optional parameters refer to `self._attack_param_checklists`. - initial_seeds (list[list]): Initial seeds used to generate mutated - samples. The format of initial seeds is [[image_data, label], - [...], ...] and the label must be one-hot. - coverage_metric (str): Model coverage metric of neural networks. All - supported metrics are: 'KMNC', 'NBC', 'SNAC'. Default: 'KMNC'. - eval_metrics (Union[list, tuple, str]): Evaluation metrics. If the - type is 'auto', it will calculate all the metrics, else if the - type is list or tuple, it will calculate the metrics specified - by user. All supported evaluate methods are 'accuracy', - 'attack_success_rate', 'kmnc', 'nbc', 'snac'. Default: 'auto'. - max_iters (int): Max number of select a seed to mutate. - Default: 10000. - mutate_num_per_seed (int): The number of mutate times for a seed. - Default: 20. + initial_seeds (list[list]): Initial seeds used to generate mutated samples. The format of initial seeds is + [[image_data, label], [...], ...] and the label must be one-hot. + coverage (CoverageMetrics): Class of neuron coverage metrics. + evaluate (bool): return evaluate report or not. Default: True. + max_iters (int): Max number of select a seed to mutate. Default: 10000. + mutate_num_per_seed (int): The number of mutate times for a seed. Default: 20. Returns: - list, mutated samples in fuzz_testing. @@ -214,18 +194,18 @@ class Fuzzer: - dict, metrics report of fuzzer. Raises: - TypeError: If the type of `eval_metrics` is not str, list or tuple. - TypeError: If the type of metric in list `eval_metrics` is not str. - ValueError: If `eval_metrics` is not equal to 'auto' when it's type is str. - ValueError: If metric in list `eval_metrics` is not in ['accuracy', - 'attack_success_rate', 'kmnc', 'nbc', 'snac']. + ValueError, coverage must be subclass of CoverageMetrics. + + ValueError, if initial seeds is empty. + + ValueError, if element of seed is not two in initial seeds. """ # Check parameters. - eval_metrics_ = _check_eval_metrics(eval_metrics) - if coverage_metric not in ['KMNC', 'NBC', 'SNAC']: - msg = "coverage_metric must be in ['KMNC', 'NBC', 'SNAC'], but got {}.".format(coverage_metric) + if not isinstance(coverage, CoverageMetrics): + msg = 'coverage must be subclass of CoverageMetrics' LOGGER.error(TAG, msg) raise ValueError(msg) + evaluate = check_param_type('evaluate', evaluate, bool) max_iters = check_int_positive('max_iters', max_iters) mutate_num_per_seed = check_int_positive('mutate_num_per_seed', mutate_num_per_seed) mutate_config = self._check_mutate_config(mutate_config) @@ -235,15 +215,21 @@ class Fuzzer: if not initial_seeds: msg = 'initial_seeds must not be empty.' raise ValueError(msg) + initial_samples = [] for seed in initial_seeds: check_param_type('seed', seed, list) if len(seed) != 2: - msg = 'seed in initial seeds must have two element image and ' \ - 'label, but got {} element.'.format(len(seed)) + msg = 'seed in initial seeds must have two element image and label, but got {} element.'.format( + len(seed)) raise ValueError(msg) check_numpy_param('seed[0]', seed[0]) check_numpy_param('seed[1]', seed[1]) + initial_samples.append(seed[0]) seed.append(0) + initial_samples = np.array(initial_samples) + # calculate the coverage of initial seeds + pre_coverage = coverage.get_metrics(initial_samples) + gain_threshold = _gain_threshold(coverage) seed, initial_seeds = _select_next(initial_seeds) fuzz_samples = [] @@ -253,30 +239,27 @@ class Fuzzer: iter_num = 0 while initial_seeds and iter_num < max_iters: # Mutate a seed. - mutate_samples, mutate_strategies = self._metamorphic_mutate(seed, - mutates, - mutate_config, + mutate_samples, mutate_strategies = self._metamorphic_mutate(seed, mutates, mutate_config, mutate_num_per_seed) # Calculate the coverages and predictions of generated samples. - coverages, predicts = self._get_coverages_and_predict(mutate_samples, coverage_metric) - coverage_gains = _coverage_gains(coverages) + coverages, predicts = self._get_coverages_and_predict(mutate_samples, coverage) + coverage_gains, pre_coverage = _coverage_gains(pre_coverage, coverages) for mutate, cov, pred, strategy in zip(mutate_samples, coverage_gains, predicts, mutate_strategies): fuzz_samples.append(mutate[0]) true_labels.append(mutate[1]) fuzz_preds.append(pred) fuzz_strategies.append(strategy) - # if the mutate samples has coverage gains add this samples in - # the initial_seeds to guide new mutates. - if cov > 0: + # if the mutate samples has coverage gains add this samples in the initial_seeds to guide new mutates. + if cov > gain_threshold: initial_seeds.append(mutate) seed, initial_seeds = _select_next(initial_seeds) iter_num += 1 metrics_report = None - if eval_metrics_ is not None: - metrics_report = self._evaluate(fuzz_samples, true_labels, fuzz_preds, fuzz_strategies, eval_metrics_) + if evaluate: + metrics_report = self._evaluate(fuzz_samples, true_labels, fuzz_preds, fuzz_strategies, coverage) return fuzz_samples, true_labels, fuzz_preds, fuzz_strategies, metrics_report - def _get_coverages_and_predict(self, mutate_samples, coverage_metric="KNMC"): + def _get_coverages_and_predict(self, mutate_samples, coverage): """ Calculate the coverages and predictions of generated samples.""" samples = [s[0] for s in mutate_samples] samples = np.array(samples) @@ -285,17 +268,10 @@ class Fuzzer: predictions = predictions.asnumpy() for index in range(len(samples)): mutate = samples[:index + 1] - self._coverage_metrics.calculate_coverage(mutate.astype(np.float32)) - if coverage_metric == 'KMNC': - coverages.append(self._coverage_metrics.get_kmnc()) - if coverage_metric == 'NBC': - coverages.append(self._coverage_metrics.get_nbc()) - if coverage_metric == 'SNAC': - coverages.append(self._coverage_metrics.get_snac()) + coverages.append(coverage.get_metrics(mutate)) return coverages, predictions - def _metamorphic_mutate(self, seed, mutates, mutate_config, - mutate_num_per_seed): + def _metamorphic_mutate(self, seed, mutates, mutate_config, mutate_num_per_seed): """Mutate a seed using strategies random selected from mutate_config.""" mutate_samples = [] mutate_strategies = [] @@ -310,8 +286,8 @@ class Fuzzer: params = strategy['params'] method = strategy['method'] selected_param = {} - for p in params: - selected_param[p] = choice(params[p]) + for param in params: + selected_param[param] = choice(params[param]) if method in list(self._pixel_value_trans_list + self._affine_trans_list): if method == 'Shear': @@ -367,8 +343,7 @@ class Fuzzer: else: for key in params.keys(): check_param_type(str(key), params[key], list) - # Methods in `metate_config` should at least have one in the type of - # pixel value based transform methods. + # Methods in `metate_config` should at least have one in the type of pixel value based transform methods. if not has_pixel_trans: msg = "mutate methods in mutate_config should at least have one in {}".format(self._pixel_value_trans_list) raise ValueError(msg) @@ -386,17 +361,7 @@ class Fuzzer: check_param_type(param_name, params[param_name], list) for param_value in params[param_name]: if param_name == 'bounds': - bounds = check_param_multi_types('bounds', param_value, [tuple]) - if len(bounds) != 2: - msg = 'The format of bounds must be format (lower_bound, upper_bound),' \ - 'but got its length as{}'.format(len(bounds)) - raise ValueError(msg) - for bound_value in bounds: - _ = check_param_multi_types('bound', bound_value, [int, float]) - if bounds[0] >= bounds[1]: - msg = "upper bound must more than lower bound, but upper bound got {}, lower bound " \ - "got {}".format(bounds[0], bounds[1]) - raise ValueError(msg) + _ = check_param_bounds('bounds', param_name) elif param_name == 'norm_level': _ = check_norm_level(param_value) else: @@ -420,57 +385,40 @@ class Fuzzer: mutates[method] = self._strategies[method](network, loss_fn=loss_fn) return mutates - def _evaluate(self, fuzz_samples, true_labels, fuzz_preds, - fuzz_strategies, metrics): + def _evaluate(self, fuzz_samples, true_labels, fuzz_preds, fuzz_strategies, coverage): """ - Evaluate generated fuzz_testing samples in three dimensions: accuracy, - attack success rate and neural coverage. + Evaluate generated fuzz_testing samples in three dimensions: accuracy, attack success rate and neural coverage. Args: - fuzz_samples ([numpy.ndarray, list]): Generated fuzz_testing samples - according to seeds. + fuzz_samples ([numpy.ndarray, list]): Generated fuzz_testing samples according to seeds. true_labels ([numpy.ndarray, list]): Ground truth labels of seeds. fuzz_preds ([numpy.ndarray, list]): Predictions of generated fuzz samples. fuzz_strategies ([numpy.ndarray, list]): Mutate strategies of fuzz samples. - metrics (Union[list, tuple, str]): evaluation metrics. + coverage (CoverageMetrics): Neuron coverage metrics class. Returns: - dict, evaluate metrics include accuracy, attack success rate - and neural coverage. + dict, evaluate metrics include accuracy, attack success rate and neural coverage. """ + fuzz_samples = np.array(fuzz_samples) true_labels = np.asarray(true_labels) fuzz_preds = np.asarray(fuzz_preds) temp = np.argmax(true_labels, axis=1) == np.argmax(fuzz_preds, axis=1) metrics_report = {} - if metrics == 'auto' or 'accuracy' in metrics: - if temp.any(): - acc = np.sum(temp) / np.size(temp) - else: - acc = 0 - metrics_report['Accuracy'] = acc - - if metrics == 'auto' or 'attack_success_rate' in metrics: - cond = [elem in self._attacks_list for elem in fuzz_strategies] - temp = temp[cond] - if temp.any(): - attack_success_rate = 1 - np.sum(temp) / np.size(temp) - else: - attack_success_rate = None - metrics_report['Attack_success_rate'] = attack_success_rate - - if metrics == 'auto' or 'kmnc' in metrics or 'nbc' in metrics or 'snac' in metrics: - self._coverage_metrics.calculate_coverage(np.array(fuzz_samples).astype(np.float32)) - - if metrics == 'auto' or 'kmnc' in metrics: - kmnc = self._coverage_metrics.get_kmnc() - metrics_report['Neural_coverage_KMNC'] = kmnc - - if metrics == 'auto' or 'nbc' in metrics: - nbc = self._coverage_metrics.get_nbc() - metrics_report['Neural_coverage_NBC'] = nbc - if metrics == 'auto' or 'snac' in metrics: - snac = self._coverage_metrics.get_snac() - metrics_report['Neural_coverage_SNAC'] = snac + if temp.any(): + acc = np.sum(temp) / np.size(temp) + else: + acc = 0 + metrics_report['Accuracy'] = acc + + cond = [elem in self._attacks_list for elem in fuzz_strategies] + temp = temp[cond] + if temp.any(): + attack_success_rate = 1 - np.sum(temp) / np.size(temp) + else: + attack_success_rate = None + metrics_report['Attack_success_rate'] = attack_success_rate + + metrics_report['Coverage_metrics'] = coverage.get_metrics(fuzz_samples) return metrics_report diff --git a/mindarmour/fuzz_testing/model_coverage_metrics.py b/mindarmour/fuzz_testing/model_coverage_metrics.py index 24482fa..7b87cd3 100644 --- a/mindarmour/fuzz_testing/model_coverage_metrics.py +++ b/mindarmour/fuzz_testing/model_coverage_metrics.py @@ -14,311 +14,396 @@ """ Model-Test Coverage Metrics. """ - +from abc import abstractmethod from collections import defaultdict +import math import numpy as np from mindspore import Tensor from mindspore import Model from mindspore.train.summary.summary_record import _get_summary_tensor_data -from mindarmour.utils._check_param import check_model, check_numpy_param, \ - check_int_positive, check_param_multi_types +from mindarmour.utils._check_param import check_model, check_numpy_param, check_int_positive, \ + check_param_type, check_value_positive from mindarmour.utils.logger import LogUtil LOGGER = LogUtil.get_instance() -TAG = 'ModelCoverageMetrics' +TAG = 'CoverageMetrics' -class ModelCoverageMetrics: +class CoverageMetrics: """ - As we all known, each neuron output of a network will have a output range - after training (we call it original range), and test dataset is used to - estimate the accuracy of the trained network. However, neurons' output - distribution would be different with different test datasets. Therefore, - similar to function fuzz, model fuzz means testing those neurons' outputs - and estimating the proportion of original range that has emerged with test + The abstract base class for Neuron coverage classes calculating coverage metrics. + + As we all known, each neuron output of a network will have a output range after training (we call it original + range), and test dataset is used to estimate the accuracy of the trained network. However, neurons' output + distribution would be different with different test datasets. Therefore, similar to function fuzz, model fuzz means + testing those neurons' outputs and estimating the proportion of original range that has emerged with test datasets. - Reference: `DeepGauge: Multi-Granularity Testing Criteria for Deep - Learning Systems `_ + Reference: `DeepGauge: Multi-Granularity Testing Criteria for Deep Learning Systems + `_ Args: model (Model): The pre-trained model which waiting for testing. - neuron_num (int): The number of testing neurons. - segmented_num (int): The number of segmented sections of neurons' output intervals. - train_dataset (numpy.ndarray): Training dataset used for determine - the neurons' output boundaries. - - Raises: - ValueError: If neuron_num is too big (for example, bigger than 1e+9). - - Examples: - >>> net = LeNet5() - >>> train_images = np.random.random((10000, 1, 32, 32)).astype(np.float32) - >>> test_images = np.random.random((5000, 1, 32, 32)).astype(np.float32) - >>> model = Model(net) - >>> neuron_num = 10 - >>> segmented_num = 1000 - >>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) - >>> model_fuzz_test.calculate_coverage(test_images) - >>> print('KMNC of this test is : %s', model_fuzz_test.get_kmnc()) - >>> print('NBC of this test is : %s', model_fuzz_test.get_nbc()) - >>> print('SNAC of this test is : %s', model_fuzz_test.get_snac()) - >>> model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold) - >>> print('NC of this test is : %s', model_fuzz_test.get_nc()) - >>> print('Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) + incremental (bool): Metrics will be calculate in incremental way or not. Default: False. + batch_size (int): The number of samples in a fuzz test batch. Default: 32. """ - def __init__(self, model, neuron_num, segmented_num, train_dataset): + def __init__(self, model, incremental=False, batch_size=32): self._model = check_model('model', model, Model) - self._segmented_num = check_int_positive('segmented_num', segmented_num) - self._neuron_num = check_int_positive('neuron_num', neuron_num) - if self._neuron_num > 1e+9: - msg = 'neuron_num should be less than 1e+10, otherwise a MemoryError would occur' - LOGGER.error(TAG, msg) - raise ValueError(msg) - train_dataset = check_numpy_param('train_dataset', train_dataset) - self._lower_bounds = [np.inf]*self._neuron_num - self._upper_bounds = [-np.inf]*self._neuron_num - self._var = [0]*self._neuron_num - self._main_section_hits = [[0 for _ in range(self._segmented_num)] for _ in range(self._neuron_num)] - self._lower_corner_hits = [0]*self._neuron_num - self._upper_corner_hits = [0]*self._neuron_num - self._bounds_get(train_dataset) - self._model_layer_dict = defaultdict(bool) - self._effective_model_layer_dict = defaultdict(bool) - - def _set_init_effective_coverage_table(self, dataset): + self.incremental = check_param_type('incremental', incremental, bool) + self.batch_size = check_int_positive('batch_size', batch_size) + self._activate_table = defaultdict(list) + + @abstractmethod + def get_metrics(self, dataset): """ - Initialise the coverage table of each neuron in the model. + Calculate coverage metrics of given dataset. Args: - dataset (numpy.ndarray): Dataset used for initialising the coverage table. + dataset (numpy.ndarray): Dataset used to calculate coverage metrics. + + Raises: + NotImplementedError: It is an abstract method. """ - self._model.predict(Tensor(dataset[0:1])) - tensors = _get_summary_tensor_data() - for name, tensor in tensors.items(): - if 'input' in name: - continue - for num_neuron in range(tensor.shape[1]): - self._model_layer_dict[(name, num_neuron)] = False - self._effective_model_layer_dict[(name, num_neuron)] = False - - def _bounds_get(self, train_dataset, batch_size=32): + msg = 'The function get_metrics() is an abstract method in class `CoverageMetrics`, and should be' \ + ' implemented in child class.' + LOGGER.error(TAG, msg) + raise NotImplementedError(msg) + + def _init_neuron_activate_table(self, data): """ - Update the lower and upper boundaries of neurons' outputs. + Initialise the activate table of each neuron in the model with format: + {'layer1': [n1, n2, n3, ..., nn], 'layer2': [n1, n2, n3, ..., nn], ...} Args: - train_dataset (numpy.ndarray): Training dataset used for - determine the neurons' output boundaries. - batch_size (int): The number of samples in a predict batch. - Default: 32. + data (numpy.ndarray): Data used for initialising the activate table. + + Return: + dict, return a activate_table. """ - batch_size = check_int_positive('batch_size', batch_size) - output_mat = [] - batches = train_dataset.shape[0] // batch_size - for i in range(batches): - inputs = train_dataset[i*batch_size: (i + 1)*batch_size] - output = self._model.predict(Tensor(inputs)).asnumpy() - output_mat.append(output) - lower_compare_array = np.concatenate([output, np.array([self._lower_bounds])], axis=0) - self._lower_bounds = np.min(lower_compare_array, axis=0) - upper_compare_array = np.concatenate([output, np.array([self._upper_bounds])], axis=0) - self._upper_bounds = np.max(upper_compare_array, axis=0) - if batches == 0: - output = self._model.predict(Tensor(train_dataset)).asnumpy() - self._lower_bounds = np.min(output, axis=0) - self._upper_bounds = np.max(output, axis=0) - output_mat.append(output) - self._var = np.std(np.concatenate(np.array(output_mat), axis=0), axis=0) - - def _sections_hits_count(self, dataset, intervals): + self._model.predict(Tensor(data)) + layer_out = _get_summary_tensor_data() + if not layer_out: + msg = 'User must use TensorSummary() operation to specify the middle layer of the model participating in ' \ + 'the coverage calculation.' + LOGGER.error(TAG, msg) + raise ValueError(msg) + activate_table = defaultdict() + for layer, value in layer_out.items(): + activate_table[layer] = np.zeros(value.shape[1], np.bool) + return activate_table + + def _get_bounds(self, train_dataset): """ - Update the coverage matrix of neurons' output subsections. + Update the lower and upper boundaries of neurons' outputs. Args: - dataset (numpy.ndarray): Testing data. - intervals (list[float]): Segmentation intervals of neurons' outputs. + train_dataset (numpy.ndarray): Training dataset used for determine the neurons' output boundaries. + + Return: + - numpy.ndarray, upper bounds of neuron' outputs. + + - numpy.ndarray, lower bounds of neuron' outputs. """ - dataset = check_numpy_param('dataset', dataset) - batch_output = self._model.predict(Tensor(dataset)).asnumpy() - batch_section_indexes = (batch_output - self._lower_bounds) // intervals - for section_indexes in batch_section_indexes: - for i in range(self._neuron_num): - if section_indexes[i] < 0: - self._lower_corner_hits[i] = 1 - elif section_indexes[i] >= self._segmented_num: - self._upper_corner_hits[i] = 1 + upper_bounds = defaultdict(list) + lower_bounds = defaultdict(list) + batches = math.ceil(train_dataset.shape[0] / self.batch_size) + for i in range(batches): + inputs = train_dataset[i * self.batch_size: (i + 1) * self.batch_size] + self._model.predict(Tensor(inputs)) + layer_out = _get_summary_tensor_data() + for layer, tensor in layer_out.items(): + value = tensor.asnumpy() + value = np.mean(value, axis=tuple([i for i in range(2, len(value.shape))])) + min_value = np.min(value, axis=0) + max_value = np.max(value, axis=0) + if np.any(upper_bounds[layer]): + max_flag = upper_bounds[layer] > max_value + min_flag = lower_bounds[layer] < min_value + upper_bounds[layer] = upper_bounds[layer] * max_flag + max_value * (1 - max_flag) + lower_bounds[layer] = lower_bounds[layer] * min_flag + min_value * (1 - min_flag) else: - self._main_section_hits[i][int(section_indexes[i])] = 1 + upper_bounds[layer] = max_value + lower_bounds[layer] = min_value + return upper_bounds, lower_bounds - def _coverage_update(self, name, tensor, scaled_mean, scaled_rank, top_k, threshold): + def _activate_rate(self): + """ + Calculate the activate rate of neurons. """ - Update the coverage matrix of neural coverage and effective neural coverage. + total_neurons = 0 + activated_neurons = 0 + for _, value in self._activate_table.items(): + activated_neurons += np.sum(value) + total_neurons += len(value) + activate_rate = activated_neurons / total_neurons - Args: - name (string): the name of the tensor. - tensor (tensor): the tensor in the network. - scaled_mean (numpy.ndarray): feature map of the tensor. - scaled_rank (numpy.ndarray): rank of tensor value. - top_k (int): neuron is covered when its output has the top k largest value in that hidden layer. - threshold (float): neuron is covered when its output is greater than the threshold. + return activate_rate + +class NeuronCoverage(CoverageMetrics): + """ + Calculate the neurons activated coverage. Neuron is activated when its output is greater than the threshold. + Neuron coverage equals the proportion of activated neurons to total neurons in the network. + + Args: + model (Model): The pre-trained model which waiting for testing. + threshold (float): Threshold used to determined neurons is activated or not. Default: 0.1. + incremental (bool): Metrics will be calculate in incremental way or not. Default: False. + batch_size (int): The number of samples in a fuzz test batch. Default: 32. + + """ + def __init__(self, model, threshold=0.1, incremental=False, batch_size=32): + super(NeuronCoverage, self).__init__(model, incremental, batch_size) + self.threshold = check_value_positive('threshold', threshold) + + def get_metrics(self, dataset): """ - for num_neuron in range(tensor.shape[1]): - if num_neuron >= (len(scaled_rank) - top_k) and not \ - self._effective_model_layer_dict[(name, scaled_rank[num_neuron])]: - self._effective_model_layer_dict[(name, scaled_rank[num_neuron])] = True - if scaled_mean[num_neuron] > threshold and not \ - self._model_layer_dict[(name, num_neuron)]: - self._model_layer_dict[(name, num_neuron)] = True - - def calculate_coverage(self, dataset, bias_coefficient=0, batch_size=32): - """ - Calculate the testing adequacy of the given dataset. + Get the metric of neuron coverage: the proportion of activated neurons to total neurons in the network. Args: - dataset (numpy.ndarray): Data for fuzz test. - bias_coefficient (Union[int, float]): The coefficient used - for changing the neurons' output boundaries. Default: 0. - batch_size (int): The number of samples in a predict batch. Default: 32. + dataset (numpy.ndarray): Dataset used to calculate coverage metrics. + + Returns: + float, the metric of 'neuron coverage'. Examples: - >>> neuron_num = 10 - >>> segmented_num = 1000 - >>> batch_size = 32 - >>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) - >>> model_fuzz_test.calculate_coverage(test_images, top_k, threshold, batch_size) + >>> nc = NeuronCoverage(model, threshold=0.1) + >>> nc_metrics = nc.get_metrics(test_data) """ - dataset = check_numpy_param('dataset', dataset) - batch_size = check_int_positive('batch_size', batch_size) - bias_coefficient = check_param_multi_types('bias_coefficient', bias_coefficient, [int, float]) - self._lower_bounds -= bias_coefficient*self._var - self._upper_bounds += bias_coefficient*self._var - intervals = (self._upper_bounds - self._lower_bounds) / self._segmented_num - batches = dataset.shape[0] // batch_size + batches = math.ceil(dataset.shape[0] / self.batch_size) + if not self.incremental or not self._activate_table: + self._activate_table = self._init_neuron_activate_table(dataset[0:1]) for i in range(batches): - self._sections_hits_count(dataset[i*batch_size: (i + 1)*batch_size], intervals) + inputs = dataset[i * self.batch_size: (i + 1) * self.batch_size] + self._model.predict(Tensor(inputs)) + layer_out = _get_summary_tensor_data() + for layer, tensor in layer_out.items(): + value = tensor.asnumpy() + value = np.mean(value, axis=tuple([i for i in range(2, len(value.shape))])) + activate = np.sum(value > self.threshold, axis=0) > 0 + self._activate_table[layer] = np.logical_or(self._activate_table[layer], activate) + neuron_coverage = self._activate_rate() + return neuron_coverage + + +class TopKNeuronCoverage(CoverageMetrics): + """ + Calculate the top k activated neurons coverage. Neuron is activated when its output has the top k largest value in + that hidden layers. Top k neurons coverage equals the proportion of activated neurons to total neurons in the + network. + Args: + model (Model): The pre-trained model which waiting for testing. + top_k (int): Neuron is activated when its output has the top k largest value in that hidden layers. Default: 3. + incremental (bool): Metrics will be calculate in incremental way or not. Default: False. + batch_size (int): The number of samples in a fuzz test batch. Default: 32. + """ + def __init__(self, model, top_k=3, incremental=False, batch_size=32): + super(TopKNeuronCoverage, self).__init__(model, incremental=incremental, batch_size=batch_size) + self.top_k = check_int_positive('top_k', top_k) - def calculate_effective_coverage(self, dataset, top_k=3, threshold=0.1, batch_size=32): + def get_metrics(self, dataset): """ - Calculate the effective testing adequacy of the given dataset. - In effective neural coverage, neuron is covered when its output has the top k largest value - in that hidden layers. In neural coverage, neuron is covered when its output is greater than the - threshold. Coverage equals the covered neurons divided by the total neurons in the network. + Get the metric of Top K activated neuron coverage. Args: - threshold (float): neuron is covered when its output is greater than the threshold. - top_k (int): neuron is covered when its output has the top k largest value in that hiddern layer. - dataset (numpy.ndarray): Data for fuzz test. + dataset (numpy.ndarray): Dataset used to calculate coverage metrics. + + Returns: + float, the metrics of 'top k neuron coverage'. Examples: - >>> neuron_num = 10 - >>> segmented_num = 1000 - >>> top_k = 3 - >>> threshold = 0.1 - >>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) - >>> model_fuzz_test.calculate_coverage(test_images) - >>> model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold) + >>> tknc = TopKNeuronCoverage(model, top_k=3) + >>> metrics = tknc.get_metrics(test_data) """ - top_k = check_int_positive('top_k', top_k) dataset = check_numpy_param('dataset', dataset) - batch_size = check_int_positive('batch_size', batch_size) - batches = dataset.shape[0] // batch_size - self._set_init_effective_coverage_table(dataset) + batches = math.ceil(dataset.shape[0] / self.batch_size) + if not self.incremental or not self._activate_table: + self._activate_table = self._init_neuron_activate_table(dataset[0:1]) for i in range(batches): - inputs = dataset[i*batch_size: (i + 1)*batch_size] - self._model.predict(Tensor(inputs)).asnumpy() - tensors = _get_summary_tensor_data() - for name, tensor in tensors.items(): - if 'input' in name: - continue - scaled = tensor.asnumpy()[-1] - if scaled.ndim >= 3: # - scaled_mean = np.mean(scaled, axis=(1, 2)) - scaled_rank = np.argsort(scaled_mean) - self._coverage_update(name, tensor, scaled_mean, scaled_rank, top_k, threshold) - else: - scaled_rank = np.argsort(scaled) - self._coverage_update(name, tensor, scaled, scaled_rank, top_k, threshold) + inputs = dataset[i * self.batch_size: (i + 1) * self.batch_size] + self._model.predict(Tensor(inputs)) + layer_out = _get_summary_tensor_data() + for layer, tensor in layer_out.items(): + value = tensor.asnumpy() + if len(value.shape) > 2: + value = np.mean(value, axis=tuple([i for i in range(2, len(value.shape))])) + top_k_value = np.sort(value)[:, -self.top_k].reshape(value.shape[0], 1) + top_k_value = np.sum((value - top_k_value) >= 0, axis=0) > 0 + self._activate_table[layer] = np.logical_or(self._activate_table[layer], top_k_value) + top_k_neuron_coverage = self._activate_rate() + return top_k_neuron_coverage + + +class SuperNeuronActivateCoverage(CoverageMetrics): + """ + Get the metric of 'super neuron activation coverage'. :math:`SNAC = |UpperCornerNeuron|/|N|`. SNAC refers to the + proportion of neurons whose neurons output value in the test set exceeds the upper bounds of the corresponding + neurons output value in the training set. - def get_nc(self): + Args: + model (Model): The pre-trained model which waiting for testing. + train_dataset (numpy.ndarray): Training dataset used for determine the neurons' output boundaries. + incremental (bool): Metrics will be calculate in incremental way or not. Default: False. + batch_size (int): The number of samples in a fuzz test batch. Default: 32. + """ + def __init__(self, model, train_dataset, incremental=False, batch_size=32): + super(SuperNeuronActivateCoverage, self).__init__(model, incremental=incremental, batch_size=batch_size) + train_dataset = check_numpy_param('train_dataset', train_dataset) + self.upper_bounds, self.lower_bounds = self._get_bounds(train_dataset=train_dataset) + + def get_metrics(self, dataset): """ - Get the metric of 'neuron coverage'. + Get the metric of 'strong neuron activation coverage'. + + Args: + dataset (numpy.ndarray): Dataset used to calculate coverage metrics. Returns: - float, the metric of 'neuron coverage'. + float, the metric of 'strong neuron activation coverage'. Examples: - >>> model_fuzz_test.get_nc() + >>> snac = SuperNeuronActivateCoverage(model, train_dataset) + >>> metrics = snac.get_metrics(test_data) """ - covered_neurons = len([v for v in self._model_layer_dict.values() if v]) - total_neurons = len(self._model_layer_dict) - nc = covered_neurons / float(total_neurons) - return nc + dataset = check_numpy_param('dataset', dataset) + if not self.incremental or not self._activate_table: + self._activate_table = self._init_neuron_activate_table(dataset[0:1]) + batches = math.ceil(dataset.shape[0] / self.batch_size) - def get_effective_nc(self): - """ - Get the metric of 'effective neuron coverage'. + for i in range(batches): + inputs = dataset[i * self.batch_size: (i + 1) * self.batch_size] + self._model.predict(Tensor(inputs)) + layer_out = _get_summary_tensor_data() + for layer, tensor in layer_out.items(): + value = tensor.asnumpy() + if len(value.shape) > 2: + value = np.mean(value, axis=tuple([i for i in range(2, len(value.shape))])) + activate = np.sum(value > self.upper_bounds[layer], axis=0) > 0 + self._activate_table[layer] = np.logical_or(self._activate_table[layer], activate) + snac = self._activate_rate() + return snac - Returns: - float, the metric of 'the effective neuron coverage'. - Examples: - >>> model_fuzz_test.get_effective_nc() - """ - covered_neurons = len([v for v in self._effective_model_layer_dict.values() if v]) - total_neurons = len(self._effective_model_layer_dict) - effective_nc = covered_neurons / float(total_neurons) - return effective_nc +class NeuronBoundsCoverage(SuperNeuronActivateCoverage): + """ + Get the metric of 'neuron boundary coverage' :math:`NBC = (|UpperCornerNeuron| + |LowerCornerNeuron|)/(2*|N|)`, + where :math`|N|` is the number of neurons, NBC refers to the proportion of neurons whose neurons output value in + the test dataset exceeds the upper and lower bounds of the corresponding neurons output value in the training + dataset. - def get_kmnc(self): - """ - Get the metric of 'k-multisection neuron coverage'. KMNC measures how - thoroughly the given set of test inputs covers the range of neurons - output values derived from training dataset. + Args: + model (Model): The pre-trained model which waiting for testing. + train_dataset (numpy.ndarray): Training dataset used for determine the neurons' output boundaries. + incremental (bool): Metrics will be calculate in incremental way or not. Default: False. + batch_size (int): The number of samples in a fuzz test batch. Default: 32. + """ - Returns: - float, the metric of 'k-multisection neuron coverage'. + def __init__(self, model, train_dataset, incremental=False, batch_size=32): + super(NeuronBoundsCoverage, self).__init__(model, train_dataset, incremental=incremental, batch_size=batch_size) - Examples: - >>> model_fuzz_test.get_kmnc() + def get_metrics(self, dataset): """ - kmnc = np.sum(self._main_section_hits) / (self._neuron_num*self._segmented_num) - return kmnc + Get the metric of 'neuron boundary coverage'. - def get_nbc(self): - """ - Get the metric of 'neuron boundary coverage' :math:`NBC = (|UpperCornerNeuron| - + |LowerCornerNeuron|)/(2*|N|)`, where :math`|N|` is the number of neurons, - NBC refers to the proportion of neurons whose neurons output value in - the test dataset exceeds the upper and lower bounds of the corresponding - neurons output value in the training dataset. + Args: + dataset (numpy.ndarray): Dataset used to calculate coverage metrics. Returns: float, the metric of 'neuron boundary coverage'. Examples: - >>> model_fuzz_test.get_nbc() + >>> nbc = NeuronBoundsCoverage(model, train_dataset) + >>> metrics = nbc.get_metrics(test_data) """ - nbc = (np.sum(self._lower_corner_hits) + np.sum(self._upper_corner_hits)) / (2*self._neuron_num) + dataset = check_numpy_param('dataset', dataset) + if not self.incremental or not self._activate_table: + self._activate_table = self._init_neuron_activate_table(dataset[0:1]) + + batches = math.ceil(dataset.shape[0] / self.batch_size) + for i in range(batches): + inputs = dataset[i * self.batch_size: (i + 1) * self.batch_size] + self._model.predict(Tensor(inputs)) + layer_out = _get_summary_tensor_data() + for layer, tensor in layer_out.items(): + value = tensor.asnumpy() + if len(value.shape) > 2: + value = np.mean(value, axis=tuple([i for i in range(2, len(value.shape))])) + outer = np.logical_or(value > self.upper_bounds[layer], value < self.lower_bounds[layer]) + activate = np.sum(outer, axis=0) > 0 + self._activate_table[layer] = np.logical_or(self._activate_table[layer], activate) + nbc = self._activate_rate() return nbc - def get_snac(self): + +class KMultisectionNeuronCoverage(SuperNeuronActivateCoverage): + """ + Get the metric of 'k-multisection neuron coverage'. KMNC measures how thoroughly the given set of test inputs + covers the range of neurons output values derived from training dataset. + + Args: + model (Model): The pre-trained model which waiting for testing. + train_dataset (numpy.ndarray): Training dataset used for determine the neurons' output boundaries. + segmented_num (int): The number of segmented sections of neurons' output intervals. Default: 100. + incremental (bool): Metrics will be calculate in incremental way or not. Default: False. + batch_size (int): The number of samples in a fuzz test batch. Default: 32. + """ + + def __init__(self, model, train_dataset, segmented_num=100, incremental=False, batch_size=32): + super(KMultisectionNeuronCoverage, self).__init__(model, train_dataset, incremental=incremental, + batch_size=batch_size) + self.segmented_num = check_int_positive('segmented_num', segmented_num) + self.intervals = defaultdict(list) + for keys in self.upper_bounds.keys(): + self.intervals[keys] = (self.upper_bounds[keys] - self.lower_bounds[keys]) / self.segmented_num + + def _init_k_multisection_table(self, data): + """ Initial the activate table.""" + self._model.predict(Tensor(data)) + layer_out = _get_summary_tensor_data() + activate_section_table = defaultdict() + for layer, value in layer_out.items(): + activate_section_table[layer] = np.zeros((value.shape[1], self.segmented_num), np.bool) + return activate_section_table + + def get_metrics(self, dataset): """ - Get the metric of 'strong neuron activation coverage'. - :math:`SNAC = |UpperCornerNeuron|/|N|`. SNAC refers to the proportion - of neurons whose neurons output value in the test set exceeds the upper - bounds of the corresponding neurons output value in the training set. + Get the metric of 'k-multisection neuron coverage'. + + Args: + dataset (numpy.ndarray): Dataset used to calculate coverage metrics. Returns: - float, the metric of 'strong neuron activation coverage'. + float, the metric of 'k-multisection neuron coverage'. Examples: - >>> model_fuzz_test.get_snac() + >>> kmnc = KMultisectionNeuronCoverage(model, train_dataset, segmented_num=100) + >>> metrics = kmnc.get_metrics(test_data) """ - snac = np.sum(self._upper_corner_hits) / self._neuron_num - return snac + + dataset = check_numpy_param('dataset', dataset) + if not self.incremental or not self._activate_table: + self._activate_table = self._init_k_multisection_table(dataset[0:1]) + + batches = math.ceil(dataset.shape[0] / self.batch_size) + for i in range(batches): + inputs = dataset[i * self.batch_size: (i + 1) * self.batch_size] + self._model.predict(Tensor(inputs)) + layer_out = _get_summary_tensor_data() + for layer, tensor in layer_out.items(): + value = tensor.asnumpy() + value = np.mean(value, axis=tuple([i for i in range(2, len(value.shape))])) + hits = np.floor((value - self.lower_bounds[layer]) / self.intervals[layer]).astype(int) + hits = np.transpose(hits, [1, 0]) + for n in range(len(hits)): + for sec in hits[n]: + if sec >= self.segmented_num or sec < 0: + continue + self._activate_table[layer][n][sec] = True + + kmnc = self._activate_rate() / self.segmented_num + return kmnc diff --git a/mindarmour/utils/_check_param.py b/mindarmour/utils/_check_param.py index 14b5dd7..1892268 100644 --- a/mindarmour/utils/_check_param.py +++ b/mindarmour/utils/_check_param.py @@ -39,9 +39,7 @@ def _check_array_not_empty(arg_name, arg_value): def check_param_type(arg_name, arg_value, valid_type): """Check parameter type.""" if not isinstance(arg_value, valid_type): - msg = '{} must be {}, but got {}'.format(arg_name, - valid_type, - type(arg_value).__name__) + msg = '{} must be {}, but got {}'.format(arg_name, valid_type, type(arg_value).__name__) LOGGER.error(TAG, msg) raise TypeError(msg) @@ -51,8 +49,7 @@ def check_param_type(arg_name, arg_value, valid_type): def check_param_multi_types(arg_name, arg_value, valid_types): """Check parameter multi types.""" if not isinstance(arg_value, tuple(valid_types)): - msg = 'type of {} must be in {}, but got {}' \ - .format(arg_name, valid_types, type(arg_value).__name__) + msg = 'type of {} must be in {}, but got {}'.format(arg_name, valid_types, type(arg_value).__name__) LOGGER.error(TAG, msg) raise TypeError(msg) @@ -68,8 +65,7 @@ def check_int_positive(arg_name, arg_value): raise ValueError(msg) arg_value = check_param_type(arg_name, arg_value, int) if arg_value <= 0: - msg = '{} must be greater than 0, but got {}'.format(arg_name, - arg_value) + msg = '{} must be greater than 0, but got {}'.format(arg_name, arg_value) LOGGER.error(TAG, msg) raise ValueError(msg) return arg_value @@ -79,8 +75,7 @@ def check_value_non_negative(arg_name, arg_value): """Check non negative value.""" arg_value = check_param_multi_types(arg_name, arg_value, (int, float)) if float(arg_value) < 0.0: - msg = '{} must not be less than 0, but got {}'.format(arg_name, - arg_value) + msg = '{} must not be less than 0, but got {}'.format(arg_name, arg_value) LOGGER.error(TAG, msg) raise ValueError(msg) return arg_value @@ -90,8 +85,7 @@ def check_value_positive(arg_name, arg_value): """Check positive value.""" arg_value = check_param_multi_types(arg_name, arg_value, (int, float)) if float(arg_value) <= 0.0: - msg = '{} must be greater than zero, but got {}'.format(arg_name, - arg_value) + msg = '{} must be greater than zero, but got {}'.format(arg_name, arg_value) LOGGER.error(TAG, msg) raise ValueError(msg) return arg_value @@ -102,10 +96,7 @@ def check_param_in_range(arg_name, arg_value, lower, upper): Check range of parameter. """ if arg_value <= lower or arg_value >= upper: - msg = '{} must be between {} and {}, but got {}'.format(arg_name, - lower, - upper, - arg_value) + msg = '{} must be between {} and {}, but got {}'.format(arg_name, lower, upper, arg_value) LOGGER.error(TAG, msg) raise ValueError(msg) @@ -129,10 +120,7 @@ def check_model(model_name, model, model_type): """ if isinstance(model, model_type): return model - msg = '{} should be an instance of {}, but got {}' \ - .format(model_name, - model_type, - type(model).__name__) + msg = '{} should be an instance of {}, but got {}'.format(model_name, model_type, type(model).__name__) LOGGER.error(TAG, msg) raise TypeError(msg) @@ -175,11 +163,9 @@ def check_pair_numpy_param(inputs_name, inputs, labels_name, labels): labels (numpy.ndarray): Labels of `inputs`. Returns: - - numpy.ndarray, if `inputs` 's dimension equals to - `labels`, return inputs with type of numpy.ndarray. + - numpy.ndarray, if `inputs` 's dimension equals to `labels`, return inputs with type of numpy.ndarray. - - numpy.ndarray, if `inputs` 's dimension equals to - `labels` , return labels with type of numpy.ndarray. + - numpy.ndarray, if `inputs` 's dimension equals to `labels` , return labels with type of numpy.ndarray. Raises: ValueError: If inputs.shape[0] is not equal to labels.shape[0]. @@ -188,8 +174,7 @@ def check_pair_numpy_param(inputs_name, inputs, labels_name, labels): labels = check_numpy_param(labels_name, labels) if inputs.shape[0] != labels.shape[0]: msg = '{} shape[0] must equal {} shape[0], bot got shape of ' \ - 'inputs {}, shape of labels {}'.format(inputs_name, labels_name, - inputs.shape, labels.shape) + 'inputs {}, shape of labels {}'.format(inputs_name, labels_name, inputs.shape, labels.shape) LOGGER.error(TAG, msg) raise ValueError(msg) return inputs, labels @@ -198,10 +183,8 @@ def check_pair_numpy_param(inputs_name, inputs, labels_name, labels): def check_equal_length(para_name1, value1, para_name2, value2): """Check weather the two parameters have equal length.""" if len(value1) != len(value2): - msg = 'The dimension of {0} must equal to the ' \ - '{1}, but got {0} is {2}, ' \ - '{1} is {3}'.format(para_name1, para_name2, len(value1), - len(value2)) + msg = 'The dimension of {0} must equal to the {1}, but got {0} is {2}, {1} is {3}'\ + .format(para_name1, para_name2, len(value1), len(value2)) LOGGER.error(TAG, msg) raise ValueError(msg) return value1, value2 @@ -210,10 +193,8 @@ def check_equal_length(para_name1, value1, para_name2, value2): def check_equal_shape(para_name1, value1, para_name2, value2): """Check weather the two parameters have equal shape.""" if value1.shape != value2.shape: - msg = 'The shape of {0} must equal to the ' \ - '{1}, but got {0} is {2}, ' \ - '{1} is {3}'.format(para_name1, para_name2, value1.shape, - value2.shape) + msg = 'The shape of {0} must equal to the {1}, but got {0} is {2}, {1} is {3}'.\ + format(para_name1, para_name2, value1.shape, value2.shape) LOGGER.error(TAG, msg) raise ValueError(msg) return value1, value2 @@ -225,8 +206,7 @@ def check_norm_level(norm_level): msg = 'Type of norm_level must be in [int, str], but got {}'.format(type(norm_level)) accept_norm = [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', np.inf] if norm_level not in accept_norm: - msg = 'norm_level must be in {}, but got {}'.format(accept_norm, - norm_level) + msg = 'norm_level must be in {}, but got {}'.format(accept_norm, norm_level) LOGGER.error(TAG, msg) raise ValueError(msg) return norm_level @@ -252,20 +232,16 @@ def normalize_value(value, norm_level): value_reshape = value.reshape((value.shape[0], -1)) avoid_zero_div = 1e-12 if norm_level in (1, '1', 'l1'): - norm = np.linalg.norm(value_reshape, ord=1, axis=1, keepdims=True) + \ - avoid_zero_div + norm = np.linalg.norm(value_reshape, ord=1, axis=1, keepdims=True) + avoid_zero_div norm_value = value_reshape / norm elif norm_level in (2, '2', 'l2'): - norm = np.linalg.norm(value_reshape, ord=2, axis=1, keepdims=True) + \ - avoid_zero_div + norm = np.linalg.norm(value_reshape, ord=2, axis=1, keepdims=True) + avoid_zero_div norm_value = value_reshape / norm elif norm_level in (np.inf, 'inf'): - norm = np.max(abs(value_reshape), axis=1, keepdims=True) + \ - avoid_zero_div + norm = np.max(abs(value_reshape), axis=1, keepdims=True) + avoid_zero_div norm_value = value_reshape / norm else: - msg = 'Values of `norm_level` different from 1, 2 and ' \ - '`np.inf` are currently not supported, but got {}.' \ + msg = 'Values of `norm_level` different from 1, 2 and `np.inf` are currently not supported, but got {}.' \ .format(norm_level) LOGGER.error(TAG, msg) raise NotImplementedError(msg) @@ -339,13 +315,30 @@ def check_inputs_labels(inputs, labels): inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs if isinstance(inputs, tuple): for i, inputs_item in enumerate(inputs): - _ = check_pair_numpy_param('inputs_image', inputs_image, \ - 'inputs[{}]'.format(i), inputs_item) + _ = check_pair_numpy_param('inputs_image', inputs_image, 'inputs[{}]'.format(i), inputs_item) if isinstance(labels, tuple): for i, labels_item in enumerate(labels): - _ = check_pair_numpy_param('inputs', inputs_image, \ - 'labels[{}]'.format(i), labels_item) + _ = check_pair_numpy_param('inputs', inputs_image, 'labels[{}]'.format(i), labels_item) else: - _ = check_pair_numpy_param('inputs', inputs_image, \ - 'labels', labels) + _ = check_pair_numpy_param('inputs', inputs_image, 'labels', labels) return inputs_image, inputs, labels + + +def check_param_bounds(arg_name, arg_value): + """Check bounds is valid""" + arg_value = check_param_multi_types(arg_name, arg_value, [tuple, list]) + if len(arg_value) != 2: + msg = 'length of {0} must be 2, but got length of {0} is {1}'.format(arg_name, len(arg_value)) + LOGGER.error(TAG, msg) + raise ValueError(msg) + for i, b in enumerate(arg_value): + if not isinstance(b, (float, int)): + msg = 'each value in {} must be int or float, but got the {}th value is {}'.format(arg_name, i, b) + LOGGER.error(TAG, msg) + raise ValueError(msg) + if arg_value[0] > arg_value[1]: + msg = "lower boundary cannot be greater than upper boundary, corresponding values in {} are {} and {}". \ + format(arg_name, arg_value[0], arg_value[1]) + LOGGER.error(TAG, msg) + raise ValueError(msg) + return arg_value diff --git a/tests/ut/python/fuzzing/test_coverage_metrics.py b/tests/ut/python/fuzzing/test_coverage_metrics.py index 95f3ee6..61f34b4 100644 --- a/tests/ut/python/fuzzing/test_coverage_metrics.py +++ b/tests/ut/python/fuzzing/test_coverage_metrics.py @@ -25,7 +25,8 @@ from mindspore.ops import TensorSummary from mindarmour.adv_robustness.attacks import FastGradientSignMethod from mindarmour.utils.logger import LogUtil -from mindarmour.fuzz_testing import ModelCoverageMetrics +from mindarmour.fuzz_testing import NeuronCoverage, TopKNeuronCoverage, SuperNeuronActivateCoverage, \ + NeuronBoundsCoverage, KMultisectionNeuronCoverage LOGGER = LogUtil.get_instance() TAG = 'Neuron coverage test' @@ -74,39 +75,48 @@ def test_lenet_mnist_coverage_cpu(): model = Model(net) # initialize fuzz test with training dataset - neuron_num = 10 - segmented_num = 1000 - top_k = 3 - threshold = 0.1 training_data = (np.random.random((10000, 10))*20).astype(np.float32) - model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data) - # fuzz test with original test data # get test data test_data = (np.random.random((2000, 10))*20).astype(np.float32) test_labels = np.random.randint(0, 10, 2000).astype(np.int32) - model_fuzz_test.calculate_coverage(test_data) - LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) - LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) - LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) - model_fuzz_test.calculate_effective_coverage(test_data, top_k, threshold) - LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc()) - LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) + nc = NeuronCoverage(model, threshold=0.1) + nc_metric = nc.get_metrics(test_data) + + tknc = TopKNeuronCoverage(model, top_k=3) + tknc_metrics = tknc.get_metrics(test_data) + + snac = SuperNeuronActivateCoverage(model, training_data) + snac_metrics = snac.get_metrics(test_data) + + nbc = NeuronBoundsCoverage(model, training_data) + nbc_metrics = nbc.get_metrics(test_data) + + kmnc = KMultisectionNeuronCoverage(model, training_data, segmented_num=100) + kmnc_metrics = kmnc.get_metrics(test_data) + + print('KMNC of this test is: ', kmnc_metrics) + print('NBC of this test is: ', nbc_metrics) + print('SNAC of this test is: ', snac_metrics) + print('NC of this test is: ', nc_metric) + print('TKNC of this test is: ', tknc_metrics) # generate adv_data loss = SoftmaxCrossEntropyWithLogits(sparse=True) attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss) adv_data = attack.batch_generate(test_data, test_labels, batch_size=32) - model_fuzz_test.calculate_coverage(adv_data, bias_coefficient=0.5) - LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) - LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) - LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) - - model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold) - LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc()) - LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) + nc_metric = nc.get_metrics(adv_data) + tknc_metrics = tknc.get_metrics(adv_data) + snac_metrics = snac.get_metrics(adv_data) + nbc_metrics = nbc.get_metrics(adv_data) + kmnc_metrics = kmnc.get_metrics(adv_data) + print('KMNC of adv data is: ', kmnc_metrics) + print('NBC of adv data is: ', nbc_metrics) + print('SNAC of adv data is: ', snac_metrics) + print('NC of adv data is: ', nc_metric) + print('TKNC of adv data is: ', tknc_metrics) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @@ -120,35 +130,28 @@ def test_lenet_mnist_coverage_ascend(): model = Model(net) # initialize fuzz test with training dataset - neuron_num = 10 - segmented_num = 1000 - top_k = 3 - threshold = 0.1 training_data = (np.random.random((10000, 10))*20).astype(np.float32) - model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data) # fuzz test with original test data # get test data test_data = (np.random.random((2000, 10))*20).astype(np.float32) - test_labels = np.random.randint(0, 10, 2000) - test_labels = (np.eye(10)[test_labels]).astype(np.float32) - model_fuzz_test.calculate_coverage(test_data) - LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) - LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) - LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) + nc = NeuronCoverage(model, threshold=0.1) + nc_metric = nc.get_metrics(test_data) - model_fuzz_test.calculate_effective_coverage(test_data, top_k, threshold) - LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc()) - LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) + tknc = TopKNeuronCoverage(model, top_k=3) + tknc_metrics = tknc.get_metrics(test_data) - # generate adv_data - attack = FastGradientSignMethod(net, eps=0.3, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False)) - adv_data = attack.batch_generate(test_data, test_labels, batch_size=32) - model_fuzz_test.calculate_coverage(adv_data, bias_coefficient=0.5) - LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) - LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) - LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) - - model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold) - LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc()) - LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc()) + snac = SuperNeuronActivateCoverage(model, training_data) + snac_metrics = snac.get_metrics(test_data) + + nbc = NeuronBoundsCoverage(model, training_data) + nbc_metrics = nbc.get_metrics(test_data) + + kmnc = KMultisectionNeuronCoverage(model, training_data, segmented_num=100) + kmnc_metrics = kmnc.get_metrics(test_data) + + print('KMNC of this test is: ', kmnc_metrics) + print('NBC of this test is: ', nbc_metrics) + print('SNAC of this test is: ', snac_metrics) + print('NC of this test is: ', nc_metric) + print('TKNC of this test is: ', tknc_metrics) diff --git a/tests/ut/python/fuzzing/test_fuzzer.py b/tests/ut/python/fuzzing/test_fuzzer.py index 5066bcf..1d585d4 100644 --- a/tests/ut/python/fuzzing/test_fuzzer.py +++ b/tests/ut/python/fuzzing/test_fuzzer.py @@ -21,9 +21,10 @@ from mindspore import nn from mindspore.common.initializer import TruncatedNormal from mindspore.ops import operations as P from mindspore.train import Model +from mindspore.ops import TensorSummary from mindarmour.fuzz_testing import Fuzzer -from mindarmour.fuzz_testing import ModelCoverageMetrics +from mindarmour.fuzz_testing import KMultisectionNeuronCoverage from mindarmour.utils.logger import LogUtil LOGGER = LogUtil.get_instance() @@ -52,30 +53,37 @@ class Net(nn.Cell): """ Lenet network """ + def __init__(self): super(Net, self).__init__() self.conv1 = conv(1, 6, 5) self.conv2 = conv(6, 16, 5) - self.fc1 = fc_with_initialize(16*5*5, 120) + self.fc1 = fc_with_initialize(16 * 5 * 5, 120) self.fc2 = fc_with_initialize(120, 84) self.fc3 = fc_with_initialize(84, 10) self.relu = nn.ReLU() self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) self.reshape = P.Reshape() + self.summary = TensorSummary() def construct(self, x): x = self.conv1(x) x = self.relu(x) + self.summary('conv1', x) x = self.max_pool2d(x) x = self.conv2(x) x = self.relu(x) + self.summary('conv2', x) x = self.max_pool2d(x) - x = self.reshape(x, (-1, 16*5*5)) + x = self.reshape(x, (-1, 16 * 5 * 5)) x = self.fc1(x) x = self.relu(x) + self.summary('fc1', x) x = self.fc2(x) x = self.relu(x) + self.summary('fc2', x) x = self.fc3(x) + self.summary('fc3', x) return x @@ -100,12 +108,8 @@ def test_fuzzing_ascend(): {'method': 'FGSM', 'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}} ] - # initialize fuzz test with training dataset - neuron_num = 10 - segmented_num = 1000 - train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) - model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) + train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) # fuzz test with original test data # get test data test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32) @@ -118,13 +122,12 @@ def test_fuzzing_ascend(): initial_seeds.append([img, label]) initial_seeds = initial_seeds[:100] - model_coverage_test.calculate_coverage( - np.array(test_images[:100]).astype(np.float32)) - LOGGER.info(TAG, 'KMNC of this test is : %s', - model_coverage_test.get_kmnc()) - model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num) - _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds) + nc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100) + cn_metrics = nc.get_metrics(test_images[:100]) + print('neuron coverage of initial seeds is: ', cn_metrics) + model_fuzz_test = Fuzzer(model) + _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, nc, max_iters=100) print(metrics) @@ -139,8 +142,6 @@ def test_fuzzing_cpu(): model = Model(net) batch_size = 8 num_classe = 10 - neuron_num = 10 - segmented_num = 1000 mutate_config = [{'method': 'Blur', 'params': {'auto_param': [True]}}, {'method': 'Contrast', @@ -152,7 +153,6 @@ def test_fuzzing_cpu(): ] # initialize fuzz test with training dataset train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) - model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) # fuzz test with original test data # get test data @@ -166,11 +166,9 @@ def test_fuzzing_cpu(): initial_seeds.append([img, label]) initial_seeds = initial_seeds[:100] - model_coverage_test.calculate_coverage( - np.array(test_images[:100]).astype(np.float32)) - LOGGER.info(TAG, 'KMNC of this test is : %s', - model_coverage_test.get_kmnc()) - - model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num) - _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds) + nc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100) + tknc_metrics = nc.get_metrics(test_images[:100]) + print('neuron coverage of initial seeds is: ', tknc_metrics) + model_fuzz_test = Fuzzer(model) + _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, nc, max_iters=100) print(metrics)