From 72a3f6bd6c361e2b862db62e1c133df0be26fb4d Mon Sep 17 00:00:00 2001
From: RyanZ
Date: Tue, 8 Jun 2021 10:27:45 +0800
Subject: [PATCH] Add feature of NC and Effective NC to test coverage

Signed-off-by: zhengyang (H)

add feature of NC and effective NC to test coverage

Signed-off-by: zhengyang (H)
---
 examples/ai_fuzzer/README.md                       |  28 +++--
 .../ai_fuzzer/fuzz_testing_and_model_enhense.py    |   6 +-
 examples/ai_fuzzer/lenet5_mnist_coverage.py        |  16 ++-
 examples/ai_fuzzer/lenet5_mnist_fuzzing.py         |   2 +-
 .../networks/lenet5/lenet5_net_for_fuzzing.py      |  99 +++++++++++++++++
 mindarmour/fuzz_testing/model_coverage_metrics.py  | 120 ++++++++++++++++++++-
 tests/ut/python/fuzzing/test_coverage_metrics.py   |  28 ++++-
 7 files changed, 281 insertions(+), 18 deletions(-)
 create mode 100644 examples/common/networks/lenet5/lenet5_net_for_fuzzing.py

diff --git a/examples/ai_fuzzer/README.md b/examples/ai_fuzzer/README.md
index 22405b4..598e33a 100644
--- a/examples/ai_fuzzer/README.md
+++ b/examples/ai_fuzzer/README.md
@@ -1,24 +1,32 @@
 # Application demos of model fuzzing
+
 ## Introduction
+
 The same as the traditional software fuzz testing, we can also design fuzz test for AI models. Compared to branch
 coverage or line coverage of traditional software, some people propose the concept of 'neuron coverage' based on the
 unique structure of deep neural network. We can use the neuron coverage as a guide to search more metamorphic inputs
 to test our models.
-## 1. calculation of neuron coverage
-There are three metrics proposed for evaluating the neuron coverage of a test:KMNC, NBC and SNAC. Usually we need to
- feed all the training dataset into the model first, and record the output range of all neurons (however, only the last
- layer of neurons are recorded in our method). In the testing phase, we feed test samples into the model, and
- calculate those three metrics mentioned above according to those neurons' output distribution.
+## 1. Calculation of neuron coverage
+
+There are five metrics proposed for evaluating the neuron coverage of a test: NC, Effective NC, KMNC, NBC and SNAC.
+ Usually we need to feed all the training dataset into the model first and record the output range of all neurons
+ (however, for KMNC, NBC and SNAC, only the last layer of neurons is recorded in our method). In the testing phase,
+ we feed test samples into the model and calculate those five metrics according to the neurons'
+ output distribution.
+
 ```sh
-$ cd examples/ai_fuzzer/
-$ python lenet5_mnist_coverage.py
+cd examples/ai_fuzzer/
+python lenet5_mnist_coverage.py
 ```
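+
+A minimal sketch of how these metrics are read off with the `ModelCoverageMetrics` class (the
+ `model`, `train_images` and `test_images` objects are placeholders prepared by the caller, as in
+ the demo script above):
+
+```python
+from mindarmour.fuzz_testing import ModelCoverageMetrics
+
+# profile the output bounds of the last-layer neurons on the training data
+model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, train_images)
+model_fuzz_test.calculate_coverage(test_images)
+print('KMNC:', model_fuzz_test.get_kmnc())
+print('NBC:', model_fuzz_test.get_nbc())
+print('SNAC:', model_fuzz_test.get_snac())
+
+# NC and Effective NC are collected from the TensorSummary records of every layer
+model_fuzz_test.calculate_effective_coverage(test_images, top_k=3, threshold=0.1)
+print('NC:', model_fuzz_test.get_nc())
+print('Effective NC:', model_fuzz_test.get_effective_nc())
+```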
-## 2. fuzz test for AI model
+
+## 2. Fuzz test for AI model
+
 We have provided several types of methods for manipulating metamorphic inputs: affine transformation, pixel
 transformation and adversarial attacks. Usually we feed the original samples into the fuzz function as seeds, and
 then metamorphic samples are generated through iterative manipulations.
+
 ```sh
-$ cd examples/ai_fuzzer/
-$ python lenet5_mnist_fuzzing.py
+cd examples/ai_fuzzer/
+python lenet5_mnist_fuzzing.py
 ```
\ No newline at end of file
diff --git a/examples/ai_fuzzer/fuzz_testing_and_model_enhense.py b/examples/ai_fuzzer/fuzz_testing_and_model_enhense.py
index d958542..84b3e30 100644
--- a/examples/ai_fuzzer/fuzz_testing_and_model_enhense.py
+++ b/examples/ai_fuzzer/fuzz_testing_and_model_enhense.py
@@ -31,7 +31,7 @@ from mindarmour.fuzz_testing import ModelCoverageMetrics
 from mindarmour.utils.logger import LogUtil
 
 from examples.common.dataset.data_processing import generate_mnist_dataset
-from examples.common.networks.lenet5.lenet5_net import LeNet5
+from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5
 
 LOGGER = LogUtil.get_instance()
 TAG = 'Fuzz_testing and enhance model'
@@ -75,9 +75,11 @@ def example_lenet_mnist_fuzzing():
         images = data[0].astype(np.float32)
         train_images.append(images)
     train_images = np.concatenate(train_images, axis=0)
+    neuron_num = 10
+    segmented_num = 1000
 
     # initialize fuzz test with training dataset
-    model_coverage_test = ModelCoverageMetrics(model, 10, 1000, train_images)
+    model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)
 
     # fuzz test with original test data
     # get test data
diff --git a/examples/ai_fuzzer/lenet5_mnist_coverage.py b/examples/ai_fuzzer/lenet5_mnist_coverage.py
index 355c95b..4e77a32 100644
--- a/examples/ai_fuzzer/lenet5_mnist_coverage.py
+++ b/examples/ai_fuzzer/lenet5_mnist_coverage.py
@@ -22,7 +22,7 @@ from mindarmour.fuzz_testing import ModelCoverageMetrics
 from mindarmour.utils.logger import LogUtil
 
 from examples.common.dataset.data_processing import generate_mnist_dataset
-from examples.common.networks.lenet5.lenet5_net import LeNet5
+from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5
 
 LOGGER = LogUtil.get_instance()
 TAG = 'Neuron coverage test'
@@ -46,9 +46,13 @@ def test_lenet_mnist_coverage():
         images = data[0].astype(np.float32)
         train_images.append(images)
     train_images = np.concatenate(train_images, axis=0)
+    neuron_num = 10
+    segmented_num = 1000
+    top_k = 3
+    threshold = 0.1
 
     # initialize fuzz test with training dataset
-    model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, train_images)
+    model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)
 
     # fuzz test with original test data
     # get test data
@@ -69,6 +73,10 @@
     LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
     LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
 
+    model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold)
+    LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc())
+    LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())
+
     # generate adv_data
     loss = SoftmaxCrossEntropyWithLogits(sparse=True)
     attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss)
@@ -78,6 +86,10 @@
     LOGGER.info(TAG, 'NBC of this adv data is : %s', model_fuzz_test.get_nbc())
     LOGGER.info(TAG, 'SNAC of this adv data is : %s', model_fuzz_test.get_snac())
 
+    model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold)
+    LOGGER.info(TAG, 'NC of this adv data is : %s', model_fuzz_test.get_nc())
+    LOGGER.info(TAG, 'Effective_NC of this adv data is : %s', model_fuzz_test.get_effective_nc())
+
 
 if __name__ == '__main__':
     # device_target can be "CPU", "GPU" or "Ascend"
diff --git a/examples/ai_fuzzer/lenet5_mnist_fuzzing.py b/examples/ai_fuzzer/lenet5_mnist_fuzzing.py
index 3f92dea..7707c44 100644
--- a/examples/ai_fuzzer/lenet5_mnist_fuzzing.py
+++ b/examples/ai_fuzzer/lenet5_mnist_fuzzing.py
@@ -21,7 +21,7 @@ from mindarmour.fuzz_testing import ModelCoverageMetrics
 from mindarmour.utils.logger import LogUtil
 
 from examples.common.dataset.data_processing import generate_mnist_dataset
-from examples.common.networks.lenet5.lenet5_net import LeNet5
+from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5
 
 LOGGER = LogUtil.get_instance()
 TAG = 'Fuzz_test'
diff --git a/examples/common/networks/lenet5/lenet5_net_for_fuzzing.py b/examples/common/networks/lenet5/lenet5_net_for_fuzzing.py
new file mode 100644
index 0000000..803edd0
--- /dev/null
+++ b/examples/common/networks/lenet5/lenet5_net_for_fuzzing.py
@@ -0,0 +1,99 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+lenet network with summary
+"""
+from mindspore import nn
+from mindspore.common.initializer import TruncatedNormal
+from mindspore.ops import TensorSummary
+
+
+def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
+    weight = weight_variable()
+    return nn.Conv2d(in_channels, out_channels,
+                     kernel_size=kernel_size, stride=stride, padding=padding,
+                     weight_init=weight, has_bias=False, pad_mode="valid")
+
+
+def fc_with_initialize(input_channels, out_channels):
+    weight = weight_variable()
+    bias = weight_variable()
+    return nn.Dense(input_channels, out_channels, weight, bias)
+
+
+def weight_variable():
+    return TruncatedNormal(0.05)
+
+
+class LeNet5(nn.Cell):
+    """
+    Lenet network
+    """
+    def __init__(self):
+        super(LeNet5, self).__init__()
+        self.conv1 = conv(1, 6, 5)
+        self.conv2 = conv(6, 16, 5)
+        self.fc1 = fc_with_initialize(16*5*5, 120)
+        self.fc2 = fc_with_initialize(120, 84)
+        self.fc3 = fc_with_initialize(84, 10)
+        self.relu = nn.ReLU()
+        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.flatten = nn.Flatten()
+
+        self.summary = TensorSummary()
+
+    def construct(self, x):
+        """
+        construct the network architecture
+        Returns:
+            x (tensor): network output
+        """
+        self.summary('input', x)
+
+        x = self.conv1(x)
+        self.summary('1', x)
+
+        x = self.relu(x)
+        self.summary('2', x)
+
+        x = self.max_pool2d(x)
+        self.summary('3', x)
+
+        x = self.conv2(x)
+        self.summary('4', x)
+
+        x = self.relu(x)
+        self.summary('5', x)
+
+        x = self.max_pool2d(x)
+        self.summary('6', x)
+
+        x = self.flatten(x)
+        self.summary('7', x)
+
+        x = self.fc1(x)
+        self.summary('8', x)
+
+        x = self.relu(x)
+        self.summary('9', x)
+
+        x = self.fc2(x)
+        self.summary('10', x)
+
+        x = self.relu(x)
+        self.summary('11', x)
+
+        x = self.fc3(x)
+        self.summary('output', x)
+        return x
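The instrumented network above emits a TensorSummary record for every layer output. For orientation, here is a sketch (not part of the patch) of how those records can be read back after a forward pass; it mirrors what `calculate_effective_coverage` in the next diff does internally, and assumes `net` is an instance of the `LeNet5` class defined above with trained weights:

```python
import numpy as np
from mindspore import Model, Tensor
from mindspore.train.summary.summary_record import _get_summary_tensor_data

net = LeNet5()  # the instrumented network from lenet5_net_for_fuzzing.py
model = Model(net)
# one forward pass fills the summary buffer with every layer's output
model.predict(Tensor(np.random.random((1, 1, 32, 32)).astype(np.float32)))
for name, tensor in _get_summary_tensor_data().items():
    if 'input' in name:
        continue  # coverage tracking skips the raw input record
    print(name, tensor.shape)  # one entry per self.summary(...) call
```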
diff --git a/mindarmour/fuzz_testing/model_coverage_metrics.py b/mindarmour/fuzz_testing/model_coverage_metrics.py
index f6a5388..24482fa 100644
--- a/mindarmour/fuzz_testing/model_coverage_metrics.py
+++ b/mindarmour/fuzz_testing/model_coverage_metrics.py
@@ -15,10 +15,12 @@
 Model-Test Coverage Metrics.
 """
+from collections import defaultdict
 import numpy as np
 
 from mindspore import Tensor
 from mindspore import Model
+from mindspore.train.summary.summary_record import _get_summary_tensor_data
 
 from mindarmour.utils._check_param import check_model, check_numpy_param, \
     check_int_positive, check_param_multi_types
@@ -63,6 +65,9 @@ class ModelCoverageMetrics:
         >>> print('KMNC of this test is : %s', model_fuzz_test.get_kmnc())
         >>> print('NBC of this test is : %s', model_fuzz_test.get_nbc())
         >>> print('SNAC of this test is : %s', model_fuzz_test.get_snac())
+        >>> model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold)
+        >>> print('NC of this test is : %s', model_fuzz_test.get_nc())
+        >>> print('Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())
     """
 
     def __init__(self, model, neuron_num, segmented_num, train_dataset):
@@ -81,6 +86,24 @@
         self._lower_corner_hits = [0]*self._neuron_num
         self._upper_corner_hits = [0]*self._neuron_num
         self._bounds_get(train_dataset)
+        self._model_layer_dict = defaultdict(bool)
+        self._effective_model_layer_dict = defaultdict(bool)
+
+    def _set_init_effective_coverage_table(self, dataset):
+        """
+        Initialize the coverage table of each neuron in the model.
+
+        Args:
+            dataset (numpy.ndarray): Dataset used for initializing the coverage table.
+        """
+        self._model.predict(Tensor(dataset[0:1]))
+        tensors = _get_summary_tensor_data()
+        for name, tensor in tensors.items():
+            if 'input' in name:
+                continue
+            for num_neuron in range(tensor.shape[1]):
+                self._model_layer_dict[(name, num_neuron)] = False
+                self._effective_model_layer_dict[(name, num_neuron)] = False
 
     def _bounds_get(self, train_dataset, batch_size=32):
         """
@@ -130,6 +153,27 @@
             else:
                 self._main_section_hits[i][int(section_indexes[i])] = 1
 
+    def _coverage_update(self, name, tensor, scaled_mean, scaled_rank, top_k, threshold):
+        """
+        Update the coverage tables of neuron coverage and effective neuron coverage.
+
+        Args:
+            name (str): Name of the tensor.
+            tensor (Tensor): Output tensor of one hidden layer.
+            scaled_mean (numpy.ndarray): Per-neuron mean activation of the tensor.
+            scaled_rank (numpy.ndarray): Ranking indices of the neuron activations.
+            top_k (int): A neuron is covered when its output is among the top k largest
+                values of its hidden layer.
+            threshold (float): A neuron is covered when its output is greater than the threshold.
+        """
+        for num_neuron in range(tensor.shape[1]):
+            if num_neuron >= (len(scaled_rank) - top_k) and not \
+                    self._effective_model_layer_dict[(name, scaled_rank[num_neuron])]:
+                self._effective_model_layer_dict[(name, scaled_rank[num_neuron])] = True
+            if scaled_mean[num_neuron] > threshold and not \
+                    self._model_layer_dict[(name, num_neuron)]:
+                self._model_layer_dict[(name, num_neuron)] = True
+
     def calculate_coverage(self, dataset, bias_coefficient=0, batch_size=32):
         """
         Calculate the testing adequacy of the given dataset.
@@ -143,8 +187,9 @@
         Examples:
             >>> neuron_num = 10
             >>> segmented_num = 1000
+            >>> batch_size = 32
             >>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)
-            >>> model_fuzz_test.calculate_coverage(test_images)
+            >>> model_fuzz_test.calculate_coverage(test_images, batch_size=batch_size)
         """
 
         dataset = check_numpy_param('dataset', dataset)
@@ -157,6 +202,79 @@
         for i in range(batches):
             self._sections_hits_count(dataset[i*batch_size: (i + 1)*batch_size], intervals)
 
+    def calculate_effective_coverage(self, dataset, top_k=3, threshold=0.1, batch_size=32):
+        """
+        Calculate the effective testing adequacy of the given dataset.
+        In effective neuron coverage, a neuron is covered when its output is among the top k
+        largest values of its hidden layer. In neuron coverage, a neuron is covered when its
+        output is greater than the threshold. Each coverage equals the number of covered
+        neurons divided by the total number of neurons in the network.
+
+        Args:
+            dataset (numpy.ndarray): Data for fuzz test.
+            top_k (int): A neuron is covered when its output is among the top k largest
+                values of its hidden layer. Default: 3.
+            threshold (float): A neuron is covered when its output is greater than the
+                threshold. Default: 0.1.
+            batch_size (int): Batch size for model prediction. Default: 32.
+
+        Examples:
+            >>> neuron_num = 10
+            >>> segmented_num = 1000
+            >>> top_k = 3
+            >>> threshold = 0.1
+            >>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)
+            >>> model_fuzz_test.calculate_coverage(test_images)
+            >>> model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold)
+        """
+        top_k = check_int_positive('top_k', top_k)
+        dataset = check_numpy_param('dataset', dataset)
+        batch_size = check_int_positive('batch_size', batch_size)
+        batches = dataset.shape[0] // batch_size
+        self._set_init_effective_coverage_table(dataset)
+        for i in range(batches):
+            inputs = dataset[i*batch_size: (i + 1)*batch_size]
+            self._model.predict(Tensor(inputs)).asnumpy()
+            tensors = _get_summary_tensor_data()
+            for name, tensor in tensors.items():
+                if 'input' in name:
+                    continue
+                scaled = tensor.asnumpy()[-1]
+                if scaled.ndim >= 3:
+                    # convolution output: average each feature map over its spatial dimensions
+                    scaled_mean = np.mean(scaled, axis=(1, 2))
+                    scaled_rank = np.argsort(scaled_mean)
+                    self._coverage_update(name, tensor, scaled_mean, scaled_rank, top_k, threshold)
+                else:
+                    scaled_rank = np.argsort(scaled)
+                    self._coverage_update(name, tensor, scaled, scaled_rank, top_k, threshold)
+
+    def get_nc(self):
+        """
+        Get the metric of 'neuron coverage'.
+
+        Returns:
+            float, the metric of 'neuron coverage'.
+
+        Examples:
+            >>> model_fuzz_test.get_nc()
+        """
+        covered_neurons = len([v for v in self._model_layer_dict.values() if v])
+        total_neurons = len(self._model_layer_dict)
+        nc = covered_neurons / float(total_neurons)
+        return nc
+
+    def get_effective_nc(self):
+        """
+        Get the metric of 'effective neuron coverage'.
+
+        Returns:
+            float, the metric of 'effective neuron coverage'.
+
+        Examples:
+            >>> model_fuzz_test.get_effective_nc()
+        """
+        covered_neurons = len([v for v in self._effective_model_layer_dict.values() if v])
+        total_neurons = len(self._effective_model_layer_dict)
+        effective_nc = covered_neurons / float(total_neurons)
+        return effective_nc
+
     def get_kmnc(self):
         """
         Get the metric of 'k-multisection neuron coverage'.
 KMNC measures how thoroughly the given set of test inputs covers the range of
 neurons output values derived from training dataset.
diff --git a/tests/ut/python/fuzzing/test_coverage_metrics.py b/tests/ut/python/fuzzing/test_coverage_metrics.py
index b8a7287..95f3ee6 100644
--- a/tests/ut/python/fuzzing/test_coverage_metrics.py
+++ b/tests/ut/python/fuzzing/test_coverage_metrics.py
@@ -21,6 +21,7 @@ from mindspore import nn
 from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
 from mindspore.train import Model
 from mindspore import context
+from mindspore.ops import TensorSummary
 
 from mindarmour.adv_robustness.attacks import FastGradientSignMethod
 from mindarmour.utils.logger import LogUtil
@@ -46,6 +47,7 @@ class Net(Cell):
         """
         super(Net, self).__init__()
         self._relu = nn.ReLU()
+        self.summary = TensorSummary()
 
     def construct(self, inputs):
         """
@@ -54,7 +56,10 @@
         Args:
             inputs (Tensor): Input data.
         """
+        self.summary('input', inputs)
+
         out = self._relu(inputs)
+        self.summary('1', out)
         return out
 
 
@@ -71,7 +76,10 @@ def test_lenet_mnist_coverage_cpu():
     # initialize fuzz test with training dataset
     neuron_num = 10
     segmented_num = 1000
+    top_k = 3
+    threshold = 0.1
     training_data = (np.random.random((10000, 10))*20).astype(np.float32)
+
     model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data)
 
     # fuzz test with original test data
@@ -83,6 +91,10 @@
     LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
     LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
 
+    model_fuzz_test.calculate_effective_coverage(test_data, top_k, threshold)
+    LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc())
+    LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())
+
     # generate adv_data
     loss = SoftmaxCrossEntropyWithLogits(sparse=True)
     attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss)
@@ -92,6 +104,9 @@
     LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
     LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
 
+    model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold)
+    LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc())
+    LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())
 
 
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
@@ -107,9 +122,10 @@ def test_lenet_mnist_coverage_ascend():
     # initialize fuzz test with training dataset
     neuron_num = 10
     segmented_num = 1000
+    top_k = 3
+    threshold = 0.1
     training_data = (np.random.random((10000, 10))*20).astype(np.float32)
-
-    model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data,)
+    model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data)
 
     # fuzz test with original test data
     # get test data
@@ -121,6 +137,10 @@
     LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
     LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
 
+    model_fuzz_test.calculate_effective_coverage(test_data, top_k, threshold)
+    LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc())
+    LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())
+
     # generate adv_data
     attack = FastGradientSignMethod(net, eps=0.3, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False))
     adv_data = attack.batch_generate(test_data, test_labels, batch_size=32)
@@ -128,3 +148,7 @@
     LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
     LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
     LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
+
+    model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold)
+    LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc())
+    LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())
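For readers tracing the new logic: the per-layer rule implemented by `_coverage_update` can be restated in a few lines of plain NumPy. This is an illustrative sketch with toy numbers, not code from the patch:

```python
import numpy as np

top_k, threshold = 3, 0.1
# toy mean activations of one hidden layer (one value per neuron)
layer_output = np.array([0.05, 0.4, 0.02, 0.9, 0.3, 0.15])

nc_covered = layer_output > threshold        # NC rule: output exceeds threshold
rank = np.argsort(layer_output)              # ascending rank of neuron activations
effective_covered = np.zeros_like(nc_covered)
effective_covered[rank[-top_k:]] = True      # Effective NC rule: top-k neurons of the layer

print('NC for this layer:', nc_covered.mean())                   # 4/6
print('Effective NC for this layer:', effective_covered.mean())  # 3/6
```

In the patch itself this bookkeeping is accumulated across batches and layers in `_model_layer_dict` and `_effective_model_layer_dict`, keyed by `(layer name, neuron index)`, and `get_nc()`/`get_effective_nc()` simply divide the covered entries by the table size.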