Browse Source

!254 add feature of NC and Effective NC in test coverage

Merge pull request !254 from RyanZ/test_coverage3
tags/v1.6.0
i-robot Gitee 3 years ago
parent
commit
9eaa3dc99d
7 changed files with 281 additions and 18 deletions
  1. +18
    -10
      examples/ai_fuzzer/README.md
  2. +4
    -2
      examples/ai_fuzzer/fuzz_testing_and_model_enhense.py
  3. +14
    -2
      examples/ai_fuzzer/lenet5_mnist_coverage.py
  4. +1
    -1
      examples/ai_fuzzer/lenet5_mnist_fuzzing.py
  5. +99
    -0
      examples/common/networks/lenet5/lenet5_net_for_fuzzing.py
  6. +119
    -1
      mindarmour/fuzz_testing/model_coverage_metrics.py
  7. +26
    -2
      tests/ut/python/fuzzing/test_coverage_metrics.py

+ 18
- 10
examples/ai_fuzzer/README.md View File

@@ -1,24 +1,32 @@
# Application demos of model fuzzing

## Introduction

As with traditional software fuzz testing, we can also design fuzz tests for AI models. Analogous to the
branch coverage or line coverage of traditional software, researchers have proposed the
concept of 'neuron coverage', based on the unique structure of deep neural networks. We can use neuron coverage
as a guide to search for more metamorphic inputs with which to test our models.

## 1. calculation of neuron coverage
There are three metrics proposed for evaluating the neuron coverage of a test:KMNC, NBC and SNAC. Usually we need to
feed all the training dataset into the model first, and record the output range of all neurons (however, only the last
layer of neurons are recorded in our method). In the testing phase, we feed test samples into the model, and
calculate those three metrics mentioned above according to those neurons' output distribution.
## 1. calculation of neuron coverage

There are five metrics proposed for evaluating the neuron coverage of a test: NC, Effective NC, KMNC, NBC and SNAC.
Usually we need to feed all the training dataset into the model first, and record the output range of all neurons
(however, in KMNC, NBC and SNAC, only the last layer of neurons is recorded in our method). In the testing phase,
we feed test samples into the model, and calculate the five metrics mentioned above according to those neurons'
output distribution.

```sh
$ cd examples/ai_fuzzer/
$ python lenet5_mnist_coverage.py
cd examples/ai_fuzzer/
python lenet5_mnist_coverage.py
```
## 2. fuzz test for AI model

## 2. fuzz test for AI model

We have provided several types of methods for manipulating metamorphic inputs: affine transformation, pixel
transformation and adversarial attacks. Usually we feed the original samples into the fuzz function as seeds, and
then metamorphic samples are generated through iterative manipulations.

```sh
$ cd examples/ai_fuzzer/
$ python lenet5_mnist_fuzzing.py
cd examples/ai_fuzzer/
python lenet5_mnist_fuzzing.py
```

+ 4
- 2
examples/ai_fuzzer/fuzz_testing_and_model_enhense.py View File

@@ -31,7 +31,7 @@ from mindarmour.fuzz_testing import ModelCoverageMetrics
from mindarmour.utils.logger import LogUtil

from examples.common.dataset.data_processing import generate_mnist_dataset
from examples.common.networks.lenet5.lenet5_net import LeNet5
from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5

LOGGER = LogUtil.get_instance()
TAG = 'Fuzz_testing and enhance model'
@@ -75,9 +75,11 @@ def example_lenet_mnist_fuzzing():
images = data[0].astype(np.float32)
train_images.append(images)
train_images = np.concatenate(train_images, axis=0)
neuron_num = 10
segmented_num = 1000

# initialize fuzz test with training dataset
model_coverage_test = ModelCoverageMetrics(model, 10, 1000, train_images)
model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)

# fuzz test with original test data
# get test data


+ 14
- 2
examples/ai_fuzzer/lenet5_mnist_coverage.py View File

@@ -22,7 +22,7 @@ from mindarmour.fuzz_testing import ModelCoverageMetrics
from mindarmour.utils.logger import LogUtil

from examples.common.dataset.data_processing import generate_mnist_dataset
from examples.common.networks.lenet5.lenet5_net import LeNet5
from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5

LOGGER = LogUtil.get_instance()
TAG = 'Neuron coverage test'
@@ -46,9 +46,13 @@ def test_lenet_mnist_coverage():
images = data[0].astype(np.float32)
train_images.append(images)
train_images = np.concatenate(train_images, axis=0)
neuron_num = 10
segmented_num = 1000
top_k = 3
threshold = 0.1

# initialize fuzz test with training dataset
model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, train_images)
model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)

# fuzz test with original test data
# get test data
@@ -69,6 +73,10 @@ def test_lenet_mnist_coverage():
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())

model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold)
LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc())
LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())

# generate adv_data
loss = SoftmaxCrossEntropyWithLogits(sparse=True)
attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss)
@@ -78,6 +86,10 @@ def test_lenet_mnist_coverage():
LOGGER.info(TAG, 'NBC of this adv data is : %s', model_fuzz_test.get_nbc())
LOGGER.info(TAG, 'SNAC of this adv data is : %s', model_fuzz_test.get_snac())

model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold)
LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc())
LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())


if __name__ == '__main__':
# device_target can be "CPU", "GPU" or "Ascend"


+ 1
- 1
examples/ai_fuzzer/lenet5_mnist_fuzzing.py View File

@@ -21,7 +21,7 @@ from mindarmour.fuzz_testing import ModelCoverageMetrics
from mindarmour.utils.logger import LogUtil

from examples.common.dataset.data_processing import generate_mnist_dataset
from examples.common.networks.lenet5.lenet5_net import LeNet5
from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5

LOGGER = LogUtil.get_instance()
TAG = 'Fuzz_test'


+ 99
- 0
examples/common/networks/lenet5/lenet5_net_for_fuzzing.py View File

@@ -0,0 +1,99 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
lenet network with summary
"""
from mindspore import nn
from mindspore.common.initializer import TruncatedNormal
from mindspore.ops import TensorSummary


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """
    Build a bias-free 2D convolution layer with a freshly created weight initializer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Size of the convolution kernel.
        stride (int): Stride of the convolution. Default: 1.
        padding (int): Padding value forwarded to `nn.Conv2d`. Default: 0.

    Returns:
        nn.Conv2d, convolution layer created with `pad_mode="valid"` and `has_bias=False`.
    """
    layer_options = {
        'kernel_size': kernel_size,
        'stride': stride,
        'padding': padding,
        'weight_init': weight_variable(),
        'has_bias': False,
        'pad_mode': "valid",
    }
    return nn.Conv2d(in_channels, out_channels, **layer_options)


def fc_with_initialize(input_channels, out_channels):
    """
    Build a fully connected layer whose weight and bias use separate initializers.

    Args:
        input_channels (int): Number of input features.
        out_channels (int): Number of output features.

    Returns:
        nn.Dense, dense layer whose third and fourth positional arguments are
        two independent results of `weight_variable()`.
    """
    # The weight initializer is created first, then the bias initializer.
    initializers = [weight_variable(), weight_variable()]
    return nn.Dense(input_channels, out_channels, *initializers)


def weight_variable():
    """
    Create the initializer shared by the layer factory helpers.

    Returns:
        TruncatedNormal, a new initializer constructed with 0.05 as its only
        argument (presumably the sigma of the distribution; confirm against
        the MindSpore initializer documentation).
    """
    initializer = TruncatedNormal(0.05)
    return initializer


class LeNet5(nn.Cell):
    """
    LeNet5 network instrumented with TensorSummary.

    The network input and the output of every layer are passed to a
    TensorSummary operator under a fixed tag ('input', '1' ... '11', 'output').
    ModelCoverageMetrics reads these recorded tensors to compute neuron
    coverage and skips every tensor whose name contains 'input'.
    """
    def __init__(self):
        """Create the layers and the summary operator used in `construct`."""
        super(LeNet5, self).__init__()
        # Two bias-free 5x5 convolutions: 1 -> 6 and 6 -> 16 channels.
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        # Three dense layers: 16*5*5 -> 120 -> 84 -> 10.
        self.fc1 = fc_with_initialize(16*5*5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, 10)
        # Stateless operators reused at several points of the forward pass.
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

        # Records intermediate tensors under the tags given in `construct`.
        self.summary = TensorSummary()

    def construct(self, x):
        """
        Run the forward pass and record every intermediate tensor.

        Args:
            x (tensor): network input.

        Returns:
            x (tensor): network output
        """
        # Tagged 'input' so that coverage code can tell it apart from layer outputs.
        self.summary('input', x)

        # First convolution stage: conv -> relu -> max pool (tags '1'-'3').
        x = self.conv1(x)
        self.summary('1', x)

        x = self.relu(x)
        self.summary('2', x)

        x = self.max_pool2d(x)
        self.summary('3', x)

        # Second convolution stage: conv -> relu -> max pool (tags '4'-'6').
        x = self.conv2(x)
        self.summary('4', x)

        x = self.relu(x)
        self.summary('5', x)

        x = self.max_pool2d(x)
        self.summary('6', x)

        # Flatten the feature maps before the dense layers (tag '7').
        x = self.flatten(x)
        self.summary('7', x)

        # Dense stage: fc1 -> relu -> fc2 -> relu (tags '8'-'11').
        x = self.fc1(x)
        self.summary('8', x)

        x = self.relu(x)
        self.summary('9', x)

        x = self.fc2(x)
        self.summary('10', x)

        x = self.relu(x)
        self.summary('11', x)

        # Final dense layer; its result is recorded as 'output' and returned.
        x = self.fc3(x)
        self.summary('output', x)
        return x

+ 119
- 1
mindarmour/fuzz_testing/model_coverage_metrics.py View File

@@ -15,10 +15,12 @@
Model-Test Coverage Metrics.
"""

from collections import defaultdict
import numpy as np

from mindspore import Tensor
from mindspore import Model
from mindspore.train.summary.summary_record import _get_summary_tensor_data

from mindarmour.utils._check_param import check_model, check_numpy_param, \
check_int_positive, check_param_multi_types
@@ -63,6 +65,9 @@ class ModelCoverageMetrics:
>>> print('KMNC of this test is : %s', model_fuzz_test.get_kmnc())
>>> print('NBC of this test is : %s', model_fuzz_test.get_nbc())
>>> print('SNAC of this test is : %s', model_fuzz_test.get_snac())
>>> model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold)
>>> print('NC of this test is : %s', model_fuzz_test.get_nc())
>>> print('Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())
"""

def __init__(self, model, neuron_num, segmented_num, train_dataset):
@@ -81,6 +86,24 @@ class ModelCoverageMetrics:
self._lower_corner_hits = [0]*self._neuron_num
self._upper_corner_hits = [0]*self._neuron_num
self._bounds_get(train_dataset)
self._model_layer_dict = defaultdict(bool)
self._effective_model_layer_dict = defaultdict(bool)

def _set_init_effective_coverage_table(self, dataset):
    """
    Reset the coverage tables so that every known neuron is marked as uncovered.

    The model is run once on the first sample of the dataset, and the tensors
    returned by `_get_summary_tensor_data()` (presumably the ones recorded by
    the TensorSummary operators during that run) are used to enumerate the
    neurons: one entry per index along axis 1 of every tensor whose name does
    not contain 'input'.

    Args:
        dataset (numpy.ndarray): Dataset used for initialising the coverage table.
    """
    first_sample = Tensor(dataset[0:1])
    self._model.predict(first_sample)
    recorded = _get_summary_tensor_data()
    for layer_name, layer_output in recorded.items():
        # The network input is not a neuron layer, so it is left out.
        if 'input' not in layer_name:
            for neuron_idx in range(layer_output.shape[1]):
                key = (layer_name, neuron_idx)
                self._model_layer_dict[key] = False
                self._effective_model_layer_dict[key] = False

def _bounds_get(self, train_dataset, batch_size=32):
"""
@@ -130,6 +153,27 @@ class ModelCoverageMetrics:
else:
self._main_section_hits[i][int(section_indexes[i])] = 1

def _coverage_update(self, name, tensor, scaled_mean, scaled_rank, top_k, threshold):
"""
Update the coverage matrix of neural coverage and effective neural coverage.

Args:
name (string): the name of the tensor.
tensor (tensor): the tensor in the network.
scaled_mean (numpy.ndarray): feature map of the tensor.
scaled_rank (numpy.ndarray): rank of tensor value.
top_k (int): neuron is covered when its output has the top k largest value in that hidden layer.
threshold (float): neuron is covered when its output is greater than the threshold.

"""
for num_neuron in range(tensor.shape[1]):
if num_neuron >= (len(scaled_rank) - top_k) and not \
self._effective_model_layer_dict[(name, scaled_rank[num_neuron])]:
self._effective_model_layer_dict[(name, scaled_rank[num_neuron])] = True
if scaled_mean[num_neuron] > threshold and not \
self._model_layer_dict[(name, num_neuron)]:
self._model_layer_dict[(name, num_neuron)] = True

def calculate_coverage(self, dataset, bias_coefficient=0, batch_size=32):
"""
Calculate the testing adequacy of the given dataset.
@@ -143,8 +187,9 @@ class ModelCoverageMetrics:
Examples:
>>> neuron_num = 10
>>> segmented_num = 1000
>>> batch_size = 32
>>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)
>>> model_fuzz_test.calculate_coverage(test_images)
>>> model_fuzz_test.calculate_coverage(test_images, batch_size=batch_size)
"""

dataset = check_numpy_param('dataset', dataset)
@@ -157,6 +202,79 @@ class ModelCoverageMetrics:
for i in range(batches):
self._sections_hits_count(dataset[i*batch_size: (i + 1)*batch_size], intervals)


def calculate_effective_coverage(self, dataset, top_k=3, threshold=0.1, batch_size=32):
    """
    Calculate the neuron coverage and effective neuron coverage of the given dataset.

    In effective neuron coverage, a neuron is covered when its output is among the
    top k largest values of its hidden layer. In neuron coverage, a neuron is
    covered when its output is greater than the threshold. Coverage equals the
    covered neurons divided by the total neurons in the network.

    The coverage tables are reset at the start of every call, so the results
    returned by `get_nc` and `get_effective_nc` afterwards describe this dataset
    only. Every sample of the dataset contributes, including the samples of a
    trailing batch that is smaller than `batch_size`.

    Args:
        dataset (numpy.ndarray): Data for fuzz test.
        top_k (int): neuron is covered when its output has the top k largest
            value in that hidden layer. Default: 3.
        threshold (float): neuron is covered when its output is greater than
            the threshold. Default: 0.1.
        batch_size (int): The number of samples fed to the model in one
            prediction. Default: 32.

    Examples:
        >>> neuron_num = 10
        >>> segmented_num = 1000
        >>> top_k = 3
        >>> threshold = 0.1
        >>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)
        >>> model_fuzz_test.calculate_coverage(test_images)
        >>> model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold)
    """
    top_k = check_int_positive('top_k', top_k)
    dataset = check_numpy_param('dataset', dataset)
    batch_size = check_int_positive('batch_size', batch_size)
    self._set_init_effective_coverage_table(dataset)
    # Step through the dataset by batch_size instead of using floor division,
    # so that the last, possibly smaller, batch is evaluated as well.
    for start in range(0, dataset.shape[0], batch_size):
        inputs = dataset[start: start + batch_size]
        # The prediction itself is not needed; the forward pass is run so that
        # the summary tensors of this batch can be read afterwards.
        self._model.predict(Tensor(inputs)).asnumpy()
        tensors = _get_summary_tensor_data()
        for name, tensor in tensors.items():
            if 'input' in name:
                continue
            # Iterate over axis 0 so that every sample of the batch is counted,
            # not only the last one.
            for scaled in tensor.asnumpy():
                if scaled.ndim >= 3:
                    # Reduce each feature map to a single value per channel.
                    scaled_mean = np.mean(scaled, axis=(1, 2))
                else:
                    scaled_mean = scaled
                scaled_rank = np.argsort(scaled_mean)
                self._coverage_update(name, tensor, scaled_mean, scaled_rank, top_k, threshold)

def get_nc(self):
    """
    Get the metric of 'neuron coverage'.

    Returns:
        float, the number of neurons marked as covered divided by the number
        of neurons registered in the coverage table. Returns 0.0 when the
        table is empty, which is the case until `calculate_effective_coverage`
        has been called.

    Examples:
        >>> model_fuzz_test.get_nc()
    """
    total_neurons = len(self._model_layer_dict)
    if not total_neurons:
        # Nothing has been registered yet; avoid a ZeroDivisionError.
        return 0.0
    covered_neurons = sum(1 for covered in self._model_layer_dict.values() if covered)
    return covered_neurons / float(total_neurons)

def get_effective_nc(self):
    """
    Get the metric of 'effective neuron coverage'.

    Returns:
        float, the number of neurons marked as effectively covered divided by
        the number of neurons registered in the effective coverage table.
        Returns 0.0 when the table is empty, which is the case until
        `calculate_effective_coverage` has been called.

    Examples:
        >>> model_fuzz_test.get_effective_nc()
    """
    total_neurons = len(self._effective_model_layer_dict)
    if not total_neurons:
        # Nothing has been registered yet; avoid a ZeroDivisionError.
        return 0.0
    covered_neurons = sum(1 for covered in self._effective_model_layer_dict.values() if covered)
    return covered_neurons / float(total_neurons)

def get_kmnc(self):
"""
Get the metric of 'k-multisection neuron coverage'. KMNC measures how


+ 26
- 2
tests/ut/python/fuzzing/test_coverage_metrics.py View File

@@ -21,6 +21,7 @@ from mindspore import nn
from mindspore.nn import Cell, SoftmaxCrossEntropyWithLogits
from mindspore.train import Model
from mindspore import context
from mindspore.ops import TensorSummary

from mindarmour.adv_robustness.attacks import FastGradientSignMethod
from mindarmour.utils.logger import LogUtil
@@ -46,6 +47,7 @@ class Net(Cell):
"""
super(Net, self).__init__()
self._relu = nn.ReLU()
self.summary = TensorSummary()

def construct(self, inputs):
"""
@@ -54,7 +56,10 @@ class Net(Cell):
Args:
inputs (Tensor): Input data.
"""
self.summary('input', inputs)

out = self._relu(inputs)
self.summary('1', out)
return out


@@ -71,7 +76,10 @@ def test_lenet_mnist_coverage_cpu():
# initialize fuzz test with training dataset
neuron_num = 10
segmented_num = 1000
top_k = 3
threshold = 0.1
training_data = (np.random.random((10000, 10))*20).astype(np.float32)

model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data)

# fuzz test with original test data
@@ -83,6 +91,10 @@ def test_lenet_mnist_coverage_cpu():
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())

model_fuzz_test.calculate_effective_coverage(test_data, top_k, threshold)
LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc())
LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())

# generate adv_data
loss = SoftmaxCrossEntropyWithLogits(sparse=True)
attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss)
@@ -92,6 +104,9 @@ def test_lenet_mnist_coverage_cpu():
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())

model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold)
LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc())
LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@@ -107,9 +122,10 @@ def test_lenet_mnist_coverage_ascend():
# initialize fuzz test with training dataset
neuron_num = 10
segmented_num = 1000
top_k = 3
threshold = 0.1
training_data = (np.random.random((10000, 10))*20).astype(np.float32)

model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data,)
model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data)

# fuzz test with original test data
# get test data
@@ -121,6 +137,10 @@ def test_lenet_mnist_coverage_ascend():
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())

model_fuzz_test.calculate_effective_coverage(test_data, top_k, threshold)
LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc())
LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())

# generate adv_data
attack = FastGradientSignMethod(net, eps=0.3, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False))
adv_data = attack.batch_generate(test_data, test_labels, batch_size=32)
@@ -128,3 +148,7 @@ def test_lenet_mnist_coverage_ascend():
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())

model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold)
LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc())
LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())

Loading…
Cancel
Save