@@ -27,7 +27,7 @@ from mindspore.nn.optim.momentum import Momentum
 from mindarmour.adv_robustness.defenses import AdversarialDefense
 from mindarmour.fuzz_testing import Fuzzer
-from mindarmour.fuzz_testing import ModelCoverageMetrics
+from mindarmour.fuzz_testing import KMultisectionNeuronCoverage
 from mindarmour.utils.logger import LogUtil
 from examples.common.dataset.data_processing import generate_mnist_dataset
@@ -38,33 +38,66 @@ TAG = 'Fuzz_testing and enhance model'
 LOGGER.set_level('INFO')
+def split_dataset(image, label, proportion):
+    """
+    Split the generated fuzz data into train and test set.
+    """
+    indices = np.arange(len(image))
+    random.shuffle(indices)
+    train_length = int(len(image) * proportion)
+    train_image = [image[i] for i in indices[:train_length]]
+    train_label = [label[i] for i in indices[:train_length]]
+    test_image = [image[i] for i in indices[train_length:]]
+    test_label = [label[i] for i in indices[train_length:]]
+    return train_image, train_label, test_image, test_label
 def example_lenet_mnist_fuzzing():
     """
     An example of fuzz testing and then enhancing the non-robust model.
     """
     # upload trained network
-    ckpt_path = '../common/networks/lenet5/trained_ckpt_file/lenet_m1-10_1250.ckpt'
+    ckpt_path = '../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
     net = LeNet5()
     load_dict = load_checkpoint(ckpt_path)
     load_param_into_net(net, load_dict)
     model = Model(net)
-    mutate_config = [{'method': 'Blur',
-                      'params': {'auto_param': [True]}},
-                     {'method': 'Contrast',
-                      'params': {'auto_param': [True]}},
-                     {'method': 'Translate',
-                      'params': {'auto_param': [True]}},
-                     {'method': 'Brightness',
-                      'params': {'auto_param': [True]}},
-                     {'method': 'Noise',
-                      'params': {'auto_param': [True]}},
-                     {'method': 'Scale',
-                      'params': {'auto_param': [True]}},
-                     {'method': 'Shear',
-                      'params': {'auto_param': [True]}},
-                     {'method': 'FGSM',
-                      'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1]}}
-                    ]
+    mutate_config = [
+        {'method': 'GaussianBlur',
+         'params': {'ksize': [1, 2, 3, 5], 'auto_param': [True, False]}},
+        {'method': 'MotionBlur',
+         'params': {'degree': [1, 2, 5], 'angle': [45, 10, 100, 140, 210, 270, 300], 'auto_param': [True]}},
+        {'method': 'GradientBlur',
+         'params': {'point': [[10, 10]], 'auto_param': [True]}},
+        {'method': 'UniformNoise',
+         'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
+        {'method': 'GaussianNoise',
+         'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
+        {'method': 'SaltAndPepperNoise',
+         'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
+        {'method': 'NaturalNoise',
+         'params': {'ratio': [0.1], 'k_x_range': [(1, 3), (1, 5)], 'k_y_range': [(1, 5)], 'auto_param': [False, True]}},
+        {'method': 'Contrast',
+         'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}},
+        {'method': 'GradientLuminance',
+         'params': {'color_start': [(0, 0, 0)], 'color_end': [(255, 255, 255)], 'start_point': [(10, 10)],
+                    'scope': [0.5], 'pattern': ['light'], 'bright_rate': [0.3], 'mode': ['circle'],
+                    'auto_param': [False, True]}},
+        {'method': 'Translate',
+         'params': {'x_bias': [0, 0.05, -0.05], 'y_bias': [0, -0.05, 0.05], 'auto_param': [False, True]}},
+        {'method': 'Scale',
+         'params': {'factor_x': [1, 0.9], 'factor_y': [1, 0.9], 'auto_param': [False, True]}},
+        {'method': 'Shear',
+         'params': {'factor': [0.2, 0.1], 'direction': ['horizontal', 'vertical'], 'auto_param': [False, True]}},
+        {'method': 'Rotate',
+         'params': {'angle': [20, 90], 'auto_param': [False, True]}},
+        {'method': 'Perspective',
+         'params': {'ori_pos': [[[0, 0], [0, 800], [800, 0], [800, 800]]],
+                    'dst_pos': [[[50, 0], [0, 800], [780, 0], [800, 800]]], 'auto_param': [False, True]}},
+        {'method': 'Curve',
+         'params': {'curves': [5], 'depth': [2], 'mode': ['vertical'], 'auto_param': [False, True]}},
+        {'method': 'FGSM',
+         'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}]
     # get training data
     data_list = "../common/dataset/MNIST/train"
@@ -75,49 +108,36 @@ def example_lenet_mnist_fuzzing():
         images = data[0].astype(np.float32)
         train_images.append(images)
     train_images = np.concatenate(train_images, axis=0)
-    neuron_num = 10
-    segmented_num = 1000
-    # initialize fuzz test with training dataset
-    model_coverage_test = ModelCoverageMetrics(model, neuron_num, segmented_num, train_images)
+    segmented_num = 100
     # fuzz test with original test data
+    # get test data
     data_list = "../common/dataset/MNIST/test"
-    batch_size = 32
-    init_samples = 5000
-    max_iters = 50000
+    batch_size = batch_size
+    init_samples = 50
+    max_iters = 500
     mutate_num_per_seed = 10
-    ds = generate_mnist_dataset(data_list, batch_size, num_samples=init_samples,
-                                sparse=False)
+    ds = generate_mnist_dataset(data_list, batch_size=batch_size, num_samples=init_samples, sparse=False)
     test_images = []
     test_labels = []
     for data in ds.create_tuple_iterator(output_numpy=True):
-        images = data[0].astype(np.float32)
-        labels = data[1]
-        test_images.append(images)
-        test_labels.append(labels)
+        test_images.append(data[0].astype(np.float32))
+        test_labels.append(data[1])
     test_images = np.concatenate(test_images, axis=0)
     test_labels = np.concatenate(test_labels, axis=0)
-    initial_seeds = []
+    coverage = KMultisectionNeuronCoverage(model, train_images, segmented_num=segmented_num, incremental=True)
+    kmnc = coverage.get_metrics(test_images[:100])
+    print('kmnc: ', kmnc)
     # make initial seeds
+    initial_seeds = []
     for img, label in zip(test_images, test_labels):
         initial_seeds.append([img, label])
-    model_coverage_test.calculate_coverage(
-        np.array(test_images[:100]).astype(np.float32))
-    LOGGER.info(TAG, 'KMNC of test dataset before fuzzing is : %s',
-                model_coverage_test.get_kmnc())
-    LOGGER.info(TAG, 'NBC of test dataset before fuzzing is : %s',
-                model_coverage_test.get_nbc())
-    LOGGER.info(TAG, 'SNAC of test dataset before fuzzing is : %s',
-                model_coverage_test.get_snac())
-    model_fuzz_test = Fuzzer(model, train_images, 10, 1000)
+    model_fuzz_test = Fuzzer(model)
     gen_samples, gt, _, _, metrics = model_fuzz_test.fuzzing(mutate_config,
-                                                             initial_seeds,
-                                                             eval_metrics='auto',
+                                                             initial_seeds, coverage,
+                                                             evaluate=True,
                                                              max_iters=max_iters,
                                                              mutate_num_per_seed=mutate_num_per_seed)
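Note on the coverage call in this hunk: KMultisectionNeuronCoverage replaces the old ModelCoverageMetrics helper. It profiles each neuron's output range on the training images, splits that range into `segmented_num` equal sections, and reports the fraction of sections that the test activations actually fall into. The toy sketch below only illustrates that idea and is not the MindArmour implementation; `kmnc_sketch`, `train_acts` and `test_acts` are made-up names.

    import numpy as np

    def kmnc_sketch(train_acts, test_acts, k=100):
        """Toy k-multisection coverage over activation arrays of shape (samples, neurons)."""
        lower, upper = train_acts.min(axis=0), train_acts.max(axis=0)
        width = (upper - lower) / k
        hit = np.zeros((train_acts.shape[1], k), dtype=bool)
        for neuron in range(train_acts.shape[1]):
            if width[neuron] == 0:
                continue  # constant neuron on the training data: nothing to cover
            sections = np.floor((test_acts[:, neuron] - lower[neuron]) / width[neuron]).astype(int)
            sections = sections[(sections >= 0) & (sections < k)]
            hit[neuron, sections] = True
        return hit.sum() / hit.size

    # e.g. kmnc_sketch(np.random.rand(1000, 10), np.random.rand(200, 10), k=100) -> value in [0, 1]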
@@ -125,24 +145,10 @@ def example_lenet_mnist_fuzzing():
     for key in metrics:
         LOGGER.info(TAG, key + ': %s', metrics[key])
-    def split_dataset(image, label, proportion):
-        """
-        Split the generated fuzz data into train and test set.
-        """
-        indices = np.arange(len(image))
-        random.shuffle(indices)
-        train_length = int(len(image) * proportion)
-        train_image = [image[i] for i in indices[:train_length]]
-        train_label = [label[i] for i in indices[:train_length]]
-        test_image = [image[i] for i in indices[:train_length]]
-        test_label = [label[i] for i in indices[:train_length]]
-        return train_image, train_label, test_image, test_label
-    train_image, train_label, test_image, test_label = split_dataset(
-        gen_samples, gt, 0.7)
+    train_image, train_label, test_image, test_label = split_dataset(gen_samples, gt, 0.7)
     # load model B and test it on the test set
-    ckpt_path = '../common/networks/lenet5/trained_ckpt_file/lenet_m2-10_1250.ckpt'
+    ckpt_path = '../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
     net = LeNet5()
     load_dict = load_checkpoint(ckpt_path)
     load_param_into_net(net, load_dict)
@@ -154,12 +160,11 @@ def example_lenet_mnist_fuzzing():
     # enhance model robustness
     lr = 0.001
     momentum = 0.9
-    loss_fn = SoftmaxCrossEntropyWithLogits(Sparse=True)
+    loss_fn = SoftmaxCrossEntropyWithLogits(sparse=True)
     optimizer = Momentum(net.trainable_params(), lr, momentum)
     adv_defense = AdversarialDefense(net, loss_fn, optimizer)
-    adv_defense.batch_defense(np.array(train_image).astype(np.float32),
-                              np.argmax(train_label, axis=1).astype(np.int32))
+    adv_defense.batch_defense(np.array(train_image).astype(np.float32), np.argmax(train_label, axis=1).astype(np.int32))
     preds_en = net(Tensor(test_image, dtype=mindspore.float32)).asnumpy()
     acc_en = np.sum(np.argmax(preds_en, axis=1) == np.argmax(test_label, axis=1)) / len(test_label)
     print('Accuracy of enhanced model on test set is ', acc_en)
@@ -167,5 +172,5 @@ def example_lenet_mnist_fuzzing():
 if __name__ == '__main__':
     # device_target can be "CPU", "GPU" or "Ascend"
-    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
     example_lenet_mnist_fuzzing()
@@ -35,24 +35,50 @@ def test_lenet_mnist_fuzzing():
     load_dict = load_checkpoint(ckpt_path)
     load_param_into_net(net, load_dict)
     model = Model(net)
-    mutate_config = [{'method': 'Blur',
-                      'params': {'radius': [0.1, 0.2, 0.3],
-                                 'auto_param': [True, False]}},
-                     {'method': 'Contrast',
-                      'params': {'auto_param': [True]}},
-                     {'method': 'Translate',
-                      'params': {'auto_param': [True]}},
-                     {'method': 'Brightness',
-                      'params': {'auto_param': [True]}},
-                     {'method': 'Noise',
-                      'params': {'auto_param': [True]}},
-                     {'method': 'Scale',
-                      'params': {'auto_param': [True]}},
-                     {'method': 'Shear',
-                      'params': {'auto_param': [True]}},
-                     {'method': 'FGSM',
-                      'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}
-                    ]
+    mutate_config = [
+        {'method': 'GaussianBlur',
+         'params': {'ksize': [1, 2, 3, 5],
+                    'auto_param': [True, False]}},
+        {'method': 'MotionBlur',
+         'params': {'degree': [1, 2, 5], 'angle': [45, 10, 100, 140, 210, 270, 300], 'auto_param': [True]}},
+        {'method': 'GradientBlur',
+         'params': {'point': [[10, 10]], 'auto_param': [True]}},
+        {'method': 'UniformNoise',
+         'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
+        {'method': 'GaussianNoise',
+         'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
+        {'method': 'SaltAndPepperNoise',
+         'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
+        {'method': 'NaturalNoise',
+         'params': {'ratio': [0.1, 0.2, 0.3], 'k_x_range': [(1, 3), (1, 5)], 'k_y_range': [(1, 5)],
+                    'auto_param': [False, True]}},
+        {'method': 'Contrast',
+         'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}},
+        {'method': 'GradientLuminance',
+         'params': {'color_start': [(0, 0, 0)], 'color_end': [(255, 255, 255)], 'start_point': [(10, 10)],
+                    'scope': [0.5], 'pattern': ['light'], 'bright_rate': [0.3], 'mode': ['circle'],
+                    'auto_param': [False, True]}},
+        {'method': 'Translate',
+         'params': {'x_bias': [0, 0.05, -0.05], 'y_bias': [0, -0.05, 0.05], 'auto_param': [False, True]}},
+        {'method': 'Scale',
+         'params': {'factor_x': [1, 0.9], 'factor_y': [1, 0.9], 'auto_param': [False, True]}},
+        {'method': 'Shear',
+         'params': {'factor': [0.2, 0.1], 'direction': ['horizontal', 'vertical'], 'auto_param': [False, True]}},
+        {'method': 'Rotate',
+         'params': {'angle': [20, 90], 'auto_param': [False, True]}},
+        {'method': 'Perspective',
+         'params': {'ori_pos': [[[0, 0], [0, 800], [800, 0], [800, 800]]],
+                    'dst_pos': [[[50, 0], [0, 800], [780, 0], [800, 800]]], 'auto_param': [False, True]}},
+        {'method': 'Curve',
+         'params': {'curves': [5], 'depth': [2], 'mode': ['vertical'], 'auto_param': [False, True]}},
+        {'method': 'FGSM',
+         'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}},
+        {'method': 'PGD',
+         'params': {'eps': [0.1, 0.2, 0.4], 'eps_iter': [0.05, 0.1], 'nb_iter': [1, 3]}},
+        {'method': 'MDIIM',
+         'params': {'eps': [0.1, 0.2, 0.4], 'prob': [0.5, 0.1],
+                    'norm_level': [1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf', 'linf']}}
+    ]
     # get training data
     data_list = "../common/dataset/MNIST/train"
@@ -88,7 +114,10 @@ def test_lenet_mnist_fuzzing():
     print('KMNC of initial seeds is: ', kmnc)
     initial_seeds = initial_seeds[:100]
     model_fuzz_test = Fuzzer(model)
-    _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds, coverage, evaluate=True, max_iters=10,
+    _, _, _, _, metrics = model_fuzz_test.fuzzing(mutate_config,
+                                                  initial_seeds, coverage,
+                                                  evaluate=True,
+                                                  max_iters=10,
                                                   mutate_num_per_seed=20)
     if metrics:
@@ -24,10 +24,11 @@ from mindspore import nn
 from mindarmour.utils._check_param import check_model, check_numpy_param, check_param_multi_types, check_norm_level, \
     check_param_in_range, check_param_type, check_int_positive, check_param_bounds
 from mindarmour.utils.logger import LogUtil
-from ..adv_robustness.attacks import FastGradientSignMethod, \
+from mindarmour.adv_robustness.attacks import FastGradientSignMethod, \
     MomentumDiverseInputIterativeMethod, ProjectedGradientDescent
-from .image_transform import Contrast, Brightness, Blur, \
-    Noise, Translate, Scale, Shear, Rotate
+from mindarmour.natural_robustness.transform.image import GaussianBlur, MotionBlur, GradientBlur, UniformNoise, \
+    GaussianNoise, SaltAndPepperNoise, NaturalNoise, Contrast, GradientLuminance, Translate, Scale, Shear, Rotate, \
+    Perspective, Curve
 from .model_coverage_metrics import CoverageMetrics, KMultisectionNeuronCoverage
 LOGGER = LogUtil.get_instance()
@@ -104,17 +105,79 @@ class Fuzzer:
         target_model (Model): Target fuzz model.
     Examples:
+        >>> import numpy as np
+        >>> from mindspore import context
+        >>> from mindspore import nn
+        >>> from mindspore.common.initializer import TruncatedNormal
+        >>> from mindspore.ops import operations as P
+        >>> from mindspore.train import Model
+        >>> from mindspore.ops import TensorSummary
+        >>> from mindarmour.fuzz_testing import Fuzzer
+        >>> from mindarmour.fuzz_testing import KMultisectionNeuronCoverage
+        >>>
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.conv1 = nn.Conv2d(1, 6, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid")
+        >>>         self.conv2 = nn.Conv2d(6, 16, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid")
+        >>>         self.fc1 = nn.Dense(16 * 5 * 5, 120, TruncatedNormal(0.02), TruncatedNormal(0.02))
+        >>>         self.fc2 = nn.Dense(120, 84, TruncatedNormal(0.02), TruncatedNormal(0.02))
+        >>>         self.fc3 = nn.Dense(84, 10, TruncatedNormal(0.02), TruncatedNormal(0.02))
+        >>>         self.relu = nn.ReLU()
+        >>>         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
+        >>>         self.reshape = P.Reshape()
+        >>>         self.summary = TensorSummary()
+        >>>
+        >>>     def construct(self, x):
+        >>>         x = self.conv1(x)
+        >>>         x = self.relu(x)
+        >>>         self.summary('conv1', x)
+        >>>         x = self.max_pool2d(x)
+        >>>         x = self.conv2(x)
+        >>>         x = self.relu(x)
+        >>>         self.summary('conv2', x)
+        >>>         x = self.max_pool2d(x)
+        >>>         x = self.reshape(x, (-1, 16 * 5 * 5))
+        >>>         x = self.fc1(x)
+        >>>         x = self.relu(x)
+        >>>         self.summary('fc1', x)
+        >>>         x = self.fc2(x)
+        >>>         x = self.relu(x)
+        >>>         self.summary('fc2', x)
+        >>>         x = self.fc3(x)
+        >>>         self.summary('fc3', x)
+        >>>         return x
+        >>>
         >>> net = Net()
         >>> model = Model(net)
-        >>> mutate_config = [{'method': 'Blur',
-        ...                   'params': {'auto_param': [True]}},
+        >>> mutate_config = [{'method': 'GaussianBlur',
+        ...                   'params': {'ksize': [1, 2, 3, 5], 'auto_param': [True, False]}},
+        ...                  {'method': 'MotionBlur',
+        ...                   'params': {'degree': [1, 2, 5], 'angle': [45, 10, 100, 140, 210, 270, 300],
+        ...                              'auto_param': [True]}},
+        ...                  {'method': 'UniformNoise',
+        ...                   'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
+        ...                  {'method': 'GaussianNoise',
+        ...                   'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
         ...                  {'method': 'Contrast',
-        ...                   'params': {'factor': [2]}},
-        ...                  {'method': 'Translate',
-        ...                   'params': {'x_bias': [0.1, 0.2], 'y_bias': [0.2]}},
+        ...                   'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}},
+        ...                  {'method': 'Rotate',
+        ...                   'params': {'angle': [20, 90], 'auto_param': [False, True]}},
         ...                  {'method': 'FGSM',
-        ...                   'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}]
-        >>> nc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100)
+        ...                   'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}]
+        >>> batch_size = 8
+        >>> num_classe = 10
+        >>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
+        >>> test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
+        >>> test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32)
+        >>> test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32)
+        >>> initial_seeds = []
+        >>> # make initial seeds
+        >>> for img, label in zip(test_images, test_labels):
+        >>>     initial_seeds.append([img, label])
+        >>> initial_seeds = initial_seeds[:10]
+        >>> nc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100, incremental=True)
        >>> model_fuzz_test = Fuzzer(model)
         >>> samples, gt_labels, preds, strategies, metrics = model_fuzz_test.fuzzing(mutate_config, initial_seeds,
         ...                                                                          nc, max_iters=100)
@@ -125,18 +188,26 @@ class Fuzzer:
         # Allowed mutate strategies so far.
         self._strategies = {'Contrast': Contrast,
-                            'Brightness': Brightness,
-                            'Blur': Blur,
-                            'Noise': Noise,
+                            'GradientLuminance': GradientLuminance,
+                            'GaussianBlur': GaussianBlur,
+                            'MotionBlur': MotionBlur,
+                            'GradientBlur': GradientBlur,
+                            'UniformNoise': UniformNoise,
+                            'GaussianNoise': GaussianNoise,
+                            'SaltAndPepperNoise': SaltAndPepperNoise,
+                            'NaturalNoise': NaturalNoise,
                             'Translate': Translate,
                             'Scale': Scale,
                             'Shear': Shear,
                             'Rotate': Rotate,
+                            'Perspective': Perspective,
+                            'Curve': Curve,
                             'FGSM': FastGradientSignMethod,
                             'PGD': ProjectedGradientDescent,
                             'MDIIM': MomentumDiverseInputIterativeMethod}
-        self._affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate']
-        self._pixel_value_trans_list = ['Contrast', 'Brightness', 'Blur', 'Noise']
+        self._affine_trans_list = ['Translate', 'Scale', 'Shear', 'Rotate', 'Perspective', 'Curve']
+        self._pixel_value_trans_list = ['Contrast', 'GradientLuminance', 'GaussianBlur', 'MotionBlur', 'GradientBlur',
+                                        'UniformNoise', 'GaussianNoise', 'SaltAndPepperNoise', 'NaturalNoise']
         self._attacks_list = ['FGSM', 'PGD', 'MDIIM']
         self._attack_param_checklists = {
             'FGSM': {'eps': {'dtype': [float], 'range': [0, 1]},
@@ -144,10 +215,11 @@ class Fuzzer:
                      'bounds': {'dtype': [tuple, list]}},
             'PGD': {'eps': {'dtype': [float], 'range': [0, 1]},
                     'eps_iter': {'dtype': [float], 'range': [0, 1]},
-                    'nb_iter': {'dtype': [int], 'range': [0, 100000]},
+                    'nb_iter': {'dtype': [int]},
                     'bounds': {'dtype': [tuple, list]}},
             'MDIIM': {'eps': {'dtype': [float], 'range': [0, 1]},
-                      'norm_level': {'dtype': [str, int], 'range': [1, 2, '1', '2', 'l1', 'l2', 'inf', 'np.inf']},
+                      'norm_level': {'dtype': [str, int],
+                                     'range': [1, 2, '1', '2', 'l1', 'l2', 'inf', 'linf', 'np.inf']},
                       'prob': {'dtype': [float], 'range': [0, 1]},
                       'bounds': {'dtype': [tuple, list]}}}
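For context on how these checklists are consumed: every candidate value listed for an attack parameter in `mutate_config` has to match the `dtype` entry and, where one is given, fall inside the `range` entry. The sketch below is illustrative only; `checklist` and `validate_attack_params` are names invented here, not the class's private helpers.

    checklist = {'FGSM': {'eps': {'dtype': [float], 'range': [0, 1]},
                          'alpha': {'dtype': [float], 'range': [0, 1]},
                          'bounds': {'dtype': [tuple, list]}}}

    def validate_attack_params(method, params):
        # Reject any candidate value whose type or range violates the checklist.
        for name, candidates in params.items():
            spec = checklist[method][name]
            for value in candidates:
                if not isinstance(value, tuple(spec['dtype'])):
                    raise TypeError('%s of %s must be one of %s' % (name, method, spec['dtype']))
                if 'range' in spec and not spec['range'][0] <= value <= spec['range'][1]:
                    raise ValueError('%s of %s must lie in %s' % (name, method, spec['range']))

    validate_attack_params('FGSM', {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]})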
@@ -157,18 +229,26 @@ class Fuzzer:
         Args:
             mutate_config (list): Mutate configs. The format is
-                [{'method': 'Blur',
-                  'params': {'radius': [0.1, 0.2], 'auto_param': [True, False]}},
-                 {'method': 'Contrast',
-                  'params': {'factor': [1, 1.5, 2]}},
-                 {'method': 'FGSM',
-                  'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1]}},
-                 ...].
+                [{'method': 'GaussianBlur',
+                  'params': {'ksize': [1, 2, 3, 5], 'auto_param': [True, False]}},
+                 {'method': 'UniformNoise',
+                  'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
+                 {'method': 'GaussianNoise',
+                  'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
+                 {'method': 'Contrast',
+                  'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}},
+                 {'method': 'Rotate',
+                  'params': {'angle': [20, 90], 'auto_param': [False, True]}},
+                 {'method': 'FGSM',
+                  'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}},
+                 ...].
                The supported methods list is in `self._strategies`, and the params of each method must be within the
                range of optional parameters. Supported methods are grouped in three types: Firstly, pixel value based
                transform methods include: 'Contrast', 'GradientLuminance', 'GaussianBlur', 'MotionBlur', 'GradientBlur',
                'UniformNoise', 'GaussianNoise', 'SaltAndPepperNoise' and 'NaturalNoise'. Secondly, affine transform
                methods include: 'Translate', 'Scale', 'Shear', 'Rotate', 'Perspective' and 'Curve'. Thirdly, attack methods include: 'FGSM',
-                'PGD' and 'MDIIM'. `mutate_config` must have method in the type of pixel value based transform methods.
+                'PGD' and 'MDIIM', which are abbreviations of FastGradientSignMethod,
+                ProjectedGradientDescent and MomentumDiverseInputIterativeMethod.
+                `mutate_config` must include at least one method of the pixel value based transform type.
                The way of setting parameters for the first and second type methods can be seen in
                'mindarmour/natural_robustness/transform/image.py'. For the third type methods, the optional parameters refer to
                `self._attack_param_checklists`.
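As a compact illustration of the three groups described above, a valid config always carries at least one pixel value based entry; the method names and parameter keys below are taken from the configs elsewhere in this change, while the specific values are only examples.

    mutate_config = [
        {'method': 'GaussianBlur',              # pixel value based transform (at least one required)
         'params': {'ksize': [2, 3], 'auto_param': [False]}},
        {'method': 'Rotate',                    # affine transform
         'params': {'angle': [20, 90], 'auto_param': [False, True]}},
        {'method': 'FGSM',                      # adversarial attack
         'params': {'eps': [0.3], 'alpha': [0.1], 'bounds': [(0, 1)]}}]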
@@ -278,7 +358,6 @@ class Fuzzer:
             if only_pixel_trans:
                 while strategy['method'] not in self._pixel_value_trans_list:
                     strategy = choice(mutate_config)
-            transform = mutates[strategy['method']]
             params = strategy['params']
             method = strategy['method']
             selected_param = {}
@@ -290,9 +369,10 @@ class Fuzzer:
                 shear_keys = selected_param.keys()
                 if 'factor_x' in shear_keys and 'factor_y' in shear_keys:
                     selected_param[choice(['factor_x', 'factor_y'])] = 0
-                transform.set_params(**selected_param)
-                mutate_sample = transform.transform(seed[0])
+                transform = mutates[strategy['method']](**selected_param)
+                mutate_sample = transform(seed[0])
             else:
+                transform = mutates[strategy['method']]
                 for param_name in selected_param:
                     transform.__setattr__('_' + str(param_name), selected_param[param_name])
                 mutate_sample = transform.generate(np.array([seed[0].astype(np.float32)]), np.array([seed[1]]))[0]
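The hunk above changes how a sampled strategy is applied: image transforms are now instantiated with the sampled parameters and called directly on the seed image, while adversarial attacks still go through `generate`. A simplified, illustrative sketch of the parameter sampling step (`sample_params` and `seed_image` are names invented here):

    from random import choice

    def sample_params(strategy):
        """Draw one concrete value from each candidate list of a mutate_config entry."""
        return {name: choice(values) for name, values in strategy['params'].items()}

    strategy = {'method': 'GaussianBlur', 'params': {'ksize': [1, 2, 3, 5], 'auto_param': [True, False]}}
    selected = sample_params(strategy)     # e.g. {'ksize': 3, 'auto_param': False}
    # transform = GaussianBlur(**selected); mutated = transform(seed_image)   # transform path (new style)
    # attack.generate(samples, labels)                                        # attack path (unchanged)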
@@ -360,6 +440,8 @@ class Fuzzer:
                     _ = check_param_bounds('bounds', param_value)
                 elif param_name == 'norm_level':
                     _ = check_norm_level(param_value)
+                elif param_name == 'nb_iter':
+                    _ = check_int_positive(param_name, param_value)
                 else:
                     allow_type = self._attack_param_checklists[method][param_name]['dtype']
                     allow_range = self._attack_param_checklists[method][param_name]['range']
@@ -372,7 +454,8 @@ class Fuzzer:
         for mutate in mutate_config:
             method = mutate['method']
             if method not in self._attacks_list:
-                mutates[method] = self._strategies[method]()
+                mutates[method] = self._strategies[method]
             else:
                 network = self._target_model._network
                 loss_fn = self._target_model._loss_fn
@@ -414,7 +497,6 @@ class Fuzzer:
         else:
             attack_success_rate = None
         metrics_report['Attack_success_rate'] = attack_success_rate
         metrics_report['Coverage_metrics'] = coverage.get_metrics(fuzz_samples)
         return metrics_report
@@ -1,609 +0,0 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Image transform
-"""
-import numpy as np
-from PIL import Image, ImageEnhance, ImageFilter
-from mindspore.dataset.vision.py_transforms_util import is_numpy, \
-    to_pil, hwc_to_chw
-from mindarmour.utils._check_param import check_param_multi_types, check_param_in_range, check_numpy_param
-from mindarmour.utils.logger import LogUtil
-LOGGER = LogUtil.get_instance()
-TAG = 'Image Transformation'
-def chw_to_hwc(img):
-    """
-    Transpose the input image; shape (C, H, W) to shape (H, W, C).
-    Args:
-        img (numpy.ndarray): Image to be converted.
-    Returns:
-        img (numpy.ndarray), Converted image.
-    """
-    if is_numpy(img):
-        return img.transpose(1, 2, 0).copy()
-    raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img)))
-def is_hwc(img):
-    """
-    Check if the input image is shape (H, W, C).
-    Args:
-        img (numpy.ndarray): Image to be checked.
-    Returns:
-        Bool, True if input is shape (H, W, C).
-    """
-    if is_numpy(img):
-        img_shape = np.shape(img)
-        if img_shape[2] == 3 and img_shape[1] > 3 and img_shape[0] > 3:
-            return True
-        return False
-    raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img)))
-def is_chw(img):
-    """
-    Check if the input image is shape (H, W, C).
-    Args:
-        img (numpy.ndarray): Image to be checked.
-    Returns:
-        Bool, True if input is shape (H, W, C).
-    """
-    if is_numpy(img):
-        img_shape = np.shape(img)
-        if img_shape[0] == 3 and img_shape[1] > 3 and img_shape[2] > 3:
-            return True
-        return False
-    raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img)))
-def is_rgb(img):
-    """
-    Check if the input image is RGB.
-    Args:
-        img (numpy.ndarray): Image to be checked.
-    Returns:
-        Bool, True if input is RGB.
-    """
-    if is_numpy(img):
-        img_shape = np.shape(img)
-        if len(np.shape(img)) == 3 and (img_shape[0] == 3 or img_shape[2] == 3):
-            return True
-        return False
-    raise TypeError('img should be numpy.ndarray. Got {}'.format(type(img)))
-def is_normalized(img):
-    """
-    Check if the input image is normalized between 0 to 1.
-    Args:
-        img (numpy.ndarray): Image to be checked.
-    Returns:
-        Bool, True if input is normalized between 0 to 1.
-    """
-    if is_numpy(img):
-        minimal = np.min(img)
-        maximun = np.max(img)
-        if minimal >= 0 and maximun <= 1:
-            return True
-        return False
-    raise TypeError('img should be Numpy array. Got {}'.format(type(img)))
-class ImageTransform:
-    """
-    The abstract base class for all image transform classes.
-    """
-    def __init__(self):
-        pass
-    def _check(self, image):
-        """ Check image format. If input image is RGB and its shape
-        is (C, H, W), it will be transposed to (H, W, C). If the value
-        of the image is not normalized , it will be normalized between 0 to 1."""
-        rgb = is_rgb(image)
-        chw = False
-        gray3dim = False
-        normalized = is_normalized(image)
-        if rgb:
-            chw = is_chw(image)
-            if chw:
-                image = chw_to_hwc(image)
-            else:
-                image = image
-        else:
-            if len(np.shape(image)) == 3:
-                gray3dim = True
-                image = image[0]
-            else:
-                image = image
-        if normalized:
-            image = image*255
-        return rgb, chw, normalized, gray3dim, np.uint8(image)
-    def _original_format(self, image, chw, normalized, gray3dim):
-        """ Return transformed image with original format. """
-        if not is_numpy(image):
-            image = np.array(image)
-        if chw:
-            image = hwc_to_chw(image)
-        if normalized:
-            image = image / 255
-        if gray3dim:
-            image = np.expand_dims(image, 0)
-        return image
-    def transform(self, image):
-        pass
-class Contrast(ImageTransform):
-    """
-    Contrast of an image.
-    Args:
-        factor (Union[float, int]): Control the contrast of an image. If 1.0,
-            gives the original image. If 0, gives a gray image. Default: 1.
-    """
-    def __init__(self, factor=1):
-        super(Contrast, self).__init__()
-        self.set_params(factor)
-    def set_params(self, factor=1, auto_param=False):
-        """
-        Set contrast parameters.
-        Args:
-            factor (Union[float, int]): Control the contrast of an image. If 1.0
-                gives the original image. If 0 gives a gray image. Default: 1.
-            auto_param (bool): True if auto generate parameters. Default: False.
-        """
-        if auto_param:
-            self.factor = np.random.uniform(-5, 5)
-        else:
-            self.factor = check_param_multi_types('factor', factor, [int, float])
-    def transform(self, image):
-        """
-        Transform the image.
-        Args:
-            image (numpy.ndarray): Original image to be transformed.
-        Returns:
-            numpy.ndarray, transformed image.
-        """
-        image = check_numpy_param('image', image)
-        ori_dtype = image.dtype
-        _, chw, normalized, gray3dim, image = self._check(image)
-        image = to_pil(image)
-        img_contrast = ImageEnhance.Contrast(image)
-        trans_image = img_contrast.enhance(self.factor)
-        trans_image = self._original_format(trans_image, chw, normalized,
-                                            gray3dim)
-        return trans_image.astype(ori_dtype)
-class Brightness(ImageTransform):
-    """
-    Brightness of an image.
-    Args:
-        factor (Union[float, int]): Control the brightness of an image. If 1.0
-            gives the original image. If 0 gives a black image. Default: 1.
-    """
-    def __init__(self, factor=1):
-        super(Brightness, self).__init__()
-        self.set_params(factor)
-    def set_params(self, factor=1, auto_param=False):
-        """
-        Set brightness parameters.
-        Args:
-            factor (Union[float, int]): Control the brightness of an image. If 1
-                gives the original image. If 0 gives a black image. Default: 1.
-            auto_param (bool): True if auto generate parameters. Default: False.
-        """
-        if auto_param:
-            self.factor = np.random.uniform(0, 5)
-        else:
-            self.factor = check_param_multi_types('factor', factor, [int, float])
-    def transform(self, image):
-        """
-        Transform the image.
-        Args:
-            image (numpy.ndarray): Original image to be transformed.
-        Returns:
-            numpy.ndarray, transformed image.
-        """
-        image = check_numpy_param('image', image)
-        ori_dtype = image.dtype
-        _, chw, normalized, gray3dim, image = self._check(image)
-        image = to_pil(image)
-        img_contrast = ImageEnhance.Brightness(image)
-        trans_image = img_contrast.enhance(self.factor)
-        trans_image = self._original_format(trans_image, chw, normalized,
-                                            gray3dim)
-        return trans_image.astype(ori_dtype)
-class Blur(ImageTransform):
-    """
-    Blurs the image using Gaussian blur filter.
-    Args:
-        radius(Union[float, int]): Blur radius, 0 means no blur. Default: 0.
-    """
-    def __init__(self, radius=0):
-        super(Blur, self).__init__()
-        self.set_params(radius)
-    def set_params(self, radius=0, auto_param=False):
-        """
-        Set blur parameters.
-        Args:
-            radius (Union[float, int]): Blur radius, 0 means no blur. Default: 0.
-            auto_param (bool): True if auto generate parameters. Default: False.
-        """
-        if auto_param:
-            self.radius = np.random.uniform(-1.5, 1.5)
-        else:
-            self.radius = check_param_multi_types('radius', radius, [int, float])
-    def transform(self, image):
-        """
-        Transform the image.
-        Args:
-            image (numpy.ndarray): Original image to be transformed.
-        Returns:
-            numpy.ndarray, transformed image.
-        """
-        image = check_numpy_param('image', image)
-        ori_dtype = image.dtype
-        _, chw, normalized, gray3dim, image = self._check(image)
-        image = to_pil(image)
-        trans_image = image.filter(ImageFilter.GaussianBlur(radius=self.radius))
-        trans_image = self._original_format(trans_image, chw, normalized,
-                                            gray3dim)
-        return trans_image.astype(ori_dtype)
-class Noise(ImageTransform):
-    """
-    Add noise of an image.
-    Args:
-        factor (float): factor is the ratio of pixels to add noise.
-            If 0 gives the original image. Default 0.
-    """
-    def __init__(self, factor=0):
-        super(Noise, self).__init__()
-        self.set_params(factor)
-    def set_params(self, factor=0, auto_param=False):
-        """
-        Set noise parameters.
-        Args:
-            factor (Union[float, int]): factor is the ratio of pixels to
-                add noise. If 0 gives the original image. Default 0.
-            auto_param (bool): True if auto generate parameters. Default: False.
-        """
-        if auto_param:
-            self.factor = np.random.uniform(0, 1)
-        else:
-            self.factor = check_param_multi_types('factor', factor, [int, float])
-    def transform(self, image):
-        """
-        Transform the image.
-        Args:
-            image (numpy.ndarray): Original image to be transformed.
-        Returns:
-            numpy.ndarray, transformed image.
-        """
-        image = check_numpy_param('image', image)
-        ori_dtype = image.dtype
-        _, chw, normalized, gray3dim, image = self._check(image)
-        noise = np.random.uniform(low=-1, high=1, size=np.shape(image))
-        trans_image = np.copy(image)
-        threshold = 1 - self.factor
-        trans_image[noise < -threshold] = 0
-        trans_image[noise > threshold] = 1
-        trans_image = self._original_format(trans_image, chw, normalized,
-                                            gray3dim)
-        return trans_image.astype(ori_dtype)
-class Translate(ImageTransform):
-    """
-    Translate an image.
-    Args:
-        x_bias (Union[int, float]): X-direction translation, x = x + x_bias*image_length.
-            Default: 0.
-        y_bias (Union[int, float]): Y-direction translation, y = y + y_bias*image_wide.
-            Default: 0.
-    """
-    def __init__(self, x_bias=0, y_bias=0):
-        super(Translate, self).__init__()
-        self.set_params(x_bias, y_bias)
-    def set_params(self, x_bias=0, y_bias=0, auto_param=False):
-        """
-        Set translate parameters.
-        Args:
-            x_bias (Union[float, int]): X-direction translation, and x_bias should be in range of (-1, 1). Default: 0.
-            y_bias (Union[float, int]): Y-direction translation, and y_bias should be in range of (-1, 1). Default: 0.
-            auto_param (bool): True if auto generate parameters. Default: False.
-        """
-        x_bias = check_param_in_range('x_bias', x_bias, -1, 1)
-        y_bias = check_param_in_range('y_bias', y_bias, -1, 1)
-        self.auto_param = auto_param
-        if auto_param:
-            self.x_bias = np.random.uniform(-0.3, 0.3)
-            self.y_bias = np.random.uniform(-0.3, 0.3)
-        else:
-            self.x_bias = check_param_multi_types('x_bias', x_bias,
-                                                  [int, float])
-            self.y_bias = check_param_multi_types('y_bias', y_bias,
-                                                  [int, float])
-    def transform(self, image):
-        """
-        Transform the image.
-        Args:
-            image(numpy.ndarray): Original image to be transformed.
-        Returns:
-            numpy.ndarray, transformed image.
-        """
-        image = check_numpy_param('image', image)
-        ori_dtype = image.dtype
-        _, chw, normalized, gray3dim, image = self._check(image)
-        img = to_pil(image)
-        image_shape = np.shape(image)
-        self.x_bias = image_shape[1]*self.x_bias
-        self.y_bias = image_shape[0]*self.y_bias
-        trans_image = img.transform(img.size, Image.AFFINE,
-                                    (1, 0, self.x_bias, 0, 1, self.y_bias))
-        trans_image = self._original_format(trans_image, chw, normalized,
-                                            gray3dim)
-        return trans_image.astype(ori_dtype)
-class Scale(ImageTransform):
-    """
-    Scale an image in the middle.
-    Args:
-        factor_x (Union[float, int]): Rescale in X-direction, x=factor_x*x.
-            Default: 1.
-        factor_y (Union[float, int]): Rescale in Y-direction, y=factor_y*y.
-            Default: 1.
-    """
-    def __init__(self, factor_x=1, factor_y=1):
-        super(Scale, self).__init__()
-        self.set_params(factor_x, factor_y)
-    def set_params(self, factor_x=1, factor_y=1, auto_param=False):
-        """
-        Set scale parameters.
-        Args:
-            factor_x (Union[float, int]): Rescale in X-direction, x=factor_x*x.
-                Default: 1.
-            factor_y (Union[float, int]): Rescale in Y-direction, y=factor_y*y.
-                Default: 1.
-            auto_param (bool): True if auto generate parameters. Default: False.
-        """
-        if auto_param:
-            self.factor_x = np.random.uniform(0.7, 3)
-            self.factor_y = np.random.uniform(0.7, 3)
-        else:
-            self.factor_x = check_param_multi_types('factor_x', factor_x,
-                                                    [int, float])
-            self.factor_y = check_param_multi_types('factor_y', factor_y,
-                                                    [int, float])
-    def transform(self, image):
-        """
-        Transform the image.
-        Args:
-            image(numpy.ndarray): Original image to be transformed.
-        Returns:
-            numpy.ndarray, transformed image.
-        """
-        image = check_numpy_param('image', image)
-        ori_dtype = image.dtype
-        rgb, chw, normalized, gray3dim, image = self._check(image)
-        if rgb:
-            h, w, _ = np.shape(image)
-        else:
-            h, w = np.shape(image)
-        move_x_centor = w / 2*(1 - self.factor_x)
-        move_y_centor = h / 2*(1 - self.factor_y)
-        img = to_pil(image)
-        trans_image = img.transform(img.size, Image.AFFINE,
-                                    (self.factor_x, 0, move_x_centor,
-                                     0, self.factor_y, move_y_centor))
-        trans_image = self._original_format(trans_image, chw, normalized,
-                                            gray3dim)
-        return trans_image.astype(ori_dtype)
-class Shear(ImageTransform):
-    """
-    Shear an image, for each pixel (x, y) in the sheared image, the new value is
-    taken from a position (x+factor_x*y, factor_y*x+y) in the origin image. Then
-    the sheared image will be rescaled to fit original size.
-    Args:
-        factor_x (Union[float, int]): Shear factor of horizontal direction.
-            Default: 0.
-        factor_y (Union[float, int]): Shear factor of vertical direction.
-            Default: 0.
-    """
-    def __init__(self, factor_x=0, factor_y=0):
-        super(Shear, self).__init__()
-        self.set_params(factor_x, factor_y)
-    def set_params(self, factor_x=0, factor_y=0, auto_param=False):
-        """
-        Set shear parameters.
-        Args:
-            factor_x (Union[float, int]): Shear factor of horizontal direction.
-                Default: 0.
-            factor_y (Union[float, int]): Shear factor of vertical direction.
-                Default: 0.
-            auto_param (bool): True if auto generate parameters. Default: False.
-        """
-        if factor_x != 0 and factor_y != 0:
-            msg = 'At least one of factor_x and factor_y is zero.'
-            LOGGER.error(TAG, msg)
-            raise ValueError(msg)
-        if auto_param:
-            if np.random.uniform(-1, 1) > 0:
-                self.factor_x = np.random.uniform(-2, 2)
-                self.factor_y = 0
-            else:
-                self.factor_x = 0
-                self.factor_y = np.random.uniform(-2, 2)
-        else:
-            self.factor_x = check_param_multi_types('factor', factor_x,
-                                                    [int, float])
-            self.factor_y = check_param_multi_types('factor', factor_y,
-                                                    [int, float])
-    def transform(self, image):
-        """
-        Transform the image.
-        Args:
-            image(numpy.ndarray): Original image to be transformed.
-        Returns:
-            numpy.ndarray, transformed image.
-        """
-        image = check_numpy_param('image', image)
-        ori_dtype = image.dtype
-        rgb, chw, normalized, gray3dim, image = self._check(image)
-        img = to_pil(image)
-        if rgb:
-            h, w, _ = np.shape(image)
-        else:
-            h, w = np.shape(image)
-        if self.factor_x != 0:
-            boarder_x = [0, -w, -self.factor_x*h, -w - self.factor_x*h]
-            min_x = min(boarder_x)
-            max_x = max(boarder_x)
-            scale = (max_x - min_x) / w
-            move_x_cen = (w - scale*w - scale*h*self.factor_x) / 2
-            move_y_cen = h*(1 - scale) / 2
-        else:
-            boarder_y = [0, -h, -self.factor_y*w, -h - self.factor_y*w]
-            min_y = min(boarder_y)
-            max_y = max(boarder_y)
-            scale = (max_y - min_y) / h
-            move_y_cen = (h - scale*h - scale*w*self.factor_y) / 2
-            move_x_cen = w*(1 - scale) / 2
-        trans_image = img.transform(img.size, Image.AFFINE,
-                                    (scale, scale*self.factor_x, move_x_cen,
-                                     scale*self.factor_y, scale, move_y_cen))
-        trans_image = self._original_format(trans_image, chw, normalized,
-                                            gray3dim)
-        return trans_image.astype(ori_dtype)
-class Rotate(ImageTransform):
-    """
-    Rotate an image of degrees counter clockwise around its center.
-    Args:
-        angle(Union[float, int]): Degrees counter clockwise. Default: 0.
-    """
-    def __init__(self, angle=0):
-        super(Rotate, self).__init__()
-        self.set_params(angle)
-    def set_params(self, angle=0, auto_param=False):
-        """
-        Set rotate parameters.
-        Args:
-            angle(Union[float, int]): Degrees counter clockwise. Default: 0.
-            auto_param (bool): True if auto generate parameters. Default: False.
-        """
-        if auto_param:
-            self.angle = np.random.uniform(0, 360)
-        else:
-            self.angle = check_param_multi_types('angle', angle, [int, float])
-    def transform(self, image):
-        """
-        Transform the image.
-        Args:
-            image(numpy.ndarray): Original image to be transformed.
-        Returns:
-            numpy.ndarray, transformed image.
-        """
-        image = check_numpy_param('image', image)
-        ori_dtype = image.dtype
-        _, chw, normalized, gray3dim, image = self._check(image)
-        img = to_pil(image)
-        trans_image = img.rotate(self.angle, expand=False)
-        trans_image = self._original_format(trans_image, chw, normalized,
-                                            gray3dim)
-        return trans_image.astype(ori_dtype)
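The file deleted above is superseded by the transform classes that fuzzing.py now imports from mindarmour.natural_robustness.transform.image. Judging from the fuzzer hunk that constructs a transform with the sampled parameters and then calls it on the seed, usage would look roughly like the sketch below; the construct-then-call pattern is inferred from that hunk and the parameter values are only examples.

    import numpy as np
    from mindarmour.natural_robustness.transform.image import GaussianBlur, Contrast

    image = np.random.rand(32, 32).astype(np.float32)   # same toy input shape the old tests used
    blurred = GaussianBlur(ksize=3)(image)              # construct with params, then call on the image
    adjusted = Contrast(alpha=1.5, beta=0)(image)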
@@ -99,15 +99,17 @@ def test_fuzzing_ascend():
     model = Model(net)
     batch_size = 8
     num_classe = 10
-    mutate_config = [{'method': 'Blur',
-                      'params': {'auto_param': [True]}},
+    mutate_config = [{'method': 'GaussianBlur',
+                      'params': {'ksize': [1, 2, 3, 5],
+                                 'auto_param': [True, False]}},
+                     {'method': 'UniformNoise',
+                      'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
                      {'method': 'Contrast',
-                      'params': {'factor': [2, 1]}},
-                     {'method': 'Translate',
-                      'params': {'x_bias': [0.1, 0.3], 'y_bias': [0.2]}},
+                      'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}},
+                     {'method': 'Rotate',
+                      'params': {'angle': [20, 90], 'auto_param': [False, True]}},
                      {'method': 'FGSM',
-                      'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}
-                    ]
+                      'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}]
     train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
     # fuzz test with original test data
@@ -142,15 +144,17 @@ def test_fuzzing_cpu():
     model = Model(net)
     batch_size = 8
     num_classe = 10
-    mutate_config = [{'method': 'Blur',
-                      'params': {'auto_param': [True]}},
+    mutate_config = [{'method': 'GaussianBlur',
+                      'params': {'ksize': [1, 2, 3, 5],
+                                 'auto_param': [True, False]}},
+                     {'method': 'UniformNoise',
+                      'params': {'factor': [0.1, 0.2, 0.3], 'auto_param': [False, True]}},
                      {'method': 'Contrast',
-                      'params': {'factor': [2, 1]}},
-                     {'method': 'Translate',
-                      'params': {'x_bias': [0.1, 0.3], 'y_bias': [0.2]}},
+                      'params': {'alpha': [0.5, 1, 1.5], 'beta': [-10, 0, 10], 'auto_param': [False, True]}},
+                     {'method': 'Rotate',
+                      'params': {'angle': [20, 90], 'auto_param': [False, True]}},
                      {'method': 'FGSM',
-                      'params': {'eps': [0.1, 0.2, 0.3], 'alpha': [0.1]}}
-                    ]
+                      'params': {'eps': [0.3, 0.2, 0.4], 'alpha': [0.1], 'bounds': [(0, 1)]}}]
     # initialize fuzz test with training dataset
     train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
@@ -1,126 +0,0 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Image transform test.
-"""
-import numpy as np
-import pytest
-from mindarmour.utils.logger import LogUtil
-from mindarmour.fuzz_testing.image_transform import Contrast, Brightness, \
-    Blur, Noise, Translate, Scale, Shear, Rotate
-LOGGER = LogUtil.get_instance()
-TAG = 'Image transform test'
-LOGGER.set_level('INFO')
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-@pytest.mark.component_mindarmour
-def test_contrast():
-    image = (np.random.rand(32, 32)).astype(np.float32)
-    trans = Contrast()
-    trans.set_params(auto_param=True)
-    _ = trans.transform(image)
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-@pytest.mark.component_mindarmour
-def test_brightness():
-    image = (np.random.rand(32, 32)).astype(np.float32)
-    trans = Brightness()
-    trans.set_params(auto_param=True)
-    _ = trans.transform(image)
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.platform_x86_ascend_training
-@pytest.mark.platform_arm_ascend_training
-@pytest.mark.env_onecard
-@pytest.mark.component_mindarmour
-def test_blur():
-    image = (np.random.rand(32, 32)).astype(np.float32)
-    trans = Blur()
-    trans.set_params(auto_param=True)
-    _ = trans.transform(image)
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.platform_x86_ascend_training
-@pytest.mark.platform_arm_ascend_training
-@pytest.mark.env_onecard
-@pytest.mark.component_mindarmour
-def test_noise():
-    image = (np.random.rand(32, 32)).astype(np.float32)
-    trans = Noise()
-    trans.set_params(auto_param=True)
-    _ = trans.transform(image)
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.platform_x86_ascend_training
-@pytest.mark.platform_arm_ascend_training
-@pytest.mark.env_onecard
-@pytest.mark.component_mindarmour
-def test_translate():
-    image = (np.random.rand(32, 32)).astype(np.float32)
-    trans = Translate()
-    trans.set_params(auto_param=True)
-    _ = trans.transform(image)
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.platform_x86_ascend_training
-@pytest.mark.platform_arm_ascend_training
-@pytest.mark.env_onecard
-@pytest.mark.component_mindarmour
-def test_shear():
-    image = (np.random.rand(32, 32)).astype(np.float32)
-    trans = Shear()
-    trans.set_params(auto_param=True)
-    _ = trans.transform(image)
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.platform_x86_ascend_training
-@pytest.mark.platform_arm_ascend_training
-@pytest.mark.env_onecard
-@pytest.mark.component_mindarmour
-def test_scale():
-    image = (np.random.rand(32, 32)).astype(np.float32)
-    trans = Scale()
-    trans.set_params(auto_param=True)
-    _ = trans.transform(image)
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.platform_x86_ascend_training
-@pytest.mark.platform_arm_ascend_training
-@pytest.mark.env_onecard
-@pytest.mark.component_mindarmour
-def test_rotate():
-    image = (np.random.rand(32, 32)).astype(np.float32)
-    trans = Rotate()
-    trans.set_params(auto_param=True)
-    _ = trans.transform(image)
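With image_transform.py removed, these tests go with it. A replacement smoke test against the new transform classes might look like the sketch below; it mirrors the old tests' style but assumes the new classes accept the same toy grayscale arrays and the construct-then-call pattern used by the fuzzer, so treat it as a starting point rather than part of this change.

    import numpy as np
    import pytest
    from mindarmour.natural_robustness.transform.image import GaussianBlur, Contrast, Rotate

    @pytest.mark.parametrize('trans', [GaussianBlur(ksize=3),
                                       Contrast(alpha=1.5, beta=0),
                                       Rotate(angle=20)])
    def test_transform_smoke(trans):
        # Old tests only checked that a transform runs without raising; do the same here.
        image = (np.random.rand(32, 32)).astype(np.float32)
        _ = trans(image)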