diff --git a/examples/reliability/model_fault_injection.py b/examples/reliability/model_fault_injection.py new file mode 100644 index 0000000..849388f --- /dev/null +++ b/examples/reliability/model_fault_injection.py @@ -0,0 +1,92 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Fault injection example. +Download checkpoint from: https://www.mindspore.cn/resources/hub or just trained your own checkpoint. +Download dataset from: http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz. +File structure: + --cifar10-batches-bin + --train + --data_batch_1.bin + --data_batch_2.bin + --data_batch_3.bin + --data_batch_4.bin + --data_batch_5.bin + --test + --test_batch.bin + +Please extract and restructure the file as shown above. +""" +import argparse + +from mindspore import Model, context +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector +from examples.common.networks.lenet5.lenet5_net import LeNet5 +from examples.common.networks.vgg.vgg import vgg16 +from examples.common.networks.resnet.resnet import resnet50 +from examples.common.dataset.data_processing import create_dataset_cifar, generate_mnist_dataset + + +parser = argparse.ArgumentParser(description='layer_states') +parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU']) +parser.add_argument('--model', type=str, default='lenet', choices=['lenet', 'resnet50', 'vgg16']) +parser.add_argument('--device_id', type=int, default=0) +args = parser.parse_args() +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) + + +test_flag = args.model +if test_flag == 'lenet': + # load data + DATA_FILE = '../common/dataset/MNIST_Data/test' + ckpt_path = '../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' + ds_eval = generate_mnist_dataset(DATA_FILE, batch_size=64) + net = LeNet5() +elif test_flag == 'vgg16': + from examples.common.networks.vgg.config import cifar_cfg as cfg + DATA_FILE = '../common/dataset/cifar10-batches-bin' + ckpt_path = '../common/networks/vgg16_ascend_v111_cifar10_offical_cv_bs64_acc93.ckpt' + ds_eval = create_dataset_cifar(DATA_FILE, 224, 224, training=False) + net = vgg16(10, cfg, 'test') +elif test_flag == 'resnet50': + DATA_FILE = '../common/dataset/cifar10-batches-bin' + ckpt_path = '../common/networks/resnet50_ascend_v111_cifar10_offical_cv_bs32_acc92.ckpt' + ds_eval = create_dataset_cifar(DATA_FILE, 224, 224, training=False) + net = resnet50(10) +else: + exit() +param_dict = load_checkpoint(ckpt_path) +load_param_into_net(net, param_dict) +model = Model(net) + +# Initialization +fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', + 'nan', 'inf', 'anti_activation', 'precision_loss'] +fi_mode = ['single_layer', 'all_layer'] +fi_size = [1, 2, 3] + +# Fault injection +fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) +results = fi.kick_off() +result_summary = fi.metrics() + +# print result +for result in results: + print(result) +for result in result_summary: + print(result) diff --git a/mindarmour/reliability/__init__.py b/mindarmour/reliability/__init__.py index e69de29..9093e46 100644 --- a/mindarmour/reliability/__init__.py +++ b/mindarmour/reliability/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Reliability methods of MindArmour +""" + +from .model_fault_injection.fault_injection import FaultInjector + +__all__ = ['FaultInjector'] diff --git a/mindarmour/reliability/model_fault_injection/README.md b/mindarmour/reliability/model_fault_injection/README.md new file mode 100644 index 0000000..e118912 --- /dev/null +++ b/mindarmour/reliability/model_fault_injection/README.md @@ -0,0 +1,169 @@ +# Demos of model fault injection + +## Introduction + +This is a demo of fault injection for Mindspore applications written in Python. + +## Preparation + +For the demo, we should prepare both datasets and pre-train models + +### Dateset + +For example: + +`MINST`:Download MNIST dataset from: http://yann.lecun.com/exdb/mnist/ and extract as follows + +```test +File structure: + - data_path + - train + - train-images-idx3-ubyte + - train-labels-idx1-ubyte + - test + - t10k-images-idx3-ubyte + - t10k-labels-idx1-ubyte +``` + +`CIFAR10`:Download CIFAR10 dataset from: http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz and extract as follows + +```test +File structure: + - data_path + - train + - data_batch_1.bin + - data_batch_2.bin + - data_batch_3.bin + - data_batch_4.bin + - data_batch_5.bin + - test + - test_batch.bin +``` + +### CheckPoint file + +Download checkpoint from: https://www.mindspore.cn/resources/hub or just trained your own checkpoint + +## Configuration + +There are five parameters need to set up. + +```python +DATA_FILE = '../common/dataset/MNIST_Data/test' +ckpt_path = '../common/networks/checkpoint_lenet_1-10_1875.ckpt' + +... + +fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', 'nan', 'inf', 'anti_activation', 'precision_loss'] +fi_mode = ['single_layer', 'all_layer'] +fi_size = [1, 2, 3] +``` + +`DATA_FILE` is the directory where you store the data. + +`ckpt_path` is the directory where you store the checkpoint file. + +`fi_type` : +Eight types of faults can be injected. These are `bitflips_random`, `bitflips_designated`, `random`, `zeros`, `nan`, `inf`, `anti_activation` and `precision_loss` + +bitflips_random: Bits are flipped randomly in the chosen value. + +bitflips_designated: Specified bit is flipped in the chosen value. + +random: The chosen value are replaced with random value in the range [-1, 1] + +zeros: The chosen value are replaced with zero. + +nan: The chosen value are replaced with NaN. + +inf: The chosen value are replaced with Inf. + +anti_activation: Changing the sign of the chosen value. + +precision_loss: Round the chosen value to 1 decimal place + +`fi_mode` : +There are twe kinds of injection modes can be specified, `single_layer` or `all_layer`. + +`fi_size` is usually the exact number of values to be injected with the specified fault. For `zeros`, `anti_activation` and `precision_loss` fault, `fi_size` is the percentage of total tensor values and varies from 0% to 100% + +### Example configuration + +Sample 1: + +```python +fi_type = ['bitflips_random', 'random', 'zeros', 'inf'] +fi_mode = ['single_layer'] +fi_size = [1] +``` + +Sample 2: + +```python +fi_type = ['bitflips_designated', 'random', 'inf', 'anti_activation', 'precision_loss'] +fi_mode = ['single_layer', 'all_layer'] +fi_size = [1, 2] +``` + +## Usage + +Run the test to observe the fault injection. For example: + +```bash +#!/bin/bash +cd examples/reliability/ +python model_fault_injection.py --device_target GPU --device_id 2 --model lenet +``` + +`device_target` +`model` is the target model need to be evaluation, choose from `lenet`, `vgg16` and `resnet`, or implement your own model. + +## Result + +Finally, there are three kinds of result will be return. + +Sample: + +```test +original_acc:0.979768 +type:bitflips_random mode:single_layer size:1 acc:0.968950 SDC:0.010817 +type:bitflips_random mode:single_layer size:2 acc:0.948017 SDC:0.031751 +... +type:precision_loss mode:all_layer size:2 acc:0.978966 SDC:0.000801 +type:precision_loss mode:all_layer size:3 acc:0.979167 SDC:0.000601 +single_layer_acc_mean:0.819732 single_layer_acc_max:0.980068 single_layer_acc_min:0.192107 +single_layer_SDC_mean:0.160035 single_layer_SDC_max:0.787660 single_layer_SDC_min:-0.000300 +all_layer_acc_mean:0.697049 all_layer_acc_max:0.979167 all_layer_acc_min:0.089443 +all_layer_acc_mean:0.282719 all_layer_acc_max:0.890325 all_layer_acc_min:0.000601 +``` + +### Original_acc + +The original accuracy of model: + +```test +original_acc:0.979768 +``` + +### Specific result of each input parameter + +Each result including `type`, `mode`, `size`, `acc` and `SDC`. `type`, `mode` and `size` match along with `fi_type`, `fi_mode` and `fi_size`. + +```test +type:bitflips_random mode:single_layer size:1 acc:0.968950 SDC:0.010817 +type:bitflips_random mode:single_layer size:2 acc:0.948017 SDC:0.031751 +... +type:precision_loss mode:all_layer size:2 acc:0.978966 SDC:0.000801 +type:precision_loss mode:all_layer size:3 acc:0.979167 SDC:0.000601 +``` + +### Summary of mode + +Summary of `single_layer` or `all_layer`. + +```test +single_layer_acc_mean:0.819732 single_layer_acc_max:0.980068 single_layer_acc_min:0.192107 +single_layer_SDC_mean:0.160035 single_layer_SDC_max:0.787660 single_layer_SDC_min:-0.000300 +all_layer_acc_mean:0.697049 all_layer_acc_max:0.979167 all_layer_acc_min:0.089443 +all_layer_SDC_mean:0.282719 all_layer_SDC_max:0.890325 all_layer_SDC_min:0.000601 +``` \ No newline at end of file diff --git a/mindarmour/reliability/model_fault_injection/__init__.py b/mindarmour/reliability/model_fault_injection/__init__.py new file mode 100644 index 0000000..f2c861a --- /dev/null +++ b/mindarmour/reliability/model_fault_injection/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +This module provides model fault injection to evaluate the reliability of given model. +""" +from .fault_injection import FaultInjector +from .fault_type import FaultType + +__all__ = ['FaultInjector', 'FaultType'] diff --git a/mindarmour/reliability/model_fault_injection/fault_injection.py b/mindarmour/reliability/model_fault_injection/fault_injection.py new file mode 100644 index 0000000..e099e92 --- /dev/null +++ b/mindarmour/reliability/model_fault_injection/fault_injection.py @@ -0,0 +1,224 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +Fault injection module +""" + +import random +import numpy as np + +import mindspore +from mindspore import ops, Tensor + +from mindarmour.reliability.model_fault_injection.fault_type import FaultType +from mindarmour.utils.logger import LogUtil +from mindarmour.utils._check_param import check_int_positive + +LOGGER = LogUtil.get_instance() +TAG = 'FaultInjector' + + +class FaultInjector: + """ + Fault injection for deep neural networks and evaluate performance. + + Args: + model (Model): The model need to be evaluated. + data (Dataset): The data for testing. The evaluation is base on this data. + fi_type (list): The type of the fault injection which include bitflips_random(flip randomly), + bitflips_designated(flip the key bit), random, zeros, nan, inf, anti_activation precision_loss etc. + fi_mode (list): The mode of fault injection. Fault inject on just single layer or all layers. + fi_size (list): The number of fault injection.It mean that how many values need to be injected. + + Examples: + >>> net = Net() + >>> model = Model(net) + >>> ds_eval = create_dataloader() + >>> fi_type = ['bitflips_random', 'zeros'] + >>> fi_mode = ['single_layer', 'all_layer'] + >>> fi_size = [1, 2, 3] + >>> fi = FaultInjector(model, ds_eval, fi_type=fi_type, fi_mode=fi_mode, fi_size=fi_size) + >>> fi.kick_off() + """ + + def __init__(self, model, data, fi_type=None, fi_mode=None, fi_size=None): + """FaultInjector initiated.""" + self.running_list = [] + self._init_running_list(fi_type, fi_mode, fi_size) + self.model = model + self.data = data + self._fault_type = FaultType() + self._check_param() + self.result_list = [] + self.original_acc = 0 + self.original_parameter = {} + self.argmax = ops.Argmax() + self._reducesum = ops.ReduceSum(keep_dims=False) + self._frozen() + + def _check_param(self): + """Check input parameters.""" + attr = self._fault_type.__dir__() + if not isinstance(self.data, mindspore.dataset.Dataset): + msg = "'Input data should be Mindspore Dataset', got {}.".format(type(self.data)) + LOGGER.error(TAG, msg) + raise TypeError(msg) + _ = check_int_positive('dataset_size', self.data.get_dataset_size()) + if not isinstance(self.model, mindspore.Model): + msg = "'Input model should be Mindspore Model', got {}.".format(type(self.model)) + LOGGER.error(TAG, msg) + raise TypeError(msg) + for param in self.running_list: + if param['fi_type'] not in attr: + msg = "'Undefined fault type', got {}.".format(param['fi_type']) + LOGGER.error(TAG, msg) + raise AttributeError(msg) + if param['fi_mode'] not in ['single_layer', 'all_layer']: + msg = "'fault mode should be single_layer or all_layer', but got {}.".format(param['fi_mode']) + LOGGER.error(TAG, msg) + raise ValueError(msg) + _ = check_int_positive('fi_size', param['fi_size']) + + def _init_running_list(self, type_, mode_, size_): + """Initiate fault injection parameters of this evaluation.""" + if type_ is None: + type_ = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', 'nan', 'inf', + 'anti_activation', 'precision_loss'] + if mode_ is None: + mode_ = ['single_layer', 'all_layer'] + if size_ is None: + size_ = list(range(1, 4)) + for i in type_: + i = i if i.startswith('_') else '_' + i + for j in mode_: + for k in size_: + dict_ = {'fi_type': i, 'fi_mode': j, 'fi_size': k} + self.running_list.append(dict_) + + def _frozen(self): + """Store original parameters of model.""" + trainable_param = self.model.predict_network.trainable_params() + for param in trainable_param: + np_param = param.asnumpy().copy() + bytes_ = np_param.tobytes() + self.original_parameter[param.name] = {} + self.original_parameter[param.name]['datatype'] = np_param.dtype + self.original_parameter[param.name]['shape'] = np_param.shape + self.original_parameter[param.name]['data'] = bytes_.hex() + + def _reset_model(self): + """Reset model with original parameters.""" + for weight in self.model.predict_network.trainable_params(): + name = weight.name + if name in self.original_parameter.keys(): + bytes_w = bytes.fromhex(self.original_parameter[name]['data']) + datatype_w = self.original_parameter[name]['datatype'] + shape_w = self.original_parameter[name]['shape'] + np_w = np.frombuffer(bytes_w, dtype=datatype_w).reshape(shape_w) + weight.assign_value(Tensor.from_numpy(np_w)) + else: + msg = "Layer name not matched, got {}.".format(name) + LOGGER.error(TAG, msg) + raise KeyError(msg) + + def kick_off(self): + """ + Startup and return final results. + Returns: + list, the result of fault injection. + """ + result_list = [] + for i in range(-1, len(self.running_list)): + arg = self.running_list[i] + total = 0 + correct = 0 + for data in self.data.create_dict_iterator(): + batch = data['image'] + label = data['label'] + if i != -1: + self._reset_model() + self._layer_states(arg['fi_type'], arg['fi_mode'], arg['fi_size']) + output = self.model.predict(batch) + predict = self.argmax(output) + mask = predict == label + total += predict.size + correct += self._reducesum(mask.astype(mindspore.float32)).asnumpy() + acc = correct / total + if i == -1: + self.original_acc = acc + result_list.append({'original_acc': self.original_acc}) + else: + result_list.append({'type': arg['fi_type'], 'mode': arg['fi_mode'], 'size': arg['fi_size'], + 'acc': acc, 'SDC': self.original_acc - acc}) + self.data.reset() + self._reset_model() + self.result_list = result_list + return result_list + + def metrics(self): + """metrics of final result.""" + result_summary = [] + single_layer_acc = [] + single_layer_sdc = [] + all_layer_acc = [] + all_layer_sdc = [] + for result in self.result_list: + if 'mode' in result.keys(): + if result['mode'] == 'single_layer': + single_layer_acc.append(float(result['acc'])) + single_layer_sdc.append(float(result['SDC'])) + else: + all_layer_acc.append(float(result['acc'])) + all_layer_sdc.append(float(result['SDC'])) + s_acc = np.array(single_layer_acc) + s_sdc = np.array(single_layer_sdc) + a_acc = np.array(all_layer_acc) + a_sdc = np.array(all_layer_sdc) + if single_layer_acc: + result_summary.append('single_layer_acc_mean:%f single_layer_acc_max:%f single_layer_acc_min:%f' + % (np.mean(s_acc), np.max(s_acc), np.min(s_acc))) + result_summary.append('single_layer_SDC_mean:%f single_layer_SDC_max:%f single_layer_SDC_min:%f' + % (np.mean(s_sdc), np.max(s_sdc), np.min(s_sdc))) + if all_layer_acc: + result_summary.append('all_layer_acc_mean:%f all_layer_acc_max:%f all_layer_acc_min:%f' + % (np.mean(a_acc), np.max(a_acc), np.min(a_acc))) + result_summary.append('all_layer_SDC_mean:%f all_layer_SDC_max:%f all_layer_SDC_min:%f' + % (np.mean(a_sdc), np.max(a_sdc), np.min(a_sdc))) + return result_summary + + def _layer_states(self, fi_type, fi_mode, fi_size): + """FI in layer states.""" + # Choose a random layer for injection + if fi_mode == "single_layer": + # Single layer fault injection mode + random_num = [random.randint(0, len(self.model.predict_network.trainable_params()) - 1)] + elif fi_mode == "all_layer": + # Multiple layer fault injection mode + random_num = list(range(len(self.model.predict_network.trainable_params()) - 1)) + else: + msg = 'undefined fi_mode {}'.format(fi_mode) + LOGGER.error(TAG, msg) + raise ValueError(msg) + for n in random_num: + # Get layer states info + w = self.model.predict_network.trainable_params()[n] + w_np = w.asnumpy().copy() + elem_shape = w_np.shape + w_np = w_np.reshape(-1) + + # fault inject + new_w_np = self._fault_type._fault_inject(w_np, fi_type, fi_size) + + # Reshape into original dimensions and store the faulty tensor + new_w_np = np.reshape(new_w_np, elem_shape) + w.set_data(Tensor.from_numpy(new_w_np)) diff --git a/mindarmour/reliability/model_fault_injection/fault_type.py b/mindarmour/reliability/model_fault_injection/fault_type.py new file mode 100644 index 0000000..6e40c22 --- /dev/null +++ b/mindarmour/reliability/model_fault_injection/fault_type.py @@ -0,0 +1,219 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +Fault type module +""" + +import math +import random +from struct import pack, unpack +import numpy as np + +from mindarmour.utils.logger import LogUtil + +LOGGER = LogUtil.get_instance() +TAG = 'FaultType' + + +class FaultType: + """Implementation of specified fault type.""" + @staticmethod + def _bitflip(value, pos): + """ + Implement of bitflip. + Args: + value (numpy.ndarray): Input data. + pos (list): The index of flip position. + + Returns: + numpy.ndarray, bitflip data. + """ + bits = str(value.dtype)[-2:] if str(value.dtype)[-2].isdigit() else str(value.dtype)[-1] + value_format = 'B' * int(int(bits) / 8) + value_bytes = value.tobytes() + bytes_ = list(unpack(value_format, value_bytes)) + for p in pos: + [q, r] = divmod(p, 8) + bytes_[q] ^= 1 << r + new_value_bytes = pack(value_format, *bytes_) + new_value = np.frombuffer(new_value_bytes, value.dtype) + return new_value[0] + + def _fault_inject(self, value, fi_type, fi_size): + """ + Inject the specified fault into the randomly chosen values. + For zeros, anti_activation and precision_loss, fi_size is the percentage of + total number. And the others fault, fi_size is the exact number of values to + be injected. + Args: + value (numpy.ndarray): Input data. + fi_type (str): Fault type. + fi_size (int): The number of fault injection. + + Returns: + numpy.ndarray, data after fault injection. + """ + num = value.size + if fi_type in ['zeros', 'anti_activation', 'precision_loss']: + change_size = (fi_size * num) / 100 + change_size = math.floor(change_size) + else: + change_size = fi_size + + if change_size > num: + change_size = num + # Choose the indices for FI + ind = random.sample(range(num), change_size) + + # got specified fault type + try: + func = getattr(self, fi_type) + value = func(value, ind) + return value + except AttributeError: + msg = "'Undefined fault type', got {}.".format(fi_type) + LOGGER.error(TAG, msg) + raise AttributeError(msg) + + def _bitflips_random(self, value, fi_indices): + """ + Flip bit randomly for specified value. + Args: + value (numpy.ndarray): Input data. + fi_indices (list): The index of injected data. + + Returns: + numpy.ndarray, data after fault injection. + """ + for item in fi_indices: + val = value[item] + pos = random.sample(range(int(str(val.dtype)[-2:])), + 1 if np.random.random() < 0.618 else 2) + val_new = self._bitflip(val, pos) + value[item] = val_new + return value + + def _bitflips_designated(self, value, fi_indices): + """ + Flip the key bit for specified value. + + Args: + value (numpy.ndarray): Input data. + fi_indices (list): The index of injected data. + + Returns: + numpy.ndarray, data after fault injection. + """ + for item in fi_indices: + val = value[item] + # uint8 uint16 uint32 uint64 + bits = str(value.dtype)[-2:] if str(value.dtype)[-2].isdigit() else str(value.dtype)[-1] + if 'uint' in str(val.dtype): + pos = int(bits) - 1 + # int8 int16 int32 int64 float16 float32 float64 + else: + pos = int(bits) - 2 + val_new = self._bitflip(val, [pos]) + value[item] = val_new + return value + + @staticmethod + def _random(value, fi_indices): + """ + Reset specified value randomly, range from -1 to 1. + Args: + value (numpy.ndarray): Input data. + fi_indices (list): The index of injected data. + + Returns: + numpy.ndarray, data after fault injection. + """ + for item in fi_indices: + value[item] = np.random.random() * 2 - 1 + return value + + @staticmethod + def _zeros(value, fi_indices): + """ + Set specified value into zeros. + Args: + value (numpy.ndarray): Input data. + fi_indices (list): The index of injected data. + + Returns: + numpy.ndarray, data after fault injection. + """ + value[fi_indices] = 0. + return value + + @staticmethod + def _nan(value, fi_indices): + """ + Set specified value into nan. + Args: + value (numpy.ndarray): Input data. + fi_indices (list): The index of injected data. + + Returns: + numpy.ndarray, data after fault injection. + """ + try: + value[fi_indices] = np.nan + return value + except ValueError: + return value + + @staticmethod + def _inf(value, fi_indices): + """ + Set specified value into inf + Args: + value (numpy.ndarray): Input data. + fi_indices (list): The index of injected data. + + Returns: + numpy.ndarray, data after fault injection. + """ + try: + value[fi_indices] = np.inf + return value + except OverflowError: + return value + + @staticmethod + def _anti_activation(value, fi_indices): + """ + Minus specified value. + Args: + value (numpy.ndarray): Input data. + fi_indices (list): The index of injected data. + + Returns: + numpy.ndarray, data after fault injection. + """ + value[fi_indices] = -value[fi_indices] + return value + + @staticmethod + def _precision_loss(value, fi_indices): + """ + Round specified value, round to 1 decimal place. + Args: + value (numpy.ndarray): Input data. + fi_indices (list): The index of injected data. + + Returns: + numpy.ndarray, data after fault injection. + """ + value[fi_indices] = np.around(value[fi_indices], decimals=1) + return value diff --git a/tests/ut/python/reliability/model_fault_injection/test_fault_injection.py b/tests/ut/python/reliability/model_fault_injection/test_fault_injection.py new file mode 100644 index 0000000..ca727db --- /dev/null +++ b/tests/ut/python/reliability/model_fault_injection/test_fault_injection.py @@ -0,0 +1,240 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test for fault injection. +""" + +import pytest +import numpy as np + +from mindspore import Model +import mindspore.dataset as ds +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +from mindarmour.utils.logger import LogUtil +from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector + +from tests.ut.python.utils.mock_net import Net + +LOGGER = LogUtil.get_instance() +TAG = 'Fault injection test' +LOGGER.set_level('INFO') + + +def dataset_generator(): + """mock training data.""" + batch_size = 32 + batches = 128 + data = np.random.random((batches*batch_size, 1, 32, 32)).astype( + np.float32) + label = np.random.randint(0, 10, batches*batch_size).astype(np.int32) + for i in range(batches): + yield data[i*batch_size:(i + 1)*batch_size],\ + label[i*batch_size:(i + 1)*batch_size] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_fault_injector(): + """ + Feature: Fault injector + Description: Test fault injector + Expectation: Run kick_off and metrics successfully + """ + # load model + ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' + net = Net() + param_dict = load_checkpoint(ckpt_path) + load_param_into_net(net, param_dict) + model = Model(net) + + ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) + fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', + 'nan', 'inf', 'anti_activation', 'precision_loss'] + fi_mode = ['single_layer', 'all_layer'] + fi_size = [1] + + # Fault injection + fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) + _ = fi.kick_off() + _ = fi.metrics() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_wrong_model(): + """ + Feature: Fault injector + Description: Test fault injector + Expectation: Throw TypeError exception + """ + # load model + ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' + net = Net() + param_dict = load_checkpoint(ckpt_path) + load_param_into_net(net, param_dict) + + ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) + fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', + 'nan', 'inf', 'anti_activation', 'precision_loss'] + fi_mode = ['single_layer', 'all_layer'] + fi_size = [1] + + # Fault injection + with pytest.raises(TypeError) as exc_info: + fi = FaultInjector(net, ds_eval, fi_type, fi_mode, fi_size) + _ = fi.kick_off() + _ = fi.metrics() + assert exc_info.type is TypeError + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_wrong_data(): + """ + Feature: Fault injector + Description: Test fault injector + Expectation: Throw TypeError exception + """ + # load model + ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' + net = Net() + param_dict = load_checkpoint(ckpt_path) + load_param_into_net(net, param_dict) + model = Model(net) + + ds_eval = np.random.random((1000, 32, 32, 1)).astype(np.float32) + fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', + 'nan', 'inf', 'anti_activation', 'precision_loss'] + fi_mode = ['single_layer', 'all_layer'] + fi_size = [1] + + # Fault injection + with pytest.raises(TypeError) as exc_info: + fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) + _ = fi.kick_off() + _ = fi.metrics() + assert exc_info.type is TypeError + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_wrong_fi_type(): + """ + Feature: Fault injector + Description: Test fault injector + Expectation: Throw AttributeError exception + """ + # load model + ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' + net = Net() + param_dict = load_checkpoint(ckpt_path) + load_param_into_net(net, param_dict) + model = Model(net) + + ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) + fi_type = ['bitflips_random_haha', 'bitflips_designated', 'random', 'zeros', + 'nan', 'inf', 'anti_activation', 'precision_loss'] + fi_mode = ['single_layer', 'all_layer'] + fi_size = [1] + + # Fault injection + with pytest.raises(AttributeError) as exc_info: + fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) + _ = fi.kick_off() + _ = fi.metrics() + assert exc_info.type is AttributeError + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_wrong_fi_mode(): + """ + Feature: Fault injector + Description: Test fault injector + Expectation: Throw ValueError exception + """ + # load model + ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' + net = Net() + param_dict = load_checkpoint(ckpt_path) + load_param_into_net(net, param_dict) + model = Model(net) + + ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) + fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', + 'nan', 'inf', 'anti_activation', 'precision_loss'] + fi_mode = ['single_layer_tail', 'all_layer'] + fi_size = [1] + + # Fault injection + with pytest.raises(ValueError) as exc_info: + fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) + _ = fi.kick_off() + _ = fi.metrics() + assert exc_info.type is ValueError + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_wrong_fi_size(): + """ + Feature: Fault injector + Description: Test fault injector + Expectation: Throw ValueError exception + """ + # load model + ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' + net = Net() + param_dict = load_checkpoint(ckpt_path) + load_param_into_net(net, param_dict) + model = Model(net) + + ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) + fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', + 'nan', 'inf', 'anti_activation', 'precision_loss'] + fi_mode = ['single_layer', 'all_layer'] + fi_size = [-1] + + # Fault injection + with pytest.raises(ValueError) as exc_info: + fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) + _ = fi.kick_off() + _ = fi.metrics() + assert exc_info.type is ValueError