@@ -0,0 +1,89 @@ | |||
# Copyright 2019 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import sys | |||
import numpy as np | |||
from mindspore import Model | |||
from mindspore import context | |||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
from mindspore.nn import SoftmaxCrossEntropyWithLogits | |||
from mindarmour.attacks.gradient_method import FastGradientSignMethod | |||
from mindarmour.utils.logger import LogUtil | |||
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | |||
from mindarmour.fuzzing.fuzzing import Fuzzing | |||
from lenet5_net import LeNet5 | |||
sys.path.append("..") | |||
from data_processing import generate_mnist_dataset | |||
LOGGER = LogUtil.get_instance() | |||
TAG = 'Fuzz_test' | |||
LOGGER.set_level('INFO') | |||
def test_lenet_mnist_fuzzing(): | |||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
# upload trained network | |||
ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||
net = LeNet5() | |||
load_dict = load_checkpoint(ckpt_name) | |||
load_param_into_net(net, load_dict) | |||
model = Model(net) | |||
# get training data | |||
data_list = "./MNIST_datasets/train" | |||
batch_size = 32 | |||
ds = generate_mnist_dataset(data_list, batch_size, sparse=True) | |||
train_images = [] | |||
for data in ds.create_tuple_iterator(): | |||
images = data[0].astype(np.float32) | |||
train_images.append(images) | |||
train_images = np.concatenate(train_images, axis=0) | |||
# initialize fuzz test with training dataset | |||
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, train_images) | |||
# fuzz test with original test data | |||
# get test data | |||
data_list = "./MNIST_datasets/test" | |||
batch_size = 32 | |||
ds = generate_mnist_dataset(data_list, batch_size, sparse=True) | |||
test_images = [] | |||
test_labels = [] | |||
for data in ds.create_tuple_iterator(): | |||
images = data[0].astype(np.float32) | |||
labels = data[1] | |||
test_images.append(images) | |||
test_labels.append(labels) | |||
test_images = np.concatenate(test_images, axis=0) | |||
test_labels = np.concatenate(test_labels, axis=0) | |||
initial_seeds = [] | |||
# make initial seeds | |||
for img, label in zip(test_images, test_labels): | |||
initial_seeds.append([img, label, 0]) | |||
initial_seeds = initial_seeds[:100] | |||
model_coverage_test.test_adequacy_coverage_calculate(np.array(test_images[:100]).astype(np.float32)) | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) | |||
model_fuzz_test = Fuzzing(initial_seeds, model, train_images, 20) | |||
failed_tests = model_fuzz_test.fuzzing() | |||
model_coverage_test.test_adequacy_coverage_calculate(np.array(failed_tests).astype(np.float32)) | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) | |||
if __name__ == '__main__': | |||
test_lenet_mnist_fuzzing() |
@@ -1,3 +1,8 @@ | |||
""" | |||
This module includes various metrics to fuzzing the test of DNN. | |||
""" | |||
from .fuzzing import Fuzzing | |||
from .model_coverage_metrics import ModelCoverageMetrics | |||
__all__ = ['ModelCoverageMetrics'] | |||
__all__ = ['Fuzzing', | |||
'ModelCoverageMetrics'] |
@@ -0,0 +1,169 @@ | |||
# Copyright 2019 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" | |||
Fuzzing. | |||
""" | |||
import numpy as np | |||
from random import choice | |||
from mindspore import Tensor | |||
from mindspore import Model | |||
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | |||
from mindarmour.utils.image_transform import Contrast, Brightness, Blur, Noise, \ | |||
Translate, Scale, Shear, Rotate | |||
from mindarmour.utils._check_param import check_model, check_numpy_param, \ | |||
check_int_positive | |||
class Fuzzing: | |||
""" | |||
Fuzzing test framework for deep neural networks. | |||
Reference: `DeepHunter: A Coverage-Guided Fuzz Testing Framework for Deep | |||
Neural Networks <https://dl.acm.org/doi/10.1145/3293882.3330579>`_ | |||
Args: | |||
initial_seeds (list): Initial fuzzing seed, format: [[image, label, 0], | |||
[image, label, 0], ...]. | |||
target_model (Model): Target fuzz model. | |||
train_dataset (numpy.ndarray): Training dataset used for determine | |||
the neurons' output boundaries. | |||
const_K (int): The number of mutate tests for a seed. | |||
mode (str): Image mode used in image transform, 'L' means grey graph. | |||
Default: 'L'. | |||
""" | |||
def __init__(self, initial_seeds, target_model, train_dataset, const_K, | |||
mode='L', max_seed_num=1000): | |||
self.initial_seeds = initial_seeds | |||
self.target_model = check_model('model', target_model, Model) | |||
self.train_dataset = check_numpy_param('train_dataset', train_dataset) | |||
self.K = check_int_positive('const_k', const_K) | |||
self.mode = mode | |||
self.max_seed_num = check_int_positive('max_seed_num', max_seed_num) | |||
self.coverage_metrics = ModelCoverageMetrics(target_model, 1000, 10, | |||
train_dataset) | |||
def _image_value_expand(self, image): | |||
return image*255 | |||
def _image_value_compress(self, image): | |||
return image / 255 | |||
def _metamorphic_mutate(self, seed, try_num=50): | |||
if self.mode == 'L': | |||
seed = seed[0] | |||
info = [seed, seed] | |||
mutate_tests = [] | |||
affine_trans = ['Contrast', 'Brightness', 'Blur', 'Noise'] | |||
pixel_value_trans = ['Translate', 'Scale', 'Shear', 'Rotate'] | |||
strages = {'Contrast': Contrast, 'Brightness': Brightness, 'Blur': Blur, | |||
'Noise': Noise, | |||
'Translate': Translate, 'Scale': Scale, 'Shear': Shear, | |||
'Rotate': Rotate} | |||
for _ in range(self.K): | |||
for _ in range(try_num): | |||
if (info[0] == info[1]).all(): | |||
trans_strage = self._random_pick_mutate(affine_trans, | |||
pixel_value_trans) | |||
else: | |||
trans_strage = self._random_pick_mutate(affine_trans, []) | |||
transform = strages[trans_strage]( | |||
self._image_value_expand(seed), self.mode) | |||
transform.random_param() | |||
mutate_test = transform.transform() | |||
mutate_test = np.expand_dims( | |||
self._image_value_compress(mutate_test), 0) | |||
if self._is_trans_valid(seed, mutate_test): | |||
if trans_strage in affine_trans: | |||
info[1] = mutate_test | |||
mutate_tests.append(mutate_test) | |||
if len(mutate_tests) == 0: | |||
mutate_tests.append(seed) | |||
return np.array(mutate_tests) | |||
def fuzzing(self, coverage_metric='KMNC'): | |||
""" | |||
Fuzzing tests for deep neural networks. | |||
Args: | |||
coverage_metric (str): Model coverage metric of neural networks. | |||
Default: 'KMNC'. | |||
Returns: | |||
list, mutated tests mis-predicted by target dnn model. | |||
""" | |||
seed = self._select_next() | |||
failed_tests = [] | |||
seed_num = 0 | |||
while len(seed) > 0 and seed_num < self.max_seed_num: | |||
mutate_tests = self._metamorphic_mutate(seed[0]) | |||
coverages, results = self._run(mutate_tests, coverage_metric) | |||
coverage_gains = self._coverage_gains(coverages) | |||
for mutate, cov, res in zip(mutate_tests, coverage_gains, results): | |||
if np.argmax(seed[1]) != np.argmax(res): | |||
failed_tests.append(mutate) | |||
continue | |||
if cov > 0: | |||
self.initial_seeds.append([mutate, seed[1], 0]) | |||
seed = self._select_next() | |||
seed_num += 1 | |||
return failed_tests | |||
def _coverage_gains(self, coverages): | |||
gains = [0] + coverages[:-1] | |||
gains = np.array(coverages) - np.array(gains) | |||
return gains | |||
def _run(self, mutate_tests, coverage_metric="KNMC"): | |||
coverages = [] | |||
result = self.target_model.predict( | |||
Tensor(mutate_tests.astype(np.float32))) | |||
result = result.asnumpy() | |||
for index in range(len(mutate_tests)): | |||
mutate = np.expand_dims(mutate_tests[index], 0) | |||
self.coverage_metrics.test_adequacy_coverage_calculate( | |||
mutate.astype(np.float32), batch_size=1) | |||
if coverage_metric == "KMNC": | |||
coverages.append(self.coverage_metrics.get_kmnc()) | |||
return coverages, result | |||
def _select_next(self): | |||
seed = choice(self.initial_seeds) | |||
return seed | |||
def _random_pick_mutate(self, affine_trans_list, pixel_value_trans_list): | |||
strage = choice(affine_trans_list + pixel_value_trans_list) | |||
return strage | |||
def _is_trans_valid(self, seed, mutate_test): | |||
is_valid = False | |||
alpha = 0.02 | |||
beta = 0.2 | |||
diff = np.array(seed - mutate_test).flatten() | |||
size = np.shape(diff)[0] | |||
L0 = np.linalg.norm(diff, ord=0) | |||
Linf = np.linalg.norm(diff, ord=np.inf) | |||
if L0 > alpha*size: | |||
if Linf < 256: | |||
is_valid = True | |||
else: | |||
if Linf < beta*255: | |||
is_valid = True | |||
return is_valid |
@@ -76,6 +76,11 @@ class ModelCoverageMetrics: | |||
upper_compare_array = np.concatenate( | |||
[output, np.array([self._upper_bounds])], axis=0) | |||
self._upper_bounds = np.max(upper_compare_array, axis=0) | |||
if batches == 0: | |||
output = self._model.predict(Tensor(train_dataset)).asnumpy() | |||
self._lower_bounds = np.min(output, axis=0) | |||
self._upper_bounds = np.max(output, axis=0) | |||
output_mat.append(output) | |||
self._var = np.std(np.concatenate(np.array(output_mat), axis=0), | |||
axis=0) | |||
@@ -0,0 +1,267 @@ | |||
# Copyright 2019 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" | |||
Image transform | |||
""" | |||
import numpy as np | |||
from PIL import Image, ImageEnhance, ImageFilter | |||
import random | |||
from mindarmour.utils._check_param import check_numpy_param | |||
class ImageTransform: | |||
""" | |||
The abstract base class for all image transform classes. | |||
""" | |||
def __init__(self): | |||
pass | |||
def random_param(self): | |||
pass | |||
def transform(self): | |||
pass | |||
class Contrast(ImageTransform): | |||
""" | |||
Contrast of an image. | |||
Args: | |||
image (numpy.ndarray): Original image to be transformed. | |||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||
'L' means grey image. | |||
""" | |||
def __init__(self, image, mode): | |||
super(Contrast, self).__init__() | |||
self.image = check_numpy_param('image', image) | |||
self.mode = mode | |||
def random_param(self): | |||
""" Random generate parameters. """ | |||
self.factor = random.uniform(-10, 10) | |||
def transform(self): | |||
img = Image.fromarray(self.image, self.mode) | |||
img_contrast = ImageEnhance.Contrast(img) | |||
trans_image = img_contrast.enhance(self.factor) | |||
trans_image = np.array(trans_image) | |||
return trans_image | |||
class Brightness(ImageTransform): | |||
""" | |||
Brightness of an image. | |||
Args: | |||
image (numpy.ndarray): Original image to be transformed. | |||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||
'L' means grey image. | |||
""" | |||
def __init__(self, image, mode): | |||
super(Brightness, self).__init__() | |||
self.image = check_numpy_param('image', image) | |||
self.mode = mode | |||
def random_param(self): | |||
""" Random generate parameters. """ | |||
self.factor = random.uniform(-10, 10) | |||
def transform(self): | |||
img = Image.fromarray(self.image, self.mode) | |||
img_contrast = ImageEnhance.Brightness(img) | |||
trans_image = img_contrast.enhance(self.factor) | |||
trans_image = np.array(trans_image) | |||
return trans_image | |||
class Blur(ImageTransform): | |||
""" | |||
GaussianBlur of an image. | |||
Args: | |||
image (numpy.ndarray): Original image to be transformed. | |||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||
'L' means grey image. | |||
""" | |||
def __init__(self, image, mode): | |||
super(Blur, self).__init__() | |||
self.image = check_numpy_param('image', image) | |||
self.mode = mode | |||
def random_param(self): | |||
""" Random generate parameters. """ | |||
self.radius = random.uniform(-10, 10) | |||
def transform(self): | |||
""" Transform the image. """ | |||
img = Image.fromarray(self.image, self.mode) | |||
trans_image = img.filter(ImageFilter.GaussianBlur(radius=self.radius)) | |||
trans_image = np.array(trans_image) | |||
return trans_image | |||
class Noise(ImageTransform): | |||
""" | |||
Add noise of an image. | |||
Args: | |||
image (numpy.ndarray): Original image to be transformed. | |||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||
'L' means grey image. | |||
""" | |||
def __init__(self, image, mode): | |||
super(Noise, self).__init__() | |||
self.image = check_numpy_param('image', image) | |||
self.mode = mode | |||
def random_param(self): | |||
""" random generate parameters """ | |||
self.factor = random.uniform(-1, 1) | |||
def transform(self): | |||
""" Random generate parameters. """ | |||
noise = np.random.uniform(low=-1, high=1, size=self.image.shape) | |||
trans_image = np.copy(self.image) | |||
trans_image[noise < -self.factor] = 0 | |||
trans_image[noise > self.factor] = 255 | |||
trans_image = np.array(trans_image) | |||
return trans_image | |||
class Translate(ImageTransform): | |||
""" | |||
Translate an image. | |||
Args: | |||
image (numpy.ndarray): Original image to be transformed. | |||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||
'L' means grey image. | |||
""" | |||
def __init__(self, image, mode): | |||
super(Translate, self).__init__() | |||
self.image = check_numpy_param('image', image) | |||
self.mode = mode | |||
def random_param(self): | |||
""" Random generate parameters. """ | |||
image_shape = np.shape(self.image) | |||
self.x_bias = random.uniform(0, image_shape[0]) | |||
self.y_bias = random.uniform(0, image_shape[1]) | |||
def transform(self): | |||
""" Transform the image. """ | |||
img = Image.fromarray(self.image, self.mode) | |||
trans_image = img.transform(img.size, Image.AFFINE, | |||
(1, 0, self.x_bias, 0, 1, self.y_bias)) | |||
trans_image = np.array(trans_image) | |||
return trans_image | |||
class Scale(ImageTransform): | |||
""" | |||
Scale an image. | |||
Args: | |||
image(numpy.ndarray): Original image to be transformed. | |||
mode(str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||
'L' means grey image. | |||
""" | |||
def __init__(self, image, mode): | |||
super(Scale, self).__init__() | |||
self.image = check_numpy_param('image', image) | |||
self.mode = mode | |||
def random_param(self): | |||
""" Random generate parameters. """ | |||
self.factor_x = random.uniform(0, 1) | |||
self.factor_y = random.uniform(0, 1) | |||
def transform(self): | |||
""" Transform the image. """ | |||
img = Image.fromarray(self.image, self.mode) | |||
trans_image = img.transform(img.size, Image.AFFINE, | |||
(self.factor_x, 0, 0, 0, self.factor_y, 0)) | |||
trans_image = np.array(trans_image) | |||
return trans_image | |||
class Shear(ImageTransform): | |||
""" | |||
Shear an image. | |||
Args: | |||
image (numpy.ndarray): Original image to be transformed. | |||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||
'L' means grey image. | |||
""" | |||
def __init__(self, image, mode): | |||
super(Shear, self).__init__() | |||
self.image = check_numpy_param('image', image) | |||
self.mode = mode | |||
def random_param(self): | |||
""" Random generate parameters. """ | |||
self.factor = random.uniform(0, 1) | |||
def transform(self): | |||
""" Transform the image. """ | |||
img = Image.fromarray(self.image, self.mode) | |||
if np.random.random() > 0.5: | |||
level = -self.factor | |||
else: | |||
level = self.factor | |||
if np.random.random() > 0.5: | |||
trans_image = img.transform(img.size, Image.AFFINE, | |||
(1, level, 0, 0, 1, 0)) | |||
else: | |||
trans_image = img.transform(img.size, Image.AFFINE, | |||
(1, 0, 0, level, 1, 0)) | |||
trans_image = np.array(trans_image, dtype=np.float) | |||
return trans_image | |||
class Rotate(ImageTransform): | |||
""" | |||
Rotate an image. | |||
Args: | |||
image (numpy.ndarray): Original image to be transformed. | |||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||
'L' means grey image. | |||
""" | |||
def __init__(self, image, mode): | |||
super(Rotate, self).__init__() | |||
self.image = check_numpy_param('image', image) | |||
self.mode = mode | |||
def random_param(self): | |||
""" Random generate parameters. """ | |||
self.angle = random.uniform(0, 360) | |||
def transform(self): | |||
""" Transform the image. """ | |||
img = Image.fromarray(self.image, self.mode) | |||
trans_image = img.rotate(self.angle) | |||
trans_image = np.array(trans_image) | |||
return trans_image |
@@ -0,0 +1,161 @@ | |||
# Copyright 2019 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" | |||
Model-fuzz coverage test. | |||
""" | |||
import numpy as np | |||
import pytest | |||
import sys | |||
from mindspore.train import Model | |||
from mindspore import nn | |||
from mindspore.ops import operations as P | |||
from mindspore import context | |||
from mindspore.common.initializer import TruncatedNormal | |||
from mindarmour.utils.logger import LogUtil | |||
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | |||
from mindarmour.fuzzing.fuzzing import Fuzzing | |||
LOGGER = LogUtil.get_instance() | |||
TAG = 'Fuzzing test' | |||
LOGGER.set_level('INFO') | |||
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): | |||
weight = weight_variable() | |||
return nn.Conv2d(in_channels, out_channels, | |||
kernel_size=kernel_size, stride=stride, padding=padding, | |||
weight_init=weight, has_bias=False, pad_mode="valid") | |||
def fc_with_initialize(input_channels, out_channels): | |||
weight = weight_variable() | |||
bias = weight_variable() | |||
return nn.Dense(input_channels, out_channels, weight, bias) | |||
def weight_variable(): | |||
return TruncatedNormal(0.02) | |||
class Net(nn.Cell): | |||
""" | |||
Lenet network | |||
""" | |||
def __init__(self): | |||
super(Net, self).__init__() | |||
self.conv1 = conv(1, 6, 5) | |||
self.conv2 = conv(6, 16, 5) | |||
self.fc1 = fc_with_initialize(16*5*5, 120) | |||
self.fc2 = fc_with_initialize(120, 84) | |||
self.fc3 = fc_with_initialize(84, 10) | |||
self.relu = nn.ReLU() | |||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | |||
self.reshape = P.Reshape() | |||
def construct(self, x): | |||
x = self.conv1(x) | |||
x = self.relu(x) | |||
x = self.max_pool2d(x) | |||
x = self.conv2(x) | |||
x = self.relu(x) | |||
x = self.max_pool2d(x) | |||
x = self.reshape(x, (-1, 16*5*5)) | |||
x = self.fc1(x) | |||
x = self.relu(x) | |||
x = self.fc2(x) | |||
x = self.relu(x) | |||
x = self.fc3(x) | |||
return x | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_fuzzing_ascend(): | |||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||
# load network | |||
net = Net() | |||
model = Model(net) | |||
batch_size = 8 | |||
num_classe = 10 | |||
# initialize fuzz test with training dataset | |||
training_data = np.random.rand(32, 1, 32, 32).astype(np.float32) | |||
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, training_data) | |||
# fuzz test with original test data | |||
# get test data | |||
test_data = np.random.rand(batch_size, 1, 32, 32).astype(np.float32) | |||
test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32) | |||
test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32) | |||
initial_seeds = [] | |||
for img, label in zip(test_data, test_labels): | |||
initial_seeds.append([img, label, 0]) | |||
model_coverage_test.test_adequacy_coverage_calculate( | |||
np.array(test_data).astype(np.float32)) | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', | |||
model_coverage_test.get_kmnc()) | |||
model_fuzz_test = Fuzzing(initial_seeds, model, training_data, 5, | |||
max_seed_num=10) | |||
failed_tests = model_fuzz_test.fuzzing() | |||
model_coverage_test.test_adequacy_coverage_calculate( | |||
np.array(failed_tests).astype(np.float32)) | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', | |||
model_coverage_test.get_kmnc()) | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_cpu | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_fuzzing_ascend(): | |||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
# load network | |||
net = Net() | |||
model = Model(net) | |||
batch_size = 8 | |||
num_classe = 10 | |||
# initialize fuzz test with training dataset | |||
training_data = np.random.rand(32, 1, 32, 32).astype(np.float32) | |||
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, training_data) | |||
# fuzz test with original test data | |||
# get test data | |||
test_data = np.random.rand(batch_size, 1, 32, 32).astype(np.float32) | |||
test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32) | |||
test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32) | |||
initial_seeds = [] | |||
for img, label in zip(test_data, test_labels): | |||
initial_seeds.append([img, label, 0]) | |||
model_coverage_test.test_adequacy_coverage_calculate( | |||
np.array(test_data).astype(np.float32)) | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', | |||
model_coverage_test.get_kmnc()) | |||
model_fuzz_test = Fuzzing(initial_seeds, model, training_data, 5, | |||
max_seed_num=10) | |||
failed_tests = model_fuzz_test.fuzzing() | |||
model_coverage_test.test_adequacy_coverage_calculate( | |||
np.array(failed_tests).astype(np.float32)) | |||
LOGGER.info(TAG, 'KMNC of this test is : %s', | |||
model_coverage_test.get_kmnc()) |
@@ -0,0 +1,136 @@ | |||
# Copyright 2019 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" | |||
Image transform test. | |||
""" | |||
import numpy as np | |||
import pytest | |||
from mindarmour.utils.logger import LogUtil | |||
from mindarmour.utils.image_transform import Contrast, Brightness, Blur, Noise, \ | |||
Translate, Scale, Shear, Rotate | |||
LOGGER = LogUtil.get_instance() | |||
TAG = 'Image transform test' | |||
LOGGER.set_level('INFO') | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_cpu | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_contrast(): | |||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||
mode = 'L' | |||
trans = Contrast(image, mode) | |||
trans.random_param() | |||
trans_image = trans.transform() | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_cpu | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_brightness(): | |||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||
mode = 'L' | |||
trans = Brightness(image, mode) | |||
trans.random_param() | |||
trans_image = trans.transform() | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_cpu | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_blur(): | |||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||
mode = 'L' | |||
trans = Blur(image, mode) | |||
trans.random_param() | |||
trans_image = trans.transform() | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_cpu | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_noise(): | |||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||
mode = 'L' | |||
trans = Noise(image, mode) | |||
trans.random_param() | |||
trans_image = trans.transform() | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_cpu | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_translate(): | |||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||
mode = 'L' | |||
trans = Translate(image, mode) | |||
trans.random_param() | |||
trans_image = trans.transform() | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_cpu | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_shear(): | |||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||
mode = 'L' | |||
trans = Shear(image, mode) | |||
trans.random_param() | |||
trans_image = trans.transform() | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_cpu | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_scale(): | |||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||
mode = 'L' | |||
trans = Scale(image, mode) | |||
trans.random_param() | |||
trans_image = trans.transform() | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_cpu | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_rotate(): | |||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||
mode = 'L' | |||
trans = Rotate(image, mode) | |||
trans.random_param() | |||
trans_image = trans.transform() | |||