Browse Source

!300 Changing default type of input data from mindspore.dataset to iterable

Merge pull request !300 from ye12121/master
tags/v1.6.0
i-robot Gitee 3 years ago
parent
commit
da422e92b0
3 changed files with 384 additions and 271 deletions
  1. +16
    -5
      examples/reliability/model_fault_injection.py
  2. +81
    -26
      mindarmour/reliability/model_fault_injection/fault_injection.py
  3. +287
    -240
      tests/ut/python/reliability/model_fault_injection/test_fault_injection.py

+ 16
- 5
examples/reliability/model_fault_injection.py View File

@@ -32,8 +32,9 @@ Please extract and restructure the file as shown above.
"""
import argparse

import numpy as np
from mindspore import Model, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.serialization import load_checkpoint

from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector
from examples.common.networks.lenet5.lenet5_net import LeNet5
@@ -70,8 +71,18 @@ elif test_flag == 'resnet50':
net = resnet50(10)
else:
exit()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)

test_images = []
test_labels = []
for data in ds_eval.create_tuple_iterator(output_numpy=True):
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
ds_data = np.concatenate(test_images, axis=0)
ds_label = np.concatenate(test_labels, axis=0)

param_dict = load_checkpoint(ckpt_path, net=net)
model = Model(net)

# Initialization
@@ -81,8 +92,8 @@ fi_mode = ['single_layer', 'all_layer']
fi_size = [1, 2, 3]

# Fault injection
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
results = fi.kick_off()
fi = FaultInjector(model, fi_type, fi_mode, fi_size)
results = fi.kick_off(ds_data, ds_label, iter_times=100)
result_summary = fi.metrics()

# print result


+ 81
- 26
mindarmour/reliability/model_fault_injection/fault_injection.py View File

@@ -22,7 +22,7 @@ from mindspore import ops, Tensor

from mindarmour.reliability.model_fault_injection.fault_type import FaultType
from mindarmour.utils.logger import LogUtil
from mindarmour.utils._check_param import check_int_positive
from mindarmour.utils._check_param import check_int_positive, check_param_type, _check_array_not_empty

LOGGER = LogUtil.get_instance()
TAG = 'FaultInjector'
@@ -34,7 +34,6 @@ class FaultInjector:

Args:
model (Model): The model need to be evaluated.
data (Dataset): The data for testing. The evaluation is base on this data.
fi_type (list): The type of the fault injection which include bitflips_random(flip randomly),
bitflips_designated(flip the key bit), random, zeros, nan, inf, anti_activation precision_loss etc.
fi_mode (list): The mode of fault injection. Fault inject on just single layer or all layers.
@@ -43,20 +42,20 @@ class FaultInjector:
Examples:
>>> net = Net()
>>> model = Model(net)
>>> ds_eval = create_dataloader()
>>> ds_data, ds_label = create_data()
>>> fi_type = ['bitflips_random', 'zeros']
>>> fi_mode = ['single_layer', 'all_layer']
>>> fi_size = [1, 2, 3]
>>> fi = FaultInjector(model, ds_eval, fi_type=fi_type, fi_mode=fi_mode, fi_size=fi_size)
>>> fi.kick_off()
>>> fi_size = [1, 2]
>>> fi = FaultInjector(model, fi_type=fi_type, fi_mode=fi_mode, fi_size=fi_size)
>>> fi.kick_off(ds_data, ds_label)
"""

def __init__(self, model, data, fi_type=None, fi_mode=None, fi_size=None):
def __init__(self, model, fi_type=None, fi_mode=None, fi_size=None):
"""FaultInjector initiated."""
self.running_list = []
self.fi_type_map = {}
self._init_running_list(fi_type, fi_mode, fi_size)
self.model = model
self.data = data
self._fault_type = FaultType()
self._check_param()
self.result_list = []
@@ -68,19 +67,18 @@ class FaultInjector:

def _check_param(self):
"""Check input parameters."""
attr = self._fault_type.__dir__()
if not isinstance(self.data, mindspore.dataset.Dataset):
msg = "'Input data should be Mindspore Dataset', got {}.".format(type(self.data))
LOGGER.error(TAG, msg)
raise TypeError(msg)
_ = check_int_positive('dataset_size', self.data.get_dataset_size())
ori_attr = self._fault_type.__dir__()
attr = []
for attr_ in ori_attr:
if not attr_.startswith('__') and attr_ not in ['_bitflip', '_fault_inject']:
attr.append(attr_)
if not isinstance(self.model, mindspore.Model):
msg = "'Input model should be Mindspore Model', got {}.".format(type(self.model))
LOGGER.error(TAG, msg)
raise TypeError(msg)
for param in self.running_list:
if param['fi_type'] not in attr:
msg = "'Undefined fault type', got {}.".format(param['fi_type'])
msg = "'Undefined fault type', got {}.".format(self.fi_type_map[param['fi_type']])
LOGGER.error(TAG, msg)
raise AttributeError(msg)
if param['fi_mode'] not in ['single_layer', 'all_layer']:
@@ -98,11 +96,28 @@ class FaultInjector:
mode_ = ['single_layer', 'all_layer']
if size_ is None:
size_ = list(range(1, 4))
if not isinstance(type_, list):
msg = "'fi_type should be list', got {}.".format(type(type_))
LOGGER.error(TAG, msg)
raise TypeError(msg)
if not isinstance(mode_, list):
msg = "'fi_mode should be list', got {}.".format(type(mode_))
LOGGER.error(TAG, msg)
raise TypeError(msg)
if not isinstance(size_, list):
msg = "'fi_size should be list', got {}.".format(type(size_))
LOGGER.error(TAG, msg)
raise TypeError(msg)
for i in type_:
i = i if i.startswith('_') else '_' + i
if not isinstance(i, str):
msg = "'fi_type element should be str', got {} type {}.".format(i, type(i))
LOGGER.error(TAG, msg)
raise TypeError(msg)
new_i = i if i.startswith('_') else '_' + i
self.fi_type_map[new_i] = i
for j in mode_:
for k in size_:
dict_ = {'fi_type': i, 'fi_mode': j, 'fi_size': k}
dict_ = {'fi_type': new_i, 'fi_mode': j, 'fi_size': k}
self.running_list.append(dict_)

def _frozen(self):
@@ -131,20 +146,57 @@ class FaultInjector:
LOGGER.error(TAG, msg)
raise KeyError(msg)

def kick_off(self):
@staticmethod
def _calculate_batch_size(num, iter_times):
"""Calculate batch size based on iter_times."""
if num <= iter_times:
batch_list = [1] * num
idx_list = [0] * (num + 1)
else:
base_batch_size = num // iter_times
gt_num = num - iter_times * base_batch_size
le_num = iter_times - gt_num
batch_list = [base_batch_size + 1] * gt_num + [base_batch_size] * le_num
idx_list = [0] * (iter_times + 1)
for i, _ in enumerate(batch_list):
idx_list[i + 1] = idx_list[i] + batch_list[i]
return idx_list

@staticmethod
def _check_kick_off_param(ds_data, ds_label, iter_times):
"""check input data and label."""
_ = check_int_positive('iter_times', iter_times)
_ = check_param_type('ds_data', ds_data, np.ndarray)
_ = _check_array_not_empty('ds_data', ds_data)
_ = check_param_type('ds_label', ds_label, np.ndarray)
_ = _check_array_not_empty('ds_label', ds_label)

def kick_off(self, ds_data, ds_label, iter_times=100):
"""
Startup and return final results.

Args:
ds_data(np.ndarray): Input data for testing. The evaluation is based on this data.
ds_label(np.ndarray): The label of data, corresponding to the data.
iter_times(int): The number of evaluations, which will determine the batch size.

Returns:
list, the result of fault injection.
- list, the result of fault injection.
"""
self._check_kick_off_param(ds_data, ds_label, iter_times)
num = ds_data.shape[0]
idx_list = self._calculate_batch_size(num, iter_times)
result_list = []
for i in range(-1, len(self.running_list)):
arg = self.running_list[i]
total = 0
correct = 0
for data in self.data.create_dict_iterator():
batch = data['image']
label = data['label']
for idx in range(len(idx_list) - 1):
a = ds_data[idx_list[idx]:idx_list[idx + 1], ...]
batch = Tensor.from_numpy(a)
label = Tensor.from_numpy(ds_label[idx_list[idx]:idx_list[idx + 1], ...])
if label.ndim == 2:
label = self.argmax(label)
if i != -1:
self._reset_model()
self._layer_states(arg['fi_type'], arg['fi_mode'], arg['fi_size'])
@@ -153,20 +205,23 @@ class FaultInjector:
mask = predict == label
total += predict.size
correct += self._reducesum(mask.astype(mindspore.float32)).asnumpy()
acc = correct / total
acc = correct / total if total else 0
if i == -1:
self.original_acc = acc
result_list.append({'original_acc': self.original_acc})
else:
result_list.append({'type': arg['fi_type'], 'mode': arg['fi_mode'], 'size': arg['fi_size'],
result_list.append({'type': arg['fi_type'][1:], 'mode': arg['fi_mode'], 'size': arg['fi_size'],
'acc': acc, 'SDC': self.original_acc - acc})
self.data.reset()
self._reset_model()
self.result_list = result_list
return result_list

def metrics(self):
"""metrics of final result."""
"""
metrics of final result.
Returns:
list, the summary of result.
"""
result_summary = []
single_layer_acc = []
single_layer_sdc = []


+ 287
- 240
tests/ut/python/reliability/model_fault_injection/test_fault_injection.py View File

@@ -1,240 +1,287 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Test for fault injection.
"""

import pytest
import numpy as np

from mindspore import Model
import mindspore.dataset as ds
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from mindarmour.utils.logger import LogUtil
from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector

from tests.ut.python.utils.mock_net import Net

LOGGER = LogUtil.get_instance()
TAG = 'Fault injection test'
LOGGER.set_level('INFO')


def dataset_generator():
"""mock training data."""
batch_size = 32
batches = 128
data = np.random.random((batches*batch_size, 1, 32, 32)).astype(
np.float32)
label = np.random.randint(0, 10, batches*batch_size).astype(np.int32)
for i in range(batches):
yield data[i*batch_size:(i + 1)*batch_size],\
label[i*batch_size:(i + 1)*batch_size]


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_fault_injector():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Run kick_off and metrics successfully
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)

ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1]

# Fault injection
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
_ = fi.kick_off()
_ = fi.metrics()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_model():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw TypeError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)

ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1]

# Fault injection
with pytest.raises(TypeError) as exc_info:
fi = FaultInjector(net, ds_eval, fi_type, fi_mode, fi_size)
_ = fi.kick_off()
_ = fi.metrics()
assert exc_info.type is TypeError


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_data():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw TypeError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)

ds_eval = np.random.random((1000, 32, 32, 1)).astype(np.float32)
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1]

# Fault injection
with pytest.raises(TypeError) as exc_info:
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
_ = fi.kick_off()
_ = fi.metrics()
assert exc_info.type is TypeError


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_fi_type():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw AttributeError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)

ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
fi_type = ['bitflips_random_haha', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1]

# Fault injection
with pytest.raises(AttributeError) as exc_info:
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
_ = fi.kick_off()
_ = fi.metrics()
assert exc_info.type is AttributeError


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_fi_mode():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw ValueError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)

ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer_tail', 'all_layer']
fi_size = [1]

# Fault injection
with pytest.raises(ValueError) as exc_info:
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
_ = fi.kick_off()
_ = fi.metrics()
assert exc_info.type is ValueError


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_fi_size():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw ValueError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)

ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [-1]

# Fault injection
with pytest.raises(ValueError) as exc_info:
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
_ = fi.kick_off()
_ = fi.metrics()
assert exc_info.type is ValueError
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test for fault injection.
"""
import pytest
import numpy as np
from mindspore import Model
import mindspore.dataset as ds
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.utils.logger import LogUtil
from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector
from tests.ut.python.utils.mock_net import Net
LOGGER = LogUtil.get_instance()
TAG = 'Fault injection test'
LOGGER.set_level('INFO')
def dataset_generator():
"""mock training data."""
batch_size = 32
batches = 128
data = np.random.random((batches*batch_size, 1, 32, 32)).astype(
np.float32)
label = np.random.randint(0, 10, batches*batch_size).astype(np.int32)
for i in range(batches):
yield data[i*batch_size:(i + 1)*batch_size],\
label[i*batch_size:(i + 1)*batch_size]
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_fault_injector():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Run kick_off and metrics successfully
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)
ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
test_images = []
test_labels = []
for data in ds_eval.create_tuple_iterator(output_numpy=True):
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
ds_data = np.concatenate(test_images, axis=0)
ds_label = np.concatenate(test_labels, axis=0)
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1]
# Fault injection
fi = FaultInjector(model, fi_type, fi_mode, fi_size)
_ = fi.kick_off(ds_data, ds_label, iter_times=100)
_ = fi.metrics()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_model():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw TypeError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
test_images = []
test_labels = []
for data in ds_eval.create_tuple_iterator(output_numpy=True):
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
ds_data = np.concatenate(test_images, axis=0)
ds_label = np.concatenate(test_labels, axis=0)
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1]
# Fault injection
with pytest.raises(TypeError) as exc_info:
fi = FaultInjector(net, fi_type, fi_mode, fi_size)
_ = fi.kick_off(ds_data, ds_label, iter_times=100)
_ = fi.metrics()
assert exc_info.type is TypeError
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_data():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw TypeError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)
ds_data = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
ds_label = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1]
# Fault injection
with pytest.raises(TypeError) as exc_info:
fi = FaultInjector(model, fi_type, fi_mode, fi_size)
_ = fi.kick_off(ds_data, ds_label, iter_times=100)
_ = fi.metrics()
assert exc_info.type is TypeError
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_fi_type():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw AttributeError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)
ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
test_images = []
test_labels = []
for data in ds_eval.create_tuple_iterator(output_numpy=True):
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
ds_data = np.concatenate(test_images, axis=0)
ds_label = np.concatenate(test_labels, axis=0)
fi_type = ['bitflips_random_haha', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1]
# Fault injection
with pytest.raises(AttributeError) as exc_info:
fi = FaultInjector(model, fi_type, fi_mode, fi_size)
_ = fi.kick_off(ds_data, ds_label, iter_times=100)
_ = fi.metrics()
assert exc_info.type is AttributeError
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_fi_mode():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw ValueError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)
ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
test_images = []
test_labels = []
for data in ds_eval.create_tuple_iterator(output_numpy=True):
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
ds_data = np.concatenate(test_images, axis=0)
ds_label = np.concatenate(test_labels, axis=0)
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer_tail', 'all_layer']
fi_size = [1]
# Fault injection
with pytest.raises(ValueError) as exc_info:
fi = FaultInjector(model, fi_type, fi_mode, fi_size)
_ = fi.kick_off(ds_data, ds_label, iter_times=100)
_ = fi.metrics()
assert exc_info.type is ValueError
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_fi_size():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw ValueError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)
ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
test_images = []
test_labels = []
for data in ds_eval.create_tuple_iterator(output_numpy=True):
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
ds_data = np.concatenate(test_images, axis=0)
ds_label = np.concatenate(test_labels, axis=0)
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [-1]
# Fault injection
with pytest.raises(ValueError) as exc_info:
fi = FaultInjector(model, fi_type, fi_mode, fi_size)
_ = fi.kick_off(ds_data, ds_label, iter_times=100)
_ = fi.metrics()
assert exc_info.type is ValueError

Loading…
Cancel
Save