Browse Source

reconstruct AI fuzzer and Model Neuron Coverages

tags/v1.6.0
ZhidanLiu 3 years ago
parent
commit
9fd4bfd2b9
9 changed files with 567 additions and 549 deletions
  1. +22
    -30
      examples/ai_fuzzer/lenet5_mnist_coverage.py
  2. +13
    -19
      examples/ai_fuzzer/lenet5_mnist_fuzzing.py
  3. +4
    -5
      examples/common/networks/lenet5/lenet5_net_for_fuzzing.py
  4. +8
    -2
      mindarmour/fuzz_testing/__init__.py
  5. +98
    -150
      mindarmour/fuzz_testing/fuzzing.py
  6. +308
    -223
      mindarmour/fuzz_testing/model_coverage_metrics.py
  7. +42
    -49
      mindarmour/utils/_check_param.py
  8. +50
    -47
      tests/ut/python/fuzzing/test_coverage_metrics.py
  9. +22
    -24
      tests/ut/python/fuzzing/test_fuzzer.py

+ 22
- 30
examples/ai_fuzzer/lenet5_mnist_coverage.py View File

@@ -14,11 +14,10 @@
import numpy as np
from mindspore import Model
from mindspore import context
from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from mindarmour.adv_robustness.attacks import FastGradientSignMethod
from mindarmour.fuzz_testing import ModelCoverageMetrics
from mindarmour.fuzz_testing.model_coverage_metrics import NeuronCoverage, TopKNeuronCoverage, NeuronBoundsCoverage,\
SuperNeuronActivateCoverage, KMultisectionNeuronCoverage
from mindarmour.utils.logger import LogUtil

from examples.common.dataset.data_processing import generate_mnist_dataset
@@ -46,13 +45,6 @@ def test_lenet_mnist_coverage():
images = data[0].astype(np.float32)
train_images.append(images)
train_images = np.concatenate(train_images, axis=0)
neuron_num = 10
segmented_num = 1000
top_k = 3
threshold = 0.1

# initialize fuzz test with training dataset
model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)

# fuzz test with original test data
# get test data
@@ -67,31 +59,31 @@ def test_lenet_mnist_coverage():
test_images.append(images)
test_labels.append(labels)
test_images = np.concatenate(test_images, axis=0)
test_labels = np.concatenate(test_labels, axis=0)
model_fuzz_test.calculate_coverage(test_images)
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())

model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold)
LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc())
LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())
# initialize fuzz test with training dataset
nc = NeuronCoverage(model, threshold=0.1)
nc_metric = nc.get_metrics(test_images)

tknc = TopKNeuronCoverage(model, top_k=3)
tknc_metrics = tknc.get_metrics(test_images)

snac = SuperNeuronActivateCoverage(model, train_images)
snac_metrics = snac.get_metrics(test_images)

nbc = NeuronBoundsCoverage(model, train_images)
nbc_metrics = nbc.get_metrics(test_images)

# generate adv_data
loss = SoftmaxCrossEntropyWithLogits(sparse=True)
attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss)
adv_data = attack.batch_generate(test_images, test_labels, batch_size=32)
model_fuzz_test.calculate_coverage(adv_data, bias_coefficient=0.5)
LOGGER.info(TAG, 'KMNC of this adv data is : %s', model_fuzz_test.get_kmnc())
LOGGER.info(TAG, 'NBC of this adv data is : %s', model_fuzz_test.get_nbc())
LOGGER.info(TAG, 'SNAC of this adv data is : %s', model_fuzz_test.get_snac())
kmnc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100)
kmnc_metrics = kmnc.get_metrics(test_images)

model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold)
LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc())
LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())
print('KMNC of this test is: ', kmnc_metrics)
print('NBC of this test is: ', nbc_metrics)
print('SNAC of this test is: ', snac_metrics)
print('NC of this test is: ', nc_metric)
print('TKNC of this test is: ', tknc_metrics)


if __name__ == '__main__':
# device_target can be "CPU", "GPU" or "Ascend"
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
test_lenet_mnist_coverage()

+ 13
- 19
examples/ai_fuzzer/lenet5_mnist_fuzzing.py View File

@@ -14,11 +14,11 @@
import numpy as np
from mindspore import Model
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import load_checkpoint, load_param_into_net

from mindarmour.fuzz_testing import Fuzzer
from mindarmour.fuzz_testing import ModelCoverageMetrics
from mindarmour.utils.logger import LogUtil
from mindarmour.fuzz_testing import KMultisectionNeuronCoverage
from mindarmour.utils import LogUtil

from examples.common.dataset.data_processing import generate_mnist_dataset
from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5
@@ -52,7 +52,7 @@ def test_lenet_mnist_fuzzing():
'params': {'auto_param': [True]}},
{'method': 'FGSM',
'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1]}}
]
]

# get training data
data_list = "../common/dataset/MNIST/train"
@@ -63,11 +63,6 @@ def test_lenet_mnist_fuzzing():
images = data[0].astype(np.float32)
train_images.append(images)
train_images = np.concatenate(train_images, axis=0)
neuron_num = 10
segmented_num = 1000

# initialize fuzz test with training dataset
model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)

# fuzz test with original test data
# get test data
@@ -88,21 +83,20 @@ def test_lenet_mnist_fuzzing():
# make initial seeds
for img, label in zip(test_images, test_labels):
initial_seeds.append([img, label])

coverage = KMultisectionNeuronCoverage(model, train_images, segmented_num=100, incremental=True)
kmnc = coverage.get_metrics(test_images[:100])
print('KMNC of initial seeds is: ', kmnc)
initial_seeds = initial_seeds[:100]
model_coverage_test.calculate_coverage(
np.array(test_images[:100]).astype(np.float32))
LOGGER.info(TAG, 'KMNC of this test is : %s',
model_coverage_test.get_kmnc())
model_fuzz_test = Fuzzer(model)
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, coverage, evaluate=True, max_iters=10,
mutate_num_per_seed=20)

model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num)
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, eval_metrics='auto')
if metrics:
for key in metrics:
LOGGER.info(TAG, key + ': %s', metrics[key])
print(key + ': ', metrics[key])


if __name__ == '__main__':
# device_target can be "CPU", "GPU" or "Ascend"
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
# device_target can be "CPU"GPU, "" or "Ascend"
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
test_lenet_mnist_fuzzing()

+ 4
- 5
examples/common/networks/lenet5/lenet5_net_for_fuzzing.py View File

@@ -20,19 +20,21 @@ from mindspore.ops import TensorSummary


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
"""Wrap conv."""
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
"""Wrap initialize method of full connection layer."""
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)


def weight_variable():
"""Wrap initialize variable."""
return TruncatedNormal(0.05)


@@ -50,7 +52,6 @@ class LeNet5(nn.Cell):
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()

self.summary = TensorSummary()

def construct(self, x):
@@ -59,8 +60,6 @@ class LeNet5(nn.Cell):
Returns:
x (tensor): network output
"""
self.summary('input', x)

x = self.conv1(x)
self.summary('1', x)



+ 8
- 2
mindarmour/fuzz_testing/__init__.py View File

@@ -16,7 +16,13 @@ This module provides a neuron coverage-gain based fuzz method to evaluate the
robustness of given model.
"""
from .fuzzing import Fuzzer
from .model_coverage_metrics import ModelCoverageMetrics
from .model_coverage_metrics import CoverageMetrics, NeuronCoverage, TopKNeuronCoverage, NeuronBoundsCoverage, \
SuperNeuronActivateCoverage, KMultisectionNeuronCoverage

__all__ = ['Fuzzer',
'ModelCoverageMetrics']
'CoverageMetrics',
'NeuronCoverage',
'TopKNeuronCoverage',
'NeuronBoundsCoverage',
'SuperNeuronActivateCoverage',
'KMultisectionNeuronCoverage']

+ 98
- 150
mindarmour/fuzz_testing/fuzzing.py View File

@@ -21,15 +21,14 @@ from mindspore import Model
from mindspore import Tensor
from mindspore import nn

from mindarmour.utils._check_param import check_model, check_numpy_param, \
check_param_multi_types, check_norm_level, check_param_in_range, \
check_param_type, check_int_positive
from mindarmour.utils._check_param import check_model, check_numpy_param, check_param_multi_types, check_norm_level, \
check_param_in_range, check_param_type, check_int_positive, check_param_bounds
from mindarmour.utils.logger import LogUtil
from ..adv_robustness.attacks import FastGradientSignMethod, \
MomentumDiverseInputIterativeMethod, ProjectedGradientDescent
from .image_transform import Contrast, Brightness, Blur, \
Noise, Translate, Scale, Shear, Rotate
from .model_coverage_metrics import ModelCoverageMetrics
from .model_coverage_metrics import CoverageMetrics, KMultisectionNeuronCoverage

LOGGER = LogUtil.get_instance()
TAG = 'Fuzzer'
@@ -43,11 +42,22 @@ def _select_next(initial_seeds):
return seed, initial_seeds


def _coverage_gains(coverages):
""" Calculate the coverage gains of mutated samples."""
gains = [0] + coverages[:-1]
def _coverage_gains(pre_coverage, coverages):
"""
Calculate the coverage gains of mutated samples.

Args:
pre_coverage (float): Last value of coverages for previous mutated samples.
coverages (list): Coverage of mutated samples.

Returns:
- list, coverage gains for mutated samples.

- float, last value in parameter coverages.
"""
gains = [pre_coverage] + coverages[:-1]
gains = np.array(coverages) - np.array(gains)
return gains
return gains, coverages[-1]


def _is_trans_valid(seed, mutate_sample):
@@ -65,37 +75,22 @@ def _is_trans_valid(seed, mutate_sample):
size = np.shape(diff)[0]
l0_norm = np.linalg.norm(diff, ord=0)
linf = np.linalg.norm(diff, ord=np.inf)
if l0_norm > pixels_change_rate*size:
if l0_norm > pixels_change_rate * size:
if linf < 256:
is_valid = True
else:
if linf < pixel_value_change_rate*255:
if linf < pixel_value_change_rate * 255:
is_valid = True
return is_valid


def _check_eval_metrics(eval_metrics):
""" Check evaluation metrics."""
if isinstance(eval_metrics, (list, tuple)):
eval_metrics_ = []
available_metrics = ['accuracy', 'attack_success_rate', 'kmnc', 'nbc', 'snac']
for elem in eval_metrics:
if elem not in available_metrics:
msg = 'metric in list `eval_metrics` must be in {}, but got {}.'.format(available_metrics, elem)
LOGGER.error(TAG, msg)
raise ValueError(msg)
eval_metrics_.append(elem.lower())
elif isinstance(eval_metrics, str):
if eval_metrics != 'auto':
msg = "the value of `eval_metrics` must be 'auto' if it's type is str, but got {}.".format(eval_metrics)
LOGGER.error(TAG, msg)
raise ValueError(msg)
eval_metrics_ = 'auto'
def _gain_threshold(coverage):
"""Get threshold for given neuron coverage class."""
if coverage is isinstance(coverage, KMultisectionNeuronCoverage):
gain_threshold = 0.1 / coverage.segmented_num
else:
msg = "the type of `eval_metrics` must be str, list or tuple, but got {}.".format(type(eval_metrics))
LOGGER.error(TAG, msg)
raise TypeError(msg)
return eval_metrics_
gain_threshold = 0
return gain_threshold


class Fuzzer:
@@ -113,6 +108,7 @@ class Fuzzer:

Examples:
>>> net = Net()
>>> model = Model(net)
>>> mutate_config = [{'method': 'Blur',
>>> 'params': {'auto_param': [True]}},
>>> {'method': 'Contrast',
@@ -121,18 +117,15 @@ class Fuzzer:
>>> 'params': {'x_bias': [0.1, 0.2], 'y_bias': [0.2]}},
>>> {'method': 'FGSM',
>>> 'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}]
>>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
>>> neuron_num = 10
>>> segmented_num = 1000
>>> model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num)
>>> samples, labels, preds, strategies, report = model_fuzz_test.fuzz_testing(mutate_config, initial_seeds)
>>> nc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100)
>>> model_fuzz_test = Fuzzer(model)
>>> samples, gt_labels, preds, strategies, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds,
>>> nc, max_iters=100)
"""

def __init__(self, target_model, train_dataset, neuron_num,
segmented_num=1000):
def __init__(self, target_model):
self._target_model = check_model('model', target_model, Model)
train_dataset = check_numpy_param('train_dataset', train_dataset)
self._coverage_metrics = ModelCoverageMetrics(target_model, neuron_num, segmented_num, train_dataset)

# Allowed mutate strategies so far.
self._strategies = {'Contrast': Contrast,
'Brightness': Brightness,
@@ -161,8 +154,7 @@ class Fuzzer:
'prob': {'dtype': [float], 'range': [0, 1]},
'bounds': {'dtype': [tuple]}}}

def fuzzing(self, mutate_config, initial_seeds, coverage_metric='KMNC',
eval_metrics='auto', max_iters=10000, mutate_num_per_seed=20):
def fuzzing(self, mutate_config, initial_seeds, coverage, evaluate=True, max_iters=10000, mutate_num_per_seed=20):
"""
Fuzzing tests for deep neural networks.

@@ -175,32 +167,20 @@ class Fuzzer:
{'method': 'FGSM',
'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1]}},
...].
The supported methods list is in `self._strategies`, and the
params of each method must within the range of optional parameters. 
Supported methods are grouped in three types:
Firstly, pixel value based transform methods include:
'Contrast', 'Brightness', 'Blur' and 'Noise'. Secondly, affine
transform methods include: 'Translate', 'Scale', 'Shear' and
'Rotate'. Thirdly, attack methods include: 'FGSM', 'PGD' and 'MDIIM'.
`mutate_config` must have method in the type of pixel value based
transform methods. The way of setting parameters for first and
second type methods can be seen in 'mindarmour/fuzz_testing/image_transform.py'.
For third type methods, the optional parameters refer to
The supported methods list is in `self._strategies`, and the params of each method must within the
range of optional parameters. Supported methods are grouped in three types: Firstly, pixel value based
transform methods include: 'Contrast', 'Brightness', 'Blur' and 'Noise'. Secondly, affine transform
methods include: 'Translate', 'Scale', 'Shear' and 'Rotate'. Thirdly, attack methods include: 'FGSM',
'PGD' and 'MDIIM'. `mutate_config` must have method in the type of pixel value based transform methods.
The way of setting parameters for first and second type methods can be seen in
'mindarmour/fuzz_testing/image_transform.py'. For third type methods, the optional parameters refer to
`self._attack_param_checklists`.
initial_seeds (list[list]): Initial seeds used to generate mutated
samples. The format of initial seeds is [[image_data, label],
[...], ...] and the label must be one-hot.
coverage_metric (str): Model coverage metric of neural networks. All
supported metrics are: 'KMNC', 'NBC', 'SNAC'. Default: 'KMNC'.
eval_metrics (Union[list, tuple, str]): Evaluation metrics. If the
type is 'auto', it will calculate all the metrics, else if the
type is list or tuple, it will calculate the metrics specified
by user. All supported evaluate methods are 'accuracy',
'attack_success_rate', 'kmnc', 'nbc', 'snac'. Default: 'auto'.
max_iters (int): Max number of select a seed to mutate.
Default: 10000.
mutate_num_per_seed (int): The number of mutate times for a seed.
Default: 20.
initial_seeds (list[list]): Initial seeds used to generate mutated samples. The format of initial seeds is
[[image_data, label], [...], ...] and the label must be one-hot.
coverage (CoverageMetrics): Class of neuron coverage metrics.
evaluate (bool): return evaluate report or not. Default: True.
max_iters (int): Max number of select a seed to mutate. Default: 10000.
mutate_num_per_seed (int): The number of mutate times for a seed. Default: 20.

Returns:
- list, mutated samples in fuzz_testing.
@@ -214,18 +194,18 @@ class Fuzzer:
- dict, metrics report of fuzzer.

Raises:
TypeError: If the type of `eval_metrics` is not str, list or tuple.
TypeError: If the type of metric in list `eval_metrics` is not str.
ValueError: If `eval_metrics` is not equal to 'auto' when it's type is str.
ValueError: If metric in list `eval_metrics` is not in ['accuracy',
'attack_success_rate', 'kmnc', 'nbc', 'snac'].
ValueError, coverage must be subclass of CoverageMetrics.
ValueError, if initial seeds is empty.
ValueError, if element of seed is not two in initial seeds.
"""
# Check parameters.
eval_metrics_ = _check_eval_metrics(eval_metrics)
if coverage_metric not in ['KMNC', 'NBC', 'SNAC']:
msg = "coverage_metric must be in ['KMNC', 'NBC', 'SNAC'], but got {}.".format(coverage_metric)
if not isinstance(coverage, CoverageMetrics):
msg = 'coverage must be subclass of CoverageMetrics'
LOGGER.error(TAG, msg)
raise ValueError(msg)
evaluate = check_param_type('evaluate', evaluate, bool)
max_iters = check_int_positive('max_iters', max_iters)
mutate_num_per_seed = check_int_positive('mutate_num_per_seed', mutate_num_per_seed)
mutate_config = self._check_mutate_config(mutate_config)
@@ -235,15 +215,21 @@ class Fuzzer:
if not initial_seeds:
msg = 'initial_seeds must not be empty.'
raise ValueError(msg)
initial_samples = []
for seed in initial_seeds:
check_param_type('seed', seed, list)
if len(seed) != 2:
msg = 'seed in initial seeds must have two element image and ' \
'label, but got {} element.'.format(len(seed))
msg = 'seed in initial seeds must have two element image and label, but got {} element.'.format(
len(seed))
raise ValueError(msg)
check_numpy_param('seed[0]', seed[0])
check_numpy_param('seed[1]', seed[1])
initial_samples.append(seed[0])
seed.append(0)
initial_samples = np.array(initial_samples)
# calculate the coverage of initial seeds
pre_coverage = coverage.get_metrics(initial_samples)
gain_threshold = _gain_threshold(coverage)

seed, initial_seeds = _select_next(initial_seeds)
fuzz_samples = []
@@ -253,30 +239,27 @@ class Fuzzer:
iter_num = 0
while initial_seeds and iter_num < max_iters:
# Mutate a seed.
mutate_samples, mutate_strategies = self._metamorphic_mutate(seed,
mutates,
mutate_config,
mutate_samples, mutate_strategies = self._metamorphic_mutate(seed, mutates, mutate_config,
mutate_num_per_seed)
# Calculate the coverages and predictions of generated samples.
coverages, predicts = self._get_coverages_and_predict(mutate_samples, coverage_metric)
coverage_gains = _coverage_gains(coverages)
coverages, predicts = self._get_coverages_and_predict(mutate_samples, coverage)
coverage_gains, pre_coverage = _coverage_gains(pre_coverage, coverages)
for mutate, cov, pred, strategy in zip(mutate_samples, coverage_gains, predicts, mutate_strategies):
fuzz_samples.append(mutate[0])
true_labels.append(mutate[1])
fuzz_preds.append(pred)
fuzz_strategies.append(strategy)
# if the mutate samples has coverage gains add this samples in
# the initial_seeds to guide new mutates.
if cov > 0:
# if the mutate samples has coverage gains add this samples in the initial_seeds to guide new mutates.
if cov > gain_threshold:
initial_seeds.append(mutate)
seed, initial_seeds = _select_next(initial_seeds)
iter_num += 1
metrics_report = None
if eval_metrics_ is not None:
metrics_report = self._evaluate(fuzz_samples, true_labels, fuzz_preds, fuzz_strategies, eval_metrics_)
if evaluate:
metrics_report = self._evaluate(fuzz_samples, true_labels, fuzz_preds, fuzz_strategies, coverage)
return fuzz_samples, true_labels, fuzz_preds, fuzz_strategies, metrics_report

def _get_coverages_and_predict(self, mutate_samples, coverage_metric="KNMC"):
def _get_coverages_and_predict(self, mutate_samples, coverage):
""" Calculate the coverages and predictions of generated samples."""
samples = [s[0] for s in mutate_samples]
samples = np.array(samples)
@@ -285,17 +268,10 @@ class Fuzzer:
predictions = predictions.asnumpy()
for index in range(len(samples)):
mutate = samples[:index + 1]
self._coverage_metrics.calculate_coverage(mutate.astype(np.float32))
if coverage_metric == 'KMNC':
coverages.append(self._coverage_metrics.get_kmnc())
if coverage_metric == 'NBC':
coverages.append(self._coverage_metrics.get_nbc())
if coverage_metric == 'SNAC':
coverages.append(self._coverage_metrics.get_snac())
coverages.append(coverage.get_metrics(mutate))
return coverages, predictions

def _metamorphic_mutate(self, seed, mutates, mutate_config,
mutate_num_per_seed):
def _metamorphic_mutate(self, seed, mutates, mutate_config, mutate_num_per_seed):
"""Mutate a seed using strategies random selected from mutate_config."""
mutate_samples = []
mutate_strategies = []
@@ -310,8 +286,8 @@ class Fuzzer:
params = strategy['params']
method = strategy['method']
selected_param = {}
for p in params:
selected_param[p] = choice(params[p])
for param in params:
selected_param[param] = choice(params[param])

if method in list(self._pixel_value_trans_list + self._affine_trans_list):
if method == 'Shear':
@@ -367,8 +343,7 @@ class Fuzzer:
else:
for key in params.keys():
check_param_type(str(key), params[key], list)
# Methods in `metate_config` should at least have one in the type of
# pixel value based transform methods.
# Methods in `metate_config` should at least have one in the type of pixel value based transform methods.
if not has_pixel_trans:
msg = "mutate methods in mutate_config should at least have one in {}".format(self._pixel_value_trans_list)
raise ValueError(msg)
@@ -386,17 +361,7 @@ class Fuzzer:
check_param_type(param_name, params[param_name], list)
for param_value in params[param_name]:
if param_name == 'bounds':
bounds = check_param_multi_types('bounds', param_value, [tuple])
if len(bounds) != 2:
msg = 'The format of bounds must be format (lower_bound, upper_bound),' \
'but got its length as{}'.format(len(bounds))
raise ValueError(msg)
for bound_value in bounds:
_ = check_param_multi_types('bound', bound_value, [int, float])
if bounds[0] >= bounds[1]:
msg = "upper bound must more than lower bound, but upper bound got {}, lower bound " \
"got {}".format(bounds[0], bounds[1])
raise ValueError(msg)
_ = check_param_bounds('bounds', param_name)
elif param_name == 'norm_level':
_ = check_norm_level(param_value)
else:
@@ -420,57 +385,40 @@ class Fuzzer:
mutates[method] = self._strategies[method](network, loss_fn=loss_fn)
return mutates

def _evaluate(self, fuzz_samples, true_labels, fuzz_preds,
fuzz_strategies, metrics):
def _evaluate(self, fuzz_samples, true_labels, fuzz_preds, fuzz_strategies, coverage):
"""
Evaluate generated fuzz_testing samples in three dimensions: accuracy,
attack success rate and neural coverage.
Evaluate generated fuzz_testing samples in three dimensions: accuracy, attack success rate and neural coverage.

Args:
fuzz_samples ([numpy.ndarray, list]): Generated fuzz_testing samples
according to seeds.
fuzz_samples ([numpy.ndarray, list]): Generated fuzz_testing samples according to seeds.
true_labels ([numpy.ndarray, list]): Ground truth labels of seeds.
fuzz_preds ([numpy.ndarray, list]): Predictions of generated fuzz samples.
fuzz_strategies ([numpy.ndarray, list]): Mutate strategies of fuzz samples.
metrics (Union[list, tuple, str]): evaluation metrics.
coverage (CoverageMetrics): Neuron coverage metrics class.

Returns:
dict, evaluate metrics include accuracy, attack success rate
and neural coverage.
dict, evaluate metrics include accuracy, attack success rate and neural coverage.
"""
fuzz_samples = np.array(fuzz_samples)
true_labels = np.asarray(true_labels)
fuzz_preds = np.asarray(fuzz_preds)
temp = np.argmax(true_labels, axis=1) == np.argmax(fuzz_preds, axis=1)
metrics_report = {}
if metrics == 'auto' or 'accuracy' in metrics:
if temp.any():
acc = np.sum(temp) / np.size(temp)
else:
acc = 0
metrics_report['Accuracy'] = acc

if metrics == 'auto' or 'attack_success_rate' in metrics:
cond = [elem in self._attacks_list for elem in fuzz_strategies]
temp = temp[cond]
if temp.any():
attack_success_rate = 1 - np.sum(temp) / np.size(temp)
else:
attack_success_rate = None
metrics_report['Attack_success_rate'] = attack_success_rate

if metrics == 'auto' or 'kmnc' in metrics or 'nbc' in metrics or 'snac' in metrics:
self._coverage_metrics.calculate_coverage(np.array(fuzz_samples).astype(np.float32))

if metrics == 'auto' or 'kmnc' in metrics:
kmnc = self._coverage_metrics.get_kmnc()
metrics_report['Neural_coverage_KMNC'] = kmnc

if metrics == 'auto' or 'nbc' in metrics:
nbc = self._coverage_metrics.get_nbc()
metrics_report['Neural_coverage_NBC'] = nbc

if metrics == 'auto' or 'snac' in metrics:
snac = self._coverage_metrics.get_snac()
metrics_report['Neural_coverage_SNAC'] = snac
if temp.any():
acc = np.sum(temp) / np.size(temp)
else:
acc = 0
metrics_report['Accuracy'] = acc

cond = [elem in self._attacks_list for elem in fuzz_strategies]
temp = temp[cond]
if temp.any():
attack_success_rate = 1 - np.sum(temp) / np.size(temp)
else:
attack_success_rate = None
metrics_report['Attack_success_rate'] = attack_success_rate

metrics_report['Coverage_metrics'] = coverage.get_metrics(fuzz_samples)

return metrics_report

+ 308
- 223
mindarmour/fuzz_testing/model_coverage_metrics.py View File

@@ -14,311 +14,396 @@
"""
Model-Test Coverage Metrics.
"""
from abc import abstractmethod
from collections import defaultdict
import math
import numpy as np

from mindspore import Tensor
from mindspore import Model
from mindspore.train.summary.summary_record import _get_summary_tensor_data

from mindarmour.utils._check_param import check_model, check_numpy_param, \
check_int_positive, check_param_multi_types
from mindarmour.utils._check_param import check_model, check_numpy_param, check_int_positive, \
check_param_type, check_value_positive
from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
TAG = 'ModelCoverageMetrics'
TAG = 'CoverageMetrics'


class ModelCoverageMetrics:
class CoverageMetrics:
"""
As we all known, each neuron output of a network will have a output range
after training (we call it original range), and test dataset is used to
estimate the accuracy of the trained network. However, neurons' output
distribution would be different with different test datasets. Therefore,
similar to function fuzz, model fuzz means testing those neurons' outputs
and estimating the proportion of original range that has emerged with test
The abstract base class for Neuron coverage classes calculating coverage metrics.
As we all known, each neuron output of a network will have a output range after training (we call it original
range), and test dataset is used to estimate the accuracy of the trained network. However, neurons' output
distribution would be different with different test datasets. Therefore, similar to function fuzz, model fuzz means
testing those neurons' outputs and estimating the proportion of original range that has emerged with test
datasets.

Reference: `DeepGauge: Multi-Granularity Testing Criteria for Deep
Learning Systems <https://arxiv.org/abs/1803.07519>`_
Reference: `DeepGauge: Multi-Granularity Testing Criteria for Deep Learning Systems
<https://arxiv.org/abs/1803.07519>`_

Args:
model (Model): The pre-trained model which waiting for testing.
neuron_num (int): The number of testing neurons.
segmented_num (int): The number of segmented sections of neurons' output intervals.
train_dataset (numpy.ndarray): Training dataset used for determine
the neurons' output boundaries.

Raises:
ValueError: If neuron_num is too big (for example, bigger than 1e+9).

Examples:
>>> net = LeNet5()
>>> train_images = np.random.random((10000, 1, 32, 32)).astype(np.float32)
>>> test_images = np.random.random((5000, 1, 32, 32)).astype(np.float32)
>>> model = Model(net)
>>> neuron_num = 10
>>> segmented_num = 1000
>>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)
>>> model_fuzz_test.calculate_coverage(test_images)
>>> print('KMNC of this test is : %s', model_fuzz_test.get_kmnc())
>>> print('NBC of this test is : %s', model_fuzz_test.get_nbc())
>>> print('SNAC of this test is : %s', model_fuzz_test.get_snac())
>>> model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold)
>>> print('NC of this test is : %s', model_fuzz_test.get_nc())
>>> print('Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())
incremental (bool): Metrics will be calculate in incremental way or not. Default: False.
batch_size (int): The number of samples in a fuzz test batch. Default: 32.
"""

def __init__(self, model, neuron_num, segmented_num, train_dataset):
def __init__(self, model, incremental=False, batch_size=32):
self._model = check_model('model', model, Model)
self._segmented_num = check_int_positive('segmented_num', segmented_num)
self._neuron_num = check_int_positive('neuron_num', neuron_num)
if self._neuron_num > 1e+9:
msg = 'neuron_num should be less than 1e+10, otherwise a MemoryError would occur'
LOGGER.error(TAG, msg)
raise ValueError(msg)
train_dataset = check_numpy_param('train_dataset', train_dataset)
self._lower_bounds = [np.inf]*self._neuron_num
self._upper_bounds = [-np.inf]*self._neuron_num
self._var = [0]*self._neuron_num
self._main_section_hits = [[0 for _ in range(self._segmented_num)] for _ in range(self._neuron_num)]
self._lower_corner_hits = [0]*self._neuron_num
self._upper_corner_hits = [0]*self._neuron_num
self._bounds_get(train_dataset)
self._model_layer_dict = defaultdict(bool)
self._effective_model_layer_dict = defaultdict(bool)

def _set_init_effective_coverage_table(self, dataset):
self.incremental = check_param_type('incremental', incremental, bool)
self.batch_size = check_int_positive('batch_size', batch_size)
self._activate_table = defaultdict(list)

@abstractmethod
def get_metrics(self, dataset):
"""
Initialise the coverage table of each neuron in the model.
Calculate coverage metrics of given dataset.

Args:
dataset (numpy.ndarray): Dataset used for initialising the coverage table.
dataset (numpy.ndarray): Dataset used to calculate coverage metrics.

Raises:
NotImplementedError: It is an abstract method.
"""
self._model.predict(Tensor(dataset[0:1]))
tensors = _get_summary_tensor_data()
for name, tensor in tensors.items():
if 'input' in name:
continue
for num_neuron in range(tensor.shape[1]):
self._model_layer_dict[(name, num_neuron)] = False
self._effective_model_layer_dict[(name, num_neuron)] = False

def _bounds_get(self, train_dataset, batch_size=32):
msg = 'The function get_metrics() is an abstract method in class `CoverageMetrics`, and should be' \
' implemented in child class.'
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)

def _init_neuron_activate_table(self, data):
"""
Update the lower and upper boundaries of neurons' outputs.
Initialise the activate table of each neuron in the model with format:
{'layer1': [n1, n2, n3, ..., nn], 'layer2': [n1, n2, n3, ..., nn], ...}

Args:
train_dataset (numpy.ndarray): Training dataset used for
determine the neurons' output boundaries.
batch_size (int): The number of samples in a predict batch.
Default: 32.
data (numpy.ndarray): Data used for initialising the activate table.
Return:
dict, return a activate_table.
"""
batch_size = check_int_positive('batch_size', batch_size)
output_mat = []
batches = train_dataset.shape[0] // batch_size
for i in range(batches):
inputs = train_dataset[i*batch_size: (i + 1)*batch_size]
output = self._model.predict(Tensor(inputs)).asnumpy()
output_mat.append(output)
lower_compare_array = np.concatenate([output, np.array([self._lower_bounds])], axis=0)
self._lower_bounds = np.min(lower_compare_array, axis=0)
upper_compare_array = np.concatenate([output, np.array([self._upper_bounds])], axis=0)
self._upper_bounds = np.max(upper_compare_array, axis=0)
if batches == 0:
output = self._model.predict(Tensor(train_dataset)).asnumpy()
self._lower_bounds = np.min(output, axis=0)
self._upper_bounds = np.max(output, axis=0)
output_mat.append(output)
self._var = np.std(np.concatenate(np.array(output_mat), axis=0), axis=0)

def _sections_hits_count(self, dataset, intervals):
self._model.predict(Tensor(data))
layer_out = _get_summary_tensor_data()
if not layer_out:
msg = 'User must use TensorSummary() operation to specify the middle layer of the model participating in ' \
'the coverage calculation.'
LOGGER.error(TAG, msg)
raise ValueError(msg)
activate_table = defaultdict()
for layer, value in layer_out.items():
activate_table[layer] = np.zeros(value.shape[1], np.bool)
return activate_table

def _get_bounds(self, train_dataset):
"""
Update the coverage matrix of neurons' output subsections.
Update the lower and upper boundaries of neurons' outputs.

Args:
dataset (numpy.ndarray): Testing data.
intervals (list[float]): Segmentation intervals of neurons' outputs.
train_dataset (numpy.ndarray): Training dataset used for determine the neurons' output boundaries.

Return:
- numpy.ndarray, upper bounds of neuron' outputs.

- numpy.ndarray, lower bounds of neuron' outputs.
"""
dataset = check_numpy_param('dataset', dataset)
batch_output = self._model.predict(Tensor(dataset)).asnumpy()
batch_section_indexes = (batch_output - self._lower_bounds) // intervals
for section_indexes in batch_section_indexes:
for i in range(self._neuron_num):
if section_indexes[i] < 0:
self._lower_corner_hits[i] = 1
elif section_indexes[i] >= self._segmented_num:
self._upper_corner_hits[i] = 1
upper_bounds = defaultdict(list)
lower_bounds = defaultdict(list)
batches = math.ceil(train_dataset.shape[0] / self.batch_size)
for i in range(batches):
inputs = train_dataset[i * self.batch_size: (i + 1) * self.batch_size]
self._model.predict(Tensor(inputs))
layer_out = _get_summary_tensor_data()
for layer, tensor in layer_out.items():
value = tensor.asnumpy()
value = np.mean(value, axis=tuple([i for i in range(2, len(value.shape))]))
min_value = np.min(value, axis=0)
max_value = np.max(value, axis=0)
if np.any(upper_bounds[layer]):
max_flag = upper_bounds[layer] > max_value
min_flag = lower_bounds[layer] < min_value
upper_bounds[layer] = upper_bounds[layer] * max_flag + max_value * (1 - max_flag)
lower_bounds[layer] = lower_bounds[layer] * min_flag + min_value * (1 - min_flag)
else:
self._main_section_hits[i][int(section_indexes[i])] = 1
upper_bounds[layer] = max_value
lower_bounds[layer] = min_value
return upper_bounds, lower_bounds

def _coverage_update(self, name, tensor, scaled_mean, scaled_rank, top_k, threshold):
def _activate_rate(self):
"""
Calculate the activate rate of neurons.
"""
Update the coverage matrix of neural coverage and effective neural coverage.
total_neurons = 0
activated_neurons = 0
for _, value in self._activate_table.items():
activated_neurons += np.sum(value)
total_neurons += len(value)
activate_rate = activated_neurons / total_neurons

Args:
name (string): the name of the tensor.
tensor (tensor): the tensor in the network.
scaled_mean (numpy.ndarray): feature map of the tensor.
scaled_rank (numpy.ndarray): rank of tensor value.
top_k (int): neuron is covered when its output has the top k largest value in that hidden layer.
threshold (float): neuron is covered when its output is greater than the threshold.
return activate_rate


class NeuronCoverage(CoverageMetrics):
"""
Calculate the neurons activated coverage. Neuron is activated when its output is greater than the threshold.
Neuron coverage equals the proportion of activated neurons to total neurons in the network.

Args:
model (Model): The pre-trained model which waiting for testing.
threshold (float): Threshold used to determined neurons is activated or not. Default: 0.1.
incremental (bool): Metrics will be calculate in incremental way or not. Default: False.
batch_size (int): The number of samples in a fuzz test batch. Default: 32.

"""
def __init__(self, model, threshold=0.1, incremental=False, batch_size=32):
super(NeuronCoverage, self).__init__(model, incremental, batch_size)
self.threshold = check_value_positive('threshold', threshold)

def get_metrics(self, dataset):
"""
for num_neuron in range(tensor.shape[1]):
if num_neuron >= (len(scaled_rank) - top_k) and not \
self._effective_model_layer_dict[(name, scaled_rank[num_neuron])]:
self._effective_model_layer_dict[(name, scaled_rank[num_neuron])] = True
if scaled_mean[num_neuron] > threshold and not \
self._model_layer_dict[(name, num_neuron)]:
self._model_layer_dict[(name, num_neuron)] = True

def calculate_coverage(self, dataset, bias_coefficient=0, batch_size=32):
"""
Calculate the testing adequacy of the given dataset.
Get the metric of neuron coverage: the proportion of activated neurons to total neurons in the network.

Args:
dataset (numpy.ndarray): Data for fuzz test.
bias_coefficient (Union[int, float]): The coefficient used
for changing the neurons' output boundaries. Default: 0.
batch_size (int): The number of samples in a predict batch. Default: 32.
dataset (numpy.ndarray): Dataset used to calculate coverage metrics.

Returns:
float, the metric of 'neuron coverage'.

Examples:
>>> neuron_num = 10
>>> segmented_num = 1000
>>> batch_size = 32
>>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)
>>> model_fuzz_test.calculate_coverage(test_images, top_k, threshold, batch_size)
>>> nc = NeuronCoverage(model, threshold=0.1)
>>> nc_metrics = nc.get_metrics(test_data)
"""

dataset = check_numpy_param('dataset', dataset)
batch_size = check_int_positive('batch_size', batch_size)
bias_coefficient = check_param_multi_types('bias_coefficient', bias_coefficient, [int, float])
self._lower_bounds -= bias_coefficient*self._var
self._upper_bounds += bias_coefficient*self._var
intervals = (self._upper_bounds - self._lower_bounds) / self._segmented_num
batches = dataset.shape[0] // batch_size
batches = math.ceil(dataset.shape[0] / self.batch_size)
if not self.incremental or not self._activate_table:
self._activate_table = self._init_neuron_activate_table(dataset[0:1])
for i in range(batches):
self._sections_hits_count(dataset[i*batch_size: (i + 1)*batch_size], intervals)
inputs = dataset[i * self.batch_size: (i + 1) * self.batch_size]
self._model.predict(Tensor(inputs))
layer_out = _get_summary_tensor_data()
for layer, tensor in layer_out.items():
value = tensor.asnumpy()
value = np.mean(value, axis=tuple([i for i in range(2, len(value.shape))]))
activate = np.sum(value > self.threshold, axis=0) > 0
self._activate_table[layer] = np.logical_or(self._activate_table[layer], activate)
neuron_coverage = self._activate_rate()
return neuron_coverage


class TopKNeuronCoverage(CoverageMetrics):
"""
Calculate the top k activated neurons coverage. Neuron is activated when its output has the top k largest value in
that hidden layers. Top k neurons coverage equals the proportion of activated neurons to total neurons in the
network.

Args:
model (Model): The pre-trained model which waiting for testing.
top_k (int): Neuron is activated when its output has the top k largest value in that hidden layers. Default: 3.
incremental (bool): Metrics will be calculate in incremental way or not. Default: False.
batch_size (int): The number of samples in a fuzz test batch. Default: 32.
"""
def __init__(self, model, top_k=3, incremental=False, batch_size=32):
super(TopKNeuronCoverage, self).__init__(model, incremental=incremental, batch_size=batch_size)
self.top_k = check_int_positive('top_k', top_k)

def calculate_effective_coverage(self, dataset, top_k=3, threshold=0.1, batch_size=32):
def get_metrics(self, dataset):
"""
Calculate the effective testing adequacy of the given dataset.
In effective neural coverage, neuron is covered when its output has the top k largest value
in that hidden layers. In neural coverage, neuron is covered when its output is greater than the
threshold. Coverage equals the covered neurons divided by the total neurons in the network.
Get the metric of Top K activated neuron coverage.

Args:
threshold (float): neuron is covered when its output is greater than the threshold.
top_k (int): neuron is covered when its output has the top k largest value in that hiddern layer.
dataset (numpy.ndarray): Data for fuzz test.
dataset (numpy.ndarray): Dataset used to calculate coverage metrics.

Returns:
float, the metrics of 'top k neuron coverage'.

Examples:
>>> neuron_num = 10
>>> segmented_num = 1000
>>> top_k = 3
>>> threshold = 0.1
>>> model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)
>>> model_fuzz_test.calculate_coverage(test_images)
>>> model_fuzz_test.calculate_effective_coverage(test_images, top_k, threshold)
>>> tknc = TopKNeuronCoverage(model, top_k=3)
>>> metrics = tknc.get_metrics(test_data)
"""
top_k = check_int_positive('top_k', top_k)
dataset = check_numpy_param('dataset', dataset)
batch_size = check_int_positive('batch_size', batch_size)
batches = dataset.shape[0] // batch_size
self._set_init_effective_coverage_table(dataset)
batches = math.ceil(dataset.shape[0] / self.batch_size)
if not self.incremental or not self._activate_table:
self._activate_table = self._init_neuron_activate_table(dataset[0:1])
for i in range(batches):
inputs = dataset[i*batch_size: (i + 1)*batch_size]
self._model.predict(Tensor(inputs)).asnumpy()
tensors = _get_summary_tensor_data()
for name, tensor in tensors.items():
if 'input' in name:
continue
scaled = tensor.asnumpy()[-1]
if scaled.ndim >= 3: #
scaled_mean = np.mean(scaled, axis=(1, 2))
scaled_rank = np.argsort(scaled_mean)
self._coverage_update(name, tensor, scaled_mean, scaled_rank, top_k, threshold)
else:
scaled_rank = np.argsort(scaled)
self._coverage_update(name, tensor, scaled, scaled_rank, top_k, threshold)
inputs = dataset[i * self.batch_size: (i + 1) * self.batch_size]
self._model.predict(Tensor(inputs))
layer_out = _get_summary_tensor_data()
for layer, tensor in layer_out.items():
value = tensor.asnumpy()
if len(value.shape) > 2:
value = np.mean(value, axis=tuple([i for i in range(2, len(value.shape))]))
top_k_value = np.sort(value)[:, -self.top_k].reshape(value.shape[0], 1)
top_k_value = np.sum((value - top_k_value) >= 0, axis=0) > 0
self._activate_table[layer] = np.logical_or(self._activate_table[layer], top_k_value)
top_k_neuron_coverage = self._activate_rate()
return top_k_neuron_coverage


class SuperNeuronActivateCoverage(CoverageMetrics):
"""
Get the metric of 'super neuron activation coverage'. :math:`SNAC = |UpperCornerNeuron|/|N|`. SNAC refers to the
proportion of neurons whose neurons output value in the test set exceeds the upper bounds of the corresponding
neurons output value in the training set.

def get_nc(self):
Args:
model (Model): The pre-trained model which waiting for testing.
train_dataset (numpy.ndarray): Training dataset used for determine the neurons' output boundaries.
incremental (bool): Metrics will be calculate in incremental way or not. Default: False.
batch_size (int): The number of samples in a fuzz test batch. Default: 32.
"""
def __init__(self, model, train_dataset, incremental=False, batch_size=32):
super(SuperNeuronActivateCoverage, self).__init__(model, incremental=incremental, batch_size=batch_size)
train_dataset = check_numpy_param('train_dataset', train_dataset)
self.upper_bounds, self.lower_bounds = self._get_bounds(train_dataset=train_dataset)

def get_metrics(self, dataset):
"""
Get the metric of 'neuron coverage'.
Get the metric of 'strong neuron activation coverage'.

Args:
dataset (numpy.ndarray): Dataset used to calculate coverage metrics.

Returns:
float, the metric of 'neuron coverage'.
float, the metric of 'strong neuron activation coverage'.

Examples:
>>> model_fuzz_test.get_nc()
>>> snac = SuperNeuronActivateCoverage(model, train_dataset)
>>> metrics = snac.get_metrics(test_data)
"""
covered_neurons = len([v for v in self._model_layer_dict.values() if v])
total_neurons = len(self._model_layer_dict)
nc = covered_neurons / float(total_neurons)
return nc
dataset = check_numpy_param('dataset', dataset)
if not self.incremental or not self._activate_table:
self._activate_table = self._init_neuron_activate_table(dataset[0:1])
batches = math.ceil(dataset.shape[0] / self.batch_size)

def get_effective_nc(self):
"""
Get the metric of 'effective neuron coverage'.
for i in range(batches):
inputs = dataset[i * self.batch_size: (i + 1) * self.batch_size]
self._model.predict(Tensor(inputs))
layer_out = _get_summary_tensor_data()
for layer, tensor in layer_out.items():
value = tensor.asnumpy()
if len(value.shape) > 2:
value = np.mean(value, axis=tuple([i for i in range(2, len(value.shape))]))
activate = np.sum(value > self.upper_bounds[layer], axis=0) > 0
self._activate_table[layer] = np.logical_or(self._activate_table[layer], activate)
snac = self._activate_rate()
return snac

Returns:
float, the metric of 'the effective neuron coverage'.

Examples:
>>> model_fuzz_test.get_effective_nc()
"""
covered_neurons = len([v for v in self._effective_model_layer_dict.values() if v])
total_neurons = len(self._effective_model_layer_dict)
effective_nc = covered_neurons / float(total_neurons)
return effective_nc
class NeuronBoundsCoverage(SuperNeuronActivateCoverage):
"""
Get the metric of 'neuron boundary coverage' :math:`NBC = (|UpperCornerNeuron| + |LowerCornerNeuron|)/(2*|N|)`,
where :math`|N|` is the number of neurons, NBC refers to the proportion of neurons whose neurons output value in
the test dataset exceeds the upper and lower bounds of the corresponding neurons output value in the training
dataset.

def get_kmnc(self):
"""
Get the metric of 'k-multisection neuron coverage'. KMNC measures how
thoroughly the given set of test inputs covers the range of neurons
output values derived from training dataset.
Args:
model (Model): The pre-trained model which waiting for testing.
train_dataset (numpy.ndarray): Training dataset used for determine the neurons' output boundaries.
incremental (bool): Metrics will be calculate in incremental way or not. Default: False.
batch_size (int): The number of samples in a fuzz test batch. Default: 32.
"""

Returns:
float, the metric of 'k-multisection neuron coverage'.
def __init__(self, model, train_dataset, incremental=False, batch_size=32):
super(NeuronBoundsCoverage, self).__init__(model, train_dataset, incremental=incremental, batch_size=batch_size)

Examples:
>>> model_fuzz_test.get_kmnc()
def get_metrics(self, dataset):
"""
kmnc = np.sum(self._main_section_hits) / (self._neuron_num*self._segmented_num)
return kmnc
Get the metric of 'neuron boundary coverage'.

def get_nbc(self):
"""
Get the metric of 'neuron boundary coverage' :math:`NBC = (|UpperCornerNeuron|
+ |LowerCornerNeuron|)/(2*|N|)`, where :math`|N|` is the number of neurons,
NBC refers to the proportion of neurons whose neurons output value in
the test dataset exceeds the upper and lower bounds of the corresponding
neurons output value in the training dataset.
Args:
dataset (numpy.ndarray): Dataset used to calculate coverage metrics.

Returns:
float, the metric of 'neuron boundary coverage'.

Examples:
>>> model_fuzz_test.get_nbc()
>>> nbc = NeuronBoundsCoverage(model, train_dataset)
>>> metrics = nbc.get_metrics(test_data)
"""
nbc = (np.sum(self._lower_corner_hits) + np.sum(self._upper_corner_hits)) / (2*self._neuron_num)
dataset = check_numpy_param('dataset', dataset)
if not self.incremental or not self._activate_table:
self._activate_table = self._init_neuron_activate_table(dataset[0:1])

batches = math.ceil(dataset.shape[0] / self.batch_size)
for i in range(batches):
inputs = dataset[i * self.batch_size: (i + 1) * self.batch_size]
self._model.predict(Tensor(inputs))
layer_out = _get_summary_tensor_data()
for layer, tensor in layer_out.items():
value = tensor.asnumpy()
if len(value.shape) > 2:
value = np.mean(value, axis=tuple([i for i in range(2, len(value.shape))]))
outer = np.logical_or(value > self.upper_bounds[layer], value < self.lower_bounds[layer])
activate = np.sum(outer, axis=0) > 0
self._activate_table[layer] = np.logical_or(self._activate_table[layer], activate)
nbc = self._activate_rate()
return nbc

def get_snac(self):

class KMultisectionNeuronCoverage(SuperNeuronActivateCoverage):
"""
Get the metric of 'k-multisection neuron coverage'. KMNC measures how thoroughly the given set of test inputs
covers the range of neurons output values derived from training dataset.

Args:
model (Model): The pre-trained model which waiting for testing.
train_dataset (numpy.ndarray): Training dataset used for determine the neurons' output boundaries.
segmented_num (int): The number of segmented sections of neurons' output intervals. Default: 100.
incremental (bool): Metrics will be calculate in incremental way or not. Default: False.
batch_size (int): The number of samples in a fuzz test batch. Default: 32.
"""

def __init__(self, model, train_dataset, segmented_num=100, incremental=False, batch_size=32):
super(KMultisectionNeuronCoverage, self).__init__(model, train_dataset, incremental=incremental,
batch_size=batch_size)
self.segmented_num = check_int_positive('segmented_num', segmented_num)
self.intervals = defaultdict(list)
for keys in self.upper_bounds.keys():
self.intervals[keys] = (self.upper_bounds[keys] - self.lower_bounds[keys]) / self.segmented_num

def _init_k_multisection_table(self, data):
""" Initial the activate table."""
self._model.predict(Tensor(data))
layer_out = _get_summary_tensor_data()
activate_section_table = defaultdict()
for layer, value in layer_out.items():
activate_section_table[layer] = np.zeros((value.shape[1], self.segmented_num), np.bool)
return activate_section_table

def get_metrics(self, dataset):
"""
Get the metric of 'strong neuron activation coverage'.
:math:`SNAC = |UpperCornerNeuron|/|N|`. SNAC refers to the proportion
of neurons whose neurons output value in the test set exceeds the upper
bounds of the corresponding neurons output value in the training set.
Get the metric of 'k-multisection neuron coverage'.
Args:
dataset (numpy.ndarray): Dataset used to calculate coverage metrics.

Returns:
float, the metric of 'strong neuron activation coverage'.
float, the metric of 'k-multisection neuron coverage'.

Examples:
>>> model_fuzz_test.get_snac()
>>> kmnc = KMultisectionNeuronCoverage(model, train_dataset, segmented_num=100)
>>> metrics = kmnc.get_metrics(test_data)
"""
snac = np.sum(self._upper_corner_hits) / self._neuron_num
return snac

dataset = check_numpy_param('dataset', dataset)
if not self.incremental or not self._activate_table:
self._activate_table = self._init_k_multisection_table(dataset[0:1])

batches = math.ceil(dataset.shape[0] / self.batch_size)
for i in range(batches):
inputs = dataset[i * self.batch_size: (i + 1) * self.batch_size]
self._model.predict(Tensor(inputs))
layer_out = _get_summary_tensor_data()
for layer, tensor in layer_out.items():
value = tensor.asnumpy()
value = np.mean(value, axis=tuple([i for i in range(2, len(value.shape))]))
hits = np.floor((value - self.lower_bounds[layer]) / self.intervals[layer]).astype(int)
hits = np.transpose(hits, [1, 0])
for n in range(len(hits)):
for sec in hits[n]:
if sec >= self.segmented_num or sec < 0:
continue
self._activate_table[layer][n][sec] = True

kmnc = self._activate_rate() / self.segmented_num
return kmnc

+ 42
- 49
mindarmour/utils/_check_param.py View File

@@ -39,9 +39,7 @@ def _check_array_not_empty(arg_name, arg_value):
def check_param_type(arg_name, arg_value, valid_type):
"""Check parameter type."""
if not isinstance(arg_value, valid_type):
msg = '{} must be {}, but got {}'.format(arg_name,
valid_type,
type(arg_value).__name__)
msg = '{} must be {}, but got {}'.format(arg_name, valid_type, type(arg_value).__name__)
LOGGER.error(TAG, msg)
raise TypeError(msg)

@@ -51,8 +49,7 @@ def check_param_type(arg_name, arg_value, valid_type):
def check_param_multi_types(arg_name, arg_value, valid_types):
"""Check parameter multi types."""
if not isinstance(arg_value, tuple(valid_types)):
msg = 'type of {} must be in {}, but got {}' \
.format(arg_name, valid_types, type(arg_value).__name__)
msg = 'type of {} must be in {}, but got {}'.format(arg_name, valid_types, type(arg_value).__name__)
LOGGER.error(TAG, msg)
raise TypeError(msg)

@@ -68,8 +65,7 @@ def check_int_positive(arg_name, arg_value):
raise ValueError(msg)
arg_value = check_param_type(arg_name, arg_value, int)
if arg_value <= 0:
msg = '{} must be greater than 0, but got {}'.format(arg_name,
arg_value)
msg = '{} must be greater than 0, but got {}'.format(arg_name, arg_value)
LOGGER.error(TAG, msg)
raise ValueError(msg)
return arg_value
@@ -79,8 +75,7 @@ def check_value_non_negative(arg_name, arg_value):
"""Check non negative value."""
arg_value = check_param_multi_types(arg_name, arg_value, (int, float))
if float(arg_value) < 0.0:
msg = '{} must not be less than 0, but got {}'.format(arg_name,
arg_value)
msg = '{} must not be less than 0, but got {}'.format(arg_name, arg_value)
LOGGER.error(TAG, msg)
raise ValueError(msg)
return arg_value
@@ -90,8 +85,7 @@ def check_value_positive(arg_name, arg_value):
"""Check positive value."""
arg_value = check_param_multi_types(arg_name, arg_value, (int, float))
if float(arg_value) <= 0.0:
msg = '{} must be greater than zero, but got {}'.format(arg_name,
arg_value)
msg = '{} must be greater than zero, but got {}'.format(arg_name, arg_value)
LOGGER.error(TAG, msg)
raise ValueError(msg)
return arg_value
@@ -102,10 +96,7 @@ def check_param_in_range(arg_name, arg_value, lower, upper):
Check range of parameter.
"""
if arg_value <= lower or arg_value >= upper:
msg = '{} must be between {} and {}, but got {}'.format(arg_name,
lower,
upper,
arg_value)
msg = '{} must be between {} and {}, but got {}'.format(arg_name, lower, upper, arg_value)
LOGGER.error(TAG, msg)
raise ValueError(msg)

@@ -129,10 +120,7 @@ def check_model(model_name, model, model_type):
"""
if isinstance(model, model_type):
return model
msg = '{} should be an instance of {}, but got {}' \
.format(model_name,
model_type,
type(model).__name__)
msg = '{} should be an instance of {}, but got {}'.format(model_name, model_type, type(model).__name__)
LOGGER.error(TAG, msg)
raise TypeError(msg)

@@ -175,11 +163,9 @@ def check_pair_numpy_param(inputs_name, inputs, labels_name, labels):
labels (numpy.ndarray): Labels of `inputs`.

Returns:
- numpy.ndarray, if `inputs` 's dimension equals to
`labels`, return inputs with type of numpy.ndarray.
- numpy.ndarray, if `inputs` 's dimension equals to `labels`, return inputs with type of numpy.ndarray.

- numpy.ndarray, if `inputs` 's dimension equals to
`labels` , return labels with type of numpy.ndarray.
- numpy.ndarray, if `inputs` 's dimension equals to `labels` , return labels with type of numpy.ndarray.

Raises:
ValueError: If inputs.shape[0] is not equal to labels.shape[0].
@@ -188,8 +174,7 @@ def check_pair_numpy_param(inputs_name, inputs, labels_name, labels):
labels = check_numpy_param(labels_name, labels)
if inputs.shape[0] != labels.shape[0]:
msg = '{} shape[0] must equal {} shape[0], bot got shape of ' \
'inputs {}, shape of labels {}'.format(inputs_name, labels_name,
inputs.shape, labels.shape)
'inputs {}, shape of labels {}'.format(inputs_name, labels_name, inputs.shape, labels.shape)
LOGGER.error(TAG, msg)
raise ValueError(msg)
return inputs, labels
@@ -198,10 +183,8 @@ def check_pair_numpy_param(inputs_name, inputs, labels_name, labels):
def check_equal_length(para_name1, value1, para_name2, value2):
"""Check weather the two parameters have equal length."""
if len(value1) != len(value2):
msg = 'The dimension of {0} must equal to the ' \
'{1}, but got {0} is {2}, ' \
'{1} is {3}'.format(para_name1, para_name2, len(value1),
len(value2))
msg = 'The dimension of {0} must equal to the {1}, but got {0} is {2}, {1} is {3}'\
.format(para_name1, para_name2, len(value1), len(value2))
LOGGER.error(TAG, msg)
raise ValueError(msg)
return value1, value2
@@ -210,10 +193,8 @@ def check_equal_length(para_name1, value1, para_name2, value2):
def check_equal_shape(para_name1, value1, para_name2, value2):
"""Check weather the two parameters have equal shape."""
if value1.shape != value2.shape:
msg = 'The shape of {0} must equal to the ' \
'{1}, but got {0} is {2}, ' \
'{1} is {3}'.format(para_name1, para_name2, value1.shape,
value2.shape)
msg = 'The shape of {0} must equal to the {1}, but got {0} is {2}, {1} is {3}'.\
format(para_name1, para_name2, value1.shape, value2.shape)
LOGGER.error(TAG, msg)
raise ValueError(msg)
return value1, value2
@@ -225,8 +206,7 @@ def check_norm_level(norm_level):
msg = 'Type of norm_level must be in [int, str], but got {}'.format(type(norm_level))
accept_norm = [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', np.inf]
if norm_level not in accept_norm:
msg = 'norm_level must be in {}, but got {}'.format(accept_norm,
norm_level)
msg = 'norm_level must be in {}, but got {}'.format(accept_norm, norm_level)
LOGGER.error(TAG, msg)
raise ValueError(msg)
return norm_level
@@ -252,20 +232,16 @@ def normalize_value(value, norm_level):
value_reshape = value.reshape((value.shape[0], -1))
avoid_zero_div = 1e-12
if norm_level in (1, '1', 'l1'):
norm = np.linalg.norm(value_reshape, ord=1, axis=1, keepdims=True) + \
avoid_zero_div
norm = np.linalg.norm(value_reshape, ord=1, axis=1, keepdims=True) + avoid_zero_div
norm_value = value_reshape / norm
elif norm_level in (2, '2', 'l2'):
norm = np.linalg.norm(value_reshape, ord=2, axis=1, keepdims=True) + \
avoid_zero_div
norm = np.linalg.norm(value_reshape, ord=2, axis=1, keepdims=True) + avoid_zero_div
norm_value = value_reshape / norm
elif norm_level in (np.inf, 'inf'):
norm = np.max(abs(value_reshape), axis=1, keepdims=True) + \
avoid_zero_div
norm = np.max(abs(value_reshape), axis=1, keepdims=True) + avoid_zero_div
norm_value = value_reshape / norm
else:
msg = 'Values of `norm_level` different from 1, 2 and ' \
'`np.inf` are currently not supported, but got {}.' \
msg = 'Values of `norm_level` different from 1, 2 and `np.inf` are currently not supported, but got {}.' \
.format(norm_level)
LOGGER.error(TAG, msg)
raise NotImplementedError(msg)
@@ -339,13 +315,30 @@ def check_inputs_labels(inputs, labels):
inputs_image = inputs[0] if isinstance(inputs, tuple) else inputs
if isinstance(inputs, tuple):
for i, inputs_item in enumerate(inputs):
_ = check_pair_numpy_param('inputs_image', inputs_image, \
'inputs[{}]'.format(i), inputs_item)
_ = check_pair_numpy_param('inputs_image', inputs_image, 'inputs[{}]'.format(i), inputs_item)
if isinstance(labels, tuple):
for i, labels_item in enumerate(labels):
_ = check_pair_numpy_param('inputs', inputs_image, \
'labels[{}]'.format(i), labels_item)
_ = check_pair_numpy_param('inputs', inputs_image, 'labels[{}]'.format(i), labels_item)
else:
_ = check_pair_numpy_param('inputs', inputs_image, \
'labels', labels)
_ = check_pair_numpy_param('inputs', inputs_image, 'labels', labels)
return inputs_image, inputs, labels


def check_param_bounds(arg_name, arg_value):
"""Check bounds is valid"""
arg_value = check_param_multi_types(arg_name, arg_value, [tuple, list])
if len(arg_value) != 2:
msg = 'length of {0} must be 2, but got length of {0} is {1}'.format(arg_name, len(arg_value))
LOGGER.error(TAG, msg)
raise ValueError(msg)
for i, b in enumerate(arg_value):
if not isinstance(b, (float, int)):
msg = 'each value in {} must be int or float, but got the {}th value is {}'.format(arg_name, i, b)
LOGGER.error(TAG, msg)
raise ValueError(msg)
if arg_value[0] > arg_value[1]:
msg = "lower boundary cannot be greater than upper boundary, corresponding values in {} are {} and {}". \
format(arg_name, arg_value[0], arg_value[1])
LOGGER.error(TAG, msg)
raise ValueError(msg)
return arg_value

+ 50
- 47
tests/ut/python/fuzzing/test_coverage_metrics.py View File

@@ -25,7 +25,8 @@ from mindspore.ops import TensorSummary

from mindarmour.adv_robustness.attacks import FastGradientSignMethod
from mindarmour.utils.logger import LogUtil
from mindarmour.fuzz_testing import ModelCoverageMetrics
from mindarmour.fuzz_testing import NeuronCoverage, TopKNeuronCoverage, SuperNeuronActivateCoverage, \
NeuronBoundsCoverage, KMultisectionNeuronCoverage

LOGGER = LogUtil.get_instance()
TAG = 'Neuron coverage test'
@@ -74,39 +75,48 @@ def test_lenet_mnist_coverage_cpu():
model = Model(net)

# initialize fuzz test with training dataset
neuron_num = 10
segmented_num = 1000
top_k = 3
threshold = 0.1
training_data = (np.random.random((10000, 10))*20).astype(np.float32)

model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data)

# fuzz test with original test data
# get test data
test_data = (np.random.random((2000, 10))*20).astype(np.float32)
test_labels = np.random.randint(0, 10, 2000).astype(np.int32)
model_fuzz_test.calculate_coverage(test_data)
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())

model_fuzz_test.calculate_effective_coverage(test_data, top_k, threshold)
LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc())
LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())
nc = NeuronCoverage(model, threshold=0.1)
nc_metric = nc.get_metrics(test_data)

tknc = TopKNeuronCoverage(model, top_k=3)
tknc_metrics = tknc.get_metrics(test_data)

snac = SuperNeuronActivateCoverage(model, training_data)
snac_metrics = snac.get_metrics(test_data)

nbc = NeuronBoundsCoverage(model, training_data)
nbc_metrics = nbc.get_metrics(test_data)

kmnc = KMultisectionNeuronCoverage(model, training_data, segmented_num=100)
kmnc_metrics = kmnc.get_metrics(test_data)

print('KMNC of this test is: ', kmnc_metrics)
print('NBC of this test is: ', nbc_metrics)
print('SNAC of this test is: ', snac_metrics)
print('NC of this test is: ', nc_metric)
print('TKNC of this test is: ', tknc_metrics)

# generate adv_data
loss = SoftmaxCrossEntropyWithLogits(sparse=True)
attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss)
adv_data = attack.batch_generate(test_data, test_labels, batch_size=32)
model_fuzz_test.calculate_coverage(adv_data, bias_coefficient=0.5)
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())

model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold)
LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc())
LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())
nc_metric = nc.get_metrics(adv_data)
tknc_metrics = tknc.get_metrics(adv_data)
snac_metrics = snac.get_metrics(adv_data)
nbc_metrics = nbc.get_metrics(adv_data)
kmnc_metrics = kmnc.get_metrics(adv_data)
print('KMNC of adv data is: ', kmnc_metrics)
print('NBC of adv data is: ', nbc_metrics)
print('SNAC of adv data is: ', snac_metrics)
print('NC of adv data is: ', nc_metric)
print('TKNC of adv data is: ', tknc_metrics)

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@@ -120,35 +130,28 @@ def test_lenet_mnist_coverage_ascend():
model = Model(net)

# initialize fuzz test with training dataset
neuron_num = 10
segmented_num = 1000
top_k = 3
threshold = 0.1
training_data = (np.random.random((10000, 10))*20).astype(np.float32)
model_fuzz_test = ModelCoverageMetrics(model, neuron_num, segmented_num, training_data)

# fuzz test with original test data
# get test data
test_data = (np.random.random((2000, 10))*20).astype(np.float32)
test_labels = np.random.randint(0, 10, 2000)
test_labels = (np.eye(10)[test_labels]).astype(np.float32)
model_fuzz_test.calculate_coverage(test_data)
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
nc = NeuronCoverage(model, threshold=0.1)
nc_metric = nc.get_metrics(test_data)

model_fuzz_test.calculate_effective_coverage(test_data, top_k, threshold)
LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc())
LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())
tknc = TopKNeuronCoverage(model, top_k=3)
tknc_metrics = tknc.get_metrics(test_data)

# generate adv_data
attack = FastGradientSignMethod(net, eps=0.3, loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=False))
adv_data = attack.batch_generate(test_data, test_labels, batch_size=32)
model_fuzz_test.calculate_coverage(adv_data, bias_coefficient=0.5)
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())

model_fuzz_test.calculate_effective_coverage(adv_data, top_k, threshold)
LOGGER.info(TAG, 'NC of this test is : %s', model_fuzz_test.get_nc())
LOGGER.info(TAG, 'Effective_NC of this test is : %s', model_fuzz_test.get_effective_nc())
snac = SuperNeuronActivateCoverage(model, training_data)
snac_metrics = snac.get_metrics(test_data)

nbc = NeuronBoundsCoverage(model, training_data)
nbc_metrics = nbc.get_metrics(test_data)

kmnc = KMultisectionNeuronCoverage(model, training_data, segmented_num=100)
kmnc_metrics = kmnc.get_metrics(test_data)

print('KMNC of this test is: ', kmnc_metrics)
print('NBC of this test is: ', nbc_metrics)
print('SNAC of this test is: ', snac_metrics)
print('NC of this test is: ', nc_metric)
print('TKNC of this test is: ', tknc_metrics)

+ 22
- 24
tests/ut/python/fuzzing/test_fuzzer.py View File

@@ -21,9 +21,10 @@ from mindspore import nn
from mindspore.common.initializer import TruncatedNormal
from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore.ops import TensorSummary

from mindarmour.fuzz_testing import Fuzzer
from mindarmour.fuzz_testing import ModelCoverageMetrics
from mindarmour.fuzz_testing import KMultisectionNeuronCoverage
from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
@@ -52,30 +53,37 @@ class Net(nn.Cell):
"""
Lenet network
"""

def __init__(self):
super(Net, self).__init__()
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16*5*5, 120)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, 10)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
self.summary = TensorSummary()

def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
self.summary('conv1', x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
self.summary('conv2', x)
x = self.max_pool2d(x)
x = self.reshape(x, (-1, 16*5*5))
x = self.reshape(x, (-1, 16 * 5 * 5))
x = self.fc1(x)
x = self.relu(x)
self.summary('fc1', x)
x = self.fc2(x)
x = self.relu(x)
self.summary('fc2', x)
x = self.fc3(x)
self.summary('fc3', x)
return x


@@ -100,12 +108,8 @@ def test_fuzzing_ascend():
{'method': 'FGSM',
'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}
]
# initialize fuzz test with training dataset
neuron_num = 10
segmented_num = 1000
train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)

train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
# fuzz test with original test data
# get test data
test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
@@ -118,13 +122,12 @@ def test_fuzzing_ascend():
initial_seeds.append([img, label])

initial_seeds = initial_seeds[:100]
model_coverage_test.calculate_coverage(
np.array(test_images[:100]).astype(np.float32))
LOGGER.info(TAG, 'KMNC of this test is : %s',
model_coverage_test.get_kmnc())

model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num)
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds)
nc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100)
cn_metrics = nc.get_metrics(test_images[:100])
print('neuron coverage of initial seeds is: ', cn_metrics)
model_fuzz_test = Fuzzer(model)
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, nc, max_iters=100)
print(metrics)


@@ -139,8 +142,6 @@ def test_fuzzing_cpu():
model = Model(net)
batch_size = 8
num_classe = 10
neuron_num = 10
segmented_num = 1000
mutate_config = [{'method': 'Blur',
'params': {'auto_param': [True]}},
{'method': 'Contrast',
@@ -152,7 +153,6 @@ def test_fuzzing_cpu():
]
# initialize fuzz test with training dataset
train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)

# fuzz test with original test data
# get test data
@@ -166,11 +166,9 @@ def test_fuzzing_cpu():
initial_seeds.append([img, label])

initial_seeds = initial_seeds[:100]
model_coverage_test.calculate_coverage(
np.array(test_images[:100]).astype(np.float32))
LOGGER.info(TAG, 'KMNC of this test is : %s',
model_coverage_test.get_kmnc())

model_fuzz_test = Fuzzer(model, train_images, neuron_num, segmented_num)
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds)
nc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100)
tknc_metrics = nc.get_metrics(test_images[:100])
print('neuron coverage of initial seeds is: ', tknc_metrics)
model_fuzz_test = Fuzzer(model)
_, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, nc, max_iters=100)
print(metrics)

Loading…
Cancel
Save