Merge pull request !279 from ye12121/mastertags/v1.6.0
@@ -0,0 +1,92 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" | |||
Fault injection example. | |||
Download checkpoint from: https://www.mindspore.cn/resources/hub or just trained your own checkpoint. | |||
Download dataset from: http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz. | |||
File structure: | |||
--cifar10-batches-bin | |||
--train | |||
--data_batch_1.bin | |||
--data_batch_2.bin | |||
--data_batch_3.bin | |||
--data_batch_4.bin | |||
--data_batch_5.bin | |||
--test | |||
--test_batch.bin | |||
Please extract and restructure the file as shown above. | |||
""" | |||
import argparse | |||
from mindspore import Model, context | |||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector | |||
from examples.common.networks.lenet5.lenet5_net import LeNet5 | |||
from examples.common.networks.vgg.vgg import vgg16 | |||
from examples.common.networks.resnet.resnet import resnet50 | |||
from examples.common.dataset.data_processing import create_dataset_cifar, generate_mnist_dataset | |||
parser = argparse.ArgumentParser(description='layer_states') | |||
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU']) | |||
parser.add_argument('--model', type=str, default='lenet', choices=['lenet', 'resnet50', 'vgg16']) | |||
parser.add_argument('--device_id', type=int, default=0) | |||
args = parser.parse_args() | |||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) | |||
test_flag = args.model | |||
if test_flag == 'lenet': | |||
# load data | |||
DATA_FILE = '../common/dataset/MNIST_Data/test' | |||
ckpt_path = '../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||
ds_eval = generate_mnist_dataset(DATA_FILE, batch_size=64) | |||
net = LeNet5() | |||
elif test_flag == 'vgg16': | |||
from examples.common.networks.vgg.config import cifar_cfg as cfg | |||
DATA_FILE = '../common/dataset/cifar10-batches-bin' | |||
ckpt_path = '../common/networks/vgg16_ascend_v111_cifar10_offical_cv_bs64_acc93.ckpt' | |||
ds_eval = create_dataset_cifar(DATA_FILE, 224, 224, training=False) | |||
net = vgg16(10, cfg, 'test') | |||
elif test_flag == 'resnet50': | |||
DATA_FILE = '../common/dataset/cifar10-batches-bin' | |||
ckpt_path = '../common/networks/resnet50_ascend_v111_cifar10_offical_cv_bs32_acc92.ckpt' | |||
ds_eval = create_dataset_cifar(DATA_FILE, 224, 224, training=False) | |||
net = resnet50(10) | |||
else: | |||
exit() | |||
param_dict = load_checkpoint(ckpt_path) | |||
load_param_into_net(net, param_dict) | |||
model = Model(net) | |||
# Initialization | |||
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', | |||
'nan', 'inf', 'anti_activation', 'precision_loss'] | |||
fi_mode = ['single_layer', 'all_layer'] | |||
fi_size = [1, 2, 3] | |||
# Fault injection | |||
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) | |||
results = fi.kick_off() | |||
result_summary = fi.metrics() | |||
# print result | |||
for result in results: | |||
print(result) | |||
for result in result_summary: | |||
print(result) |
@@ -0,0 +1,20 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" | |||
Reliability methods of MindArmour | |||
""" | |||
from .model_fault_injection.fault_injection import FaultInjector | |||
__all__ = ['FaultInjector'] |
@@ -0,0 +1,169 @@ | |||
# Demos of model fault injection | |||
## Introduction | |||
This is a demo of fault injection for Mindspore applications written in Python. | |||
## Preparation | |||
For the demo, we should prepare both datasets and pre-train models | |||
### Dateset | |||
For example: | |||
`MINST`:Download MNIST dataset from: http://yann.lecun.com/exdb/mnist/ and extract as follows | |||
```test | |||
File structure: | |||
- data_path | |||
- train | |||
- train-images-idx3-ubyte | |||
- train-labels-idx1-ubyte | |||
- test | |||
- t10k-images-idx3-ubyte | |||
- t10k-labels-idx1-ubyte | |||
``` | |||
`CIFAR10`:Download CIFAR10 dataset from: http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz and extract as follows | |||
```test | |||
File structure: | |||
- data_path | |||
- train | |||
- data_batch_1.bin | |||
- data_batch_2.bin | |||
- data_batch_3.bin | |||
- data_batch_4.bin | |||
- data_batch_5.bin | |||
- test | |||
- test_batch.bin | |||
``` | |||
### CheckPoint file | |||
Download checkpoint from: https://www.mindspore.cn/resources/hub or just trained your own checkpoint | |||
## Configuration | |||
There are five parameters need to set up. | |||
```python | |||
DATA_FILE = '../common/dataset/MNIST_Data/test' | |||
ckpt_path = '../common/networks/checkpoint_lenet_1-10_1875.ckpt' | |||
... | |||
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', 'nan', 'inf', 'anti_activation', 'precision_loss'] | |||
fi_mode = ['single_layer', 'all_layer'] | |||
fi_size = [1, 2, 3] | |||
``` | |||
`DATA_FILE` is the directory where you store the data. | |||
`ckpt_path` is the directory where you store the checkpoint file. | |||
`fi_type` : | |||
Eight types of faults can be injected. These are `bitflips_random`, `bitflips_designated`, `random`, `zeros`, `nan`, `inf`, `anti_activation` and `precision_loss` | |||
bitflips_random: Bits are flipped randomly in the chosen value. | |||
bitflips_designated: Specified bit is flipped in the chosen value. | |||
random: The chosen value are replaced with random value in the range [-1, 1] | |||
zeros: The chosen value are replaced with zero. | |||
nan: The chosen value are replaced with NaN. | |||
inf: The chosen value are replaced with Inf. | |||
anti_activation: Changing the sign of the chosen value. | |||
precision_loss: Round the chosen value to 1 decimal place | |||
`fi_mode` : | |||
There are twe kinds of injection modes can be specified, `single_layer` or `all_layer`. | |||
`fi_size` is usually the exact number of values to be injected with the specified fault. For `zeros`, `anti_activation` and `precision_loss` fault, `fi_size` is the percentage of total tensor values and varies from 0% to 100% | |||
### Example configuration | |||
Sample 1: | |||
```python | |||
fi_type = ['bitflips_random', 'random', 'zeros', 'inf'] | |||
fi_mode = ['single_layer'] | |||
fi_size = [1] | |||
``` | |||
Sample 2: | |||
```python | |||
fi_type = ['bitflips_designated', 'random', 'inf', 'anti_activation', 'precision_loss'] | |||
fi_mode = ['single_layer', 'all_layer'] | |||
fi_size = [1, 2] | |||
``` | |||
## Usage | |||
Run the test to observe the fault injection. For example: | |||
```bash | |||
#!/bin/bash | |||
cd examples/reliability/ | |||
python model_fault_injection.py --device_target GPU --device_id 2 --model lenet | |||
``` | |||
`device_target` | |||
`model` is the target model need to be evaluation, choose from `lenet`, `vgg16` and `resnet`, or implement your own model. | |||
## Result | |||
Finally, there are three kinds of result will be return. | |||
Sample: | |||
```test | |||
original_acc:0.979768 | |||
type:bitflips_random mode:single_layer size:1 acc:0.968950 SDC:0.010817 | |||
type:bitflips_random mode:single_layer size:2 acc:0.948017 SDC:0.031751 | |||
... | |||
type:precision_loss mode:all_layer size:2 acc:0.978966 SDC:0.000801 | |||
type:precision_loss mode:all_layer size:3 acc:0.979167 SDC:0.000601 | |||
single_layer_acc_mean:0.819732 single_layer_acc_max:0.980068 single_layer_acc_min:0.192107 | |||
single_layer_SDC_mean:0.160035 single_layer_SDC_max:0.787660 single_layer_SDC_min:-0.000300 | |||
all_layer_acc_mean:0.697049 all_layer_acc_max:0.979167 all_layer_acc_min:0.089443 | |||
all_layer_acc_mean:0.282719 all_layer_acc_max:0.890325 all_layer_acc_min:0.000601 | |||
``` | |||
### Original_acc | |||
The original accuracy of model: | |||
```test | |||
original_acc:0.979768 | |||
``` | |||
### Specific result of each input parameter | |||
Each result including `type`, `mode`, `size`, `acc` and `SDC`. `type`, `mode` and `size` match along with `fi_type`, `fi_mode` and `fi_size`. | |||
```test | |||
type:bitflips_random mode:single_layer size:1 acc:0.968950 SDC:0.010817 | |||
type:bitflips_random mode:single_layer size:2 acc:0.948017 SDC:0.031751 | |||
... | |||
type:precision_loss mode:all_layer size:2 acc:0.978966 SDC:0.000801 | |||
type:precision_loss mode:all_layer size:3 acc:0.979167 SDC:0.000601 | |||
``` | |||
### Summary of mode | |||
Summary of `single_layer` or `all_layer`. | |||
```test | |||
single_layer_acc_mean:0.819732 single_layer_acc_max:0.980068 single_layer_acc_min:0.192107 | |||
single_layer_SDC_mean:0.160035 single_layer_SDC_max:0.787660 single_layer_SDC_min:-0.000300 | |||
all_layer_acc_mean:0.697049 all_layer_acc_max:0.979167 all_layer_acc_min:0.089443 | |||
all_layer_SDC_mean:0.282719 all_layer_SDC_max:0.890325 all_layer_SDC_min:0.000601 | |||
``` |
@@ -0,0 +1,18 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
This module provides model fault injection to evaluate the reliability of given model. | |||
""" | |||
from .fault_injection import FaultInjector | |||
from .fault_type import FaultType | |||
__all__ = ['FaultInjector', 'FaultType'] |
@@ -0,0 +1,224 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
Fault injection module | |||
""" | |||
import random | |||
import numpy as np | |||
import mindspore | |||
from mindspore import ops, Tensor | |||
from mindarmour.reliability.model_fault_injection.fault_type import FaultType | |||
from mindarmour.utils.logger import LogUtil | |||
from mindarmour.utils._check_param import check_int_positive | |||
LOGGER = LogUtil.get_instance() | |||
TAG = 'FaultInjector' | |||
class FaultInjector: | |||
""" | |||
Fault injection for deep neural networks and evaluate performance. | |||
Args: | |||
model (Model): The model need to be evaluated. | |||
data (Dataset): The data for testing. The evaluation is base on this data. | |||
fi_type (list): The type of the fault injection which include bitflips_random(flip randomly), | |||
bitflips_designated(flip the key bit), random, zeros, nan, inf, anti_activation precision_loss etc. | |||
fi_mode (list): The mode of fault injection. Fault inject on just single layer or all layers. | |||
fi_size (list): The number of fault injection.It mean that how many values need to be injected. | |||
Examples: | |||
>>> net = Net() | |||
>>> model = Model(net) | |||
>>> ds_eval = create_dataloader() | |||
>>> fi_type = ['bitflips_random', 'zeros'] | |||
>>> fi_mode = ['single_layer', 'all_layer'] | |||
>>> fi_size = [1, 2, 3] | |||
>>> fi = FaultInjector(model, ds_eval, fi_type=fi_type, fi_mode=fi_mode, fi_size=fi_size) | |||
>>> fi.kick_off() | |||
""" | |||
def __init__(self, model, data, fi_type=None, fi_mode=None, fi_size=None): | |||
"""FaultInjector initiated.""" | |||
self.running_list = [] | |||
self._init_running_list(fi_type, fi_mode, fi_size) | |||
self.model = model | |||
self.data = data | |||
self._fault_type = FaultType() | |||
self._check_param() | |||
self.result_list = [] | |||
self.original_acc = 0 | |||
self.original_parameter = {} | |||
self.argmax = ops.Argmax() | |||
self._reducesum = ops.ReduceSum(keep_dims=False) | |||
self._frozen() | |||
def _check_param(self): | |||
"""Check input parameters.""" | |||
attr = self._fault_type.__dir__() | |||
if not isinstance(self.data, mindspore.dataset.Dataset): | |||
msg = "'Input data should be Mindspore Dataset', got {}.".format(type(self.data)) | |||
LOGGER.error(TAG, msg) | |||
raise TypeError(msg) | |||
_ = check_int_positive('dataset_size', self.data.get_dataset_size()) | |||
if not isinstance(self.model, mindspore.Model): | |||
msg = "'Input model should be Mindspore Model', got {}.".format(type(self.model)) | |||
LOGGER.error(TAG, msg) | |||
raise TypeError(msg) | |||
for param in self.running_list: | |||
if param['fi_type'] not in attr: | |||
msg = "'Undefined fault type', got {}.".format(param['fi_type']) | |||
LOGGER.error(TAG, msg) | |||
raise AttributeError(msg) | |||
if param['fi_mode'] not in ['single_layer', 'all_layer']: | |||
msg = "'fault mode should be single_layer or all_layer', but got {}.".format(param['fi_mode']) | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
_ = check_int_positive('fi_size', param['fi_size']) | |||
def _init_running_list(self, type_, mode_, size_): | |||
"""Initiate fault injection parameters of this evaluation.""" | |||
if type_ is None: | |||
type_ = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', 'nan', 'inf', | |||
'anti_activation', 'precision_loss'] | |||
if mode_ is None: | |||
mode_ = ['single_layer', 'all_layer'] | |||
if size_ is None: | |||
size_ = list(range(1, 4)) | |||
for i in type_: | |||
i = i if i.startswith('_') else '_' + i | |||
for j in mode_: | |||
for k in size_: | |||
dict_ = {'fi_type': i, 'fi_mode': j, 'fi_size': k} | |||
self.running_list.append(dict_) | |||
def _frozen(self): | |||
"""Store original parameters of model.""" | |||
trainable_param = self.model.predict_network.trainable_params() | |||
for param in trainable_param: | |||
np_param = param.asnumpy().copy() | |||
bytes_ = np_param.tobytes() | |||
self.original_parameter[param.name] = {} | |||
self.original_parameter[param.name]['datatype'] = np_param.dtype | |||
self.original_parameter[param.name]['shape'] = np_param.shape | |||
self.original_parameter[param.name]['data'] = bytes_.hex() | |||
def _reset_model(self): | |||
"""Reset model with original parameters.""" | |||
for weight in self.model.predict_network.trainable_params(): | |||
name = weight.name | |||
if name in self.original_parameter.keys(): | |||
bytes_w = bytes.fromhex(self.original_parameter[name]['data']) | |||
datatype_w = self.original_parameter[name]['datatype'] | |||
shape_w = self.original_parameter[name]['shape'] | |||
np_w = np.frombuffer(bytes_w, dtype=datatype_w).reshape(shape_w) | |||
weight.assign_value(Tensor.from_numpy(np_w)) | |||
else: | |||
msg = "Layer name not matched, got {}.".format(name) | |||
LOGGER.error(TAG, msg) | |||
raise KeyError(msg) | |||
def kick_off(self): | |||
""" | |||
Startup and return final results. | |||
Returns: | |||
list, the result of fault injection. | |||
""" | |||
result_list = [] | |||
for i in range(-1, len(self.running_list)): | |||
arg = self.running_list[i] | |||
total = 0 | |||
correct = 0 | |||
for data in self.data.create_dict_iterator(): | |||
batch = data['image'] | |||
label = data['label'] | |||
if i != -1: | |||
self._reset_model() | |||
self._layer_states(arg['fi_type'], arg['fi_mode'], arg['fi_size']) | |||
output = self.model.predict(batch) | |||
predict = self.argmax(output) | |||
mask = predict == label | |||
total += predict.size | |||
correct += self._reducesum(mask.astype(mindspore.float32)).asnumpy() | |||
acc = correct / total | |||
if i == -1: | |||
self.original_acc = acc | |||
result_list.append({'original_acc': self.original_acc}) | |||
else: | |||
result_list.append({'type': arg['fi_type'], 'mode': arg['fi_mode'], 'size': arg['fi_size'], | |||
'acc': acc, 'SDC': self.original_acc - acc}) | |||
self.data.reset() | |||
self._reset_model() | |||
self.result_list = result_list | |||
return result_list | |||
def metrics(self): | |||
"""metrics of final result.""" | |||
result_summary = [] | |||
single_layer_acc = [] | |||
single_layer_sdc = [] | |||
all_layer_acc = [] | |||
all_layer_sdc = [] | |||
for result in self.result_list: | |||
if 'mode' in result.keys(): | |||
if result['mode'] == 'single_layer': | |||
single_layer_acc.append(float(result['acc'])) | |||
single_layer_sdc.append(float(result['SDC'])) | |||
else: | |||
all_layer_acc.append(float(result['acc'])) | |||
all_layer_sdc.append(float(result['SDC'])) | |||
s_acc = np.array(single_layer_acc) | |||
s_sdc = np.array(single_layer_sdc) | |||
a_acc = np.array(all_layer_acc) | |||
a_sdc = np.array(all_layer_sdc) | |||
if single_layer_acc: | |||
result_summary.append('single_layer_acc_mean:%f single_layer_acc_max:%f single_layer_acc_min:%f' | |||
% (np.mean(s_acc), np.max(s_acc), np.min(s_acc))) | |||
result_summary.append('single_layer_SDC_mean:%f single_layer_SDC_max:%f single_layer_SDC_min:%f' | |||
% (np.mean(s_sdc), np.max(s_sdc), np.min(s_sdc))) | |||
if all_layer_acc: | |||
result_summary.append('all_layer_acc_mean:%f all_layer_acc_max:%f all_layer_acc_min:%f' | |||
% (np.mean(a_acc), np.max(a_acc), np.min(a_acc))) | |||
result_summary.append('all_layer_SDC_mean:%f all_layer_SDC_max:%f all_layer_SDC_min:%f' | |||
% (np.mean(a_sdc), np.max(a_sdc), np.min(a_sdc))) | |||
return result_summary | |||
def _layer_states(self, fi_type, fi_mode, fi_size): | |||
"""FI in layer states.""" | |||
# Choose a random layer for injection | |||
if fi_mode == "single_layer": | |||
# Single layer fault injection mode | |||
random_num = [random.randint(0, len(self.model.predict_network.trainable_params()) - 1)] | |||
elif fi_mode == "all_layer": | |||
# Multiple layer fault injection mode | |||
random_num = list(range(len(self.model.predict_network.trainable_params()) - 1)) | |||
else: | |||
msg = 'undefined fi_mode {}'.format(fi_mode) | |||
LOGGER.error(TAG, msg) | |||
raise ValueError(msg) | |||
for n in random_num: | |||
# Get layer states info | |||
w = self.model.predict_network.trainable_params()[n] | |||
w_np = w.asnumpy().copy() | |||
elem_shape = w_np.shape | |||
w_np = w_np.reshape(-1) | |||
# fault inject | |||
new_w_np = self._fault_type._fault_inject(w_np, fi_type, fi_size) | |||
# Reshape into original dimensions and store the faulty tensor | |||
new_w_np = np.reshape(new_w_np, elem_shape) | |||
w.set_data(Tensor.from_numpy(new_w_np)) |
@@ -0,0 +1,219 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
""" | |||
Fault type module | |||
""" | |||
import math | |||
import random | |||
from struct import pack, unpack | |||
import numpy as np | |||
from mindarmour.utils.logger import LogUtil | |||
LOGGER = LogUtil.get_instance() | |||
TAG = 'FaultType' | |||
class FaultType: | |||
"""Implementation of specified fault type.""" | |||
@staticmethod | |||
def _bitflip(value, pos): | |||
""" | |||
Implement of bitflip. | |||
Args: | |||
value (numpy.ndarray): Input data. | |||
pos (list): The index of flip position. | |||
Returns: | |||
numpy.ndarray, bitflip data. | |||
""" | |||
bits = str(value.dtype)[-2:] if str(value.dtype)[-2].isdigit() else str(value.dtype)[-1] | |||
value_format = 'B' * int(int(bits) / 8) | |||
value_bytes = value.tobytes() | |||
bytes_ = list(unpack(value_format, value_bytes)) | |||
for p in pos: | |||
[q, r] = divmod(p, 8) | |||
bytes_[q] ^= 1 << r | |||
new_value_bytes = pack(value_format, *bytes_) | |||
new_value = np.frombuffer(new_value_bytes, value.dtype) | |||
return new_value[0] | |||
def _fault_inject(self, value, fi_type, fi_size): | |||
""" | |||
Inject the specified fault into the randomly chosen values. | |||
For zeros, anti_activation and precision_loss, fi_size is the percentage of | |||
total number. And the others fault, fi_size is the exact number of values to | |||
be injected. | |||
Args: | |||
value (numpy.ndarray): Input data. | |||
fi_type (str): Fault type. | |||
fi_size (int): The number of fault injection. | |||
Returns: | |||
numpy.ndarray, data after fault injection. | |||
""" | |||
num = value.size | |||
if fi_type in ['zeros', 'anti_activation', 'precision_loss']: | |||
change_size = (fi_size * num) / 100 | |||
change_size = math.floor(change_size) | |||
else: | |||
change_size = fi_size | |||
if change_size > num: | |||
change_size = num | |||
# Choose the indices for FI | |||
ind = random.sample(range(num), change_size) | |||
# got specified fault type | |||
try: | |||
func = getattr(self, fi_type) | |||
value = func(value, ind) | |||
return value | |||
except AttributeError: | |||
msg = "'Undefined fault type', got {}.".format(fi_type) | |||
LOGGER.error(TAG, msg) | |||
raise AttributeError(msg) | |||
def _bitflips_random(self, value, fi_indices): | |||
""" | |||
Flip bit randomly for specified value. | |||
Args: | |||
value (numpy.ndarray): Input data. | |||
fi_indices (list): The index of injected data. | |||
Returns: | |||
numpy.ndarray, data after fault injection. | |||
""" | |||
for item in fi_indices: | |||
val = value[item] | |||
pos = random.sample(range(int(str(val.dtype)[-2:])), | |||
1 if np.random.random() < 0.618 else 2) | |||
val_new = self._bitflip(val, pos) | |||
value[item] = val_new | |||
return value | |||
def _bitflips_designated(self, value, fi_indices): | |||
""" | |||
Flip the key bit for specified value. | |||
Args: | |||
value (numpy.ndarray): Input data. | |||
fi_indices (list): The index of injected data. | |||
Returns: | |||
numpy.ndarray, data after fault injection. | |||
""" | |||
for item in fi_indices: | |||
val = value[item] | |||
# uint8 uint16 uint32 uint64 | |||
bits = str(value.dtype)[-2:] if str(value.dtype)[-2].isdigit() else str(value.dtype)[-1] | |||
if 'uint' in str(val.dtype): | |||
pos = int(bits) - 1 | |||
# int8 int16 int32 int64 float16 float32 float64 | |||
else: | |||
pos = int(bits) - 2 | |||
val_new = self._bitflip(val, [pos]) | |||
value[item] = val_new | |||
return value | |||
@staticmethod | |||
def _random(value, fi_indices): | |||
""" | |||
Reset specified value randomly, range from -1 to 1. | |||
Args: | |||
value (numpy.ndarray): Input data. | |||
fi_indices (list): The index of injected data. | |||
Returns: | |||
numpy.ndarray, data after fault injection. | |||
""" | |||
for item in fi_indices: | |||
value[item] = np.random.random() * 2 - 1 | |||
return value | |||
@staticmethod | |||
def _zeros(value, fi_indices): | |||
""" | |||
Set specified value into zeros. | |||
Args: | |||
value (numpy.ndarray): Input data. | |||
fi_indices (list): The index of injected data. | |||
Returns: | |||
numpy.ndarray, data after fault injection. | |||
""" | |||
value[fi_indices] = 0. | |||
return value | |||
@staticmethod | |||
def _nan(value, fi_indices): | |||
""" | |||
Set specified value into nan. | |||
Args: | |||
value (numpy.ndarray): Input data. | |||
fi_indices (list): The index of injected data. | |||
Returns: | |||
numpy.ndarray, data after fault injection. | |||
""" | |||
try: | |||
value[fi_indices] = np.nan | |||
return value | |||
except ValueError: | |||
return value | |||
@staticmethod | |||
def _inf(value, fi_indices): | |||
""" | |||
Set specified value into inf | |||
Args: | |||
value (numpy.ndarray): Input data. | |||
fi_indices (list): The index of injected data. | |||
Returns: | |||
numpy.ndarray, data after fault injection. | |||
""" | |||
try: | |||
value[fi_indices] = np.inf | |||
return value | |||
except OverflowError: | |||
return value | |||
@staticmethod | |||
def _anti_activation(value, fi_indices): | |||
""" | |||
Minus specified value. | |||
Args: | |||
value (numpy.ndarray): Input data. | |||
fi_indices (list): The index of injected data. | |||
Returns: | |||
numpy.ndarray, data after fault injection. | |||
""" | |||
value[fi_indices] = -value[fi_indices] | |||
return value | |||
@staticmethod | |||
def _precision_loss(value, fi_indices): | |||
""" | |||
Round specified value, round to 1 decimal place. | |||
Args: | |||
value (numpy.ndarray): Input data. | |||
fi_indices (list): The index of injected data. | |||
Returns: | |||
numpy.ndarray, data after fault injection. | |||
""" | |||
value[fi_indices] = np.around(value[fi_indices], decimals=1) | |||
return value |
@@ -0,0 +1,240 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" | |||
Test for fault injection. | |||
""" | |||
import pytest | |||
import numpy as np | |||
from mindspore import Model | |||
import mindspore.dataset as ds | |||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
from mindarmour.utils.logger import LogUtil | |||
from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector | |||
from tests.ut.python.utils.mock_net import Net | |||
LOGGER = LogUtil.get_instance() | |||
TAG = 'Fault injection test' | |||
LOGGER.set_level('INFO') | |||
def dataset_generator(): | |||
"""mock training data.""" | |||
batch_size = 32 | |||
batches = 128 | |||
data = np.random.random((batches*batch_size, 1, 32, 32)).astype( | |||
np.float32) | |||
label = np.random.randint(0, 10, batches*batch_size).astype(np.int32) | |||
for i in range(batches): | |||
yield data[i*batch_size:(i + 1)*batch_size],\ | |||
label[i*batch_size:(i + 1)*batch_size] | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_cpu | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_fault_injector(): | |||
""" | |||
Feature: Fault injector | |||
Description: Test fault injector | |||
Expectation: Run kick_off and metrics successfully | |||
""" | |||
# load model | |||
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||
net = Net() | |||
param_dict = load_checkpoint(ckpt_path) | |||
load_param_into_net(net, param_dict) | |||
model = Model(net) | |||
ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) | |||
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', | |||
'nan', 'inf', 'anti_activation', 'precision_loss'] | |||
fi_mode = ['single_layer', 'all_layer'] | |||
fi_size = [1] | |||
# Fault injection | |||
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) | |||
_ = fi.kick_off() | |||
_ = fi.metrics() | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_cpu | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_wrong_model(): | |||
""" | |||
Feature: Fault injector | |||
Description: Test fault injector | |||
Expectation: Throw TypeError exception | |||
""" | |||
# load model | |||
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||
net = Net() | |||
param_dict = load_checkpoint(ckpt_path) | |||
load_param_into_net(net, param_dict) | |||
ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) | |||
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', | |||
'nan', 'inf', 'anti_activation', 'precision_loss'] | |||
fi_mode = ['single_layer', 'all_layer'] | |||
fi_size = [1] | |||
# Fault injection | |||
with pytest.raises(TypeError) as exc_info: | |||
fi = FaultInjector(net, ds_eval, fi_type, fi_mode, fi_size) | |||
_ = fi.kick_off() | |||
_ = fi.metrics() | |||
assert exc_info.type is TypeError | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_cpu | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_wrong_data(): | |||
""" | |||
Feature: Fault injector | |||
Description: Test fault injector | |||
Expectation: Throw TypeError exception | |||
""" | |||
# load model | |||
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||
net = Net() | |||
param_dict = load_checkpoint(ckpt_path) | |||
load_param_into_net(net, param_dict) | |||
model = Model(net) | |||
ds_eval = np.random.random((1000, 32, 32, 1)).astype(np.float32) | |||
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', | |||
'nan', 'inf', 'anti_activation', 'precision_loss'] | |||
fi_mode = ['single_layer', 'all_layer'] | |||
fi_size = [1] | |||
# Fault injection | |||
with pytest.raises(TypeError) as exc_info: | |||
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) | |||
_ = fi.kick_off() | |||
_ = fi.metrics() | |||
assert exc_info.type is TypeError | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_cpu | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_wrong_fi_type(): | |||
""" | |||
Feature: Fault injector | |||
Description: Test fault injector | |||
Expectation: Throw AttributeError exception | |||
""" | |||
# load model | |||
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||
net = Net() | |||
param_dict = load_checkpoint(ckpt_path) | |||
load_param_into_net(net, param_dict) | |||
model = Model(net) | |||
ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) | |||
fi_type = ['bitflips_random_haha', 'bitflips_designated', 'random', 'zeros', | |||
'nan', 'inf', 'anti_activation', 'precision_loss'] | |||
fi_mode = ['single_layer', 'all_layer'] | |||
fi_size = [1] | |||
# Fault injection | |||
with pytest.raises(AttributeError) as exc_info: | |||
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) | |||
_ = fi.kick_off() | |||
_ = fi.metrics() | |||
assert exc_info.type is AttributeError | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_cpu | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_wrong_fi_mode(): | |||
""" | |||
Feature: Fault injector | |||
Description: Test fault injector | |||
Expectation: Throw ValueError exception | |||
""" | |||
# load model | |||
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||
net = Net() | |||
param_dict = load_checkpoint(ckpt_path) | |||
load_param_into_net(net, param_dict) | |||
model = Model(net) | |||
ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) | |||
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', | |||
'nan', 'inf', 'anti_activation', 'precision_loss'] | |||
fi_mode = ['single_layer_tail', 'all_layer'] | |||
fi_size = [1] | |||
# Fault injection | |||
with pytest.raises(ValueError) as exc_info: | |||
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) | |||
_ = fi.kick_off() | |||
_ = fi.metrics() | |||
assert exc_info.type is ValueError | |||
@pytest.mark.level0 | |||
@pytest.mark.platform_x86_cpu | |||
@pytest.mark.platform_x86_ascend_training | |||
@pytest.mark.platform_arm_ascend_training | |||
@pytest.mark.env_onecard | |||
@pytest.mark.component_mindarmour | |||
def test_wrong_fi_size(): | |||
""" | |||
Feature: Fault injector | |||
Description: Test fault injector | |||
Expectation: Throw ValueError exception | |||
""" | |||
# load model | |||
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||
net = Net() | |||
param_dict = load_checkpoint(ckpt_path) | |||
load_param_into_net(net, param_dict) | |||
model = Model(net) | |||
ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) | |||
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', | |||
'nan', 'inf', 'anti_activation', 'precision_loss'] | |||
fi_mode = ['single_layer', 'all_layer'] | |||
fi_size = [-1] | |||
# Fault injection | |||
with pytest.raises(ValueError) as exc_info: | |||
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) | |||
_ = fi.kick_off() | |||
_ = fi.metrics() | |||
assert exc_info.type is ValueError |