Browse Source

!279 Provides mindarmour with model fault injection

Merge pull request !279 from ye12121/master
tags/v1.6.0
i-robot Gitee 3 years ago
parent
commit
cbed4ad6bd
7 changed files with 982 additions and 0 deletions
  1. +92
    -0
      examples/reliability/model_fault_injection.py
  2. +20
    -0
      mindarmour/reliability/__init__.py
  3. +169
    -0
      mindarmour/reliability/model_fault_injection/README.md
  4. +18
    -0
      mindarmour/reliability/model_fault_injection/__init__.py
  5. +224
    -0
      mindarmour/reliability/model_fault_injection/fault_injection.py
  6. +219
    -0
      mindarmour/reliability/model_fault_injection/fault_type.py
  7. +240
    -0
      tests/ut/python/reliability/model_fault_injection/test_fault_injection.py

+ 92
- 0
examples/reliability/model_fault_injection.py View File

@@ -0,0 +1,92 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Fault injection example.
Download checkpoint from: https://www.mindspore.cn/resources/hub or just trained your own checkpoint.
Download dataset from: http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz.
File structure:
--cifar10-batches-bin
--train
--data_batch_1.bin
--data_batch_2.bin
--data_batch_3.bin
--data_batch_4.bin
--data_batch_5.bin
--test
--test_batch.bin

Please extract and restructure the file as shown above.
"""
import argparse

from mindspore import Model, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector
from examples.common.networks.lenet5.lenet5_net import LeNet5
from examples.common.networks.vgg.vgg import vgg16
from examples.common.networks.resnet.resnet import resnet50
from examples.common.dataset.data_processing import create_dataset_cifar, generate_mnist_dataset


parser = argparse.ArgumentParser(description='layer_states')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'])
parser.add_argument('--model', type=str, default='lenet', choices=['lenet', 'resnet50', 'vgg16'])
parser.add_argument('--device_id', type=int, default=0)
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)


test_flag = args.model
if test_flag == 'lenet':
# load data
DATA_FILE = '../common/dataset/MNIST_Data/test'
ckpt_path = '../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
ds_eval = generate_mnist_dataset(DATA_FILE, batch_size=64)
net = LeNet5()
elif test_flag == 'vgg16':
from examples.common.networks.vgg.config import cifar_cfg as cfg
DATA_FILE = '../common/dataset/cifar10-batches-bin'
ckpt_path = '../common/networks/vgg16_ascend_v111_cifar10_offical_cv_bs64_acc93.ckpt'
ds_eval = create_dataset_cifar(DATA_FILE, 224, 224, training=False)
net = vgg16(10, cfg, 'test')
elif test_flag == 'resnet50':
DATA_FILE = '../common/dataset/cifar10-batches-bin'
ckpt_path = '../common/networks/resnet50_ascend_v111_cifar10_offical_cv_bs32_acc92.ckpt'
ds_eval = create_dataset_cifar(DATA_FILE, 224, 224, training=False)
net = resnet50(10)
else:
exit()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)

# Initialization
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1, 2, 3]

# Fault injection
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
results = fi.kick_off()
result_summary = fi.metrics()

# print result
for result in results:
print(result)
for result in result_summary:
print(result)

+ 20
- 0
mindarmour/reliability/__init__.py View File

@@ -0,0 +1,20 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Reliability methods of MindArmour
"""

from .model_fault_injection.fault_injection import FaultInjector

__all__ = ['FaultInjector']

+ 169
- 0
mindarmour/reliability/model_fault_injection/README.md View File

@@ -0,0 +1,169 @@
# Demos of model fault injection

## Introduction

This is a demo of fault injection for Mindspore applications written in Python.

## Preparation

For the demo, we should prepare both datasets and pre-train models

### Dateset

For example:

`MINST`:Download MNIST dataset from: http://yann.lecun.com/exdb/mnist/ and extract as follows

```test
File structure:
- data_path
- train
- train-images-idx3-ubyte
- train-labels-idx1-ubyte
- test
- t10k-images-idx3-ubyte
- t10k-labels-idx1-ubyte
```

`CIFAR10`:Download CIFAR10 dataset from: http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz and extract as follows

```test
File structure:
- data_path
- train
- data_batch_1.bin
- data_batch_2.bin
- data_batch_3.bin
- data_batch_4.bin
- data_batch_5.bin
- test
- test_batch.bin
```

### CheckPoint file

Download checkpoint from: https://www.mindspore.cn/resources/hub or just trained your own checkpoint

## Configuration

There are five parameters need to set up.

```python
DATA_FILE = '../common/dataset/MNIST_Data/test'
ckpt_path = '../common/networks/checkpoint_lenet_1-10_1875.ckpt'

...

fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', 'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1, 2, 3]
```

`DATA_FILE` is the directory where you store the data.

`ckpt_path` is the directory where you store the checkpoint file.

`fi_type` :
Eight types of faults can be injected. These are `bitflips_random`, `bitflips_designated`, `random`, `zeros`, `nan`, `inf`, `anti_activation` and `precision_loss`

bitflips_random: Bits are flipped randomly in the chosen value.

bitflips_designated: Specified bit is flipped in the chosen value.

random: The chosen value are replaced with random value in the range [-1, 1]

zeros: The chosen value are replaced with zero.

nan: The chosen value are replaced with NaN.

inf: The chosen value are replaced with Inf.

anti_activation: Changing the sign of the chosen value.

precision_loss: Round the chosen value to 1 decimal place

`fi_mode` :
There are twe kinds of injection modes can be specified, `single_layer` or `all_layer`.

`fi_size` is usually the exact number of values to be injected with the specified fault. For `zeros`, `anti_activation` and `precision_loss` fault, `fi_size` is the percentage of total tensor values and varies from 0% to 100%

### Example configuration

Sample 1:

```python
fi_type = ['bitflips_random', 'random', 'zeros', 'inf']
fi_mode = ['single_layer']
fi_size = [1]
```

Sample 2:

```python
fi_type = ['bitflips_designated', 'random', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1, 2]
```

## Usage

Run the test to observe the fault injection. For example:

```bash
#!/bin/bash
cd examples/reliability/
python model_fault_injection.py --device_target GPU --device_id 2 --model lenet
```

`device_target`
`model` is the target model need to be evaluation, choose from `lenet`, `vgg16` and `resnet`, or implement your own model.

## Result

Finally, there are three kinds of result will be return.

Sample:

```test
original_acc:0.979768
type:bitflips_random mode:single_layer size:1 acc:0.968950 SDC:0.010817
type:bitflips_random mode:single_layer size:2 acc:0.948017 SDC:0.031751
...
type:precision_loss mode:all_layer size:2 acc:0.978966 SDC:0.000801
type:precision_loss mode:all_layer size:3 acc:0.979167 SDC:0.000601
single_layer_acc_mean:0.819732 single_layer_acc_max:0.980068 single_layer_acc_min:0.192107
single_layer_SDC_mean:0.160035 single_layer_SDC_max:0.787660 single_layer_SDC_min:-0.000300
all_layer_acc_mean:0.697049 all_layer_acc_max:0.979167 all_layer_acc_min:0.089443
all_layer_acc_mean:0.282719 all_layer_acc_max:0.890325 all_layer_acc_min:0.000601
```

### Original_acc

The original accuracy of model:

```test
original_acc:0.979768
```

### Specific result of each input parameter

Each result including `type`, `mode`, `size`, `acc` and `SDC`. `type`, `mode` and `size` match along with `fi_type`, `fi_mode` and `fi_size`.

```test
type:bitflips_random mode:single_layer size:1 acc:0.968950 SDC:0.010817
type:bitflips_random mode:single_layer size:2 acc:0.948017 SDC:0.031751
...
type:precision_loss mode:all_layer size:2 acc:0.978966 SDC:0.000801
type:precision_loss mode:all_layer size:3 acc:0.979167 SDC:0.000601
```

### Summary of mode

Summary of `single_layer` or `all_layer`.

```test
single_layer_acc_mean:0.819732 single_layer_acc_max:0.980068 single_layer_acc_min:0.192107
single_layer_SDC_mean:0.160035 single_layer_SDC_max:0.787660 single_layer_SDC_min:-0.000300
all_layer_acc_mean:0.697049 all_layer_acc_max:0.979167 all_layer_acc_min:0.089443
all_layer_SDC_mean:0.282719 all_layer_SDC_max:0.890325 all_layer_SDC_min:0.000601
```

+ 18
- 0
mindarmour/reliability/model_fault_injection/__init__.py View File

@@ -0,0 +1,18 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
This module provides model fault injection to evaluate the reliability of given model.
"""
from .fault_injection import FaultInjector
from .fault_type import FaultType

__all__ = ['FaultInjector', 'FaultType']

+ 224
- 0
mindarmour/reliability/model_fault_injection/fault_injection.py View File

@@ -0,0 +1,224 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Fault injection module
"""

import random
import numpy as np

import mindspore
from mindspore import ops, Tensor

from mindarmour.reliability.model_fault_injection.fault_type import FaultType
from mindarmour.utils.logger import LogUtil
from mindarmour.utils._check_param import check_int_positive

LOGGER = LogUtil.get_instance()
TAG = 'FaultInjector'


class FaultInjector:
"""
Fault injection for deep neural networks and evaluate performance.

Args:
model (Model): The model need to be evaluated.
data (Dataset): The data for testing. The evaluation is base on this data.
fi_type (list): The type of the fault injection which include bitflips_random(flip randomly),
bitflips_designated(flip the key bit), random, zeros, nan, inf, anti_activation precision_loss etc.
fi_mode (list): The mode of fault injection. Fault inject on just single layer or all layers.
fi_size (list): The number of fault injection.It mean that how many values need to be injected.

Examples:
>>> net = Net()
>>> model = Model(net)
>>> ds_eval = create_dataloader()
>>> fi_type = ['bitflips_random', 'zeros']
>>> fi_mode = ['single_layer', 'all_layer']
>>> fi_size = [1, 2, 3]
>>> fi = FaultInjector(model, ds_eval, fi_type=fi_type, fi_mode=fi_mode, fi_size=fi_size)
>>> fi.kick_off()
"""

def __init__(self, model, data, fi_type=None, fi_mode=None, fi_size=None):
"""FaultInjector initiated."""
self.running_list = []
self._init_running_list(fi_type, fi_mode, fi_size)
self.model = model
self.data = data
self._fault_type = FaultType()
self._check_param()
self.result_list = []
self.original_acc = 0
self.original_parameter = {}
self.argmax = ops.Argmax()
self._reducesum = ops.ReduceSum(keep_dims=False)
self._frozen()

def _check_param(self):
"""Check input parameters."""
attr = self._fault_type.__dir__()
if not isinstance(self.data, mindspore.dataset.Dataset):
msg = "'Input data should be Mindspore Dataset', got {}.".format(type(self.data))
LOGGER.error(TAG, msg)
raise TypeError(msg)
_ = check_int_positive('dataset_size', self.data.get_dataset_size())
if not isinstance(self.model, mindspore.Model):
msg = "'Input model should be Mindspore Model', got {}.".format(type(self.model))
LOGGER.error(TAG, msg)
raise TypeError(msg)
for param in self.running_list:
if param['fi_type'] not in attr:
msg = "'Undefined fault type', got {}.".format(param['fi_type'])
LOGGER.error(TAG, msg)
raise AttributeError(msg)
if param['fi_mode'] not in ['single_layer', 'all_layer']:
msg = "'fault mode should be single_layer or all_layer', but got {}.".format(param['fi_mode'])
LOGGER.error(TAG, msg)
raise ValueError(msg)
_ = check_int_positive('fi_size', param['fi_size'])

def _init_running_list(self, type_, mode_, size_):
"""Initiate fault injection parameters of this evaluation."""
if type_ is None:
type_ = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', 'nan', 'inf',
'anti_activation', 'precision_loss']
if mode_ is None:
mode_ = ['single_layer', 'all_layer']
if size_ is None:
size_ = list(range(1, 4))
for i in type_:
i = i if i.startswith('_') else '_' + i
for j in mode_:
for k in size_:
dict_ = {'fi_type': i, 'fi_mode': j, 'fi_size': k}
self.running_list.append(dict_)

def _frozen(self):
"""Store original parameters of model."""
trainable_param = self.model.predict_network.trainable_params()
for param in trainable_param:
np_param = param.asnumpy().copy()
bytes_ = np_param.tobytes()
self.original_parameter[param.name] = {}
self.original_parameter[param.name]['datatype'] = np_param.dtype
self.original_parameter[param.name]['shape'] = np_param.shape
self.original_parameter[param.name]['data'] = bytes_.hex()

def _reset_model(self):
"""Reset model with original parameters."""
for weight in self.model.predict_network.trainable_params():
name = weight.name
if name in self.original_parameter.keys():
bytes_w = bytes.fromhex(self.original_parameter[name]['data'])
datatype_w = self.original_parameter[name]['datatype']
shape_w = self.original_parameter[name]['shape']
np_w = np.frombuffer(bytes_w, dtype=datatype_w).reshape(shape_w)
weight.assign_value(Tensor.from_numpy(np_w))
else:
msg = "Layer name not matched, got {}.".format(name)
LOGGER.error(TAG, msg)
raise KeyError(msg)

def kick_off(self):
"""
Startup and return final results.
Returns:
list, the result of fault injection.
"""
result_list = []
for i in range(-1, len(self.running_list)):
arg = self.running_list[i]
total = 0
correct = 0
for data in self.data.create_dict_iterator():
batch = data['image']
label = data['label']
if i != -1:
self._reset_model()
self._layer_states(arg['fi_type'], arg['fi_mode'], arg['fi_size'])
output = self.model.predict(batch)
predict = self.argmax(output)
mask = predict == label
total += predict.size
correct += self._reducesum(mask.astype(mindspore.float32)).asnumpy()
acc = correct / total
if i == -1:
self.original_acc = acc
result_list.append({'original_acc': self.original_acc})
else:
result_list.append({'type': arg['fi_type'], 'mode': arg['fi_mode'], 'size': arg['fi_size'],
'acc': acc, 'SDC': self.original_acc - acc})
self.data.reset()
self._reset_model()
self.result_list = result_list
return result_list

def metrics(self):
"""metrics of final result."""
result_summary = []
single_layer_acc = []
single_layer_sdc = []
all_layer_acc = []
all_layer_sdc = []
for result in self.result_list:
if 'mode' in result.keys():
if result['mode'] == 'single_layer':
single_layer_acc.append(float(result['acc']))
single_layer_sdc.append(float(result['SDC']))
else:
all_layer_acc.append(float(result['acc']))
all_layer_sdc.append(float(result['SDC']))
s_acc = np.array(single_layer_acc)
s_sdc = np.array(single_layer_sdc)
a_acc = np.array(all_layer_acc)
a_sdc = np.array(all_layer_sdc)
if single_layer_acc:
result_summary.append('single_layer_acc_mean:%f single_layer_acc_max:%f single_layer_acc_min:%f'
% (np.mean(s_acc), np.max(s_acc), np.min(s_acc)))
result_summary.append('single_layer_SDC_mean:%f single_layer_SDC_max:%f single_layer_SDC_min:%f'
% (np.mean(s_sdc), np.max(s_sdc), np.min(s_sdc)))
if all_layer_acc:
result_summary.append('all_layer_acc_mean:%f all_layer_acc_max:%f all_layer_acc_min:%f'
% (np.mean(a_acc), np.max(a_acc), np.min(a_acc)))
result_summary.append('all_layer_SDC_mean:%f all_layer_SDC_max:%f all_layer_SDC_min:%f'
% (np.mean(a_sdc), np.max(a_sdc), np.min(a_sdc)))
return result_summary

def _layer_states(self, fi_type, fi_mode, fi_size):
"""FI in layer states."""
# Choose a random layer for injection
if fi_mode == "single_layer":
# Single layer fault injection mode
random_num = [random.randint(0, len(self.model.predict_network.trainable_params()) - 1)]
elif fi_mode == "all_layer":
# Multiple layer fault injection mode
random_num = list(range(len(self.model.predict_network.trainable_params()) - 1))
else:
msg = 'undefined fi_mode {}'.format(fi_mode)
LOGGER.error(TAG, msg)
raise ValueError(msg)
for n in random_num:
# Get layer states info
w = self.model.predict_network.trainable_params()[n]
w_np = w.asnumpy().copy()
elem_shape = w_np.shape
w_np = w_np.reshape(-1)

# fault inject
new_w_np = self._fault_type._fault_inject(w_np, fi_type, fi_size)

# Reshape into original dimensions and store the faulty tensor
new_w_np = np.reshape(new_w_np, elem_shape)
w.set_data(Tensor.from_numpy(new_w_np))

+ 219
- 0
mindarmour/reliability/model_fault_injection/fault_type.py View File

@@ -0,0 +1,219 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Fault type module
"""

import math
import random
from struct import pack, unpack
import numpy as np

from mindarmour.utils.logger import LogUtil

LOGGER = LogUtil.get_instance()
TAG = 'FaultType'


class FaultType:
"""Implementation of specified fault type."""
@staticmethod
def _bitflip(value, pos):
"""
Implement of bitflip.
Args:
value (numpy.ndarray): Input data.
pos (list): The index of flip position.

Returns:
numpy.ndarray, bitflip data.
"""
bits = str(value.dtype)[-2:] if str(value.dtype)[-2].isdigit() else str(value.dtype)[-1]
value_format = 'B' * int(int(bits) / 8)
value_bytes = value.tobytes()
bytes_ = list(unpack(value_format, value_bytes))
for p in pos:
[q, r] = divmod(p, 8)
bytes_[q] ^= 1 << r
new_value_bytes = pack(value_format, *bytes_)
new_value = np.frombuffer(new_value_bytes, value.dtype)
return new_value[0]

def _fault_inject(self, value, fi_type, fi_size):
"""
Inject the specified fault into the randomly chosen values.
For zeros, anti_activation and precision_loss, fi_size is the percentage of
total number. And the others fault, fi_size is the exact number of values to
be injected.
Args:
value (numpy.ndarray): Input data.
fi_type (str): Fault type.
fi_size (int): The number of fault injection.

Returns:
numpy.ndarray, data after fault injection.
"""
num = value.size
if fi_type in ['zeros', 'anti_activation', 'precision_loss']:
change_size = (fi_size * num) / 100
change_size = math.floor(change_size)
else:
change_size = fi_size

if change_size > num:
change_size = num
# Choose the indices for FI
ind = random.sample(range(num), change_size)

# got specified fault type
try:
func = getattr(self, fi_type)
value = func(value, ind)
return value
except AttributeError:
msg = "'Undefined fault type', got {}.".format(fi_type)
LOGGER.error(TAG, msg)
raise AttributeError(msg)

def _bitflips_random(self, value, fi_indices):
"""
Flip bit randomly for specified value.
Args:
value (numpy.ndarray): Input data.
fi_indices (list): The index of injected data.

Returns:
numpy.ndarray, data after fault injection.
"""
for item in fi_indices:
val = value[item]
pos = random.sample(range(int(str(val.dtype)[-2:])),
1 if np.random.random() < 0.618 else 2)
val_new = self._bitflip(val, pos)
value[item] = val_new
return value

def _bitflips_designated(self, value, fi_indices):
"""
Flip the key bit for specified value.

Args:
value (numpy.ndarray): Input data.
fi_indices (list): The index of injected data.

Returns:
numpy.ndarray, data after fault injection.
"""
for item in fi_indices:
val = value[item]
# uint8 uint16 uint32 uint64
bits = str(value.dtype)[-2:] if str(value.dtype)[-2].isdigit() else str(value.dtype)[-1]
if 'uint' in str(val.dtype):
pos = int(bits) - 1
# int8 int16 int32 int64 float16 float32 float64
else:
pos = int(bits) - 2
val_new = self._bitflip(val, [pos])
value[item] = val_new
return value

@staticmethod
def _random(value, fi_indices):
"""
Reset specified value randomly, range from -1 to 1.
Args:
value (numpy.ndarray): Input data.
fi_indices (list): The index of injected data.

Returns:
numpy.ndarray, data after fault injection.
"""
for item in fi_indices:
value[item] = np.random.random() * 2 - 1
return value

@staticmethod
def _zeros(value, fi_indices):
"""
Set specified value into zeros.
Args:
value (numpy.ndarray): Input data.
fi_indices (list): The index of injected data.

Returns:
numpy.ndarray, data after fault injection.
"""
value[fi_indices] = 0.
return value

@staticmethod
def _nan(value, fi_indices):
"""
Set specified value into nan.
Args:
value (numpy.ndarray): Input data.
fi_indices (list): The index of injected data.

Returns:
numpy.ndarray, data after fault injection.
"""
try:
value[fi_indices] = np.nan
return value
except ValueError:
return value

@staticmethod
def _inf(value, fi_indices):
"""
Set specified value into inf
Args:
value (numpy.ndarray): Input data.
fi_indices (list): The index of injected data.

Returns:
numpy.ndarray, data after fault injection.
"""
try:
value[fi_indices] = np.inf
return value
except OverflowError:
return value

@staticmethod
def _anti_activation(value, fi_indices):
"""
Minus specified value.
Args:
value (numpy.ndarray): Input data.
fi_indices (list): The index of injected data.

Returns:
numpy.ndarray, data after fault injection.
"""
value[fi_indices] = -value[fi_indices]
return value

@staticmethod
def _precision_loss(value, fi_indices):
"""
Round specified value, round to 1 decimal place.
Args:
value (numpy.ndarray): Input data.
fi_indices (list): The index of injected data.

Returns:
numpy.ndarray, data after fault injection.
"""
value[fi_indices] = np.around(value[fi_indices], decimals=1)
return value

+ 240
- 0
tests/ut/python/reliability/model_fault_injection/test_fault_injection.py View File

@@ -0,0 +1,240 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Test for fault injection.
"""

import pytest
import numpy as np

from mindspore import Model
import mindspore.dataset as ds
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from mindarmour.utils.logger import LogUtil
from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector

from tests.ut.python.utils.mock_net import Net

LOGGER = LogUtil.get_instance()
TAG = 'Fault injection test'
LOGGER.set_level('INFO')


def dataset_generator():
"""mock training data."""
batch_size = 32
batches = 128
data = np.random.random((batches*batch_size, 1, 32, 32)).astype(
np.float32)
label = np.random.randint(0, 10, batches*batch_size).astype(np.int32)
for i in range(batches):
yield data[i*batch_size:(i + 1)*batch_size],\
label[i*batch_size:(i + 1)*batch_size]


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_fault_injector():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Run kick_off and metrics successfully
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)

ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1]

# Fault injection
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
_ = fi.kick_off()
_ = fi.metrics()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_model():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw TypeError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)

ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1]

# Fault injection
with pytest.raises(TypeError) as exc_info:
fi = FaultInjector(net, ds_eval, fi_type, fi_mode, fi_size)
_ = fi.kick_off()
_ = fi.metrics()
assert exc_info.type is TypeError


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_data():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw TypeError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)

ds_eval = np.random.random((1000, 32, 32, 1)).astype(np.float32)
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1]

# Fault injection
with pytest.raises(TypeError) as exc_info:
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
_ = fi.kick_off()
_ = fi.metrics()
assert exc_info.type is TypeError


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_fi_type():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw AttributeError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)

ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
fi_type = ['bitflips_random_haha', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1]

# Fault injection
with pytest.raises(AttributeError) as exc_info:
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
_ = fi.kick_off()
_ = fi.metrics()
assert exc_info.type is AttributeError


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_fi_mode():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw ValueError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)

ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer_tail', 'all_layer']
fi_size = [1]

# Fault injection
with pytest.raises(ValueError) as exc_info:
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
_ = fi.kick_off()
_ = fi.metrics()
assert exc_info.type is ValueError


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_fi_size():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw ValueError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)

ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [-1]

# Fault injection
with pytest.raises(ValueError) as exc_info:
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
_ = fi.kick_off()
_ = fi.metrics()
assert exc_info.type is ValueError

Loading…
Cancel
Save