From f150496e00fe4812948d05eb70a436c0ede43893 Mon Sep 17 00:00:00 2001 From: ZhidanLiu Date: Wed, 22 Apr 2020 21:57:35 +0800 Subject: [PATCH] add a fuzzing test framework and image transform method --- example/mnist_demo/lenet5_mnist_fuzzing.py | 89 +++++++++ mindarmour/fuzzing/__init__.py | 7 +- mindarmour/fuzzing/fuzzing.py | 169 ++++++++++++++++ mindarmour/fuzzing/model_coverage_metrics.py | 5 + mindarmour/utils/image_transform.py | 267 ++++++++++++++++++++++++++ tests/ut/python/fuzzing/test_fuzzing.py | 161 ++++++++++++++++ tests/ut/python/utils/test_image_transform.py | 136 +++++++++++++ 7 files changed, 833 insertions(+), 1 deletion(-) create mode 100644 example/mnist_demo/lenet5_mnist_fuzzing.py create mode 100644 mindarmour/fuzzing/fuzzing.py create mode 100644 mindarmour/utils/image_transform.py create mode 100644 tests/ut/python/fuzzing/test_fuzzing.py create mode 100644 tests/ut/python/utils/test_image_transform.py diff --git a/example/mnist_demo/lenet5_mnist_fuzzing.py b/example/mnist_demo/lenet5_mnist_fuzzing.py new file mode 100644 index 0000000..d6604fd --- /dev/null +++ b/example/mnist_demo/lenet5_mnist_fuzzing.py @@ -0,0 +1,89 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +import numpy as np + +from mindspore import Model +from mindspore import context +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.nn import SoftmaxCrossEntropyWithLogits + +from mindarmour.attacks.gradient_method import FastGradientSignMethod +from mindarmour.utils.logger import LogUtil +from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics +from mindarmour.fuzzing.fuzzing import Fuzzing +from lenet5_net import LeNet5 + +sys.path.append("..") +from data_processing import generate_mnist_dataset + +LOGGER = LogUtil.get_instance() +TAG = 'Fuzz_test' +LOGGER.set_level('INFO') + + +def test_lenet_mnist_fuzzing(): + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + # upload trained network + ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' + net = LeNet5() + load_dict = load_checkpoint(ckpt_name) + load_param_into_net(net, load_dict) + model = Model(net) + + # get training data + data_list = "./MNIST_datasets/train" + batch_size = 32 + ds = generate_mnist_dataset(data_list, batch_size, sparse=True) + train_images = [] + for data in ds.create_tuple_iterator(): + images = data[0].astype(np.float32) + train_images.append(images) + train_images = np.concatenate(train_images, axis=0) + + # initialize fuzz test with training dataset + model_coverage_test = ModelCoverageMetrics(model, 1000, 10, train_images) + + # fuzz test with original test data + # get test data + data_list = "./MNIST_datasets/test" + batch_size = 32 + ds = generate_mnist_dataset(data_list, batch_size, sparse=True) + test_images = [] + test_labels = [] + for data in ds.create_tuple_iterator(): + images = data[0].astype(np.float32) + labels = data[1] + test_images.append(images) + test_labels.append(labels) + test_images = np.concatenate(test_images, axis=0) + test_labels = np.concatenate(test_labels, axis=0) + initial_seeds = [] + + # make initial seeds + for img, label in zip(test_images, test_labels): + initial_seeds.append([img, label, 0]) + + initial_seeds = initial_seeds[:100] + model_coverage_test.test_adequacy_coverage_calculate(np.array(test_images[:100]).astype(np.float32)) + LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) + + model_fuzz_test = Fuzzing(initial_seeds, model, train_images, 20) + failed_tests = model_fuzz_test.fuzzing() + model_coverage_test.test_adequacy_coverage_calculate(np.array(failed_tests).astype(np.float32)) + LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) + + +if __name__ == '__main__': + test_lenet_mnist_fuzzing() diff --git a/mindarmour/fuzzing/__init__.py b/mindarmour/fuzzing/__init__.py index c591d2b..cc88f0c 100644 --- a/mindarmour/fuzzing/__init__.py +++ b/mindarmour/fuzzing/__init__.py @@ -1,3 +1,8 @@ +""" +This module includes various metrics to fuzzing the test of DNN. +""" +from .fuzzing import Fuzzing from .model_coverage_metrics import ModelCoverageMetrics -__all__ = ['ModelCoverageMetrics'] \ No newline at end of file +__all__ = ['Fuzzing', + 'ModelCoverageMetrics'] diff --git a/mindarmour/fuzzing/fuzzing.py b/mindarmour/fuzzing/fuzzing.py new file mode 100644 index 0000000..e0e21aa --- /dev/null +++ b/mindarmour/fuzzing/fuzzing.py @@ -0,0 +1,169 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fuzzing. +""" +import numpy as np +from random import choice + +from mindspore import Tensor +from mindspore import Model + +from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics +from mindarmour.utils.image_transform import Contrast, Brightness, Blur, Noise, \ + Translate, Scale, Shear, Rotate +from mindarmour.utils._check_param import check_model, check_numpy_param, \ + check_int_positive + + +class Fuzzing: + """ + Fuzzing test framework for deep neural networks. + + Reference: `DeepHunter: A Coverage-Guided Fuzz Testing Framework for Deep + Neural Networks `_ + + Args: + initial_seeds (list): Initial fuzzing seed, format: [[image, label, 0], + [image, label, 0], ...]. + target_model (Model): Target fuzz model. + train_dataset (numpy.ndarray): Training dataset used for determine + the neurons' output boundaries. + const_K (int): The number of mutate tests for a seed. + mode (str): Image mode used in image transform, 'L' means grey graph. + Default: 'L'. + """ + + def __init__(self, initial_seeds, target_model, train_dataset, const_K, + mode='L', max_seed_num=1000): + self.initial_seeds = initial_seeds + self.target_model = check_model('model', target_model, Model) + self.train_dataset = check_numpy_param('train_dataset', train_dataset) + self.K = check_int_positive('const_k', const_K) + self.mode = mode + self.max_seed_num = check_int_positive('max_seed_num', max_seed_num) + self.coverage_metrics = ModelCoverageMetrics(target_model, 1000, 10, + train_dataset) + + def _image_value_expand(self, image): + return image*255 + + def _image_value_compress(self, image): + return image / 255 + + def _metamorphic_mutate(self, seed, try_num=50): + if self.mode == 'L': + seed = seed[0] + info = [seed, seed] + mutate_tests = [] + affine_trans = ['Contrast', 'Brightness', 'Blur', 'Noise'] + pixel_value_trans = ['Translate', 'Scale', 'Shear', 'Rotate'] + strages = {'Contrast': Contrast, 'Brightness': Brightness, 'Blur': Blur, + 'Noise': Noise, + 'Translate': Translate, 'Scale': Scale, 'Shear': Shear, + 'Rotate': Rotate} + for _ in range(self.K): + for _ in range(try_num): + if (info[0] == info[1]).all(): + trans_strage = self._random_pick_mutate(affine_trans, + pixel_value_trans) + else: + trans_strage = self._random_pick_mutate(affine_trans, []) + transform = strages[trans_strage]( + self._image_value_expand(seed), self.mode) + transform.random_param() + mutate_test = transform.transform() + mutate_test = np.expand_dims( + self._image_value_compress(mutate_test), 0) + + if self._is_trans_valid(seed, mutate_test): + if trans_strage in affine_trans: + info[1] = mutate_test + mutate_tests.append(mutate_test) + if len(mutate_tests) == 0: + mutate_tests.append(seed) + return np.array(mutate_tests) + + def fuzzing(self, coverage_metric='KMNC'): + """ + Fuzzing tests for deep neural networks. + + Args: + coverage_metric (str): Model coverage metric of neural networks. + Default: 'KMNC'. + + Returns: + list, mutated tests mis-predicted by target dnn model. + """ + seed = self._select_next() + failed_tests = [] + seed_num = 0 + while len(seed) > 0 and seed_num < self.max_seed_num: + mutate_tests = self._metamorphic_mutate(seed[0]) + coverages, results = self._run(mutate_tests, coverage_metric) + coverage_gains = self._coverage_gains(coverages) + for mutate, cov, res in zip(mutate_tests, coverage_gains, results): + if np.argmax(seed[1]) != np.argmax(res): + failed_tests.append(mutate) + continue + if cov > 0: + self.initial_seeds.append([mutate, seed[1], 0]) + seed = self._select_next() + seed_num += 1 + + return failed_tests + + def _coverage_gains(self, coverages): + gains = [0] + coverages[:-1] + gains = np.array(coverages) - np.array(gains) + return gains + + def _run(self, mutate_tests, coverage_metric="KNMC"): + coverages = [] + result = self.target_model.predict( + Tensor(mutate_tests.astype(np.float32))) + result = result.asnumpy() + for index in range(len(mutate_tests)): + mutate = np.expand_dims(mutate_tests[index], 0) + self.coverage_metrics.test_adequacy_coverage_calculate( + mutate.astype(np.float32), batch_size=1) + if coverage_metric == "KMNC": + coverages.append(self.coverage_metrics.get_kmnc()) + + return coverages, result + + def _select_next(self): + seed = choice(self.initial_seeds) + return seed + + def _random_pick_mutate(self, affine_trans_list, pixel_value_trans_list): + strage = choice(affine_trans_list + pixel_value_trans_list) + return strage + + def _is_trans_valid(self, seed, mutate_test): + is_valid = False + alpha = 0.02 + beta = 0.2 + diff = np.array(seed - mutate_test).flatten() + size = np.shape(diff)[0] + L0 = np.linalg.norm(diff, ord=0) + Linf = np.linalg.norm(diff, ord=np.inf) + if L0 > alpha*size: + if Linf < 256: + is_valid = True + else: + if Linf < beta*255: + is_valid = True + + return is_valid diff --git a/mindarmour/fuzzing/model_coverage_metrics.py b/mindarmour/fuzzing/model_coverage_metrics.py index bc8f562..22ce4b1 100644 --- a/mindarmour/fuzzing/model_coverage_metrics.py +++ b/mindarmour/fuzzing/model_coverage_metrics.py @@ -76,6 +76,11 @@ class ModelCoverageMetrics: upper_compare_array = np.concatenate( [output, np.array([self._upper_bounds])], axis=0) self._upper_bounds = np.max(upper_compare_array, axis=0) + if batches == 0: + output = self._model.predict(Tensor(train_dataset)).asnumpy() + self._lower_bounds = np.min(output, axis=0) + self._upper_bounds = np.max(output, axis=0) + output_mat.append(output) self._var = np.std(np.concatenate(np.array(output_mat), axis=0), axis=0) diff --git a/mindarmour/utils/image_transform.py b/mindarmour/utils/image_transform.py new file mode 100644 index 0000000..5cdf0cf --- /dev/null +++ b/mindarmour/utils/image_transform.py @@ -0,0 +1,267 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image transform +""" +import numpy as np +from PIL import Image, ImageEnhance, ImageFilter +import random + +from mindarmour.utils._check_param import check_numpy_param + +class ImageTransform: + """ + The abstract base class for all image transform classes. + """ + + def __init__(self): + pass + + def random_param(self): + pass + + def transform(self): + pass + + +class Contrast(ImageTransform): + """ + Contrast of an image. + + Args: + image (numpy.ndarray): Original image to be transformed. + mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], + 'L' means grey image. + """ + + def __init__(self, image, mode): + super(Contrast, self).__init__() + self.image = check_numpy_param('image', image) + self.mode = mode + + def random_param(self): + """ Random generate parameters. """ + self.factor = random.uniform(-10, 10) + + def transform(self): + img = Image.fromarray(self.image, self.mode) + img_contrast = ImageEnhance.Contrast(img) + trans_image = img_contrast.enhance(self.factor) + trans_image = np.array(trans_image) + return trans_image + + +class Brightness(ImageTransform): + """ + Brightness of an image. + + Args: + image (numpy.ndarray): Original image to be transformed. + mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], + 'L' means grey image. + """ + + def __init__(self, image, mode): + super(Brightness, self).__init__() + self.image = check_numpy_param('image', image) + self.mode = mode + + def random_param(self): + """ Random generate parameters. """ + self.factor = random.uniform(-10, 10) + + def transform(self): + img = Image.fromarray(self.image, self.mode) + img_contrast = ImageEnhance.Brightness(img) + trans_image = img_contrast.enhance(self.factor) + trans_image = np.array(trans_image) + return trans_image + + +class Blur(ImageTransform): + """ + GaussianBlur of an image. + + Args: + image (numpy.ndarray): Original image to be transformed. + mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], + 'L' means grey image. + """ + + def __init__(self, image, mode): + super(Blur, self).__init__() + self.image = check_numpy_param('image', image) + self.mode = mode + + def random_param(self): + """ Random generate parameters. """ + self.radius = random.uniform(-10, 10) + + def transform(self): + """ Transform the image. """ + img = Image.fromarray(self.image, self.mode) + trans_image = img.filter(ImageFilter.GaussianBlur(radius=self.radius)) + trans_image = np.array(trans_image) + return trans_image + + +class Noise(ImageTransform): + """ + Add noise of an image. + + Args: + image (numpy.ndarray): Original image to be transformed. + mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], + 'L' means grey image. + """ + + def __init__(self, image, mode): + super(Noise, self).__init__() + self.image = check_numpy_param('image', image) + self.mode = mode + + def random_param(self): + """ random generate parameters """ + self.factor = random.uniform(-1, 1) + + def transform(self): + """ Random generate parameters. """ + noise = np.random.uniform(low=-1, high=1, size=self.image.shape) + trans_image = np.copy(self.image) + trans_image[noise < -self.factor] = 0 + trans_image[noise > self.factor] = 255 + trans_image = np.array(trans_image) + return trans_image + + +class Translate(ImageTransform): + """ + Translate an image. + + Args: + image (numpy.ndarray): Original image to be transformed. + mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], + 'L' means grey image. + """ + + def __init__(self, image, mode): + super(Translate, self).__init__() + self.image = check_numpy_param('image', image) + self.mode = mode + + def random_param(self): + """ Random generate parameters. """ + image_shape = np.shape(self.image) + self.x_bias = random.uniform(0, image_shape[0]) + self.y_bias = random.uniform(0, image_shape[1]) + + def transform(self): + """ Transform the image. """ + img = Image.fromarray(self.image, self.mode) + trans_image = img.transform(img.size, Image.AFFINE, + (1, 0, self.x_bias, 0, 1, self.y_bias)) + trans_image = np.array(trans_image) + return trans_image + + +class Scale(ImageTransform): + """ + Scale an image. + + Args: + image(numpy.ndarray): Original image to be transformed. + mode(str): Mode used in PIL, here mode must be in ['L', 'RGB'], + 'L' means grey image. + """ + + def __init__(self, image, mode): + super(Scale, self).__init__() + self.image = check_numpy_param('image', image) + self.mode = mode + + def random_param(self): + """ Random generate parameters. """ + self.factor_x = random.uniform(0, 1) + self.factor_y = random.uniform(0, 1) + + def transform(self): + """ Transform the image. """ + img = Image.fromarray(self.image, self.mode) + trans_image = img.transform(img.size, Image.AFFINE, + (self.factor_x, 0, 0, 0, self.factor_y, 0)) + trans_image = np.array(trans_image) + return trans_image + + +class Shear(ImageTransform): + """ + Shear an image. + + Args: + image (numpy.ndarray): Original image to be transformed. + mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], + 'L' means grey image. + """ + + def __init__(self, image, mode): + super(Shear, self).__init__() + self.image = check_numpy_param('image', image) + self.mode = mode + + def random_param(self): + """ Random generate parameters. """ + self.factor = random.uniform(0, 1) + + def transform(self): + """ Transform the image. """ + img = Image.fromarray(self.image, self.mode) + if np.random.random() > 0.5: + level = -self.factor + else: + level = self.factor + if np.random.random() > 0.5: + trans_image = img.transform(img.size, Image.AFFINE, + (1, level, 0, 0, 1, 0)) + else: + trans_image = img.transform(img.size, Image.AFFINE, + (1, 0, 0, level, 1, 0)) + trans_image = np.array(trans_image, dtype=np.float) + return trans_image + + +class Rotate(ImageTransform): + """ + Rotate an image. + + Args: + image (numpy.ndarray): Original image to be transformed. + mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], + 'L' means grey image. + """ + + def __init__(self, image, mode): + super(Rotate, self).__init__() + self.image = check_numpy_param('image', image) + self.mode = mode + + def random_param(self): + """ Random generate parameters. """ + self.angle = random.uniform(0, 360) + + def transform(self): + """ Transform the image. """ + img = Image.fromarray(self.image, self.mode) + trans_image = img.rotate(self.angle) + trans_image = np.array(trans_image) + return trans_image diff --git a/tests/ut/python/fuzzing/test_fuzzing.py b/tests/ut/python/fuzzing/test_fuzzing.py new file mode 100644 index 0000000..6ddf0aa --- /dev/null +++ b/tests/ut/python/fuzzing/test_fuzzing.py @@ -0,0 +1,161 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Model-fuzz coverage test. +""" +import numpy as np +import pytest +import sys + +from mindspore.train import Model +from mindspore import nn +from mindspore.ops import operations as P +from mindspore import context +from mindspore.common.initializer import TruncatedNormal + +from mindarmour.utils.logger import LogUtil +from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics +from mindarmour.fuzzing.fuzzing import Fuzzing + + +LOGGER = LogUtil.get_instance() +TAG = 'Fuzzing test' +LOGGER.set_level('INFO') + + +def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): + weight = weight_variable() + return nn.Conv2d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + weight_init=weight, has_bias=False, pad_mode="valid") + + +def fc_with_initialize(input_channels, out_channels): + weight = weight_variable() + bias = weight_variable() + return nn.Dense(input_channels, out_channels, weight, bias) + + +def weight_variable(): + return TruncatedNormal(0.02) + + +class Net(nn.Cell): + """ + Lenet network + """ + def __init__(self): + super(Net, self).__init__() + self.conv1 = conv(1, 6, 5) + self.conv2 = conv(6, 16, 5) + self.fc1 = fc_with_initialize(16*5*5, 120) + self.fc2 = fc_with_initialize(120, 84) + self.fc3 = fc_with_initialize(84, 10) + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.reshape = P.Reshape() + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.reshape(x, (-1, 16*5*5)) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + return x + + +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_fuzzing_ascend(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + # load network + net = Net() + model = Model(net) + batch_size = 8 + num_classe = 10 + + # initialize fuzz test with training dataset + training_data = np.random.rand(32, 1, 32, 32).astype(np.float32) + model_coverage_test = ModelCoverageMetrics(model, 1000, 10, training_data) + + # fuzz test with original test data + # get test data + test_data = np.random.rand(batch_size, 1, 32, 32).astype(np.float32) + test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32) + test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32) + + initial_seeds = [] + for img, label in zip(test_data, test_labels): + initial_seeds.append([img, label, 0]) + model_coverage_test.test_adequacy_coverage_calculate( + np.array(test_data).astype(np.float32)) + LOGGER.info(TAG, 'KMNC of this test is : %s', + model_coverage_test.get_kmnc()) + + model_fuzz_test = Fuzzing(initial_seeds, model, training_data, 5, + max_seed_num=10) + failed_tests = model_fuzz_test.fuzzing() + model_coverage_test.test_adequacy_coverage_calculate( + np.array(failed_tests).astype(np.float32)) + LOGGER.info(TAG, 'KMNC of this test is : %s', + model_coverage_test.get_kmnc()) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_fuzzing_ascend(): + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + # load network + net = Net() + model = Model(net) + batch_size = 8 + num_classe = 10 + + # initialize fuzz test with training dataset + training_data = np.random.rand(32, 1, 32, 32).astype(np.float32) + model_coverage_test = ModelCoverageMetrics(model, 1000, 10, training_data) + + # fuzz test with original test data + # get test data + test_data = np.random.rand(batch_size, 1, 32, 32).astype(np.float32) + test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32) + test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32) + + initial_seeds = [] + for img, label in zip(test_data, test_labels): + initial_seeds.append([img, label, 0]) + model_coverage_test.test_adequacy_coverage_calculate( + np.array(test_data).astype(np.float32)) + LOGGER.info(TAG, 'KMNC of this test is : %s', + model_coverage_test.get_kmnc()) + + model_fuzz_test = Fuzzing(initial_seeds, model, training_data, 5, + max_seed_num=10) + failed_tests = model_fuzz_test.fuzzing() + model_coverage_test.test_adequacy_coverage_calculate( + np.array(failed_tests).astype(np.float32)) + LOGGER.info(TAG, 'KMNC of this test is : %s', + model_coverage_test.get_kmnc()) diff --git a/tests/ut/python/utils/test_image_transform.py b/tests/ut/python/utils/test_image_transform.py new file mode 100644 index 0000000..1b49c95 --- /dev/null +++ b/tests/ut/python/utils/test_image_transform.py @@ -0,0 +1,136 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image transform test. +""" +import numpy as np +import pytest + +from mindarmour.utils.logger import LogUtil +from mindarmour.utils.image_transform import Contrast, Brightness, Blur, Noise, \ + Translate, Scale, Shear, Rotate + +LOGGER = LogUtil.get_instance() +TAG = 'Image transform test' +LOGGER.set_level('INFO') + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_contrast(): + image = (np.random.rand(32, 32)*255).astype(np.float32) + mode = 'L' + trans = Contrast(image, mode) + trans.random_param() + trans_image = trans.transform() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_brightness(): + image = (np.random.rand(32, 32)*255).astype(np.float32) + mode = 'L' + trans = Brightness(image, mode) + trans.random_param() + trans_image = trans.transform() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_blur(): + image = (np.random.rand(32, 32)*255).astype(np.float32) + mode = 'L' + trans = Blur(image, mode) + trans.random_param() + trans_image = trans.transform() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_noise(): + image = (np.random.rand(32, 32)*255).astype(np.float32) + mode = 'L' + trans = Noise(image, mode) + trans.random_param() + trans_image = trans.transform() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_translate(): + image = (np.random.rand(32, 32)*255).astype(np.float32) + mode = 'L' + trans = Translate(image, mode) + trans.random_param() + trans_image = trans.transform() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_shear(): + image = (np.random.rand(32, 32)*255).astype(np.float32) + mode = 'L' + trans = Shear(image, mode) + trans.random_param() + trans_image = trans.transform() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_scale(): + image = (np.random.rand(32, 32)*255).astype(np.float32) + mode = 'L' + trans = Scale(image, mode) + trans.random_param() + trans_image = trans.transform() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_rotate(): + image = (np.random.rand(32, 32)*255).astype(np.float32) + mode = 'L' + trans = Rotate(image, mode) + trans.random_param() + trans_image = trans.transform() + +