From 9ea649e01b574328226256bb124500d3061df87c Mon Sep 17 00:00:00 2001 From: ZhidanLiu Date: Tue, 1 Mar 2022 17:03:04 +0800 Subject: [PATCH] modify fuzz to call new image transform method --- .../ai_fuzzer/fuzz_testing_and_model_enhense.py | 139 ++--- examples/ai_fuzzer/lenet5_mnist_fuzzing.py | 67 ++- mindarmour/fuzz_testing/fuzzing.py | 142 ++++- mindarmour/fuzz_testing/image_transform.py | 609 --------------------- tests/ut/python/fuzzing/test_fuzzer.py | 32 +- tests/ut/python/fuzzing/test_image_transform.py | 126 ----- 6 files changed, 250 insertions(+), 865 deletions(-) delete mode 100644 mindarmour/fuzz_testing/image_transform.py delete mode 100644 tests/ut/python/fuzzing/test_image_transform.py diff --git a/examples/ai_fuzzer/fuzz_testing_and_model_enhense.py b/examples/ai_fuzzer/fuzz_testing_and_model_enhense.py index 84b3e30..5f21bf6 100644 --- a/examples/ai_fuzzer/fuzz_testing_and_model_enhense.py +++ b/examples/ai_fuzzer/fuzz_testing_and_model_enhense.py @@ -27,7 +27,7 @@ from mindspore.nn.optim.momentum import Momentum from mindarmour.adv_robustness.defenses import AdversarialDefense from mindarmour.fuzz_testing import Fuzzer -from mindarmour.fuzz_testing import ModelCoverageMetrics +from mindarmour.fuzz_testing import KMultisectionNeuronCoverage from mindarmour.utils.logger import LogUtil from examples.common.dataset.data_processing import generate_mnist_dataset @@ -38,33 +38,66 @@ TAG = 'Fuzz_testing and enhance model' LOGGER.set_level('INFO') +def split_dataset(image, label, proportion): + """ + Split the generated fuzz data into train and test set. + """ + indices = np.arange(len(image)) + random.shuffle(indices) + train_length = int(len(image) * proportion) + train_image = [image[i] for i in indices[:train_length]] + train_label = [label[i] for i in indices[:train_length]] + test_image = [image[i] for i in indices[:train_length]] + test_label = [label[i] for i in indices[:train_length]] + return train_image, train_label, test_image, test_label + + def example_lenet_mnist_fuzzing(): """ An example of fuzz testing and then enhance the non-robustness model. """ # upload trained network - ckpt_path = '../common/networks/lenet5/trained_ckpt_file/lenet_m1-10_1250.ckpt' + ckpt_path = '../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' net = LeNet5() load_dict = load_checkpoint(ckpt_path) load_param_into_net(net, load_dict) model = Model(net) - mutate_config = [{'method': 'Blur', - 'params': {'auto_param': [True]}}, - {'method': 'Contrast', - 'params': {'auto_param': [True]}}, - {'method': 'Translate', - 'params': {'auto_param': [True]}}, - {'method': 'Brightness', - 'params': {'auto_param': [True]}}, - {'method': 'Noise', - 'params': {'auto_param': [True]}}, - {'method': 'Scale', - 'params': {'auto_param': [True]}}, - {'method': 'Shear', - 'params': {'auto_param': [True]}}, - {'method': 'FGSM', - 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1]}} - ] + mutate_config = [ + {'method': 'GaussianBlur', + 'params': {'ksize': [1, 2, 3, 5], 'auto_param': [True, False]}}, + {'method': 'MotionBlur', + 'params': {'degree': [1, 2, 5], 'angle': [45, 10, 100, 140, 210, 270, 300], 'auto_param': [True]}}, + {'method': 'GradientBlur', + 'params': {'point': [[10, 10]], 'auto_param': [True]}}, + {'method': 'UniformNoise', + 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, + {'method': 'GaussianNoise', + 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, + {'method': 'SaltAndPepperNoise', + 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, + {'method': 'NaturalNoise', + 'params': {'ratio': [0.1], 'k_x_range': [(1, 3), (1, 5)], 'k_y_range': [(1, 5)], 'auto_param': [False, True]}}, + {'method': 'Contrast', + 'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}}, + {'method': 'GradientLuminance', + 'params': {'color_start': [(0, 0, 0)], 'color_end': [(255, 255, 255)], 'start_point': [(10, 10)], + 'scope': [0.5], 'pattern': ['light'], 'bright_rate': [0.3], 'mode': ['circle'], + 'auto_param': [False, True]}}, + {'method': 'Translate', + 'params': {'x_bias': [0, 0.05, -0.05], 'y_bias': [0, -0.05, 0.05], 'auto_param': [False, True]}}, + {'method': 'Scale', + 'params': {'factor_x': [1, 0.9], 'factor_y': [1, 0.9], 'auto_param': [False, True]}}, + {'method': 'Shear', + 'params': {'factor': [0.2, 0.1], 'direction': ['horizontal', 'vertical'], 'auto_param': [False, True]}}, + {'method': 'Rotate', + 'params': {'angle': [20, 90], 'auto_param': [False, True]}}, + {'method': 'Perspective', + 'params': {'ori_pos': [[[0, 0], [0, 800], [800, 0], [800, 800]]], + 'dst_pos': [[[50, 0], [0, 800], [780, 0], [800, 800]]], 'auto_param': [False, True]}}, + {'method': 'Curve', + 'params': {'curves': [5], 'depth': [2], 'mode': ['vertical'], 'auto_param': [False, True]}}, + {'method': 'FGSM', + 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}] # get training data data_list = "../common/dataset/MNIST/train" @@ -75,49 +108,36 @@ def example_lenet_mnist_fuzzing(): images = data[0].astype(np.float32) train_images.append(images) train_images = np.concatenate(train_images, axis=0) - neuron_num = 10 - segmented_num = 1000 - - # initialize fuzz test with training dataset - model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images) + segmented_num = 100 # fuzz test with original test data - # get test data data_list = "../common/dataset/MNIST/test" - batch_size = 32 - init_samples = 5000 - max_iters = 50000 + batch_size = batch_size + init_samples = 50 + max_iters = 500 mutate_num_per_seed = 10 - ds = generate_mnist_dataset(data_list, batch_size, num_samples=init_samples, - sparse=False) + ds = generate_mnist_dataset(data_list, batch_size=batch_size, num_samples=init_samples, sparse=False) test_images = [] test_labels = [] for data in ds.create_tuple_iterator(output_numpy=True): - images = data[0].astype(np.float32) - labels = data[1] - test_images.append(images) - test_labels.append(labels) + test_images.append(data[0].astype(np.float32)) + test_labels.append(data[1]) test_images = np.concatenate(test_images, axis=0) test_labels = np.concatenate(test_labels, axis=0) - initial_seeds = [] + + coverage = KMultisectionNeuronCoverage(model, train_images, segmented_num=segmented_num, incremental=True) + kmnc = coverage.get_metrics(test_images[:100]) + print('kmnc: ', kmnc) # make initial seeds + initial_seeds = [] for img, label in zip(test_images, test_labels): initial_seeds.append([img, label]) - model_coverage_test.calculate_coverage( - np.array(test_images[:100]).astype(np.float32)) - LOGGER.info(TAG, 'KMNC of test dataset before fuzzing is : %s', - model_coverage_test.get_kmnc()) - LOGGER.info(TAG, 'NBC of test dataset before fuzzing is : %s', - model_coverage_test.get_nbc()) - LOGGER.info(TAG, 'SNAC of test dataset before fuzzing is : %s', - model_coverage_test.get_snac()) - - model_fuzz_test = Fuzzer(model, train_images, 10, 1000) + model_fuzz_test = Fuzzer(model) gen_samples, gt, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, - initial_seeds, - eval_metrics='auto', + initial_seeds, coverage, + evaluate=True, max_iters=max_iters, mutate_num_per_seed=mutate_num_per_seed) @@ -125,24 +145,10 @@ def example_lenet_mnist_fuzzing(): for key in metrics: LOGGER.info(TAG, key + ': %s', metrics[key]) - def split_dataset(image, label, proportion): - """ - Split the generated fuzz data into train and test set. - """ - indices = np.arange(len(image)) - random.shuffle(indices) - train_length = int(len(image) * proportion) - train_image = [image[i] for i in indices[:train_length]] - train_label = [label[i] for i in indices[:train_length]] - test_image = [image[i] for i in indices[:train_length]] - test_label = [label[i] for i in indices[:train_length]] - return train_image, train_label, test_image, test_label - - train_image, train_label, test_image, test_label = split_dataset( - gen_samples, gt, 0.7) + train_image, train_label, test_image, test_label = split_dataset(gen_samples, gt, 0.7) # load model B and test it on the test set - ckpt_path = '../common/networks/lenet5/trained_ckpt_file/lenet_m2-10_1250.ckpt' + ckpt_path = '../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' net = LeNet5() load_dict = load_checkpoint(ckpt_path) load_param_into_net(net, load_dict) @@ -154,12 +160,11 @@ def example_lenet_mnist_fuzzing(): # enhense model robustness lr = 0.001 momentum = 0.9 - loss_fn = SoftmaxCrossEntropyWithLogits(Sparse=True) + loss_fn = SoftmaxCrossEntropyWithLogits(sparse=True) optimizer = Momentum(net.trainable_params(), lr, momentum) adv_defense = AdversarialDefense(net, loss_fn, optimizer) - adv_defense.batch_defense(np.array(train_image).astype(np.float32), - np.argmax(train_label, axis=1).astype(np.int32)) + adv_defense.batch_defense(np.array(train_image).astype(np.float32), np.argmax(train_label, axis=1).astype(np.int32)) preds_en = net(Tensor(test_image, dtype=mindspore.float32)).asnumpy() acc_en = np.sum(np.argmax(preds_en, axis=1) == np.argmax(test_label, axis=1)) / len(test_label) print('Accuracy of enhensed model on test set is ', acc_en) @@ -167,5 +172,5 @@ def example_lenet_mnist_fuzzing(): if __name__ == '__main__': # device_target can be "CPU", "GPU" or "Ascend" - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") example_lenet_mnist_fuzzing() diff --git a/examples/ai_fuzzer/lenet5_mnist_fuzzing.py b/examples/ai_fuzzer/lenet5_mnist_fuzzing.py index 3a30b6d..53321c3 100644 --- a/examples/ai_fuzzer/lenet5_mnist_fuzzing.py +++ b/examples/ai_fuzzer/lenet5_mnist_fuzzing.py @@ -35,24 +35,50 @@ def test_lenet_mnist_fuzzing(): load_dict = load_checkpoint(ckpt_path) load_param_into_net(net, load_dict) model = Model(net) - mutate_config = [{'method': 'Blur', - 'params': {'radius': [0.1, 0.2, 0.3], - 'auto_param': [True, False]}}, - {'method': 'Contrast', - 'params': {'auto_param': [True]}}, - {'method': 'Translate', - 'params': {'auto_param': [True]}}, - {'method': 'Brightness', - 'params': {'auto_param': [True]}}, - {'method': 'Noise', - 'params': {'auto_param': [True]}}, - {'method': 'Scale', - 'params': {'auto_param': [True]}}, - {'method': 'Shear', - 'params': {'auto_param': [True]}}, - {'method': 'FGSM', - 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}} - ] + mutate_config = [ + {'method': 'GaussianBlur', + 'params': {'ksize': [1, 2, 3, 5], + 'auto_param': [True, False]}}, + {'method': 'MotionBlur', + 'params': {'degree': [1, 2, 5], 'angle': [45, 10, 100, 140, 210, 270, 300], 'auto_param': [True]}}, + {'method': 'GradientBlur', + 'params': {'point': [[10, 10]], 'auto_param': [True]}}, + {'method': 'UniformNoise', + 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, + {'method': 'GaussianNoise', + 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, + {'method': 'SaltAndPepperNoise', + 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, + {'method': 'NaturalNoise', + 'params': {'ratio': [0.1, 0.2, 0.3], 'k_x_range': [(1, 3), (1, 5)], 'k_y_range': [(1, 5)], + 'auto_param': [False, True]}}, + {'method': 'Contrast', + 'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}}, + {'method': 'GradientLuminance', + 'params': {'color_start': [(0, 0, 0)], 'color_end': [(255, 255, 255)], 'start_point': [(10, 10)], + 'scope': [0.5], 'pattern': ['light'], 'bright_rate': [0.3], 'mode': ['circle'], + 'auto_param': [False, True]}}, + {'method': 'Translate', + 'params': {'x_bias': [0, 0.05, -0.05], 'y_bias': [0, -0.05, 0.05], 'auto_param': [False, True]}}, + {'method': 'Scale', + 'params': {'factor_x': [1, 0.9], 'factor_y': [1, 0.9], 'auto_param': [False, True]}}, + {'method': 'Shear', + 'params': {'factor': [0.2, 0.1], 'direction': ['horizontal', 'vertical'], 'auto_param': [False, True]}}, + {'method': 'Rotate', + 'params': {'angle': [20, 90], 'auto_param': [False, True]}}, + {'method': 'Perspective', + 'params': {'ori_pos': [[[0, 0], [0, 800], [800, 0], [800, 800]]], + 'dst_pos': [[[50, 0], [0, 800], [780, 0], [800, 800]]], 'auto_param': [False, True]}}, + {'method': 'Curve', + 'params': {'curves': [5], 'depth': [2], 'mode': ['vertical'], 'auto_param': [False, True]}}, + {'method': 'FGSM', + 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}, + {'method': 'PGD', + 'params': {'eps': [0.1, 0.2, 0.4], 'eps_iter': [0.05, 0.1], 'nb_iter': [1, 3]}}, + {'method': 'MDIIM', + 'params': {'eps': [0.1, 0.2, 0.4], 'prob': [0.5, 0.1], + 'norm_level': [1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf', 'linf']}} + ] # get training data data_list = "../common/dataset/MNIST/train" @@ -88,7 +114,10 @@ def test_lenet_mnist_fuzzing(): print('KMNC of initial seeds is: ', kmnc) initial_seeds = initial_seeds[:100] model_fuzz_test = Fuzzer(model) - _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, coverage, evaluate=True, max_iters=10, + _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, + initial_seeds, coverage, + evaluate=True, + max_iters=10, mutate_num_per_seed=20) if metrics: diff --git a/mindarmour/fuzz_testing/fuzzing.py b/mindarmour/fuzz_testing/fuzzing.py index 93dafce..3096b8e 100644 --- a/mindarmour/fuzz_testing/fuzzing.py +++ b/mindarmour/fuzz_testing/fuzzing.py @@ -24,10 +24,11 @@ from mindspore import nn from mindarmour.utils._check_param import check_model, check_numpy_param, check_param_multi_types, check_norm_level, \ check_param_in_range, check_param_type, check_int_positive, check_param_bounds from mindarmour.utils.logger import LogUtil -from ..adv_robustness.attacks import FastGradientSignMethod, \ +from mindarmour.adv_robustness.attacks import FastGradientSignMethod, \ MomentumDiverseInputIterativeMethod, ProjectedGradientDescent -from .image_transform import Contrast, Brightness, Blur, \ - Noise, Translate, Scale, Shear, Rotate +from mindarmour.natural_robustness.transform.image import GaussianBlur, MotionBlur, GradientBlur, UniformNoise,\ + GaussianNoise, SaltAndPepperNoise, NaturalNoise, Contrast, GradientLuminance, Translate, Scale, Shear, Rotate, \ + Perspective, Curve from .model_coverage_metrics import CoverageMetrics, KMultisectionNeuronCoverage LOGGER = LogUtil.get_instance() @@ -104,17 +105,79 @@ class Fuzzer: target_model (Model): Target fuzz model. Examples: + >>> import numpy as np + >>> from mindspore import context + >>> from mindspore import nn + >>> from mindspore.common.initializer import TruncatedNormal + >>> from mindspore.ops import operations as P + >>> from mindspore.train import Model + >>> from mindspore.ops import TensorSummary + >>> from mindarmour.fuzz_testing import Fuzzer + >>> from mindarmour.fuzz_testing import KMultisectionNeuronCoverage + >>> + >>> class Net(nn.Cell): + >>> def __init__(self): + >>> super(Net, self).__init__() + >>> self.conv1 = nn.Conv2d(1, 6, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid") + >>> self.conv2 = nn.Conv2d(6, 16, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid") + >>> self.fc1 = nn.Dense(16 * 5 * 5, 120, TruncatedNormal(0.02), TruncatedNormal(0.02)) + >>> self.fc2 = nn.Dense(120, 84, TruncatedNormal(0.02), TruncatedNormal(0.02)) + >>> self.fc3 = nn.Dense(84, 10, TruncatedNormal(0.02), TruncatedNormal(0.02)) + >>> self.relu = nn.ReLU() + >>> self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + >>> self.reshape = P.Reshape() + >>> self.summary = TensorSummary() + >>> + >>> def construct(self, x): + >>> x = self.conv1(x) + >>> x = self.relu(x) + >>> self.summary('conv1', x) + >>> x = self.max_pool2d(x) + >>> x = self.conv2(x) + >>> x = self.relu(x) + >>> self.summary('conv2', x) + >>> x = self.max_pool2d(x) + >>> x = self.reshape(x, (-1, 16 * 5 * 5)) + >>> x = self.fc1(x) + >>> x = self.relu(x) + >>> self.summary('fc1', x) + >>> x = self.fc2(x) + >>> x = self.relu(x) + >>> self.summary('fc2', x) + >>> x = self.fc3(x) + >>> self.summary('fc3', x) + >>> return x + >>> >>> net = Net() >>> model = Model(net) - >>> mutate_config = [{'method': 'Blur', - ... 'params': {'auto_param': [True]}}, + >>> mutate_config = [{'method': 'GaussianBlur', + ... 'params': {'ksize': [1, 2, 3, 5], 'auto_param': [True, False]}}, + ... {'method': 'MotionBlur', + ... 'params': {'degree': [1, 2, 5], 'angle': [45, 10, 100, 140, 210, 270, 300], + ... 'auto_param': [True]}}, + ... {'method': 'UniformNoise', + ... 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, + ... {'method': 'GaussianNoise', + ... 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, ... {'method': 'Contrast', - ... 'params': {'factor': [2]}}, - ... {'method': 'Translate', - ... 'params': {'x_bias': [0.1, 0.2], 'y_bias': [0.2]}}, + ... 'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}}, + ... {'method': 'Rotate', + ... 'params': {'angle': [20, 90], 'auto_param': [False, True]}}, ... {'method': 'FGSM', - ... 'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}] - >>> nc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100) + ... 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}] + >>> batch_size = 8 + >>> num_classe = 10 + >>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) + >>> test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32) + >>> test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32) + >>> test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32) + >>> initial_seeds = [] + >>> # make initial seeds + >>> for img, label in zip(test_images, test_labels): + >>> initial_seeds.append([img, label]) + + >>> initial_seeds = initial_seeds[:10] + >>> nc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100, incremental=True) >>> model_fuzz_test = Fuzzer(model) >>> samples, gt_labels, preds, strategies, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, ... nc, max_iters=100) @@ -125,18 +188,26 @@ class Fuzzer: # Allowed mutate strategies so far. self._strategies = {'Contrast': Contrast, - 'Brightness': Brightness, - 'Blur': Blur, - 'Noise': Noise, + 'GradientLuminance': GradientLuminance, + 'GaussianBlur': GaussianBlur, + 'MotionBlur': MotionBlur, + 'GradientBlur': GradientBlur, + 'UniformNoise': UniformNoise, + 'GaussianNoise': GaussianNoise, + 'SaltAndPepperNoise': SaltAndPepperNoise, + 'NaturalNoise': NaturalNoise, 'Translate': Translate, 'Scale': Scale, 'Shear': Shear, 'Rotate': Rotate, + 'Perspective': Perspective, + 'Curve': Curve, 'FGSM': FastGradientSignMethod, 'PGD': ProjectedGradientDescent, 'MDIIM': MomentumDiverseInputIterativeMethod} - self._affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate'] - self._pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur', 'Noise'] + self._affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate', 'Perspective', 'Curve'] + self._pixel_value_trans_list = ['Contrast', 'GradientLuminance', 'GaussianBlur', 'MotionBlur', 'GradientBlur', + 'UniformNoise', 'GaussianNoise', 'SaltAndPepperNoise', 'NaturalNoise'] self._attacks_list = ['FGSM', 'PGD', 'MDIIM'] self._attack_param_checklists = { 'FGSM': {'eps': {'dtype': [float], 'range': [0, 1]}, @@ -144,10 +215,11 @@ class Fuzzer: 'bounds': {'dtype': [tuple, list]}}, 'PGD': {'eps': {'dtype': [float], 'range': [0, 1]}, 'eps_iter': {'dtype': [float], 'range': [0, 1]}, - 'nb_iter': {'dtype': [int], 'range': [0, 100000]}, + 'nb_iter': {'dtype': [int]}, 'bounds': {'dtype': [tuple, list]}}, 'MDIIM': {'eps': {'dtype': [float], 'range': [0, 1]}, - 'norm_level': {'dtype': [str, int], 'range': [1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf']}, + 'norm_level': {'dtype': [str, int], + 'range': [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', 'np.inf']}, 'prob': {'dtype': [float], 'range': [0, 1]}, 'bounds': {'dtype': [tuple, list]}}} @@ -157,18 +229,26 @@ class Fuzzer: Args: mutate_config (list): Mutate configs. The format is - [{'method': 'Blur', - 'params': {'radius': [0.1, 0.2], 'auto_param': [True, False]}}, - {'method': 'Contrast', - 'params': {'factor': [1, 1.5, 2]}}, - {'method': 'FGSM', - 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1]}}, - ...]. + [{'method': 'GaussianBlur', + 'params': {'ksize': [1, 2, 3, 5], 'auto_param': [True, False]}}, + {'method': 'UniformNoise', + 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, + {'method': 'GaussianNoise', + 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, + {'method': 'Contrast', + 'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}}, + {'method': 'Rotate', + 'params': {'angle': [20, 90], 'auto_param': [False, True]}}, + {'method': 'FGSM', + 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}] + ...]. The supported methods list is in `self._strategies`, and the params of each method must within the range of optional parameters. Supported methods are grouped in three types: Firstly, pixel value based transform methods include: 'Contrast', 'Brightness', 'Blur' and 'Noise'. Secondly, affine transform methods include: 'Translate', 'Scale', 'Shear' and 'Rotate'. Thirdly, attack methods include: 'FGSM', - 'PGD' and 'MDIIM'. `mutate_config` must have method in the type of pixel value based transform methods. + 'PGD' and 'MDIIM'. 'FGSM', 'PGD' and 'MDIIM'. are abbreviations of FastGradientSignMethod, + ProjectedGradientDescent and MomentumDiverseInputIterativeMethod. + `mutate_config` must have method in the type of pixel value based transform methods. The way of setting parameters for first and second type methods can be seen in 'mindarmour/fuzz_testing/image_transform.py'. For third type methods, the optional parameters refer to `self._attack_param_checklists`. @@ -278,7 +358,6 @@ class Fuzzer: if only_pixel_trans: while strategy['method'] not in self._pixel_value_trans_list: strategy = choice(mutate_config) - transform = mutates[strategy['method']] params = strategy['params'] method = strategy['method'] selected_param = {} @@ -290,9 +369,10 @@ class Fuzzer: shear_keys = selected_param.keys() if 'factor_x' in shear_keys and 'factor_y' in shear_keys: selected_param[choice(['factor_x', 'factor_y'])] = 0 - transform.set_params(**selected_param) - mutate_sample = transform.transform(seed[0]) + transform = mutates[strategy['method']](**selected_param) + mutate_sample = transform(seed[0]) else: + transform = mutates[strategy['method']] for param_name in selected_param: transform.__setattr__('_' + str(param_name), selected_param[param_name]) mutate_sample = transform.generate(np.array([seed[0].astype(np.float32)]), np.array([seed[1]]))[0] @@ -360,6 +440,8 @@ class Fuzzer: _ = check_param_bounds('bounds', param_value) elif param_name == 'norm_level': _ = check_norm_level(param_value) + elif param_name == 'nb_iter': + _ = check_int_positive(param_name, param_value) else: allow_type = self._attack_param_checklists[method][param_name]['dtype'] allow_range = self._attack_param_checklists[method][param_name]['range'] @@ -372,7 +454,8 @@ class Fuzzer: for mutate in mutate_config: method = mutate['method'] if method not in self._attacks_list: - mutates[method] = self._strategies[method]() + # mutates[method] = self._strategies[method]() + mutates[method] = self._strategies[method] else: network = self._target_model._network loss_fn = self._target_model._loss_fn @@ -414,7 +497,6 @@ class Fuzzer: else: attack_success_rate = None metrics_report['Attack_success_rate'] = attack_success_rate - metrics_report['Coverage_metrics'] = coverage.get_metrics(fuzz_samples) return metrics_report diff --git a/mindarmour/fuzz_testing/image_transform.py b/mindarmour/fuzz_testing/image_transform.py deleted file mode 100644 index 52a1136..0000000 --- a/mindarmour/fuzz_testing/image_transform.py +++ /dev/null @@ -1,609 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Image transform -""" -import numpy as np -from PIL import Image, ImageEnhance, ImageFilter - -from mindspore.dataset.vision.py_transforms_util import is_numpy, \ - to_pil, hwc_to_chw -from mindarmour.utils._check_param import check_param_multi_types, check_param_in_range, check_numpy_param -from mindarmour.utils.logger import LogUtil - -LOGGER = LogUtil.get_instance() -TAG = 'Image Transformation' - - -def chw_to_hwc(img): - """ - Transpose the input image; shape (C, H, W) to shape (H, W, C). - - Args: - img (numpy.ndarray): Image to be converted. - - Returns: - img (numpy.ndarray), Converted image. - """ - if is_numpy(img): - return img.transpose(1, 2, 0).copy() - raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img))) - - -def is_hwc(img): - """ - Check if the input image is shape (H, W, C). - - Args: - img (numpy.ndarray): Image to be checked. - - Returns: - Bool, True if input is shape (H, W, C). - """ - if is_numpy(img): - img_shape = np.shape(img) - if img_shape[2] == 3 and img_shape[1] > 3 and img_shape[0] > 3: - return True - return False - raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img))) - - -def is_chw(img): - """ - Check if the input image is shape (H, W, C). - - Args: - img (numpy.ndarray): Image to be checked. - - Returns: - Bool, True if input is shape (H, W, C). - """ - if is_numpy(img): - img_shape = np.shape(img) - if img_shape[0] == 3 and img_shape[1] > 3 and img_shape[2] > 3: - return True - return False - raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img))) - - -def is_rgb(img): - """ - Check if the input image is RGB. - - Args: - img (numpy.ndarray): Image to be checked. - - Returns: - Bool, True if input is RGB. - """ - if is_numpy(img): - img_shape = np.shape(img) - if len(np.shape(img)) == 3 and (img_shape[0] == 3 or img_shape[2] == 3): - return True - return False - raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img))) - - -def is_normalized(img): - """ - Check if the input image is normalized between 0 to 1. - - Args: - img (numpy.ndarray): Image to be checked. - - Returns: - Bool, True if input is normalized between 0 to 1. - """ - if is_numpy(img): - minimal = np.min(img) - maximun = np.max(img) - if minimal >= 0 and maximun <= 1: - return True - return False - raise TypeError('img should be Numpy array. Got {}'.format(type(img))) - - -class ImageTransform: - """ - The abstract base class for all image transform classes. - """ - - def __init__(self): - pass - - def _check(self, image): - """ Check image format. If input image is RGB and its shape - is (C, H, W), it will be transposed to (H, W, C). If the value - of the image is not normalized , it will be normalized between 0 to 1.""" - rgb = is_rgb(image) - chw = False - gray3dim = False - normalized = is_normalized(image) - if rgb: - chw = is_chw(image) - if chw: - image = chw_to_hwc(image) - else: - image = image - else: - if len(np.shape(image)) == 3: - gray3dim = True - image = image[0] - else: - image = image - if normalized: - image = image*255 - return rgb, chw, normalized, gray3dim, np.uint8(image) - - def _original_format(self, image, chw, normalized, gray3dim): - """ Return transformed image with original format. """ - if not is_numpy(image): - image = np.array(image) - if chw: - image = hwc_to_chw(image) - if normalized: - image = image / 255 - if gray3dim: - image = np.expand_dims(image, 0) - return image - - def transform(self, image): - pass - - -class Contrast(ImageTransform): - """ - Contrast of an image. - - Args: - factor (Union[float, int]): Control the contrast of an image. If 1.0, - gives the original image. If 0, gives a gray image. Default: 1. - """ - - def __init__(self, factor=1): - super(Contrast, self).__init__() - self.set_params(factor) - - def set_params(self, factor=1, auto_param=False): - """ - Set contrast parameters. - - Args: - factor (Union[float, int]): Control the contrast of an image. If 1.0 - gives the original image. If 0 gives a gray image. Default: 1. - auto_param (bool): True if auto generate parameters. Default: False. - """ - if auto_param: - self.factor = np.random.uniform(-5, 5) - else: - self.factor = check_param_multi_types('factor', factor, [int, float]) - - def transform(self, image): - """ - Transform the image. - - Args: - image (numpy.ndarray): Original image to be transformed. - - Returns: - numpy.ndarray, transformed image. - """ - image = check_numpy_param('image', image) - ori_dtype = image.dtype - _, chw, normalized, gray3dim, image = self._check(image) - image = to_pil(image) - img_contrast = ImageEnhance.Contrast(image) - trans_image = img_contrast.enhance(self.factor) - trans_image = self._original_format(trans_image, chw, normalized, - gray3dim) - - return trans_image.astype(ori_dtype) - - -class Brightness(ImageTransform): - """ - Brightness of an image. - - Args: - factor (Union[float, int]): Control the brightness of an image. If 1.0 - gives the original image. If 0 gives a black image. Default: 1. - """ - - def __init__(self, factor=1): - super(Brightness, self).__init__() - self.set_params(factor) - - def set_params(self, factor=1, auto_param=False): - """ - Set brightness parameters. - - Args: - factor (Union[float, int]): Control the brightness of an image. If 1 - gives the original image. If 0 gives a black image. Default: 1. - auto_param (bool): True if auto generate parameters. Default: False. - """ - if auto_param: - self.factor = np.random.uniform(0, 5) - else: - self.factor = check_param_multi_types('factor', factor, [int, float]) - - def transform(self, image): - """ - Transform the image. - - Args: - image (numpy.ndarray): Original image to be transformed. - - Returns: - numpy.ndarray, transformed image. - """ - image = check_numpy_param('image', image) - ori_dtype = image.dtype - _, chw, normalized, gray3dim, image = self._check(image) - image = to_pil(image) - img_contrast = ImageEnhance.Brightness(image) - trans_image = img_contrast.enhance(self.factor) - trans_image = self._original_format(trans_image, chw, normalized, - gray3dim) - return trans_image.astype(ori_dtype) - - -class Blur(ImageTransform): - """ - Blurs the image using Gaussian blur filter. - - Args: - radius(Union[float, int]): Blur radius, 0 means no blur. Default: 0. - """ - - def __init__(self, radius=0): - super(Blur, self).__init__() - self.set_params(radius) - - def set_params(self, radius=0, auto_param=False): - """ - Set blur parameters. - - Args: - radius (Union[float, int]): Blur radius, 0 means no blur. Default: 0. - auto_param (bool): True if auto generate parameters. Default: False. - """ - if auto_param: - self.radius = np.random.uniform(-1.5, 1.5) - else: - self.radius = check_param_multi_types('radius', radius, [int, float]) - - def transform(self, image): - """ - Transform the image. - - Args: - image (numpy.ndarray): Original image to be transformed. - - Returns: - numpy.ndarray, transformed image. - """ - image = check_numpy_param('image', image) - ori_dtype = image.dtype - _, chw, normalized, gray3dim, image = self._check(image) - image = to_pil(image) - trans_image = image.filter(ImageFilter.GaussianBlur(radius=self.radius)) - trans_image = self._original_format(trans_image, chw, normalized, - gray3dim) - return trans_image.astype(ori_dtype) - - -class Noise(ImageTransform): - """ - Add noise of an image. - - Args: - factor (float): factor is the ratio of pixels to add noise. - If 0 gives the original image. Default 0. - """ - - def __init__(self, factor=0): - super(Noise, self).__init__() - self.set_params(factor) - - def set_params(self, factor=0, auto_param=False): - """ - Set noise parameters. - - Args: - factor (Union[float, int]): factor is the ratio of pixels to - add noise. If 0 gives the original image. Default 0. - auto_param (bool): True if auto generate parameters. Default: False. - """ - if auto_param: - self.factor = np.random.uniform(0, 1) - else: - self.factor = check_param_multi_types('factor', factor, [int, float]) - - def transform(self, image): - """ - Transform the image. - - Args: - image (numpy.ndarray): Original image to be transformed. - - Returns: - numpy.ndarray, transformed image. - """ - image = check_numpy_param('image', image) - ori_dtype = image.dtype - _, chw, normalized, gray3dim, image = self._check(image) - noise = np.random.uniform(low=-1, high=1, size=np.shape(image)) - trans_image = np.copy(image) - threshold = 1 - self.factor - trans_image[noise < -threshold] = 0 - trans_image[noise > threshold] = 1 - trans_image = self._original_format(trans_image, chw, normalized, - gray3dim) - return trans_image.astype(ori_dtype) - - -class Translate(ImageTransform): - """ - Translate an image. - - Args: - x_bias (Union[int, float]): X-direction translation, x = x + x_bias*image_length. - Default: 0. - y_bias (Union[int, float]): Y-direction translation, y = y + y_bias*image_wide. - Default: 0. - """ - - def __init__(self, x_bias=0, y_bias=0): - super(Translate, self).__init__() - self.set_params(x_bias, y_bias) - - def set_params(self, x_bias=0, y_bias=0, auto_param=False): - """ - Set translate parameters. - - Args: - x_bias (Union[float, int]): X-direction translation, and x_bias should be in range of (-1, 1). Default: 0. - y_bias (Union[float, int]): Y-direction translation, and y_bias should be in range of (-1, 1). Default: 0. - auto_param (bool): True if auto generate parameters. Default: False. - """ - x_bias = check_param_in_range('x_bias', x_bias, -1, 1) - y_bias = check_param_in_range('y_bias', y_bias, -1, 1) - self.auto_param = auto_param - if auto_param: - self.x_bias = np.random.uniform(-0.3, 0.3) - self.y_bias = np.random.uniform(-0.3, 0.3) - else: - self.x_bias = check_param_multi_types('x_bias', x_bias, - [int, float]) - self.y_bias = check_param_multi_types('y_bias', y_bias, - [int, float]) - - def transform(self, image): - """ - Transform the image. - - Args: - image(numpy.ndarray): Original image to be transformed. - - Returns: - numpy.ndarray, transformed image. - """ - image = check_numpy_param('image', image) - ori_dtype = image.dtype - _, chw, normalized, gray3dim, image = self._check(image) - img = to_pil(image) - image_shape = np.shape(image) - self.x_bias = image_shape[1]*self.x_bias - self.y_bias = image_shape[0]*self.y_bias - trans_image = img.transform(img.size, Image.AFFINE, - (1, 0, self.x_bias, 0, 1, self.y_bias)) - trans_image = self._original_format(trans_image, chw, normalized, - gray3dim) - return trans_image.astype(ori_dtype) - - -class Scale(ImageTransform): - """ - Scale an image in the middle. - - Args: - factor_x (Union[float, int]): Rescale in X-direction, x=factor_x*x. - Default: 1. - factor_y (Union[float, int]): Rescale in Y-direction, y=factor_y*y. - Default: 1. - """ - - def __init__(self, factor_x=1, factor_y=1): - super(Scale, self).__init__() - self.set_params(factor_x, factor_y) - - def set_params(self, factor_x=1, factor_y=1, auto_param=False): - - """ - Set scale parameters. - - Args: - factor_x (Union[float, int]): Rescale in X-direction, x=factor_x*x. - Default: 1. - factor_y (Union[float, int]): Rescale in Y-direction, y=factor_y*y. - Default: 1. - auto_param (bool): True if auto generate parameters. Default: False. - """ - if auto_param: - self.factor_x = np.random.uniform(0.7, 3) - self.factor_y = np.random.uniform(0.7, 3) - else: - self.factor_x = check_param_multi_types('factor_x', factor_x, - [int, float]) - self.factor_y = check_param_multi_types('factor_y', factor_y, - [int, float]) - - def transform(self, image): - """ - Transform the image. - - Args: - image(numpy.ndarray): Original image to be transformed. - - Returns: - numpy.ndarray, transformed image. - """ - image = check_numpy_param('image', image) - ori_dtype = image.dtype - rgb, chw, normalized, gray3dim, image = self._check(image) - if rgb: - h, w, _ = np.shape(image) - else: - h, w = np.shape(image) - move_x_centor = w / 2*(1 - self.factor_x) - move_y_centor = h / 2*(1 - self.factor_y) - img = to_pil(image) - trans_image = img.transform(img.size, Image.AFFINE, - (self.factor_x, 0, move_x_centor, - 0, self.factor_y, move_y_centor)) - trans_image = self._original_format(trans_image, chw, normalized, - gray3dim) - return trans_image.astype(ori_dtype) - - -class Shear(ImageTransform): - """ - Shear an image, for each pixel (x, y) in the sheared image, the new value is - taken from a position (x+factor_x*y, factor_y*x+y) in the origin image. Then - the sheared image will be rescaled to fit original size. - - Args: - factor_x (Union[float, int]): Shear factor of horizontal direction. - Default: 0. - factor_y (Union[float, int]): Shear factor of vertical direction. - Default: 0. - - """ - - def __init__(self, factor_x=0, factor_y=0): - super(Shear, self).__init__() - self.set_params(factor_x, factor_y) - - def set_params(self, factor_x=0, factor_y=0, auto_param=False): - """ - Set shear parameters. - - Args: - factor_x (Union[float, int]): Shear factor of horizontal direction. - Default: 0. - factor_y (Union[float, int]): Shear factor of vertical direction. - Default: 0. - auto_param (bool): True if auto generate parameters. Default: False. - """ - if factor_x != 0 and factor_y != 0: - msg = 'At least one of factor_x and factor_y is zero.' - LOGGER.error(TAG, msg) - raise ValueError(msg) - if auto_param: - if np.random.uniform(-1, 1) > 0: - self.factor_x = np.random.uniform(-2, 2) - self.factor_y = 0 - else: - self.factor_x = 0 - self.factor_y = np.random.uniform(-2, 2) - else: - self.factor_x = check_param_multi_types('factor', factor_x, - [int, float]) - self.factor_y = check_param_multi_types('factor', factor_y, - [int, float]) - - def transform(self, image): - """ - Transform the image. - - Args: - image(numpy.ndarray): Original image to be transformed. - - Returns: - numpy.ndarray, transformed image. - """ - image = check_numpy_param('image', image) - ori_dtype = image.dtype - rgb, chw, normalized, gray3dim, image = self._check(image) - img = to_pil(image) - if rgb: - h, w, _ = np.shape(image) - else: - h, w = np.shape(image) - if self.factor_x != 0: - boarder_x = [0, -w, -self.factor_x*h, -w - self.factor_x*h] - min_x = min(boarder_x) - max_x = max(boarder_x) - scale = (max_x - min_x) / w - move_x_cen = (w - scale*w - scale*h*self.factor_x) / 2 - move_y_cen = h*(1 - scale) / 2 - else: - boarder_y = [0, -h, -self.factor_y*w, -h - self.factor_y*w] - min_y = min(boarder_y) - max_y = max(boarder_y) - scale = (max_y - min_y) / h - move_y_cen = (h - scale*h - scale*w*self.factor_y) / 2 - move_x_cen = w*(1 - scale) / 2 - trans_image = img.transform(img.size, Image.AFFINE, - (scale, scale*self.factor_x, move_x_cen, - scale*self.factor_y, scale, move_y_cen)) - trans_image = self._original_format(trans_image, chw, normalized, - gray3dim) - return trans_image.astype(ori_dtype) - - -class Rotate(ImageTransform): - """ - Rotate an image of degrees counter clockwise around its center. - - Args: - angle(Union[float, int]): Degrees counter clockwise. Default: 0. - """ - - def __init__(self, angle=0): - super(Rotate, self).__init__() - self.set_params(angle) - - def set_params(self, angle=0, auto_param=False): - """ - Set rotate parameters. - - Args: - angle(Union[float, int]): Degrees counter clockwise. Default: 0. - auto_param (bool): True if auto generate parameters. Default: False. - """ - if auto_param: - self.angle = np.random.uniform(0, 360) - else: - self.angle = check_param_multi_types('angle', angle, [int, float]) - - def transform(self, image): - """ - Transform the image. - - Args: - image(numpy.ndarray): Original image to be transformed. - - Returns: - numpy.ndarray, transformed image. - """ - image = check_numpy_param('image', image) - ori_dtype = image.dtype - _, chw, normalized, gray3dim, image = self._check(image) - img = to_pil(image) - trans_image = img.rotate(self.angle, expand=False) - trans_image = self._original_format(trans_image, chw, normalized, - gray3dim) - return trans_image.astype(ori_dtype) diff --git a/tests/ut/python/fuzzing/test_fuzzer.py b/tests/ut/python/fuzzing/test_fuzzer.py index 1d585d4..34be0c5 100644 --- a/tests/ut/python/fuzzing/test_fuzzer.py +++ b/tests/ut/python/fuzzing/test_fuzzer.py @@ -99,15 +99,17 @@ def test_fuzzing_ascend(): model = Model(net) batch_size = 8 num_classe = 10 - mutate_config = [{'method': 'Blur', - 'params': {'auto_param': [True]}}, + mutate_config = [{'method': 'GaussianBlur', + 'params': {'ksize': [1, 2, 3, 5], + 'auto_param': [True, False]}}, + {'method': 'UniformNoise', + 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, {'method': 'Contrast', - 'params': {'factor': [2, 1]}}, - {'method': 'Translate', - 'params': {'x_bias': [0.1, 0.3], 'y_bias': [0.2]}}, + 'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}}, + {'method': 'Rotate', + 'params': {'angle': [20, 90], 'auto_param': [False, True]}}, {'method': 'FGSM', - 'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}} - ] + 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}] train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) # fuzz test with original test data @@ -142,15 +144,17 @@ def test_fuzzing_cpu(): model = Model(net) batch_size = 8 num_classe = 10 - mutate_config = [{'method': 'Blur', - 'params': {'auto_param': [True]}}, + mutate_config = [{'method': 'GaussianBlur', + 'params': {'ksize': [1, 2, 3, 5], + 'auto_param': [True, False]}}, + {'method': 'UniformNoise', + 'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}}, {'method': 'Contrast', - 'params': {'factor': [2, 1]}}, - {'method': 'Translate', - 'params': {'x_bias': [0.1, 0.3], 'y_bias': [0.2]}}, + 'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}}, + {'method': 'Rotate', + 'params': {'angle': [20, 90], 'auto_param': [False, True]}}, {'method': 'FGSM', - 'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}} - ] + 'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}] # initialize fuzz test with training dataset train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) diff --git a/tests/ut/python/fuzzing/test_image_transform.py b/tests/ut/python/fuzzing/test_image_transform.py deleted file mode 100644 index 9360746..0000000 --- a/tests/ut/python/fuzzing/test_image_transform.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Image transform test. -""" -import numpy as np -import pytest - -from mindarmour.utils.logger import LogUtil -from mindarmour.fuzz_testing.image_transform import Contrast, Brightness, \ - Blur, Noise, Translate, Scale, Shear, Rotate - -LOGGER = LogUtil.get_instance() -TAG = 'Image transform test' -LOGGER.set_level('INFO') - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_contrast(): - image = (np.random.rand(32, 32)).astype(np.float32) - trans = Contrast() - trans.set_params(auto_param=True) - _ = trans.transform(image) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_brightness(): - image = (np.random.rand(32, 32)).astype(np.float32) - trans = Brightness() - trans.set_params(auto_param=True) - _ = trans.transform(image) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_arm_ascend_training -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_blur(): - image = (np.random.rand(32, 32)).astype(np.float32) - trans = Blur() - trans.set_params(auto_param=True) - _ = trans.transform(image) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_arm_ascend_training -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_noise(): - image = (np.random.rand(32, 32)).astype(np.float32) - trans = Noise() - trans.set_params(auto_param=True) - _ = trans.transform(image) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_arm_ascend_training -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_translate(): - image = (np.random.rand(32, 32)).astype(np.float32) - trans = Translate() - trans.set_params(auto_param=True) - _ = trans.transform(image) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_arm_ascend_training -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_shear(): - image = (np.random.rand(32, 32)).astype(np.float32) - trans = Shear() - trans.set_params(auto_param=True) - _ = trans.transform(image) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_arm_ascend_training -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_scale(): - image = (np.random.rand(32, 32)).astype(np.float32) - trans = Scale() - trans.set_params(auto_param=True) - _ = trans.transform(image) - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_arm_ascend_training -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_rotate(): - image = (np.random.rand(32, 32)).astype(np.float32) - trans = Rotate() - trans.set_params(auto_param=True) - _ = trans.transform(image)