@@ -27,7 +27,7 @@ import mindspore.nn as nn | |||||
from mindspore import Tensor | from mindspore import Tensor | ||||
from mindspore import context | from mindspore import context | ||||
from mindspore.nn.optim.momentum import Momentum | from mindspore.nn.optim.momentum import Momentum | ||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | |||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor | |||||
from mindspore.train.model import Model | from mindspore.train.model import Model | ||||
from mindspore.train.serialization import load_param_into_net, load_checkpoint | from mindspore.train.serialization import load_param_into_net, load_checkpoint | ||||
from mindarmour.utils import LogUtil | from mindarmour.utils import LogUtil | ||||
@@ -187,12 +187,13 @@ if __name__ == '__main__': | |||||
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) | amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) | ||||
# checkpoint save | # checkpoint save | ||||
callbacks = [LossMonitor()] | |||||
if args.rank_save_ckpt_flag: | if args.rank_save_ckpt_flag: | ||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval*args.steps_per_epoch, | ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval*args.steps_per_epoch, | ||||
keep_checkpoint_max=args.ckpt_save_max) | keep_checkpoint_max=args.ckpt_save_max) | ||||
ckpt_cb = ModelCheckpoint(config=ckpt_config, | ckpt_cb = ModelCheckpoint(config=ckpt_config, | ||||
directory=args.outputs_dir, | directory=args.outputs_dir, | ||||
prefix='{}'.format(args.rank)) | prefix='{}'.format(args.rank)) | ||||
callbacks = ckpt_cb | |||||
callbacks.append(ckpt_cb) | |||||
model.train(args.max_epoch, dataset, callbacks=callbacks) | model.train(args.max_epoch, dataset, callbacks=callbacks) |
@@ -51,7 +51,7 @@ def test_lenet_mnist_coverage(): | |||||
train_images = np.concatenate(train_images, axis=0) | train_images = np.concatenate(train_images, axis=0) | ||||
# initialize fuzz test with training dataset | # initialize fuzz test with training dataset | ||||
model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images) | |||||
model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, train_images) | |||||
# fuzz test with original test data | # fuzz test with original test data | ||||
# get test data | # get test data | ||||
@@ -41,12 +41,20 @@ def test_lenet_mnist_fuzzing(): | |||||
mutate_config = [{'method': 'Blur', | mutate_config = [{'method': 'Blur', | ||||
'params': {'auto_param': True}}, | 'params': {'auto_param': True}}, | ||||
{'method': 'Contrast', | {'method': 'Contrast', | ||||
'params': {'factor': 2}}, | |||||
'params': {'auto_param': True}}, | |||||
{'method': 'Translate', | {'method': 'Translate', | ||||
'params': {'x_bias': 0.1, 'y_bias': 0.2}}, | |||||
'params': {'auto_param': True}}, | |||||
{'method': 'Brightness', | |||||
'params': {'auto_param': True}}, | |||||
{'method': 'Noise', | |||||
'params': {'auto_param': True}}, | |||||
{'method': 'Scale', | |||||
'params': {'auto_param': True}}, | |||||
{'method': 'Shear', | |||||
'params': {'auto_param': True}}, | |||||
{'method': 'FGSM', | {'method': 'FGSM', | ||||
'params': {'eps': 0.1, 'alpha': 0.1}} | |||||
] | |||||
'params': {'eps': 0.3, 'alpha': 0.1}} | |||||
] | |||||
# get training data | # get training data | ||||
data_list = "./MNIST_unzip/train" | data_list = "./MNIST_unzip/train" | ||||
@@ -59,7 +67,7 @@ def test_lenet_mnist_fuzzing(): | |||||
train_images = np.concatenate(train_images, axis=0) | train_images = np.concatenate(train_images, axis=0) | ||||
# initialize fuzz test with training dataset | # initialize fuzz test with training dataset | ||||
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, train_images) | |||||
model_coverage_test = ModelCoverageMetrics(model, 10, 1000, train_images) | |||||
# fuzz test with original test data | # fuzz test with original test data | ||||
# get test data | # get test data | ||||
@@ -79,7 +87,7 @@ def test_lenet_mnist_fuzzing(): | |||||
# make initial seeds | # make initial seeds | ||||
for img, label in zip(test_images, test_labels): | for img, label in zip(test_images, test_labels): | ||||
initial_seeds.append([img, label, 0]) | |||||
initial_seeds.append([img, label]) | |||||
initial_seeds = initial_seeds[:100] | initial_seeds = initial_seeds[:100] | ||||
model_coverage_test.calculate_coverage( | model_coverage_test.calculate_coverage( | ||||
@@ -11,6 +11,7 @@ from .monitor.monitor import RDPMonitor | |||||
from .monitor.monitor import ZCDPMonitor | from .monitor.monitor import ZCDPMonitor | ||||
from .optimizer.optimizer import DPOptimizerClassFactory | from .optimizer.optimizer import DPOptimizerClassFactory | ||||
from .train.model import DPModel | from .train.model import DPModel | ||||
from .evaluation.membership_inference import MembershipInference | |||||
__all__ = ['NoiseGaussianRandom', | __all__ = ['NoiseGaussianRandom', | ||||
'NoiseAdaGaussianRandom', | 'NoiseAdaGaussianRandom', | ||||
@@ -21,4 +22,5 @@ __all__ = ['NoiseGaussianRandom', | |||||
'RDPMonitor', | 'RDPMonitor', | ||||
'ZCDPMonitor', | 'ZCDPMonitor', | ||||
'DPOptimizerClassFactory', | 'DPOptimizerClassFactory', | ||||
'DPModel'] | |||||
'DPModel', | |||||
'MembershipInference'] |
@@ -21,6 +21,11 @@ from sklearn.ensemble import RandomForestClassifier | |||||
from sklearn.model_selection import GridSearchCV | from sklearn.model_selection import GridSearchCV | ||||
from sklearn.model_selection import RandomizedSearchCV | from sklearn.model_selection import RandomizedSearchCV | ||||
from mindarmour.utils.logger import LogUtil | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = "Attacker" | |||||
def _attack_knn(features, labels, param_grid): | def _attack_knn(features, labels, param_grid): | ||||
""" | """ | ||||
@@ -114,17 +119,31 @@ def get_attack_model(features, labels, config): | |||||
features (numpy.ndarray): Loss and logits characteristics of each sample. | features (numpy.ndarray): Loss and logits characteristics of each sample. | ||||
labels (numpy.ndarray): Labels of each sample whether belongs to training set. | labels (numpy.ndarray): Labels of each sample whether belongs to training set. | ||||
config (dict): Config of attacker, with key in ["method", "params"]. | config (dict): Config of attacker, with key in ["method", "params"]. | ||||
The format is {"method": "knn", "params": {"n_neighbors": [3, 5, 7]}}, | |||||
params of each method must within the range of changeable parameters. | |||||
Tips of params implement can be found in | |||||
"https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html". | |||||
Returns: | Returns: | ||||
sklearn.BaseEstimator, trained model specify by config["method"]. | sklearn.BaseEstimator, trained model specify by config["method"]. | ||||
Examples: | |||||
>>> features = np.random.randn(10, 10) | |||||
>>> labels = np.random.randint(0, 2, 10) | |||||
>>> config = {"method": "knn", "params": {"n_neighbors": [3, 5, 7]}} | |||||
>>> attack_model = get_attack_model(features, labels, config) | |||||
""" | """ | ||||
method = str.lower(config["method"]) | method = str.lower(config["method"]) | ||||
if method == "knn": | if method == "knn": | ||||
return _attack_knn(features, labels, config["params"]) | return _attack_knn(features, labels, config["params"]) | ||||
if method in ["lr", "logitic regression"]: | |||||
if method == "lr": | |||||
return _attack_lr(features, labels, config["params"]) | return _attack_lr(features, labels, config["params"]) | ||||
if method == "mlp": | if method == "mlp": | ||||
return _attack_mlpc(features, labels, config["params"]) | return _attack_mlpc(features, labels, config["params"]) | ||||
if method in ["rf", "random forest"]: | |||||
if method == "rf": | |||||
return _attack_rf(features, labels, config["params"]) | return _attack_rf(features, labels, config["params"]) | ||||
raise ValueError("Method {} is not support.".format(config["method"])) | |||||
msg = "Method {} is not supported.".format(config["method"]) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) |
@@ -19,10 +19,14 @@ import numpy as np | |||||
import mindspore as ms | import mindspore as ms | ||||
from mindspore.train import Model | from mindspore.train import Model | ||||
import mindspore.nn as nn | |||||
import mindspore.context as context | |||||
from mindspore.dataset.engine import Dataset | |||||
from mindspore import Tensor | from mindspore import Tensor | ||||
from mindarmour.diff_privacy.evaluation.attacker import get_attack_model | from mindarmour.diff_privacy.evaluation.attacker import get_attack_model | ||||
from mindarmour.utils.logger import LogUtil | |||||
LOGGER = LogUtil.get_instance() | |||||
TAG = "MembershipInference" | |||||
def _eval_info(pred, truth, option): | def _eval_info(pred, truth, option): | ||||
""" | """ | ||||
@@ -42,7 +46,9 @@ def _eval_info(pred, truth, option): | |||||
ValueError, value of parameter option must be in ["precision", "accuracy", "recall"]. | ValueError, value of parameter option must be in ["precision", "accuracy", "recall"]. | ||||
""" | """ | ||||
if pred.size == 0 or truth.size == 0: | if pred.size == 0 or truth.size == 0: | ||||
raise ValueError("Size of pred or truth is 0.") | |||||
msg = "Size of pred or truth is 0." | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
if option == "accuracy": | if option == "accuracy": | ||||
count = np.sum(pred == truth) | count = np.sum(pred == truth) | ||||
@@ -58,7 +64,25 @@ def _eval_info(pred, truth, option): | |||||
return -1 | return -1 | ||||
return count / np.sum(truth) | return count / np.sum(truth) | ||||
raise ValueError("The metric value {} is undefined.".format(option)) | |||||
msg = "The metric value {} is undefined.".format(option) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
def _softmax_cross_entropy(logits, labels): | |||||
""" | |||||
Calculate the SoftmaxCrossEntropy result between logits and labels. | |||||
Args: | |||||
logits (numpy.ndarray): Numpy array of shape(N, C). | |||||
labels (numpy.ndarray): Numpy array of shape(N, ) | |||||
Returns: | |||||
numpy.ndarray: Numpy array of shape(N, ), containing loss value for each vector in logits. | |||||
""" | |||||
labels = np.eye(logits.shape[1])[labels].astype(np.int32) | |||||
logits = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) | |||||
return -1*np.sum(labels*np.log(logits), axis=1) | |||||
class MembershipInference: | class MembershipInference: | ||||
@@ -66,22 +90,23 @@ class MembershipInference: | |||||
Evaluation proposed by Shokri, Stronati, Song and Shmatikov is a grey-box attack. | Evaluation proposed by Shokri, Stronati, Song and Shmatikov is a grey-box attack. | ||||
The attack requires obtain loss or logits results of training samples. | The attack requires obtain loss or logits results of training samples. | ||||
References: Reza Shokri, Marco Stronati, Congzheng Song, Vitaly Shmatikov. | |||||
References: `Reza Shokri, Marco Stronati, Congzheng Song, Vitaly Shmatikov. | |||||
Membership Inference Attacks against Machine Learning Models. 2017. | Membership Inference Attacks against Machine Learning Models. 2017. | ||||
arXiv:1610.05820v2 <https://arxiv.org/abs/1610.05820v2>`_ | |||||
<https://arxiv.org/abs/1610.05820v2>`_ | |||||
Args: | Args: | ||||
model (Model): Target model. | model (Model): Target model. | ||||
Examples: | Examples: | ||||
>>> # ds_train, eval_train are non-overlapping datasets from training dataset. | |||||
>>> # eval_train, eval_test are non-overlapping datasets from test dataset. | |||||
>>> train_1, train_2 are non-overlapping datasets from training dataset of target model. | |||||
>>> test_1, test_2 are non-overlapping datasets from test dataset of target model. | |||||
>>> We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model. | |||||
>>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'}) | >>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'}) | ||||
>>> inference_model = MembershipInference(model) | >>> inference_model = MembershipInference(model) | ||||
>>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}] | >>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}] | ||||
>>> inference_model.train(ds_train, ds_test, config) | |||||
>>> inference_model.train(train_1, test_1, config) | |||||
>>> metrics = ["precision", "recall", "accuracy"] | >>> metrics = ["precision", "recall", "accuracy"] | ||||
>>> result = inference_model.eval(eval_train, eval_test, metrics) | |||||
>>> result = inference_model.eval(train_2, test_2, metrics) | |||||
Raises: | Raises: | ||||
TypeError: If type of model is not mindspore.train.Model. | TypeError: If type of model is not mindspore.train.Model. | ||||
@@ -89,8 +114,12 @@ class MembershipInference: | |||||
def __init__(self, model): | def __init__(self, model): | ||||
if not isinstance(model, Model): | if not isinstance(model, Model): | ||||
raise TypeError("Type of model must be {}, but got {}.".format(type(Model), type(model))) | |||||
msg = "Type of parameter 'model' must be Model, but got {}.".format(type(model)) | |||||
LOGGER.error(TAG, msg) | |||||
raise TypeError(msg) | |||||
self.model = model | self.model = model | ||||
self.method_list = ["knn", "lr", "mlp", "rf"] | |||||
self.attack_list = [] | self.attack_list = [] | ||||
def train(self, dataset_train, dataset_test, attack_config): | def train(self, dataset_train, dataset_test, attack_config): | ||||
@@ -101,11 +130,48 @@ class MembershipInference: | |||||
Args: | Args: | ||||
dataset_train (mindspore.dataset): The training dataset for the target model. | dataset_train (mindspore.dataset): The training dataset for the target model. | ||||
dataset_test (mindspore.dataset): The test set for the target model. | dataset_test (mindspore.dataset): The test set for the target model. | ||||
attack_config (list): Parameter setting for the attack model. | |||||
attack_config (list): Parameter setting for the attack model. The format is | |||||
[{"method": "knn", "params": {"n_neighbors": [3, 5, 7]}}, | |||||
{"method": "lr", "params": {"C": np.logspace(-4, 2, 10)}}]. | |||||
The support methods list is in self.method_list, and the params of each method | |||||
must within the range of changeable parameters. Tips of params implement | |||||
can be found in | |||||
"https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html". | |||||
Raises: | Raises: | ||||
ValueError: If the method in attack_config is not in ["LR", "KNN", "RF", "MLPC"]. | |||||
KeyError: If each config in attack_config doesn't have keys {"method", "params"} | |||||
ValueError: If the method(case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"]. | |||||
""" | """ | ||||
if not isinstance(dataset_train, Dataset): | |||||
msg = "Type of parameter 'dataset_train' must be Dataset, but got {}".format(type(dataset_train)) | |||||
LOGGER.error(TAG, msg) | |||||
raise TypeError(msg) | |||||
if not isinstance(dataset_test, Dataset): | |||||
msg = "Type of parameter 'test_train' must be Dataset, but got {}".format(type(dataset_train)) | |||||
LOGGER.error(TAG, msg) | |||||
raise TypeError(msg) | |||||
if not isinstance(attack_config, list): | |||||
msg = "Type of parameter 'attack_config' must be list, but got {}.".format(type(attack_config)) | |||||
LOGGER.error(TAG, msg) | |||||
raise TypeError(msg) | |||||
for config in attack_config: | |||||
if not isinstance(config, dict): | |||||
msg = "Type of each config in 'attack_config' must be dict, but got {}.".format(type(config)) | |||||
LOGGER.error(TAG, msg) | |||||
raise TypeError(msg) | |||||
if {"params", "method"} != set(config.keys()): | |||||
msg = "Each config in attack_config must have keys 'method' and 'params'," \ | |||||
"but your key value is {}.".format(set(config.keys())) | |||||
LOGGER.error(TAG, msg) | |||||
raise KeyError(msg) | |||||
if str.lower(config["method"]) not in self.method_list: | |||||
msg = "Method {} is not support.".format(config["method"]) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
features, labels = self._transform(dataset_train, dataset_test) | features, labels = self._transform(dataset_train, dataset_test) | ||||
for config in attack_config: | for config in attack_config: | ||||
self.attack_list.append(get_attack_model(features, labels, config)) | self.attack_list.append(get_attack_model(features, labels, config)) | ||||
@@ -124,6 +190,28 @@ class MembershipInference: | |||||
Returns: | Returns: | ||||
list, Each element contains an evaluation indicator for the attack model. | list, Each element contains an evaluation indicator for the attack model. | ||||
""" | """ | ||||
if not isinstance(dataset_train, Dataset): | |||||
msg = "Type of parameter 'dataset_train' must be Dataset, but got {}".format(type(dataset_train)) | |||||
LOGGER.error(TAG, msg) | |||||
raise TypeError(msg) | |||||
if not isinstance(dataset_test, Dataset): | |||||
msg = "Type of parameter 'test_train' must be Dataset, but got {}".format(type(dataset_train)) | |||||
LOGGER.error(TAG, msg) | |||||
raise TypeError(msg) | |||||
if not isinstance(metrics, (list, tuple)): | |||||
msg = "Type of parameter 'config' must be Union[list, tuple], but got {}.".format(type(metrics)) | |||||
LOGGER.error(TAG, msg) | |||||
raise TypeError(msg) | |||||
metrics = set(metrics) | |||||
metrics_list = {"precision", "accuracy", "recall"} | |||||
if not metrics <= metrics_list: | |||||
msg = "Element in 'metrics' must be in {}, but got {}.".format(metrics_list, metrics) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
result = [] | result = [] | ||||
features, labels = self._transform(dataset_train, dataset_test) | features, labels = self._transform(dataset_train, dataset_test) | ||||
for attacker in self.attack_list: | for attacker in self.attack_list: | ||||
@@ -170,17 +258,12 @@ class MembershipInference: | |||||
N is the number of sample. C = 1 + dim(logits). | N is the number of sample. C = 1 + dim(logits). | ||||
- numpy.ndarray, Labels for each sample, Shape is (N,). | - numpy.ndarray, Labels for each sample, Shape is (N,). | ||||
""" | """ | ||||
if context.get_context("device_target") != "Ascend": | |||||
raise RuntimeError("The target device must be Ascend, " | |||||
"but current is {}.".format(context.get_context("device_target"))) | |||||
loss_logits = np.array([]) | loss_logits = np.array([]) | ||||
for batch in dataset_x.create_dict_iterator(): | for batch in dataset_x.create_dict_iterator(): | ||||
batch_data = Tensor(batch['image'], ms.float32) | batch_data = Tensor(batch['image'], ms.float32) | ||||
batch_labels = Tensor(batch['label'], ms.int32) | |||||
batch_logits = self.model.predict(batch_data) | |||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction=None) | |||||
batch_loss = loss(batch_logits, batch_labels).asnumpy() | |||||
batch_logits = batch_logits.asnumpy() | |||||
batch_labels = batch['label'].astype(np.int32) | |||||
batch_logits = self.model.predict(batch_data).asnumpy() | |||||
batch_loss = _softmax_cross_entropy(batch_logits, batch_labels) | |||||
batch_feature = np.hstack((batch_loss.reshape(-1, 1), batch_logits)) | batch_feature = np.hstack((batch_loss.reshape(-1, 1), batch_logits)) | ||||
if loss_logits.size == 0: | if loss_logits.size == 0: | ||||
@@ -193,5 +276,7 @@ class MembershipInference: | |||||
elif label == 0: | elif label == 0: | ||||
labels = np.zeros(len(loss_logits), np.int32) | labels = np.zeros(len(loss_logits), np.int32) | ||||
else: | else: | ||||
raise ValueError("The value of label must be 0 or 1, but got {}.".format(label)) | |||||
msg = "The value of label must be 0 or 1, but got {}.".format(label) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
return loss_logits, labels | return loss_logits, labels |
@@ -22,7 +22,8 @@ from mindspore import Tensor | |||||
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics | ||||
from mindarmour.utils._check_param import check_model, check_numpy_param, \ | from mindarmour.utils._check_param import check_model, check_numpy_param, \ | ||||
check_param_multi_types, check_norm_level, check_param_in_range | |||||
check_param_multi_types, check_norm_level, check_param_in_range, \ | |||||
check_param_type, check_int_positive | |||||
from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, \ | from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, \ | ||||
Noise, Translate, Scale, Shear, Rotate | Noise, Translate, Scale, Shear, Rotate | ||||
from mindarmour.attacks import FastGradientSignMethod, \ | from mindarmour.attacks import FastGradientSignMethod, \ | ||||
@@ -93,7 +94,7 @@ class Fuzzer: | |||||
>>> {'method': 'Translate', 'params': {'x_bias': 0.1, 'y_bias': 0.2}}, | >>> {'method': 'Translate', 'params': {'x_bias': 0.1, 'y_bias': 0.2}}, | ||||
>>> {'method': 'FGSM', 'params': {'eps': 0.1, 'alpha': 0.1}}] | >>> {'method': 'FGSM', 'params': {'eps': 0.1, 'alpha': 0.1}}] | ||||
>>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) | >>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) | ||||
>>> model_fuzz_test = Fuzzer(model, train_images, 1000, 10) | |||||
>>> model_fuzz_test = Fuzzer(model, train_images, 10, 1000) | |||||
>>> samples, labels, preds, strategies, report = model_fuzz_test.fuzzing(mutate_config, initial_seeds) | >>> samples, labels, preds, strategies, report = model_fuzz_test.fuzzing(mutate_config, initial_seeds) | ||||
""" | """ | ||||
@@ -101,8 +102,9 @@ class Fuzzer: | |||||
self._target_model = check_model('model', target_model, Model) | self._target_model = check_model('model', target_model, Model) | ||||
train_dataset = check_numpy_param('train_dataset', train_dataset) | train_dataset = check_numpy_param('train_dataset', train_dataset) | ||||
self._coverage_metrics = ModelCoverageMetrics(target_model, | self._coverage_metrics = ModelCoverageMetrics(target_model, | ||||
neuron_num, | |||||
segmented_num, | segmented_num, | ||||
neuron_num, train_dataset) | |||||
train_dataset) | |||||
# Allowed mutate strategies so far. | # Allowed mutate strategies so far. | ||||
self._strategies = {'Contrast': Contrast, 'Brightness': Brightness, | self._strategies = {'Contrast': Contrast, 'Brightness': Brightness, | ||||
'Blur': Blur, 'Noise': Noise, 'Translate': Translate, | 'Blur': Blur, 'Noise': Noise, 'Translate': Translate, | ||||
@@ -115,23 +117,21 @@ class Fuzzer: | |||||
'Noise'] | 'Noise'] | ||||
self._attacks_list = ['FGSM', 'PGD', 'MDIIM'] | self._attacks_list = ['FGSM', 'PGD', 'MDIIM'] | ||||
self._attack_param_checklists = { | self._attack_param_checklists = { | ||||
'FGSM': {'params': {'eps': {'dtype': [float, int], 'range': [0, 1]}, | |||||
'alpha': {'dtype': [float, int], | |||||
'FGSM': {'params': {'eps': {'dtype': [float], 'range': [0, 1]}, | |||||
'alpha': {'dtype': [float], | |||||
'range': [0, 1]}, | 'range': [0, 1]}, | ||||
'bounds': {'dtype': [list, tuple], | |||||
'range': None}}}, | |||||
'PGD': {'params': {'eps': {'dtype': [float, int], 'range': [0, 1]}, | |||||
'eps_iter': {'dtype': [float, int], | |||||
'range': [0, 1e5]}, | |||||
'nb_iter': {'dtype': [float, int], | |||||
'bounds': {'dtype': [tuple]}}}, | |||||
'PGD': {'params': {'eps': {'dtype': [float], 'range': [0, 1]}, | |||||
'eps_iter': {'dtype': [float], | |||||
'range': [0, 1]}, | |||||
'nb_iter': {'dtype': [int], | |||||
'range': [0, 1e5]}, | 'range': [0, 1e5]}, | ||||
'bounds': {'dtype': [list, tuple], | |||||
'range': None}}}, | |||||
'bounds': {'dtype': [tuple]}}}, | |||||
'MDIIM': { | 'MDIIM': { | ||||
'params': {'eps': {'dtype': [float, int], 'range': [0, 1]}, | |||||
'norm_level': {'dtype': [str], 'range': None}, | |||||
'prob': {'dtype': [float, int], 'range': [0, 1]}, | |||||
'bounds': {'dtype': [list, tuple], 'range': None}}}} | |||||
'params': {'eps': {'dtype': [float], 'range': [0, 1]}, | |||||
'norm_level': {'dtype': [str]}, | |||||
'prob': {'dtype': [float], 'range': [0, 1]}, | |||||
'bounds': {'dtype': [tuple]}}}} | |||||
def fuzzing(self, mutate_config, initial_seeds, coverage_metric='KMNC', | def fuzzing(self, mutate_config, initial_seeds, coverage_metric='KMNC', | ||||
eval_metrics='auto', max_iters=10000, mutate_num_per_seed=20): | eval_metrics='auto', max_iters=10000, mutate_num_per_seed=20): | ||||
@@ -140,16 +140,29 @@ class Fuzzer: | |||||
Args: | Args: | ||||
mutate_config (list): Mutate configs. The format is | mutate_config (list): Mutate configs. The format is | ||||
[{'method': 'Blur', 'params': {'auto_param': True}}, {'method': 'Contrast', 'params': {'factor': 2}}]. | |||||
The support methods list is in `self._strategies`, and the params of each | |||||
method must within the range of changeable parameters. | |||||
initial_seeds (numpy.ndarray): Initial seeds used to generate | |||||
mutated samples. | |||||
coverage_metric (str): Model coverage metric of neural networks. | |||||
Default: 'KMNC'. | |||||
eval_metrics (Union[list, tuple, str]): Evaluation metrics. If the type is 'auto', | |||||
it will calculate all the metrics, else if the type is list or tuple, it will | |||||
calculate the metrics specified by user. Default: 'auto'. | |||||
[{'method': 'Blur', 'params': {'auto_param': True}}, | |||||
{'method': 'Contrast', 'params': {'factor': 2}}]. The | |||||
supported methods list is in `self._strategies`, and the | |||||
params of each method must within the range of changeable parameters. | |||||
Supported methods are grouped in three types: | |||||
Firstly, pixel value based transform methods include: | |||||
'Contrast', 'Brightness', 'Blur' and 'Noise'. Secondly, affine | |||||
transform methods include: 'Translate', 'Scale', 'Shear' and | |||||
'Rotate'. Thirdly, attack methods include: 'FGSM', 'PGD' and 'MDIIM'. | |||||
`mutate_config` must have method in the type of pixel value based | |||||
transform methods. The way of setting parameters for first and | |||||
second type methods can be seen in 'mindarmour/fuzzing/image_transform.py'. | |||||
For third type methods, you can refer to the corresponding class. | |||||
initial_seeds (list[list]): Initial seeds used to generate mutated | |||||
samples. The format of initial seeds is [[image_data, label], | |||||
[...], ...]. | |||||
coverage_metric (str): Model coverage metric of neural networks. All | |||||
supported metrics are: 'KMNC', 'NBC', 'SNAC'. Default: 'KMNC'. | |||||
eval_metrics (Union[list, tuple, str]): Evaluation metrics. If the | |||||
type is 'auto', it will calculate all the metrics, else if the | |||||
type is list or tuple, it will calculate the metrics specified | |||||
by user. All supported evaluate methods are 'accuracy', | |||||
'attack_success_rate', 'kmnc', 'nbc', 'snac'. Default: 'auto'. | |||||
max_iters (int): Max number of select a seed to mutate. | max_iters (int): Max number of select a seed to mutate. | ||||
Default: 10000. | Default: 10000. | ||||
mutate_num_per_seed (int): The number of mutate times for a seed. | mutate_num_per_seed (int): The number of mutate times for a seed. | ||||
@@ -173,16 +186,10 @@ class Fuzzer: | |||||
ValueError: If metric in list `eval_metrics` is not in ['accuracy', 'attack_success_rate', | ValueError: If metric in list `eval_metrics` is not in ['accuracy', 'attack_success_rate', | ||||
'kmnc', 'nbc', 'snac']. | 'kmnc', 'nbc', 'snac']. | ||||
""" | """ | ||||
eval_metrics_ = None | |||||
if isinstance(eval_metrics, (list, tuple)): | if isinstance(eval_metrics, (list, tuple)): | ||||
eval_metrics_ = [] | eval_metrics_ = [] | ||||
avaliable_metrics = ['accuracy', 'attack_success_rate', 'kmnc', 'nbc', 'snac'] | avaliable_metrics = ['accuracy', 'attack_success_rate', 'kmnc', 'nbc', 'snac'] | ||||
for elem in eval_metrics: | for elem in eval_metrics: | ||||
if not isinstance(elem, str): | |||||
msg = 'the type of metric in list `eval_metrics` must be str, but got {}.' \ | |||||
.format(type(elem)) | |||||
LOGGER.error(TAG, msg) | |||||
raise TypeError(msg) | |||||
if elem not in avaliable_metrics: | if elem not in avaliable_metrics: | ||||
msg = 'metric in list `eval_metrics` must be in {}, but got {}.' \ | msg = 'metric in list `eval_metrics` must be in {}, but got {}.' \ | ||||
.format(avaliable_metrics, elem) | .format(avaliable_metrics, elem) | ||||
@@ -203,7 +210,33 @@ class Fuzzer: | |||||
raise TypeError(msg) | raise TypeError(msg) | ||||
# Check whether the mutate_config meet the specification. | # Check whether the mutate_config meet the specification. | ||||
mutate_config = check_param_type('mutate_config', mutate_config, list) | |||||
for config in mutate_config: | |||||
check_param_type("config['params']", config['params'], dict) | |||||
if set(config.keys()) != {'method', 'params'}: | |||||
msg = "Config must contain 'method' and 'params', but got {}." \ | |||||
.format(set(config.keys())) | |||||
LOGGER.error(TAG, msg) | |||||
raise TypeError(msg) | |||||
if config['method'] not in self._strategies.keys(): | |||||
msg = "Config methods must be in {}, but got {}." \ | |||||
.format(self._strategies.keys(), config['method']) | |||||
LOGGER.error(TAG, msg) | |||||
raise TypeError(msg) | |||||
if coverage_metric not in ['KMNC', 'NBC', 'SNAC']: | |||||
msg = "coverage_metric must be in ['KMNC', 'NBC', 'SNAC'], but got {}." \ | |||||
.format(coverage_metric) | |||||
LOGGER.error(TAG, msg) | |||||
raise ValueError(msg) | |||||
max_iters = check_int_positive('max_iters', max_iters) | |||||
mutate_num_per_seed = check_int_positive('mutate_num_per_seed', mutate_num_per_seed) | |||||
mutates = self._init_mutates(mutate_config) | mutates = self._init_mutates(mutate_config) | ||||
initial_seeds = check_param_type('initial_seeds', initial_seeds, list) | |||||
for seed in initial_seeds: | |||||
check_param_type('seed', seed, list) | |||||
check_numpy_param('seed[0]', seed[0]) | |||||
check_numpy_param('seed[1]', seed[1]) | |||||
seed.append(0) | |||||
seed, initial_seeds = _select_next(initial_seeds) | seed, initial_seeds = _select_next(initial_seeds) | ||||
fuzz_samples = [] | fuzz_samples = [] | ||||
gt_labels = [] | gt_labels = [] | ||||
@@ -248,7 +281,7 @@ class Fuzzer: | |||||
for index in range(len(samples)): | for index in range(len(samples)): | ||||
mutate = samples[:index + 1] | mutate = samples[:index + 1] | ||||
self._coverage_metrics.calculate_coverage(mutate.astype(np.float32)) | self._coverage_metrics.calculate_coverage(mutate.astype(np.float32)) | ||||
if coverage_metric == "KMNC": | |||||
if coverage_metric == 'KMNC': | |||||
coverages.append(self._coverage_metrics.get_kmnc()) | coverages.append(self._coverage_metrics.get_kmnc()) | ||||
if coverage_metric == 'NBC': | if coverage_metric == 'NBC': | ||||
coverages.append(self._coverage_metrics.get_nbc()) | coverages.append(self._coverage_metrics.get_nbc()) | ||||
@@ -357,18 +390,24 @@ class Fuzzer: | |||||
dict, evaluate metrics include accuarcy, attack success rate | dict, evaluate metrics include accuarcy, attack success rate | ||||
and neural coverage. | and neural coverage. | ||||
""" | """ | ||||
gt_labels = np.asarray(gt_labels) | |||||
fuzz_preds = np.asarray(fuzz_preds) | |||||
temp = np.argmax(gt_labels, axis=1) == np.argmax(fuzz_preds, axis=1) | temp = np.argmax(gt_labels, axis=1) == np.argmax(fuzz_preds, axis=1) | ||||
metrics_report = {} | metrics_report = {} | ||||
if metrics == 'auto' or 'accuracy' in metrics: | if metrics == 'auto' or 'accuracy' in metrics: | ||||
gt_labels = np.asarray(gt_labels) | |||||
fuzz_preds = np.asarray(fuzz_preds) | |||||
acc = np.sum(temp) / np.size(temp) | |||||
if temp.any(): | |||||
acc = np.sum(temp) / np.size(temp) | |||||
else: | |||||
acc = 0 | |||||
metrics_report['Accuracy'] = acc | metrics_report['Accuracy'] = acc | ||||
if metrics == 'auto' or 'attack_success_rate' in metrics: | if metrics == 'auto' or 'attack_success_rate' in metrics: | ||||
cond = [elem in self._attacks_list for elem in fuzz_strategies] | cond = [elem in self._attacks_list for elem in fuzz_strategies] | ||||
temp = temp[cond] | temp = temp[cond] | ||||
attack_success_rate = 1 - np.sum(temp) / np.size(temp) | |||||
if temp.any(): | |||||
attack_success_rate = 1 - np.sum(temp) / np.size(temp) | |||||
else: | |||||
attack_success_rate = None | |||||
metrics_report['Attack_success_rate'] = attack_success_rate | metrics_report['Attack_success_rate'] = attack_success_rate | ||||
if metrics == 'auto' or 'kmnc' in metrics or 'nbc' in metrics or 'snac' in metrics: | if metrics == 'auto' or 'kmnc' in metrics or 'nbc' in metrics or 'snac' in metrics: | ||||
@@ -350,8 +350,10 @@ class Translate(ImageTransform): | |||||
Translate an image. | Translate an image. | ||||
Args: | Args: | ||||
x_bias ([int, float): X-direction translation, x=x+x_bias. Default: 0. | |||||
y_bias ([int, float): Y-direction translation, y=y+y_bias. Default: 0. | |||||
x_bias ([int, float): X-direction translation, x=x+x_bias*image_length. | |||||
Default: 0. | |||||
y_bias ([int, float): Y-direction translation, y=y+y_bias*image_wide. | |||||
Default: 0. | |||||
""" | """ | ||||
def __init__(self, x_bias=0, y_bias=0): | def __init__(self, x_bias=0, y_bias=0): | ||||
@@ -363,8 +365,10 @@ class Translate(ImageTransform): | |||||
Set translate parameters. | Set translate parameters. | ||||
Args: | Args: | ||||
x_bias ([float, int]): X-direction translation, x=x+x_bias. Default: 0. | |||||
y_bias ([float, int]): Y-direction translation, y=y+y_bias. Default: 0. | |||||
x_bias ([float, int]): X-direction translation, x=x+x_bias*image_length. | |||||
Default: 0. | |||||
y_bias ([float, int]): Y-direction translation, y=y+y_bias*image_wide. | |||||
Default: 0. | |||||
auto_param (bool): True if auto generate parameters. Default: False. | auto_param (bool): True if auto generate parameters. Default: False. | ||||
""" | """ | ||||
self.auto_param = auto_param | self.auto_param = auto_param | ||||
@@ -579,7 +583,7 @@ class Rotate(ImageTransform): | |||||
""" | """ | ||||
_, chw, normalized, gray3dim, image = self._check(image) | _, chw, normalized, gray3dim, image = self._check(image) | ||||
img = to_pil(image) | img = to_pil(image) | ||||
trans_image = img.rotate(self.angle, expand=True) | |||||
trans_image = img.rotate(self.angle, expand=False) | |||||
trans_image = self._original_format(trans_image, chw, normalized, | trans_image = self._original_format(trans_image, chw, normalized, | ||||
gray3dim) | gray3dim) | ||||
return trans_image | return trans_image |
@@ -21,7 +21,7 @@ from mindspore import Tensor | |||||
from mindspore import Model | from mindspore import Model | ||||
from mindarmour.utils._check_param import check_model, check_numpy_param, \ | from mindarmour.utils._check_param import check_model, check_numpy_param, \ | ||||
check_int_positive | |||||
check_int_positive, check_param_multi_types | |||||
from mindarmour.utils.logger import LogUtil | from mindarmour.utils.logger import LogUtil | ||||
LOGGER = LogUtil.get_instance() | LOGGER = LogUtil.get_instance() | ||||
@@ -43,8 +43,8 @@ class ModelCoverageMetrics: | |||||
Args: | Args: | ||||
model (Model): The pre-trained model which waiting for testing. | model (Model): The pre-trained model which waiting for testing. | ||||
segmented_num (int): The number of segmented sections of neurons' output intervals. | |||||
neuron_num (int): The number of testing neurons. | neuron_num (int): The number of testing neurons. | ||||
segmented_num (int): The number of segmented sections of neurons' output intervals. | |||||
train_dataset (numpy.ndarray): Training dataset used for determine | train_dataset (numpy.ndarray): Training dataset used for determine | ||||
the neurons' output boundaries. | the neurons' output boundaries. | ||||
@@ -52,17 +52,18 @@ class ModelCoverageMetrics: | |||||
ValueError: If neuron_num is too big (for example, bigger than 1e+9). | ValueError: If neuron_num is too big (for example, bigger than 1e+9). | ||||
Examples: | Examples: | ||||
>>> train_images = np.random.random((10000, 128)).astype(np.float32) | |||||
>>> test_images = np.random.random((5000, 128)).astype(np.float32) | |||||
>>> net = LeNet5() | |||||
>>> train_images = np.random.random((10000, 1, 32, 32)).astype(np.float32) | |||||
>>> test_images = np.random.random((5000, 1, 32, 32)).astype(np.float32) | |||||
>>> model = Model(net) | >>> model = Model(net) | ||||
>>> model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images) | |||||
>>> model_fuzz_test.test_adequacy_coverage_calculate(test_images) | |||||
>>> model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, train_images) | |||||
>>> model_fuzz_test.calculate_coverage(test_images) | |||||
>>> print('KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | >>> print('KMNC of this test is : %s', model_fuzz_test.get_kmnc()) | ||||
>>> print('NBC of this test is : %s', model_fuzz_test.get_nbc()) | >>> print('NBC of this test is : %s', model_fuzz_test.get_nbc()) | ||||
>>> print('SNAC of this test is : %s', model_fuzz_test.get_snac()) | >>> print('SNAC of this test is : %s', model_fuzz_test.get_snac()) | ||||
""" | """ | ||||
def __init__(self, model, segmented_num, neuron_num, train_dataset): | |||||
def __init__(self, model, neuron_num, segmented_num, train_dataset): | |||||
self._model = check_model('model', model, Model) | self._model = check_model('model', model, Model) | ||||
self._segmented_num = check_int_positive('segmented_num', segmented_num) | self._segmented_num = check_int_positive('segmented_num', segmented_num) | ||||
self._neuron_num = check_int_positive('neuron_num', neuron_num) | self._neuron_num = check_int_positive('neuron_num', neuron_num) | ||||
@@ -139,8 +140,8 @@ class ModelCoverageMetrics: | |||||
Args: | Args: | ||||
dataset (numpy.ndarray): Data for fuzz test. | dataset (numpy.ndarray): Data for fuzz test. | ||||
bias_coefficient (float): The coefficient used for changing the | |||||
neurons' output boundaries. Default: 0. | |||||
bias_coefficient (Union[int, float]): The coefficient used | |||||
for changing the neurons' output boundaries. Default: 0. | |||||
batch_size (int): The number of samples in a predict batch. | batch_size (int): The number of samples in a predict batch. | ||||
Default: 32. | Default: 32. | ||||
@@ -148,8 +149,10 @@ class ModelCoverageMetrics: | |||||
>>> model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images) | >>> model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images) | ||||
>>> model_fuzz_test.calculate_coverage(test_images) | >>> model_fuzz_test.calculate_coverage(test_images) | ||||
""" | """ | ||||
dataset = check_numpy_param('dataset', dataset) | dataset = check_numpy_param('dataset', dataset) | ||||
batch_size = check_int_positive('batch_size', batch_size) | batch_size = check_int_positive('batch_size', batch_size) | ||||
bias_coefficient = check_param_multi_types('bias_coefficient', bias_coefficient, [int, float]) | |||||
self._lower_bounds -= bias_coefficient*self._var | self._lower_bounds -= bias_coefficient*self._var | ||||
self._upper_bounds += bias_coefficient*self._var | self._upper_bounds += bias_coefficient*self._var | ||||
intervals = (self._upper_bounds - self._lower_bounds) / self._segmented_num | intervals = (self._upper_bounds - self._lower_bounds) / self._segmented_num | ||||
@@ -78,7 +78,11 @@ class LogUtil: | |||||
def set_level(self, level): | def set_level(self, level): | ||||
""" | """ | ||||
Set the logging level of this logger, level must be an integer or a | Set the logging level of this logger, level must be an integer or a | ||||
string. | |||||
string. Supported levels are 'NOTSET'(integer: 0), 'ERROR'(integer: 1-40), | |||||
'WARNING'('WARN', integer: 1-30), 'INFO'(integer: 1-20) and 'DEBUG'(integer: 1-10). | |||||
For example, if logger.set_level('WARNING') or logger.set_level(21), then | |||||
logger.warn() and logger.error() in scripts would be printed while running, | |||||
while logger.info() or logger.debug() would not be printed. | |||||
Args: | Args: | ||||
level (Union[int, str]): Level of logger. | level (Union[int, str]): Level of logger. | ||||
@@ -98,7 +98,7 @@ class GradWrapWithLoss(Cell): | |||||
Examples: | Examples: | ||||
>>> data = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32)*0.01) | >>> data = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32)*0.01) | ||||
>>> label = Tensor(np.ones([1, 10]).astype(np.float32)) | |||||
>>> labels = Tensor(np.ones([1, 10]).astype(np.float32)) | |||||
>>> net = NET() | >>> net = NET() | ||||
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits() | >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits() | ||||
>>> loss_net = WithLossCell(net, loss_fn) | >>> loss_net = WithLossCell(net, loss_fn) | ||||
@@ -71,7 +71,7 @@ def test_lenet_mnist_coverage_cpu(): | |||||
# initialize fuzz test with training dataset | # initialize fuzz test with training dataset | ||||
training_data = (np.random.random((10000, 10))*20).astype(np.float32) | training_data = (np.random.random((10000, 10))*20).astype(np.float32) | ||||
model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, training_data) | |||||
model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, training_data) | |||||
# fuzz test with original test data | # fuzz test with original test data | ||||
# get test data | # get test data | ||||
@@ -105,7 +105,7 @@ def test_lenet_mnist_coverage_ascend(): | |||||
# initialize fuzz test with training dataset | # initialize fuzz test with training dataset | ||||
training_data = (np.random.random((10000, 10))*20).astype(np.float32) | training_data = (np.random.random((10000, 10))*20).astype(np.float32) | ||||
model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, training_data) | |||||
model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, training_data) | |||||
# fuzz test with original test data | # fuzz test with original test data | ||||
# get test data | # get test data | ||||
@@ -102,7 +102,7 @@ def test_fuzzing_ascend(): | |||||
] | ] | ||||
# initialize fuzz test with training dataset | # initialize fuzz test with training dataset | ||||
train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) | train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) | ||||
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, train_images) | |||||
model_coverage_test = ModelCoverageMetrics(model, 10, 1000, train_images) | |||||
# fuzz test with original test data | # fuzz test with original test data | ||||
# get test data | # get test data | ||||
@@ -113,7 +113,7 @@ def test_fuzzing_ascend(): | |||||
initial_seeds = [] | initial_seeds = [] | ||||
# make initial seeds | # make initial seeds | ||||
for img, label in zip(test_images, test_labels): | for img, label in zip(test_images, test_labels): | ||||
initial_seeds.append([img, label, 0]) | |||||
initial_seeds.append([img, label]) | |||||
initial_seeds = initial_seeds[:100] | initial_seeds = initial_seeds[:100] | ||||
model_coverage_test.calculate_coverage( | model_coverage_test.calculate_coverage( | ||||
@@ -148,7 +148,7 @@ def test_fuzzing_cpu(): | |||||
] | ] | ||||
# initialize fuzz test with training dataset | # initialize fuzz test with training dataset | ||||
train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) | train_images = np.random.rand(32, 1, 32, 32).astype(np.float32) | ||||
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, train_images) | |||||
model_coverage_test = ModelCoverageMetrics(model, 10, 1000, train_images) | |||||
# fuzz test with original test data | # fuzz test with original test data | ||||
# get test data | # get test data | ||||
@@ -159,7 +159,7 @@ def test_fuzzing_cpu(): | |||||
initial_seeds = [] | initial_seeds = [] | ||||
# make initial seeds | # make initial seeds | ||||
for img, label in zip(test_images, test_labels): | for img, label in zip(test_images, test_labels): | ||||
initial_seeds.append([img, label, 0]) | |||||
initial_seeds.append([img, label]) | |||||
initial_seeds = initial_seeds[:100] | initial_seeds = initial_seeds[:100] | ||||
model_coverage_test.calculate_coverage( | model_coverage_test.calculate_coverage( | ||||