Browse Source

The default type of input data is changed from mindspore.dataset to iterable

tags/v1.6.0
ye12121 3 years ago
parent
commit
1bca351421
3 changed files with 384 additions and 271 deletions
  1. +16
    -5
      examples/reliability/model_fault_injection.py
  2. +81
    -26
      mindarmour/reliability/model_fault_injection/fault_injection.py
  3. +287
    -240
      tests/ut/python/reliability/model_fault_injection/test_fault_injection.py

+ 16
- 5
examples/reliability/model_fault_injection.py View File

@@ -32,8 +32,9 @@ Please extract and restructure the file as shown above.
""" """
import argparse import argparse


import numpy as np
from mindspore import Model, context from mindspore import Model, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.serialization import load_checkpoint


from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector
from examples.common.networks.lenet5.lenet5_net import LeNet5 from examples.common.networks.lenet5.lenet5_net import LeNet5
@@ -70,8 +71,18 @@ elif test_flag == 'resnet50':
net = resnet50(10) net = resnet50(10)
else: else:
exit() exit()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)

test_images = []
test_labels = []
for data in ds_eval.create_tuple_iterator(output_numpy=True):
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
ds_data = np.concatenate(test_images, axis=0)
ds_label = np.concatenate(test_labels, axis=0)

param_dict = load_checkpoint(ckpt_path, net=net)
model = Model(net) model = Model(net)


# Initialization # Initialization
@@ -81,8 +92,8 @@ fi_mode = ['single_layer', 'all_layer']
fi_size = [1, 2, 3] fi_size = [1, 2, 3]


# Fault injection # Fault injection
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
results = fi.kick_off()
fi = FaultInjector(model, fi_type, fi_mode, fi_size)
results = fi.kick_off(ds_data, ds_label, iter_times=100)
result_summary = fi.metrics() result_summary = fi.metrics()


# print result # print result


+ 81
- 26
mindarmour/reliability/model_fault_injection/fault_injection.py View File

@@ -22,7 +22,7 @@ from mindspore import ops, Tensor


from mindarmour.reliability.model_fault_injection.fault_type import FaultType from mindarmour.reliability.model_fault_injection.fault_type import FaultType
from mindarmour.utils.logger import LogUtil from mindarmour.utils.logger import LogUtil
from mindarmour.utils._check_param import check_int_positive
from mindarmour.utils._check_param import check_int_positive, check_param_type, _check_array_not_empty


LOGGER = LogUtil.get_instance() LOGGER = LogUtil.get_instance()
TAG = 'FaultInjector' TAG = 'FaultInjector'
@@ -34,7 +34,6 @@ class FaultInjector:


Args: Args:
model (Model): The model need to be evaluated. model (Model): The model need to be evaluated.
data (Dataset): The data for testing. The evaluation is base on this data.
fi_type (list): The type of the fault injection which include bitflips_random(flip randomly), fi_type (list): The type of the fault injection which include bitflips_random(flip randomly),
bitflips_designated(flip the key bit), random, zeros, nan, inf, anti_activation precision_loss etc. bitflips_designated(flip the key bit), random, zeros, nan, inf, anti_activation precision_loss etc.
fi_mode (list): The mode of fault injection. Fault inject on just single layer or all layers. fi_mode (list): The mode of fault injection. Fault inject on just single layer or all layers.
@@ -43,20 +42,20 @@ class FaultInjector:
Examples: Examples:
>>> net = Net() >>> net = Net()
>>> model = Model(net) >>> model = Model(net)
>>> ds_eval = create_dataloader()
>>> ds_data, ds_label = create_data()
>>> fi_type = ['bitflips_random', 'zeros'] >>> fi_type = ['bitflips_random', 'zeros']
>>> fi_mode = ['single_layer', 'all_layer'] >>> fi_mode = ['single_layer', 'all_layer']
>>> fi_size = [1, 2, 3]
>>> fi = FaultInjector(model, ds_eval, fi_type=fi_type, fi_mode=fi_mode, fi_size=fi_size)
>>> fi.kick_off()
>>> fi_size = [1, 2]
>>> fi = FaultInjector(model, fi_type=fi_type, fi_mode=fi_mode, fi_size=fi_size)
>>> fi.kick_off(ds_data, ds_label)
""" """


def __init__(self, model, data, fi_type=None, fi_mode=None, fi_size=None):
def __init__(self, model, fi_type=None, fi_mode=None, fi_size=None):
"""FaultInjector initiated.""" """FaultInjector initiated."""
self.running_list = [] self.running_list = []
self.fi_type_map = {}
self._init_running_list(fi_type, fi_mode, fi_size) self._init_running_list(fi_type, fi_mode, fi_size)
self.model = model self.model = model
self.data = data
self._fault_type = FaultType() self._fault_type = FaultType()
self._check_param() self._check_param()
self.result_list = [] self.result_list = []
@@ -68,19 +67,18 @@ class FaultInjector:


def _check_param(self): def _check_param(self):
"""Check input parameters.""" """Check input parameters."""
attr = self._fault_type.__dir__()
if not isinstance(self.data, mindspore.dataset.Dataset):
msg = "'Input data should be Mindspore Dataset', got {}.".format(type(self.data))
LOGGER.error(TAG, msg)
raise TypeError(msg)
_ = check_int_positive('dataset_size', self.data.get_dataset_size())
ori_attr = self._fault_type.__dir__()
attr = []
for attr_ in ori_attr:
if not attr_.startswith('__') and attr_ not in ['_bitflip', '_fault_inject']:
attr.append(attr_)
if not isinstance(self.model, mindspore.Model): if not isinstance(self.model, mindspore.Model):
msg = "'Input model should be Mindspore Model', got {}.".format(type(self.model)) msg = "'Input model should be Mindspore Model', got {}.".format(type(self.model))
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
raise TypeError(msg) raise TypeError(msg)
for param in self.running_list: for param in self.running_list:
if param['fi_type'] not in attr: if param['fi_type'] not in attr:
msg = "'Undefined fault type', got {}.".format(param['fi_type'])
msg = "'Undefined fault type', got {}.".format(self.fi_type_map[param['fi_type']])
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
raise AttributeError(msg) raise AttributeError(msg)
if param['fi_mode'] not in ['single_layer', 'all_layer']: if param['fi_mode'] not in ['single_layer', 'all_layer']:
@@ -98,11 +96,28 @@ class FaultInjector:
mode_ = ['single_layer', 'all_layer'] mode_ = ['single_layer', 'all_layer']
if size_ is None: if size_ is None:
size_ = list(range(1, 4)) size_ = list(range(1, 4))
if not isinstance(type_, list):
msg = "'fi_type should be list', got {}.".format(type(type_))
LOGGER.error(TAG, msg)
raise TypeError(msg)
if not isinstance(mode_, list):
msg = "'fi_mode should be list', got {}.".format(type(mode_))
LOGGER.error(TAG, msg)
raise TypeError(msg)
if not isinstance(size_, list):
msg = "'fi_size should be list', got {}.".format(type(size_))
LOGGER.error(TAG, msg)
raise TypeError(msg)
for i in type_: for i in type_:
i = i if i.startswith('_') else '_' + i
if not isinstance(i, str):
msg = "'fi_type element should be str', got {} type {}.".format(i, type(i))
LOGGER.error(TAG, msg)
raise TypeError(msg)
new_i = i if i.startswith('_') else '_' + i
self.fi_type_map[new_i] = i
for j in mode_: for j in mode_:
for k in size_: for k in size_:
dict_ = {'fi_type': i, 'fi_mode': j, 'fi_size': k}
dict_ = {'fi_type': new_i, 'fi_mode': j, 'fi_size': k}
self.running_list.append(dict_) self.running_list.append(dict_)


def _frozen(self): def _frozen(self):
@@ -131,20 +146,57 @@ class FaultInjector:
LOGGER.error(TAG, msg) LOGGER.error(TAG, msg)
raise KeyError(msg) raise KeyError(msg)


def kick_off(self):
@staticmethod
def _calculate_batch_size(num, iter_times):
"""Calculate batch size based on iter_times."""
if num <= iter_times:
batch_list = [1] * num
idx_list = [0] * (num + 1)
else:
base_batch_size = num // iter_times
gt_num = num - iter_times * base_batch_size
le_num = iter_times - gt_num
batch_list = [base_batch_size + 1] * gt_num + [base_batch_size] * le_num
idx_list = [0] * (iter_times + 1)
for i, _ in enumerate(batch_list):
idx_list[i + 1] = idx_list[i] + batch_list[i]
return idx_list

@staticmethod
def _check_kick_off_param(ds_data, ds_label, iter_times):
"""check input data and label."""
_ = check_int_positive('iter_times', iter_times)
_ = check_param_type('ds_data', ds_data, np.ndarray)
_ = _check_array_not_empty('ds_data', ds_data)
_ = check_param_type('ds_label', ds_label, np.ndarray)
_ = _check_array_not_empty('ds_label', ds_label)

def kick_off(self, ds_data, ds_label, iter_times=100):
""" """
Startup and return final results. Startup and return final results.

Args:
ds_data(np.ndarray): Input data for testing. The evaluation is based on this data.
ds_label(np.ndarray): The label of data, corresponding to the data.
iter_times(int): The number of evaluations, which will determine the batch size.

Returns: Returns:
list, the result of fault injection.
- list, the result of fault injection.
""" """
self._check_kick_off_param(ds_data, ds_label, iter_times)
num = ds_data.shape[0]
idx_list = self._calculate_batch_size(num, iter_times)
result_list = [] result_list = []
for i in range(-1, len(self.running_list)): for i in range(-1, len(self.running_list)):
arg = self.running_list[i] arg = self.running_list[i]
total = 0 total = 0
correct = 0 correct = 0
for data in self.data.create_dict_iterator():
batch = data['image']
label = data['label']
for idx in range(len(idx_list) - 1):
a = ds_data[idx_list[idx]:idx_list[idx + 1], ...]
batch = Tensor.from_numpy(a)
label = Tensor.from_numpy(ds_label[idx_list[idx]:idx_list[idx + 1], ...])
if label.ndim == 2:
label = self.argmax(label)
if i != -1: if i != -1:
self._reset_model() self._reset_model()
self._layer_states(arg['fi_type'], arg['fi_mode'], arg['fi_size']) self._layer_states(arg['fi_type'], arg['fi_mode'], arg['fi_size'])
@@ -153,20 +205,23 @@ class FaultInjector:
mask = predict == label mask = predict == label
total += predict.size total += predict.size
correct += self._reducesum(mask.astype(mindspore.float32)).asnumpy() correct += self._reducesum(mask.astype(mindspore.float32)).asnumpy()
acc = correct / total
acc = correct / total if total else 0
if i == -1: if i == -1:
self.original_acc = acc self.original_acc = acc
result_list.append({'original_acc': self.original_acc}) result_list.append({'original_acc': self.original_acc})
else: else:
result_list.append({'type': arg['fi_type'], 'mode': arg['fi_mode'], 'size': arg['fi_size'],
result_list.append({'type': arg['fi_type'][1:], 'mode': arg['fi_mode'], 'size': arg['fi_size'],
'acc': acc, 'SDC': self.original_acc - acc}) 'acc': acc, 'SDC': self.original_acc - acc})
self.data.reset()
self._reset_model() self._reset_model()
self.result_list = result_list self.result_list = result_list
return result_list return result_list


def metrics(self): def metrics(self):
"""metrics of final result."""
"""
metrics of final result.
Returns:
list, the summary of result.
"""
result_summary = [] result_summary = []
single_layer_acc = [] single_layer_acc = []
single_layer_sdc = [] single_layer_sdc = []


+ 287
- 240
tests/ut/python/reliability/model_fault_injection/test_fault_injection.py View File

@@ -1,240 +1,287 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Test for fault injection.
"""

import pytest
import numpy as np

from mindspore import Model
import mindspore.dataset as ds
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from mindarmour.utils.logger import LogUtil
from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector

from tests.ut.python.utils.mock_net import Net

LOGGER = LogUtil.get_instance()
TAG = 'Fault injection test'
LOGGER.set_level('INFO')


def dataset_generator():
"""mock training data."""
batch_size = 32
batches = 128
data = np.random.random((batches*batch_size, 1, 32, 32)).astype(
np.float32)
label = np.random.randint(0, 10, batches*batch_size).astype(np.int32)
for i in range(batches):
yield data[i*batch_size:(i + 1)*batch_size],\
label[i*batch_size:(i + 1)*batch_size]


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_fault_injector():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Run kick_off and metrics successfully
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)

ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1]

# Fault injection
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
_ = fi.kick_off()
_ = fi.metrics()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_model():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw TypeError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)

ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1]

# Fault injection
with pytest.raises(TypeError) as exc_info:
fi = FaultInjector(net, ds_eval, fi_type, fi_mode, fi_size)
_ = fi.kick_off()
_ = fi.metrics()
assert exc_info.type is TypeError


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_data():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw TypeError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)

ds_eval = np.random.random((1000, 32, 32, 1)).astype(np.float32)
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1]

# Fault injection
with pytest.raises(TypeError) as exc_info:
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
_ = fi.kick_off()
_ = fi.metrics()
assert exc_info.type is TypeError


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_fi_type():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw AttributeError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)

ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
fi_type = ['bitflips_random_haha', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1]

# Fault injection
with pytest.raises(AttributeError) as exc_info:
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
_ = fi.kick_off()
_ = fi.metrics()
assert exc_info.type is AttributeError


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_fi_mode():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw ValueError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)

ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer_tail', 'all_layer']
fi_size = [1]

# Fault injection
with pytest.raises(ValueError) as exc_info:
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
_ = fi.kick_off()
_ = fi.metrics()
assert exc_info.type is ValueError


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_fi_size():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw ValueError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)

ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [-1]

# Fault injection
with pytest.raises(ValueError) as exc_info:
fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size)
_ = fi.kick_off()
_ = fi.metrics()
assert exc_info.type is ValueError
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test for fault injection.
"""
import pytest
import numpy as np
from mindspore import Model
import mindspore.dataset as ds
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.utils.logger import LogUtil
from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector
from tests.ut.python.utils.mock_net import Net
LOGGER = LogUtil.get_instance()
TAG = 'Fault injection test'
LOGGER.set_level('INFO')
def dataset_generator():
"""mock training data."""
batch_size = 32
batches = 128
data = np.random.random((batches*batch_size, 1, 32, 32)).astype(
np.float32)
label = np.random.randint(0, 10, batches*batch_size).astype(np.int32)
for i in range(batches):
yield data[i*batch_size:(i + 1)*batch_size],\
label[i*batch_size:(i + 1)*batch_size]
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_fault_injector():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Run kick_off and metrics successfully
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)
ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
test_images = []
test_labels = []
for data in ds_eval.create_tuple_iterator(output_numpy=True):
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
ds_data = np.concatenate(test_images, axis=0)
ds_label = np.concatenate(test_labels, axis=0)
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1]
# Fault injection
fi = FaultInjector(model, fi_type, fi_mode, fi_size)
_ = fi.kick_off(ds_data, ds_label, iter_times=100)
_ = fi.metrics()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_model():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw TypeError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
test_images = []
test_labels = []
for data in ds_eval.create_tuple_iterator(output_numpy=True):
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
ds_data = np.concatenate(test_images, axis=0)
ds_label = np.concatenate(test_labels, axis=0)
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1]
# Fault injection
with pytest.raises(TypeError) as exc_info:
fi = FaultInjector(net, fi_type, fi_mode, fi_size)
_ = fi.kick_off(ds_data, ds_label, iter_times=100)
_ = fi.metrics()
assert exc_info.type is TypeError
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_data():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw TypeError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)
ds_data = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
ds_label = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1]
# Fault injection
with pytest.raises(TypeError) as exc_info:
fi = FaultInjector(model, fi_type, fi_mode, fi_size)
_ = fi.kick_off(ds_data, ds_label, iter_times=100)
_ = fi.metrics()
assert exc_info.type is TypeError
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_fi_type():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw AttributeError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)
ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
test_images = []
test_labels = []
for data in ds_eval.create_tuple_iterator(output_numpy=True):
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
ds_data = np.concatenate(test_images, axis=0)
ds_label = np.concatenate(test_labels, axis=0)
fi_type = ['bitflips_random_haha', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1]
# Fault injection
with pytest.raises(AttributeError) as exc_info:
fi = FaultInjector(model, fi_type, fi_mode, fi_size)
_ = fi.kick_off(ds_data, ds_label, iter_times=100)
_ = fi.metrics()
assert exc_info.type is AttributeError
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_fi_mode():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw ValueError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)
ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
test_images = []
test_labels = []
for data in ds_eval.create_tuple_iterator(output_numpy=True):
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
ds_data = np.concatenate(test_images, axis=0)
ds_label = np.concatenate(test_labels, axis=0)
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer_tail', 'all_layer']
fi_size = [1]
# Fault injection
with pytest.raises(ValueError) as exc_info:
fi = FaultInjector(model, fi_type, fi_mode, fi_size)
_ = fi.kick_off(ds_data, ds_label, iter_times=100)
_ = fi.metrics()
assert exc_info.type is ValueError
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_wrong_fi_size():
"""
Feature: Fault injector
Description: Test fault injector
Expectation: Throw ValueError exception
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = Net()
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
model = Model(net)
ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label'])
test_images = []
test_labels = []
for data in ds_eval.create_tuple_iterator(output_numpy=True):
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
ds_data = np.concatenate(test_images, axis=0)
ds_label = np.concatenate(test_labels, axis=0)
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [-1]
# Fault injection
with pytest.raises(ValueError) as exc_info:
fi = FaultInjector(model, fi_type, fi_mode, fi_size)
_ = fi.kick_off(ds_data, ds_label, iter_times=100)
_ = fi.metrics()
assert exc_info.type is ValueError

Loading…
Cancel
Save