|
|
@@ -72,6 +72,34 @@ def _is_trans_valid(seed, mutate_sample): |
|
|
|
return is_valid |
|
|
|
|
|
|
|
|
|
|
|
def _check_eval_metrics(eval_metrics): |
|
|
|
""" Check evaluation metrics.""" |
|
|
|
if isinstance(eval_metrics, (list, tuple)): |
|
|
|
eval_metrics_ = [] |
|
|
|
available_metrics = ['accuracy', 'attack_success_rate', 'kmnc', 'nbc', |
|
|
|
'snac'] |
|
|
|
for elem in eval_metrics: |
|
|
|
if elem not in available_metrics: |
|
|
|
msg = 'metric in list `eval_metrics` must be in {}, but got {}.' \ |
|
|
|
.format(available_metrics, elem) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
eval_metrics_.append(elem.lower()) |
|
|
|
elif isinstance(eval_metrics, str): |
|
|
|
if eval_metrics != 'auto': |
|
|
|
msg = "the value of `eval_metrics` must be 'auto' if it's type " \ |
|
|
|
"is str, but got {}.".format(eval_metrics) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
eval_metrics_ = 'auto' |
|
|
|
else: |
|
|
|
msg = "the type of `eval_metrics` must be str, list or tuple, " \ |
|
|
|
"but got {}.".format(type(eval_metrics)) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise TypeError(msg) |
|
|
|
return eval_metrics_ |
|
|
|
|
|
|
|
|
|
|
|
class Fuzzer: |
|
|
|
""" |
|
|
|
Fuzzing test framework for deep neural networks. |
|
|
@@ -89,16 +117,21 @@ class Fuzzer: |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> net = Net() |
|
|
|
>>> mutate_config = [{'method': 'Blur', 'params': {'auto_param': True}}, |
|
|
|
>>> {'method': 'Contrast','params': {'factor': 2}}, |
|
|
|
>>> {'method': 'Translate', 'params': {'x_bias': 0.1, 'y_bias': 0.2}}, |
|
|
|
>>> {'method': 'FGSM', 'params': {'eps': 0.1, 'alpha': 0.1}}] |
|
|
|
>>> mutate_config = [{'method': 'Blur', |
|
|
|
>>> 'params': {'auto_param': [True]}}, |
|
|
|
>>> {'method': 'Contrast', |
|
|
|
>>> 'params': {'factor': [2]}}, |
|
|
|
>>> {'method': 'Translate', |
|
|
|
>>> 'params': {'x_bias': [0.1, 0.2], 'y_bias': [0.2]}}, |
|
|
|
>>> {'method': 'FGSM', |
|
|
|
>>> 'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}] |
|
|
|
>>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) |
|
|
|
>>> model_fuzz_test = Fuzzer(model, train_images, 10, 1000) |
|
|
|
>>> samples, labels, preds, strategies, report = model_fuzz_test.fuzz_testing(mutate_config, initial_seeds) |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, target_model, train_dataset, neuron_num, segmented_num=1000): |
|
|
|
def __init__(self, target_model, train_dataset, neuron_num, |
|
|
|
segmented_num=1000): |
|
|
|
self._target_model = check_model('model', target_model, Model) |
|
|
|
train_dataset = check_numpy_param('train_dataset', train_dataset) |
|
|
|
self._coverage_metrics = ModelCoverageMetrics(target_model, |
|
|
@@ -106,9 +139,14 @@ class Fuzzer: |
|
|
|
segmented_num, |
|
|
|
train_dataset) |
|
|
|
# Allowed mutate strategies so far. |
|
|
|
self._strategies = {'Contrast': Contrast, 'Brightness': Brightness, |
|
|
|
'Blur': Blur, 'Noise': Noise, 'Translate': Translate, |
|
|
|
'Scale': Scale, 'Shear': Shear, 'Rotate': Rotate, |
|
|
|
self._strategies = {'Contrast': Contrast, |
|
|
|
'Brightness': Brightness, |
|
|
|
'Blur': Blur, |
|
|
|
'Noise': Noise, |
|
|
|
'Translate': Translate, |
|
|
|
'Scale': Scale, |
|
|
|
'Shear': Shear, |
|
|
|
'Rotate': Rotate, |
|
|
|
'FGSM': FastGradientSignMethod, |
|
|
|
'PGD': ProjectedGradientDescent, |
|
|
|
'MDIIM': MomentumDiverseInputIterativeMethod} |
|
|
@@ -117,21 +155,27 @@ class Fuzzer: |
|
|
|
'Noise'] |
|
|
|
self._attacks_list = ['FGSM', 'PGD', 'MDIIM'] |
|
|
|
self._attack_param_checklists = { |
|
|
|
'FGSM': {'params': {'eps': {'dtype': [float], 'range': [0, 1]}, |
|
|
|
'alpha': {'dtype': [float], |
|
|
|
'range': [0, 1]}, |
|
|
|
'bounds': {'dtype': [tuple]}}}, |
|
|
|
'PGD': {'params': {'eps': {'dtype': [float], 'range': [0, 1]}, |
|
|
|
'eps_iter': {'dtype': [float], |
|
|
|
'range': [0, 1]}, |
|
|
|
'nb_iter': {'dtype': [int], |
|
|
|
'range': [0, 1e5]}, |
|
|
|
'bounds': {'dtype': [tuple]}}}, |
|
|
|
'MDIIM': { |
|
|
|
'params': {'eps': {'dtype': [float], 'range': [0, 1]}, |
|
|
|
'norm_level': {'dtype': [str]}, |
|
|
|
'prob': {'dtype': [float], 'range': [0, 1]}, |
|
|
|
'bounds': {'dtype': [tuple]}}}} |
|
|
|
'FGSM': {'eps': {'dtype': [float], |
|
|
|
'range': [0, 1]}, |
|
|
|
'alpha': {'dtype': [float], |
|
|
|
'range': [0, 1]}, |
|
|
|
'bounds': {'dtype': [tuple]}}, |
|
|
|
'PGD': {'eps': {'dtype': [float], |
|
|
|
'range': [0, 1]}, |
|
|
|
'eps_iter': { |
|
|
|
'dtype': [float], |
|
|
|
'range': [0, 1]}, |
|
|
|
'nb_iter': {'dtype': [int], |
|
|
|
'range': [0, 100000]}, |
|
|
|
'bounds': {'dtype': [tuple]}}, |
|
|
|
'MDIIM': {'eps': {'dtype': [float], |
|
|
|
'range': [0, 1]}, |
|
|
|
'norm_level': {'dtype': [str, int], |
|
|
|
'range': [1, 2, '1', '2', 'l1', 'l2', |
|
|
|
'inf', 'np.inf']}, |
|
|
|
'prob': {'dtype': [float], |
|
|
|
'range': [0, 1]}, |
|
|
|
'bounds': {'dtype': [tuple]}}} |
|
|
|
|
|
|
|
def fuzzing(self, mutate_config, initial_seeds, coverage_metric='KMNC', |
|
|
|
eval_metrics='auto', max_iters=10000, mutate_num_per_seed=20): |
|
|
@@ -140,10 +184,15 @@ class Fuzzer: |
|
|
|
|
|
|
|
Args: |
|
|
|
mutate_config (list): Mutate configs. The format is |
|
|
|
[{'method': 'Blur', 'params': {'auto_param': True}}, |
|
|
|
{'method': 'Contrast', 'params': {'factor': 2}}]. The |
|
|
|
supported methods list is in `self._strategies`, and the |
|
|
|
params of each method must within the range of changeable parameters. |
|
|
|
[{'method': 'Blur', |
|
|
|
'params': {'radius': [0.1, 0.2], 'auto_param': [True, False]}}, |
|
|
|
{'method': 'Contrast', |
|
|
|
'params': {'factor': [1, 1.5, 2]}}, |
|
|
|
{'method': 'FGSM', |
|
|
|
'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1]}}, |
|
|
|
...]. |
|
|
|
The supported methods list is in `self._strategies`, and the |
|
|
|
params of each method must within the range of optional parameters. |
|
|
|
Supported methods are grouped in three types: |
|
|
|
Firstly, pixel value based transform methods include: |
|
|
|
'Contrast', 'Brightness', 'Blur' and 'Noise'. Secondly, affine |
|
|
@@ -152,7 +201,8 @@ class Fuzzer: |
|
|
|
`mutate_config` must have method in the type of pixel value based |
|
|
|
transform methods. The way of setting parameters for first and |
|
|
|
second type methods can be seen in 'mindarmour/fuzz_testing/image_transform.py'. |
|
|
|
For third type methods, you can refer to the corresponding class. |
|
|
|
For third type methods, the optional parameters refer to |
|
|
|
`self._attack_param_checklists`. |
|
|
|
initial_seeds (list[list]): Initial seeds used to generate mutated |
|
|
|
samples. The format of initial seeds is [[image_data, label], |
|
|
|
[...], ...]. |
|
|
@@ -186,60 +236,32 @@ class Fuzzer: |
|
|
|
ValueError: If metric in list `eval_metrics` is not in ['accuracy', 'attack_success_rate', |
|
|
|
'kmnc', 'nbc', 'snac']. |
|
|
|
""" |
|
|
|
if isinstance(eval_metrics, (list, tuple)): |
|
|
|
eval_metrics_ = [] |
|
|
|
avaliable_metrics = ['accuracy', 'attack_success_rate', 'kmnc', 'nbc', 'snac'] |
|
|
|
for elem in eval_metrics: |
|
|
|
if elem not in avaliable_metrics: |
|
|
|
msg = 'metric in list `eval_metrics` must be in {}, but got {}.' \ |
|
|
|
.format(avaliable_metrics, elem) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
eval_metrics_.append(elem.lower()) |
|
|
|
elif isinstance(eval_metrics, str): |
|
|
|
if eval_metrics != 'auto': |
|
|
|
msg = "the value of `eval_metrics` must be 'auto' if it's type is str, " \ |
|
|
|
"but got {}.".format(eval_metrics) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
eval_metrics_ = 'auto' |
|
|
|
else: |
|
|
|
msg = "the type of `eval_metrics` must be str, list or tuple, but got {}." \ |
|
|
|
.format(type(eval_metrics)) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise TypeError(msg) |
|
|
|
|
|
|
|
# Check whether the mutate_config meet the specification. |
|
|
|
mutate_config = check_param_type('mutate_config', mutate_config, list) |
|
|
|
for config in mutate_config: |
|
|
|
check_param_type("config['params']", config['params'], dict) |
|
|
|
if set(config.keys()) != {'method', 'params'}: |
|
|
|
msg = "Config must contain 'method' and 'params', but got {}." \ |
|
|
|
.format(set(config.keys())) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise TypeError(msg) |
|
|
|
if config['method'] not in self._strategies.keys(): |
|
|
|
msg = "Config methods must be in {}, but got {}." \ |
|
|
|
.format(self._strategies.keys(), config['method']) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise TypeError(msg) |
|
|
|
# Check parameters. |
|
|
|
eval_metrics_ = _check_eval_metrics(eval_metrics) |
|
|
|
if coverage_metric not in ['KMNC', 'NBC', 'SNAC']: |
|
|
|
msg = "coverage_metric must be in ['KMNC', 'NBC', 'SNAC'], but got {}." \ |
|
|
|
.format(coverage_metric) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise ValueError(msg) |
|
|
|
max_iters = check_int_positive('max_iters', max_iters) |
|
|
|
mutate_num_per_seed = check_int_positive('mutate_num_per_seed', mutate_num_per_seed) |
|
|
|
mutate_num_per_seed = check_int_positive('mutate_num_per_seed', |
|
|
|
mutate_num_per_seed) |
|
|
|
mutate_config = self._check_mutate_config(mutate_config) |
|
|
|
mutates = self._init_mutates(mutate_config) |
|
|
|
|
|
|
|
initial_seeds = check_param_type('initial_seeds', initial_seeds, list) |
|
|
|
if not initial_seeds: |
|
|
|
msg = 'initial_seeds must not be empty.' |
|
|
|
raise ValueError(msg) |
|
|
|
for seed in initial_seeds: |
|
|
|
check_param_type('seed', seed, list) |
|
|
|
check_numpy_param('seed[0]', seed[0]) |
|
|
|
check_numpy_param('seed[1]', seed[1]) |
|
|
|
seed.append(0) |
|
|
|
|
|
|
|
seed, initial_seeds = _select_next(initial_seeds) |
|
|
|
fuzz_samples = [] |
|
|
|
gt_labels = [] |
|
|
|
true_labels = [] |
|
|
|
fuzz_preds = [] |
|
|
|
fuzz_strategies = [] |
|
|
|
iter_num = 0 |
|
|
@@ -250,13 +272,15 @@ class Fuzzer: |
|
|
|
mutate_config, |
|
|
|
mutate_num_per_seed) |
|
|
|
# Calculate the coverages and predictions of generated samples. |
|
|
|
coverages, predicts = self._run(mutate_samples, coverage_metric) |
|
|
|
coverages, predicts = self._get_coverages_and_predict( |
|
|
|
mutate_samples, |
|
|
|
coverage_metric) |
|
|
|
coverage_gains = _coverage_gains(coverages) |
|
|
|
for mutate, cov, pred, strategy in zip(mutate_samples, |
|
|
|
coverage_gains, |
|
|
|
predicts, mutate_strategies): |
|
|
|
fuzz_samples.append(mutate[0]) |
|
|
|
gt_labels.append(mutate[1]) |
|
|
|
true_labels.append(mutate[1]) |
|
|
|
fuzz_preds.append(pred) |
|
|
|
fuzz_strategies.append(strategy) |
|
|
|
# if the mutate samples has coverage gains add this samples in |
|
|
@@ -267,16 +291,21 @@ class Fuzzer: |
|
|
|
iter_num += 1 |
|
|
|
metrics_report = None |
|
|
|
if eval_metrics_ is not None: |
|
|
|
metrics_report = self._evaluate(fuzz_samples, gt_labels, fuzz_preds, |
|
|
|
fuzz_strategies, eval_metrics_) |
|
|
|
return fuzz_samples, gt_labels, fuzz_preds, fuzz_strategies, metrics_report |
|
|
|
|
|
|
|
def _run(self, mutate_samples, coverage_metric="KNMC"): |
|
|
|
metrics_report = self._evaluate(fuzz_samples, |
|
|
|
true_labels, |
|
|
|
fuzz_preds, |
|
|
|
fuzz_strategies, |
|
|
|
eval_metrics_) |
|
|
|
return fuzz_samples, true_labels, fuzz_preds, fuzz_strategies, metrics_report |
|
|
|
|
|
|
|
def _get_coverages_and_predict(self, mutate_samples, |
|
|
|
coverage_metric="KNMC"): |
|
|
|
""" Calculate the coverages and predictions of generated samples.""" |
|
|
|
samples = [s[0] for s in mutate_samples] |
|
|
|
samples = np.array(samples) |
|
|
|
coverages = [] |
|
|
|
predictions = self._target_model.predict(Tensor(samples.astype(np.float32))) |
|
|
|
predictions = self._target_model.predict( |
|
|
|
Tensor(samples.astype(np.float32))) |
|
|
|
predictions = predictions.asnumpy() |
|
|
|
for index in range(len(samples)): |
|
|
|
mutate = samples[:index + 1] |
|
|
@@ -289,31 +318,6 @@ class Fuzzer: |
|
|
|
coverages.append(self._coverage_metrics.get_snac()) |
|
|
|
return coverages, predictions |
|
|
|
|
|
|
|
def _check_attack_params(self, method, params): |
|
|
|
"""Check input parameters of attack methods.""" |
|
|
|
allow_params = self._attack_param_checklists[method]['params'].keys() |
|
|
|
for param_name in params: |
|
|
|
if param_name not in allow_params: |
|
|
|
msg = "parameters of {} must in {}".format(method, allow_params) |
|
|
|
raise ValueError(msg) |
|
|
|
|
|
|
|
param_value = params[param_name] |
|
|
|
if param_name == 'bounds': |
|
|
|
bounds = check_param_multi_types('bounds', param_value, |
|
|
|
[list, tuple]) |
|
|
|
for bound_value in bounds: |
|
|
|
_ = check_param_multi_types('bound', bound_value, [int, float]) |
|
|
|
elif param_name == 'norm_level': |
|
|
|
_ = check_norm_level(param_value) |
|
|
|
else: |
|
|
|
allow_type = self._attack_param_checklists[method]['params'][param_name][ |
|
|
|
'dtype'] |
|
|
|
allow_range = self._attack_param_checklists[method]['params'][param_name][ |
|
|
|
'range'] |
|
|
|
_ = check_param_multi_types(str(param_name), param_value, allow_type) |
|
|
|
_ = check_param_in_range(str(param_name), param_value, allow_range[0], |
|
|
|
allow_range[1]) |
|
|
|
|
|
|
|
def _metamorphic_mutate(self, seed, mutates, mutate_config, |
|
|
|
mutate_num_per_seed): |
|
|
|
"""Mutate a seed using strategies random selected from mutate_config.""" |
|
|
@@ -329,12 +333,18 @@ class Fuzzer: |
|
|
|
transform = mutates[strage['method']] |
|
|
|
params = strage['params'] |
|
|
|
method = strage['method'] |
|
|
|
if method in list(self._pixel_value_trans_list + self._affine_trans_list): |
|
|
|
transform.set_params(**params) |
|
|
|
selected_param = {} |
|
|
|
for p in params: |
|
|
|
selected_param[p] = choice(params[p]) |
|
|
|
|
|
|
|
if method in list( |
|
|
|
self._pixel_value_trans_list + self._affine_trans_list): |
|
|
|
transform.set_params(**selected_param) |
|
|
|
mutate_sample = transform.transform(seed[0]) |
|
|
|
else: |
|
|
|
for param_name in params: |
|
|
|
transform.__setattr__('_' + str(param_name), params[param_name]) |
|
|
|
for param_name in selected_param: |
|
|
|
transform.__setattr__('_' + str(param_name), |
|
|
|
selected_param[param_name]) |
|
|
|
mutate_sample = transform.generate([seed[0].astype(np.float32)], |
|
|
|
[seed[1]])[0] |
|
|
|
if method not in self._pixel_value_trans_list: |
|
|
@@ -348,51 +358,117 @@ class Fuzzer: |
|
|
|
mutate_strategies.append(None) |
|
|
|
return np.array(mutate_samples), mutate_strategies |
|
|
|
|
|
|
|
def _init_mutates(self, mutate_config): |
|
|
|
""" Check whether the mutate_config meet the specification.""" |
|
|
|
def _check_mutate_config(self, mutate_config): |
|
|
|
"""Check whether the mutate_config meet the specification.""" |
|
|
|
mutate_config = check_param_type('mutate_config', mutate_config, list) |
|
|
|
has_pixel_trans = False |
|
|
|
for mutate in mutate_config: |
|
|
|
if mutate['method'] in self._pixel_value_trans_list: |
|
|
|
|
|
|
|
for config in mutate_config: |
|
|
|
check_param_type("config", config, dict) |
|
|
|
if set(config.keys()) != {'method', 'params'}: |
|
|
|
msg = "Config must contain 'method' and 'params', but got {}." \ |
|
|
|
.format(set(config.keys())) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise TypeError(msg) |
|
|
|
|
|
|
|
method = config['method'] |
|
|
|
params = config['params'] |
|
|
|
|
|
|
|
# Method must be in the optional range. |
|
|
|
if method not in self._strategies.keys(): |
|
|
|
msg = "Config methods must be in {}, but got {}." \ |
|
|
|
.format(self._strategies.keys(), method) |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise TypeError(msg) |
|
|
|
|
|
|
|
if config['method'] in self._pixel_value_trans_list: |
|
|
|
has_pixel_trans = True |
|
|
|
break |
|
|
|
|
|
|
|
check_param_type('params', params, dict) |
|
|
|
# Check parameters of attack methods. The parameters of transformed |
|
|
|
# methods will be verified in transferred parameters. |
|
|
|
if method in self._attacks_list: |
|
|
|
self._check_attack_params(method, params) |
|
|
|
else: |
|
|
|
for key in params.keys(): |
|
|
|
check_param_type(str(key), params[key], list) |
|
|
|
# Methods in `metate_config` should at least have one in the type of |
|
|
|
# pixel value based transform methods. |
|
|
|
if not has_pixel_trans: |
|
|
|
msg = "mutate methods in mutate_config at lease have one in {}".format( |
|
|
|
self._pixel_value_trans_list) |
|
|
|
msg = "mutate methods in mutate_config should at least have one " \ |
|
|
|
"in {}".format(self._pixel_value_trans_list) |
|
|
|
raise ValueError(msg) |
|
|
|
|
|
|
|
return mutate_config |
|
|
|
|
|
|
|
def _check_attack_params(self, method, params): |
|
|
|
"""Check input parameters of attack methods.""" |
|
|
|
allow_params = self._attack_param_checklists[method].keys() |
|
|
|
for param_name in params: |
|
|
|
if param_name not in allow_params: |
|
|
|
msg = "parameters of {} must in {}".format(method, allow_params) |
|
|
|
raise ValueError(msg) |
|
|
|
|
|
|
|
check_param_type(param_name, params[param_name], list) |
|
|
|
for param_value in params[param_name]: |
|
|
|
if param_name == 'bounds': |
|
|
|
bounds = check_param_multi_types('bounds', param_value, [tuple]) |
|
|
|
for bound_value in bounds: |
|
|
|
_ = check_param_multi_types('bound', bound_value, |
|
|
|
[int, float]) |
|
|
|
if bounds[0] >= bounds[1]: |
|
|
|
msg = "upper bound must more than lower bound, but upper " \ |
|
|
|
"bound got {}, lower bound got {}".format(bounds[0], |
|
|
|
bounds[1]) |
|
|
|
raise ValueError(msg) |
|
|
|
elif param_name == 'norm_level': |
|
|
|
_ = check_norm_level(param_value) |
|
|
|
else: |
|
|
|
allow_type = \ |
|
|
|
self._attack_param_checklists[method][param_name]['dtype'] |
|
|
|
allow_range = \ |
|
|
|
self._attack_param_checklists[method][param_name]['range'] |
|
|
|
_ = check_param_multi_types(str(param_name), param_value, |
|
|
|
allow_type) |
|
|
|
_ = check_param_in_range(str(param_name), param_value, |
|
|
|
allow_range[0], |
|
|
|
allow_range[1]) |
|
|
|
|
|
|
|
def _init_mutates(self, mutate_config): |
|
|
|
""" Check whether the mutate_config meet the specification.""" |
|
|
|
mutates = {} |
|
|
|
for mutate in mutate_config: |
|
|
|
method = mutate['method'] |
|
|
|
params = mutate['params'] |
|
|
|
if method not in self._attacks_list: |
|
|
|
mutates[method] = self._strategies[method]() |
|
|
|
else: |
|
|
|
self._check_attack_params(method, params) |
|
|
|
network = self._target_model._network |
|
|
|
loss_fn = self._target_model._loss_fn |
|
|
|
mutates[method] = self._strategies[method](network, |
|
|
|
loss_fn=loss_fn) |
|
|
|
return mutates |
|
|
|
|
|
|
|
def _evaluate(self, fuzz_samples, gt_labels, fuzz_preds, |
|
|
|
def _evaluate(self, fuzz_samples, true_labels, fuzz_preds, |
|
|
|
fuzz_strategies, metrics): |
|
|
|
""" |
|
|
|
Evaluate generated fuzz_testing samples in three dimention: accuracy, |
|
|
|
Evaluate generated fuzz_testing samples in three dimensions: accuracy, |
|
|
|
attack success rate and neural coverage. |
|
|
|
|
|
|
|
Args: |
|
|
|
fuzz_samples (numpy.ndarray): Generated fuzz_testing samples according to seeds. |
|
|
|
gt_labels (numpy.ndarray): Ground Truth of seeds. |
|
|
|
fuzz_preds (numpy.ndarray): Predictions of generated fuzz samples. |
|
|
|
fuzz_strategies (numpy.ndarray): Mutate strategies of fuzz samples. |
|
|
|
fuzz_samples ([numpy.ndarray, list]): Generated fuzz_testing samples |
|
|
|
according to seeds. |
|
|
|
true_labels ([numpy.ndarray, list]): Ground Truth labels of seeds. |
|
|
|
fuzz_preds ([numpy.ndarray, list]): Predictions of generated fuzz samples. |
|
|
|
fuzz_strategies ([numpy.ndarray, list]): Mutate strategies of fuzz samples. |
|
|
|
metrics (Union[list, tuple, str]): evaluation metrics. |
|
|
|
|
|
|
|
Returns: |
|
|
|
dict, evaluate metrics include accuarcy, attack success rate |
|
|
|
dict, evaluate metrics include accuracy, attack success rate |
|
|
|
and neural coverage. |
|
|
|
""" |
|
|
|
gt_labels = np.asarray(gt_labels) |
|
|
|
true_labels = np.asarray(true_labels) |
|
|
|
fuzz_preds = np.asarray(fuzz_preds) |
|
|
|
temp = np.argmax(gt_labels, axis=1) == np.argmax(fuzz_preds, axis=1) |
|
|
|
temp = np.argmax(true_labels, axis=1) == np.argmax(fuzz_preds, axis=1) |
|
|
|
metrics_report = {} |
|
|
|
if metrics == 'auto' or 'accuracy' in metrics: |
|
|
|
if temp.any(): |
|
|
|