Merge pull request !13 from ZhidanLiu/mastertags/v1.2.1
@@ -0,0 +1,89 @@ | |||||
# Copyright 2019 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
import sys | |||||
import numpy as np | |||||
from mindspore import Model | |||||
from mindspore import context | |||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
from mindspore.nn import SoftmaxCrossEntropyWithLogits | |||||
from mindarmour.attacks.gradient_method import FastGradientSignMethod | |||||
from mindarmour.utils.logger import LogUtil | |||||
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | |||||
from mindarmour.fuzzing.fuzzing import Fuzzing | |||||
from lenet5_net import LeNet5 | |||||
sys.path.append("..") | |||||
from data_processing import generate_mnist_dataset | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = 'Fuzz_test' | |||||
LOGGER.set_level('INFO') | |||||
def test_lenet_mnist_fuzzing(): | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
# upload trained network | |||||
ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
net = LeNet5() | |||||
load_dict = load_checkpoint(ckpt_name) | |||||
load_param_into_net(net, load_dict) | |||||
model = Model(net) | |||||
# get training data | |||||
data_list = "./MNIST_datasets/train" | |||||
batch_size = 32 | |||||
ds = generate_mnist_dataset(data_list, batch_size, sparse=True) | |||||
train_images = [] | |||||
for data in ds.create_tuple_iterator(): | |||||
images = data[0].astype(np.float32) | |||||
train_images.append(images) | |||||
train_images = np.concatenate(train_images, axis=0) | |||||
# initialize fuzz test with training dataset | |||||
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, train_images) | |||||
# fuzz test with original test data | |||||
# get test data | |||||
data_list = "./MNIST_datasets/test" | |||||
batch_size = 32 | |||||
ds = generate_mnist_dataset(data_list, batch_size, sparse=True) | |||||
test_images = [] | |||||
test_labels = [] | |||||
for data in ds.create_tuple_iterator(): | |||||
images = data[0].astype(np.float32) | |||||
labels = data[1] | |||||
test_images.append(images) | |||||
test_labels.append(labels) | |||||
test_images = np.concatenate(test_images, axis=0) | |||||
test_labels = np.concatenate(test_labels, axis=0) | |||||
initial_seeds = [] | |||||
# make initial seeds | |||||
for img, label in zip(test_images, test_labels): | |||||
initial_seeds.append([img, label, 0]) | |||||
initial_seeds = initial_seeds[:100] | |||||
model_coverage_test.test_adequacy_coverage_calculate(np.array(test_images[:100]).astype(np.float32)) | |||||
LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) | |||||
model_fuzz_test = Fuzzing(initial_seeds, model, train_images, 20) | |||||
failed_tests = model_fuzz_test.fuzzing() | |||||
model_coverage_test.test_adequacy_coverage_calculate(np.array(failed_tests).astype(np.float32)) | |||||
LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) | |||||
if __name__ == '__main__': | |||||
test_lenet_mnist_fuzzing() |
@@ -1,3 +1,8 @@ | |||||
""" | |||||
This module includes various metrics to fuzzing the test of DNN. | |||||
""" | |||||
from .fuzzing import Fuzzing | |||||
from .model_coverage_metrics import ModelCoverageMetrics | from .model_coverage_metrics import ModelCoverageMetrics | ||||
__all__ = ['ModelCoverageMetrics'] | |||||
__all__ = ['Fuzzing', | |||||
'ModelCoverageMetrics'] |
@@ -0,0 +1,169 @@ | |||||
# Copyright 2019 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Fuzzing. | |||||
""" | |||||
import numpy as np | |||||
from random import choice | |||||
from mindspore import Tensor | |||||
from mindspore import Model | |||||
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | |||||
from mindarmour.utils.image_transform import Contrast, Brightness, Blur, Noise, \ | |||||
Translate, Scale, Shear, Rotate | |||||
from mindarmour.utils._check_param import check_model, check_numpy_param, \ | |||||
check_int_positive | |||||
class Fuzzing: | |||||
""" | |||||
Fuzzing test framework for deep neural networks. | |||||
Reference: `DeepHunter: A Coverage-Guided Fuzz Testing Framework for Deep | |||||
Neural Networks <https://dl.acm.org/doi/10.1145/3293882.3330579>`_ | |||||
Args: | |||||
initial_seeds (list): Initial fuzzing seed, format: [[image, label, 0], | |||||
[image, label, 0], ...]. | |||||
target_model (Model): Target fuzz model. | |||||
train_dataset (numpy.ndarray): Training dataset used for determine | |||||
the neurons' output boundaries. | |||||
const_K (int): The number of mutate tests for a seed. | |||||
mode (str): Image mode used in image transform, 'L' means grey graph. | |||||
Default: 'L'. | |||||
""" | |||||
def __init__(self, initial_seeds, target_model, train_dataset, const_K, | |||||
mode='L', max_seed_num=1000): | |||||
self.initial_seeds = initial_seeds | |||||
self.target_model = check_model('model', target_model, Model) | |||||
self.train_dataset = check_numpy_param('train_dataset', train_dataset) | |||||
self.K = check_int_positive('const_k', const_K) | |||||
self.mode = mode | |||||
self.max_seed_num = check_int_positive('max_seed_num', max_seed_num) | |||||
self.coverage_metrics = ModelCoverageMetrics(target_model, 1000, 10, | |||||
train_dataset) | |||||
def _image_value_expand(self, image): | |||||
return image*255 | |||||
def _image_value_compress(self, image): | |||||
return image / 255 | |||||
def _metamorphic_mutate(self, seed, try_num=50): | |||||
if self.mode == 'L': | |||||
seed = seed[0] | |||||
info = [seed, seed] | |||||
mutate_tests = [] | |||||
affine_trans = ['Contrast', 'Brightness', 'Blur', 'Noise'] | |||||
pixel_value_trans = ['Translate', 'Scale', 'Shear', 'Rotate'] | |||||
strages = {'Contrast': Contrast, 'Brightness': Brightness, 'Blur': Blur, | |||||
'Noise': Noise, | |||||
'Translate': Translate, 'Scale': Scale, 'Shear': Shear, | |||||
'Rotate': Rotate} | |||||
for _ in range(self.K): | |||||
for _ in range(try_num): | |||||
if (info[0] == info[1]).all(): | |||||
trans_strage = self._random_pick_mutate(affine_trans, | |||||
pixel_value_trans) | |||||
else: | |||||
trans_strage = self._random_pick_mutate(affine_trans, []) | |||||
transform = strages[trans_strage]( | |||||
self._image_value_expand(seed), self.mode) | |||||
transform.random_param() | |||||
mutate_test = transform.transform() | |||||
mutate_test = np.expand_dims( | |||||
self._image_value_compress(mutate_test), 0) | |||||
if self._is_trans_valid(seed, mutate_test): | |||||
if trans_strage in affine_trans: | |||||
info[1] = mutate_test | |||||
mutate_tests.append(mutate_test) | |||||
if len(mutate_tests) == 0: | |||||
mutate_tests.append(seed) | |||||
return np.array(mutate_tests) | |||||
def fuzzing(self, coverage_metric='KMNC'): | |||||
""" | |||||
Fuzzing tests for deep neural networks. | |||||
Args: | |||||
coverage_metric (str): Model coverage metric of neural networks. | |||||
Default: 'KMNC'. | |||||
Returns: | |||||
list, mutated tests mis-predicted by target dnn model. | |||||
""" | |||||
seed = self._select_next() | |||||
failed_tests = [] | |||||
seed_num = 0 | |||||
while len(seed) > 0 and seed_num < self.max_seed_num: | |||||
mutate_tests = self._metamorphic_mutate(seed[0]) | |||||
coverages, results = self._run(mutate_tests, coverage_metric) | |||||
coverage_gains = self._coverage_gains(coverages) | |||||
for mutate, cov, res in zip(mutate_tests, coverage_gains, results): | |||||
if np.argmax(seed[1]) != np.argmax(res): | |||||
failed_tests.append(mutate) | |||||
continue | |||||
if cov > 0: | |||||
self.initial_seeds.append([mutate, seed[1], 0]) | |||||
seed = self._select_next() | |||||
seed_num += 1 | |||||
return failed_tests | |||||
def _coverage_gains(self, coverages): | |||||
gains = [0] + coverages[:-1] | |||||
gains = np.array(coverages) - np.array(gains) | |||||
return gains | |||||
def _run(self, mutate_tests, coverage_metric="KNMC"): | |||||
coverages = [] | |||||
result = self.target_model.predict( | |||||
Tensor(mutate_tests.astype(np.float32))) | |||||
result = result.asnumpy() | |||||
for index in range(len(mutate_tests)): | |||||
mutate = np.expand_dims(mutate_tests[index], 0) | |||||
self.coverage_metrics.test_adequacy_coverage_calculate( | |||||
mutate.astype(np.float32), batch_size=1) | |||||
if coverage_metric == "KMNC": | |||||
coverages.append(self.coverage_metrics.get_kmnc()) | |||||
return coverages, result | |||||
def _select_next(self): | |||||
seed = choice(self.initial_seeds) | |||||
return seed | |||||
def _random_pick_mutate(self, affine_trans_list, pixel_value_trans_list): | |||||
strage = choice(affine_trans_list + pixel_value_trans_list) | |||||
return strage | |||||
def _is_trans_valid(self, seed, mutate_test): | |||||
is_valid = False | |||||
alpha = 0.02 | |||||
beta = 0.2 | |||||
diff = np.array(seed - mutate_test).flatten() | |||||
size = np.shape(diff)[0] | |||||
L0 = np.linalg.norm(diff, ord=0) | |||||
Linf = np.linalg.norm(diff, ord=np.inf) | |||||
if L0 > alpha*size: | |||||
if Linf < 256: | |||||
is_valid = True | |||||
else: | |||||
if Linf < beta*255: | |||||
is_valid = True | |||||
return is_valid |
@@ -76,6 +76,11 @@ class ModelCoverageMetrics: | |||||
upper_compare_array = np.concatenate( | upper_compare_array = np.concatenate( | ||||
[output, np.array([self._upper_bounds])], axis=0) | [output, np.array([self._upper_bounds])], axis=0) | ||||
self._upper_bounds = np.max(upper_compare_array, axis=0) | self._upper_bounds = np.max(upper_compare_array, axis=0) | ||||
if batches == 0: | |||||
output = self._model.predict(Tensor(train_dataset)).asnumpy() | |||||
self._lower_bounds = np.min(output, axis=0) | |||||
self._upper_bounds = np.max(output, axis=0) | |||||
output_mat.append(output) | |||||
self._var = np.std(np.concatenate(np.array(output_mat), axis=0), | self._var = np.std(np.concatenate(np.array(output_mat), axis=0), | ||||
axis=0) | axis=0) | ||||
@@ -0,0 +1,267 @@ | |||||
# Copyright 2019 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Image transform | |||||
""" | |||||
import numpy as np | |||||
from PIL import Image, ImageEnhance, ImageFilter | |||||
import random | |||||
from mindarmour.utils._check_param import check_numpy_param | |||||
class ImageTransform: | |||||
""" | |||||
The abstract base class for all image transform classes. | |||||
""" | |||||
def __init__(self): | |||||
pass | |||||
def random_param(self): | |||||
pass | |||||
def transform(self): | |||||
pass | |||||
class Contrast(ImageTransform): | |||||
""" | |||||
Contrast of an image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||||
'L' means grey image. | |||||
""" | |||||
def __init__(self, image, mode): | |||||
super(Contrast, self).__init__() | |||||
self.image = check_numpy_param('image', image) | |||||
self.mode = mode | |||||
def random_param(self): | |||||
""" Random generate parameters. """ | |||||
self.factor = random.uniform(-10, 10) | |||||
def transform(self): | |||||
img = Image.fromarray(self.image, self.mode) | |||||
img_contrast = ImageEnhance.Contrast(img) | |||||
trans_image = img_contrast.enhance(self.factor) | |||||
trans_image = np.array(trans_image) | |||||
return trans_image | |||||
class Brightness(ImageTransform): | |||||
""" | |||||
Brightness of an image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||||
'L' means grey image. | |||||
""" | |||||
def __init__(self, image, mode): | |||||
super(Brightness, self).__init__() | |||||
self.image = check_numpy_param('image', image) | |||||
self.mode = mode | |||||
def random_param(self): | |||||
""" Random generate parameters. """ | |||||
self.factor = random.uniform(-10, 10) | |||||
def transform(self): | |||||
img = Image.fromarray(self.image, self.mode) | |||||
img_contrast = ImageEnhance.Brightness(img) | |||||
trans_image = img_contrast.enhance(self.factor) | |||||
trans_image = np.array(trans_image) | |||||
return trans_image | |||||
class Blur(ImageTransform): | |||||
""" | |||||
GaussianBlur of an image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||||
'L' means grey image. | |||||
""" | |||||
def __init__(self, image, mode): | |||||
super(Blur, self).__init__() | |||||
self.image = check_numpy_param('image', image) | |||||
self.mode = mode | |||||
def random_param(self): | |||||
""" Random generate parameters. """ | |||||
self.radius = random.uniform(-10, 10) | |||||
def transform(self): | |||||
""" Transform the image. """ | |||||
img = Image.fromarray(self.image, self.mode) | |||||
trans_image = img.filter(ImageFilter.GaussianBlur(radius=self.radius)) | |||||
trans_image = np.array(trans_image) | |||||
return trans_image | |||||
class Noise(ImageTransform): | |||||
""" | |||||
Add noise of an image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||||
'L' means grey image. | |||||
""" | |||||
def __init__(self, image, mode): | |||||
super(Noise, self).__init__() | |||||
self.image = check_numpy_param('image', image) | |||||
self.mode = mode | |||||
def random_param(self): | |||||
""" random generate parameters """ | |||||
self.factor = random.uniform(-1, 1) | |||||
def transform(self): | |||||
""" Random generate parameters. """ | |||||
noise = np.random.uniform(low=-1, high=1, size=self.image.shape) | |||||
trans_image = np.copy(self.image) | |||||
trans_image[noise < -self.factor] = 0 | |||||
trans_image[noise > self.factor] = 255 | |||||
trans_image = np.array(trans_image) | |||||
return trans_image | |||||
class Translate(ImageTransform): | |||||
""" | |||||
Translate an image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||||
'L' means grey image. | |||||
""" | |||||
def __init__(self, image, mode): | |||||
super(Translate, self).__init__() | |||||
self.image = check_numpy_param('image', image) | |||||
self.mode = mode | |||||
def random_param(self): | |||||
""" Random generate parameters. """ | |||||
image_shape = np.shape(self.image) | |||||
self.x_bias = random.uniform(0, image_shape[0]) | |||||
self.y_bias = random.uniform(0, image_shape[1]) | |||||
def transform(self): | |||||
""" Transform the image. """ | |||||
img = Image.fromarray(self.image, self.mode) | |||||
trans_image = img.transform(img.size, Image.AFFINE, | |||||
(1, 0, self.x_bias, 0, 1, self.y_bias)) | |||||
trans_image = np.array(trans_image) | |||||
return trans_image | |||||
class Scale(ImageTransform): | |||||
""" | |||||
Scale an image. | |||||
Args: | |||||
image(numpy.ndarray): Original image to be transformed. | |||||
mode(str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||||
'L' means grey image. | |||||
""" | |||||
def __init__(self, image, mode): | |||||
super(Scale, self).__init__() | |||||
self.image = check_numpy_param('image', image) | |||||
self.mode = mode | |||||
def random_param(self): | |||||
""" Random generate parameters. """ | |||||
self.factor_x = random.uniform(0, 1) | |||||
self.factor_y = random.uniform(0, 1) | |||||
def transform(self): | |||||
""" Transform the image. """ | |||||
img = Image.fromarray(self.image, self.mode) | |||||
trans_image = img.transform(img.size, Image.AFFINE, | |||||
(self.factor_x, 0, 0, 0, self.factor_y, 0)) | |||||
trans_image = np.array(trans_image) | |||||
return trans_image | |||||
class Shear(ImageTransform): | |||||
""" | |||||
Shear an image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||||
'L' means grey image. | |||||
""" | |||||
def __init__(self, image, mode): | |||||
super(Shear, self).__init__() | |||||
self.image = check_numpy_param('image', image) | |||||
self.mode = mode | |||||
def random_param(self): | |||||
""" Random generate parameters. """ | |||||
self.factor = random.uniform(0, 1) | |||||
def transform(self): | |||||
""" Transform the image. """ | |||||
img = Image.fromarray(self.image, self.mode) | |||||
if np.random.random() > 0.5: | |||||
level = -self.factor | |||||
else: | |||||
level = self.factor | |||||
if np.random.random() > 0.5: | |||||
trans_image = img.transform(img.size, Image.AFFINE, | |||||
(1, level, 0, 0, 1, 0)) | |||||
else: | |||||
trans_image = img.transform(img.size, Image.AFFINE, | |||||
(1, 0, 0, level, 1, 0)) | |||||
trans_image = np.array(trans_image, dtype=np.float) | |||||
return trans_image | |||||
class Rotate(ImageTransform): | |||||
""" | |||||
Rotate an image. | |||||
Args: | |||||
image (numpy.ndarray): Original image to be transformed. | |||||
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], | |||||
'L' means grey image. | |||||
""" | |||||
def __init__(self, image, mode): | |||||
super(Rotate, self).__init__() | |||||
self.image = check_numpy_param('image', image) | |||||
self.mode = mode | |||||
def random_param(self): | |||||
""" Random generate parameters. """ | |||||
self.angle = random.uniform(0, 360) | |||||
def transform(self): | |||||
""" Transform the image. """ | |||||
img = Image.fromarray(self.image, self.mode) | |||||
trans_image = img.rotate(self.angle) | |||||
trans_image = np.array(trans_image) | |||||
return trans_image |
@@ -0,0 +1,161 @@ | |||||
# Copyright 2019 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Model-fuzz coverage test. | |||||
""" | |||||
import numpy as np | |||||
import pytest | |||||
import sys | |||||
from mindspore.train import Model | |||||
from mindspore import nn | |||||
from mindspore.ops import operations as P | |||||
from mindspore import context | |||||
from mindspore.common.initializer import TruncatedNormal | |||||
from mindarmour.utils.logger import LogUtil | |||||
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | |||||
from mindarmour.fuzzing.fuzzing import Fuzzing | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = 'Fuzzing test' | |||||
LOGGER.set_level('INFO') | |||||
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): | |||||
weight = weight_variable() | |||||
return nn.Conv2d(in_channels, out_channels, | |||||
kernel_size=kernel_size, stride=stride, padding=padding, | |||||
weight_init=weight, has_bias=False, pad_mode="valid") | |||||
def fc_with_initialize(input_channels, out_channels): | |||||
weight = weight_variable() | |||||
bias = weight_variable() | |||||
return nn.Dense(input_channels, out_channels, weight, bias) | |||||
def weight_variable(): | |||||
return TruncatedNormal(0.02) | |||||
class Net(nn.Cell): | |||||
""" | |||||
Lenet network | |||||
""" | |||||
def __init__(self): | |||||
super(Net, self).__init__() | |||||
self.conv1 = conv(1, 6, 5) | |||||
self.conv2 = conv(6, 16, 5) | |||||
self.fc1 = fc_with_initialize(16*5*5, 120) | |||||
self.fc2 = fc_with_initialize(120, 84) | |||||
self.fc3 = fc_with_initialize(84, 10) | |||||
self.relu = nn.ReLU() | |||||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | |||||
self.reshape = P.Reshape() | |||||
def construct(self, x): | |||||
x = self.conv1(x) | |||||
x = self.relu(x) | |||||
x = self.max_pool2d(x) | |||||
x = self.conv2(x) | |||||
x = self.relu(x) | |||||
x = self.max_pool2d(x) | |||||
x = self.reshape(x, (-1, 16*5*5)) | |||||
x = self.fc1(x) | |||||
x = self.relu(x) | |||||
x = self.fc2(x) | |||||
x = self.relu(x) | |||||
x = self.fc3(x) | |||||
return x | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_fuzzing_ascend(): | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
# load network | |||||
net = Net() | |||||
model = Model(net) | |||||
batch_size = 8 | |||||
num_classe = 10 | |||||
# initialize fuzz test with training dataset | |||||
training_data = np.random.rand(32, 1, 32, 32).astype(np.float32) | |||||
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, training_data) | |||||
# fuzz test with original test data | |||||
# get test data | |||||
test_data = np.random.rand(batch_size, 1, 32, 32).astype(np.float32) | |||||
test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32) | |||||
test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32) | |||||
initial_seeds = [] | |||||
for img, label in zip(test_data, test_labels): | |||||
initial_seeds.append([img, label, 0]) | |||||
model_coverage_test.test_adequacy_coverage_calculate( | |||||
np.array(test_data).astype(np.float32)) | |||||
LOGGER.info(TAG, 'KMNC of this test is : %s', | |||||
model_coverage_test.get_kmnc()) | |||||
model_fuzz_test = Fuzzing(initial_seeds, model, training_data, 5, | |||||
max_seed_num=10) | |||||
failed_tests = model_fuzz_test.fuzzing() | |||||
model_coverage_test.test_adequacy_coverage_calculate( | |||||
np.array(failed_tests).astype(np.float32)) | |||||
LOGGER.info(TAG, 'KMNC of this test is : %s', | |||||
model_coverage_test.get_kmnc()) | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_fuzzing_ascend(): | |||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
# load network | |||||
net = Net() | |||||
model = Model(net) | |||||
batch_size = 8 | |||||
num_classe = 10 | |||||
# initialize fuzz test with training dataset | |||||
training_data = np.random.rand(32, 1, 32, 32).astype(np.float32) | |||||
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, training_data) | |||||
# fuzz test with original test data | |||||
# get test data | |||||
test_data = np.random.rand(batch_size, 1, 32, 32).astype(np.float32) | |||||
test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32) | |||||
test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32) | |||||
initial_seeds = [] | |||||
for img, label in zip(test_data, test_labels): | |||||
initial_seeds.append([img, label, 0]) | |||||
model_coverage_test.test_adequacy_coverage_calculate( | |||||
np.array(test_data).astype(np.float32)) | |||||
LOGGER.info(TAG, 'KMNC of this test is : %s', | |||||
model_coverage_test.get_kmnc()) | |||||
model_fuzz_test = Fuzzing(initial_seeds, model, training_data, 5, | |||||
max_seed_num=10) | |||||
failed_tests = model_fuzz_test.fuzzing() | |||||
model_coverage_test.test_adequacy_coverage_calculate( | |||||
np.array(failed_tests).astype(np.float32)) | |||||
LOGGER.info(TAG, 'KMNC of this test is : %s', | |||||
model_coverage_test.get_kmnc()) |
@@ -0,0 +1,136 @@ | |||||
# Copyright 2019 Huawei Technologies Co., Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
""" | |||||
Image transform test. | |||||
""" | |||||
import numpy as np | |||||
import pytest | |||||
from mindarmour.utils.logger import LogUtil | |||||
from mindarmour.utils.image_transform import Contrast, Brightness, Blur, Noise, \ | |||||
Translate, Scale, Shear, Rotate | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = 'Image transform test' | |||||
LOGGER.set_level('INFO') | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_contrast(): | |||||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||||
mode = 'L' | |||||
trans = Contrast(image, mode) | |||||
trans.random_param() | |||||
trans_image = trans.transform() | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_brightness(): | |||||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||||
mode = 'L' | |||||
trans = Brightness(image, mode) | |||||
trans.random_param() | |||||
trans_image = trans.transform() | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_blur(): | |||||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||||
mode = 'L' | |||||
trans = Blur(image, mode) | |||||
trans.random_param() | |||||
trans_image = trans.transform() | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_noise(): | |||||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||||
mode = 'L' | |||||
trans = Noise(image, mode) | |||||
trans.random_param() | |||||
trans_image = trans.transform() | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_translate(): | |||||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||||
mode = 'L' | |||||
trans = Translate(image, mode) | |||||
trans.random_param() | |||||
trans_image = trans.transform() | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_shear(): | |||||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||||
mode = 'L' | |||||
trans = Shear(image, mode) | |||||
trans.random_param() | |||||
trans_image = trans.transform() | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_scale(): | |||||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||||
mode = 'L' | |||||
trans = Scale(image, mode) | |||||
trans.random_param() | |||||
trans_image = trans.transform() | |||||
@pytest.mark.level0 | |||||
@pytest.mark.platform_x86_cpu | |||||
@pytest.mark.platform_x86_ascend_training | |||||
@pytest.mark.platform_arm_ascend_training | |||||
@pytest.mark.env_onecard | |||||
@pytest.mark.component_mindarmour | |||||
def test_rotate(): | |||||
image = (np.random.rand(32, 32)*255).astype(np.float32) | |||||
mode = 'L' | |||||
trans = Rotate(image, mode) | |||||
trans.random_param() | |||||
trans_image = trans.transform() | |||||