From 1bca3514213436737c3a9fecd6590a6ec08c83da Mon Sep 17 00:00:00 2001 From: ye12121 Date: Tue, 14 Dec 2021 14:25:35 +0800 Subject: [PATCH] The default type of input data is changed from mindspore.dataset to iterable --- examples/reliability/model_fault_injection.py | 21 +- .../model_fault_injection/fault_injection.py | 107 ++++- .../model_fault_injection/test_fault_injection.py | 527 +++++++++++---------- 3 files changed, 384 insertions(+), 271 deletions(-) diff --git a/examples/reliability/model_fault_injection.py b/examples/reliability/model_fault_injection.py index 849388f..0312826 100644 --- a/examples/reliability/model_fault_injection.py +++ b/examples/reliability/model_fault_injection.py @@ -32,8 +32,9 @@ Please extract and restructure the file as shown above. """ import argparse +import numpy as np from mindspore import Model, context -from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.train.serialization import load_checkpoint from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector from examples.common.networks.lenet5.lenet5_net import LeNet5 @@ -70,8 +71,18 @@ elif test_flag == 'resnet50': net = resnet50(10) else: exit() -param_dict = load_checkpoint(ckpt_path) -load_param_into_net(net, param_dict) + +test_images = [] +test_labels = [] +for data in ds_eval.create_tuple_iterator(output_numpy=True): + images = data[0].astype(np.float32) + labels = data[1] + test_images.append(images) + test_labels.append(labels) +ds_data = np.concatenate(test_images, axis=0) +ds_label = np.concatenate(test_labels, axis=0) + +param_dict = load_checkpoint(ckpt_path, net=net) model = Model(net) # Initialization @@ -81,8 +92,8 @@ fi_mode = ['single_layer', 'all_layer'] fi_size = [1, 2, 3] # Fault injection -fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) -results = fi.kick_off() +fi = FaultInjector(model, fi_type, fi_mode, fi_size) +results = fi.kick_off(ds_data, ds_label, iter_times=100) result_summary = fi.metrics() # print result diff --git a/mindarmour/reliability/model_fault_injection/fault_injection.py b/mindarmour/reliability/model_fault_injection/fault_injection.py index e099e92..a6784ef 100644 --- a/mindarmour/reliability/model_fault_injection/fault_injection.py +++ b/mindarmour/reliability/model_fault_injection/fault_injection.py @@ -22,7 +22,7 @@ from mindspore import ops, Tensor from mindarmour.reliability.model_fault_injection.fault_type import FaultType from mindarmour.utils.logger import LogUtil -from mindarmour.utils._check_param import check_int_positive +from mindarmour.utils._check_param import check_int_positive, check_param_type, _check_array_not_empty LOGGER = LogUtil.get_instance() TAG = 'FaultInjector' @@ -34,7 +34,6 @@ class FaultInjector: Args: model (Model): The model need to be evaluated. - data (Dataset): The data for testing. The evaluation is base on this data. fi_type (list): The type of the fault injection which include bitflips_random(flip randomly), bitflips_designated(flip the key bit), random, zeros, nan, inf, anti_activation precision_loss etc. fi_mode (list): The mode of fault injection. Fault inject on just single layer or all layers. @@ -43,20 +42,20 @@ class FaultInjector: Examples: >>> net = Net() >>> model = Model(net) - >>> ds_eval = create_dataloader() + >>> ds_data, ds_label = create_data() >>> fi_type = ['bitflips_random', 'zeros'] >>> fi_mode = ['single_layer', 'all_layer'] - >>> fi_size = [1, 2, 3] - >>> fi = FaultInjector(model, ds_eval, fi_type=fi_type, fi_mode=fi_mode, fi_size=fi_size) - >>> fi.kick_off() + >>> fi_size = [1, 2] + >>> fi = FaultInjector(model, fi_type=fi_type, fi_mode=fi_mode, fi_size=fi_size) + >>> fi.kick_off(ds_data, ds_label) """ - def __init__(self, model, data, fi_type=None, fi_mode=None, fi_size=None): + def __init__(self, model, fi_type=None, fi_mode=None, fi_size=None): """FaultInjector initiated.""" self.running_list = [] + self.fi_type_map = {} self._init_running_list(fi_type, fi_mode, fi_size) self.model = model - self.data = data self._fault_type = FaultType() self._check_param() self.result_list = [] @@ -68,19 +67,18 @@ class FaultInjector: def _check_param(self): """Check input parameters.""" - attr = self._fault_type.__dir__() - if not isinstance(self.data, mindspore.dataset.Dataset): - msg = "'Input data should be Mindspore Dataset', got {}.".format(type(self.data)) - LOGGER.error(TAG, msg) - raise TypeError(msg) - _ = check_int_positive('dataset_size', self.data.get_dataset_size()) + ori_attr = self._fault_type.__dir__() + attr = [] + for attr_ in ori_attr: + if not attr_.startswith('__') and attr_ not in ['_bitflip', '_fault_inject']: + attr.append(attr_) if not isinstance(self.model, mindspore.Model): msg = "'Input model should be Mindspore Model', got {}.".format(type(self.model)) LOGGER.error(TAG, msg) raise TypeError(msg) for param in self.running_list: if param['fi_type'] not in attr: - msg = "'Undefined fault type', got {}.".format(param['fi_type']) + msg = "'Undefined fault type', got {}.".format(self.fi_type_map[param['fi_type']]) LOGGER.error(TAG, msg) raise AttributeError(msg) if param['fi_mode'] not in ['single_layer', 'all_layer']: @@ -98,11 +96,28 @@ class FaultInjector: mode_ = ['single_layer', 'all_layer'] if size_ is None: size_ = list(range(1, 4)) + if not isinstance(type_, list): + msg = "'fi_type should be list', got {}.".format(type(type_)) + LOGGER.error(TAG, msg) + raise TypeError(msg) + if not isinstance(mode_, list): + msg = "'fi_mode should be list', got {}.".format(type(mode_)) + LOGGER.error(TAG, msg) + raise TypeError(msg) + if not isinstance(size_, list): + msg = "'fi_size should be list', got {}.".format(type(size_)) + LOGGER.error(TAG, msg) + raise TypeError(msg) for i in type_: - i = i if i.startswith('_') else '_' + i + if not isinstance(i, str): + msg = "'fi_type element should be str', got {} type {}.".format(i, type(i)) + LOGGER.error(TAG, msg) + raise TypeError(msg) + new_i = i if i.startswith('_') else '_' + i + self.fi_type_map[new_i] = i for j in mode_: for k in size_: - dict_ = {'fi_type': i, 'fi_mode': j, 'fi_size': k} + dict_ = {'fi_type': new_i, 'fi_mode': j, 'fi_size': k} self.running_list.append(dict_) def _frozen(self): @@ -131,20 +146,57 @@ class FaultInjector: LOGGER.error(TAG, msg) raise KeyError(msg) - def kick_off(self): + @staticmethod + def _calculate_batch_size(num, iter_times): + """Calculate batch size based on iter_times.""" + if num <= iter_times: + batch_list = [1] * num + idx_list = [0] * (num + 1) + else: + base_batch_size = num // iter_times + gt_num = num - iter_times * base_batch_size + le_num = iter_times - gt_num + batch_list = [base_batch_size + 1] * gt_num + [base_batch_size] * le_num + idx_list = [0] * (iter_times + 1) + for i, _ in enumerate(batch_list): + idx_list[i + 1] = idx_list[i] + batch_list[i] + return idx_list + + @staticmethod + def _check_kick_off_param(ds_data, ds_label, iter_times): + """check input data and label.""" + _ = check_int_positive('iter_times', iter_times) + _ = check_param_type('ds_data', ds_data, np.ndarray) + _ = _check_array_not_empty('ds_data', ds_data) + _ = check_param_type('ds_label', ds_label, np.ndarray) + _ = _check_array_not_empty('ds_label', ds_label) + + def kick_off(self, ds_data, ds_label, iter_times=100): """ Startup and return final results. + + Args: + ds_data(np.ndarray): Input data for testing. The evaluation is based on this data. + ds_label(np.ndarray): The label of data, corresponding to the data. + iter_times(int): The number of evaluations, which will determine the batch size. + Returns: - list, the result of fault injection. + - list, the result of fault injection. """ + self._check_kick_off_param(ds_data, ds_label, iter_times) + num = ds_data.shape[0] + idx_list = self._calculate_batch_size(num, iter_times) result_list = [] for i in range(-1, len(self.running_list)): arg = self.running_list[i] total = 0 correct = 0 - for data in self.data.create_dict_iterator(): - batch = data['image'] - label = data['label'] + for idx in range(len(idx_list) - 1): + a = ds_data[idx_list[idx]:idx_list[idx + 1], ...] + batch = Tensor.from_numpy(a) + label = Tensor.from_numpy(ds_label[idx_list[idx]:idx_list[idx + 1], ...]) + if label.ndim == 2: + label = self.argmax(label) if i != -1: self._reset_model() self._layer_states(arg['fi_type'], arg['fi_mode'], arg['fi_size']) @@ -153,20 +205,23 @@ class FaultInjector: mask = predict == label total += predict.size correct += self._reducesum(mask.astype(mindspore.float32)).asnumpy() - acc = correct / total + acc = correct / total if total else 0 if i == -1: self.original_acc = acc result_list.append({'original_acc': self.original_acc}) else: - result_list.append({'type': arg['fi_type'], 'mode': arg['fi_mode'], 'size': arg['fi_size'], + result_list.append({'type': arg['fi_type'][1:], 'mode': arg['fi_mode'], 'size': arg['fi_size'], 'acc': acc, 'SDC': self.original_acc - acc}) - self.data.reset() self._reset_model() self.result_list = result_list return result_list def metrics(self): - """metrics of final result.""" + """ + metrics of final result. + Returns: + list, the summary of result. + """ result_summary = [] single_layer_acc = [] single_layer_sdc = [] diff --git a/tests/ut/python/reliability/model_fault_injection/test_fault_injection.py b/tests/ut/python/reliability/model_fault_injection/test_fault_injection.py index ca727db..f99cab2 100644 --- a/tests/ut/python/reliability/model_fault_injection/test_fault_injection.py +++ b/tests/ut/python/reliability/model_fault_injection/test_fault_injection.py @@ -1,240 +1,287 @@ -# Copyright 2021 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Test for fault injection. -""" - -import pytest -import numpy as np - -from mindspore import Model -import mindspore.dataset as ds -from mindspore.train.serialization import load_checkpoint, load_param_into_net - -from mindarmour.utils.logger import LogUtil -from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector - -from tests.ut.python.utils.mock_net import Net - -LOGGER = LogUtil.get_instance() -TAG = 'Fault injection test' -LOGGER.set_level('INFO') - - -def dataset_generator(): - """mock training data.""" - batch_size = 32 - batches = 128 - data = np.random.random((batches*batch_size, 1, 32, 32)).astype( - np.float32) - label = np.random.randint(0, 10, batches*batch_size).astype(np.int32) - for i in range(batches): - yield data[i*batch_size:(i + 1)*batch_size],\ - label[i*batch_size:(i + 1)*batch_size] - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_arm_ascend_training -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_fault_injector(): - """ - Feature: Fault injector - Description: Test fault injector - Expectation: Run kick_off and metrics successfully - """ - # load model - ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' - net = Net() - param_dict = load_checkpoint(ckpt_path) - load_param_into_net(net, param_dict) - model = Model(net) - - ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) - fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', - 'nan', 'inf', 'anti_activation', 'precision_loss'] - fi_mode = ['single_layer', 'all_layer'] - fi_size = [1] - - # Fault injection - fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) - _ = fi.kick_off() - _ = fi.metrics() - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_arm_ascend_training -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_wrong_model(): - """ - Feature: Fault injector - Description: Test fault injector - Expectation: Throw TypeError exception - """ - # load model - ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' - net = Net() - param_dict = load_checkpoint(ckpt_path) - load_param_into_net(net, param_dict) - - ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) - fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', - 'nan', 'inf', 'anti_activation', 'precision_loss'] - fi_mode = ['single_layer', 'all_layer'] - fi_size = [1] - - # Fault injection - with pytest.raises(TypeError) as exc_info: - fi = FaultInjector(net, ds_eval, fi_type, fi_mode, fi_size) - _ = fi.kick_off() - _ = fi.metrics() - assert exc_info.type is TypeError - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_arm_ascend_training -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_wrong_data(): - """ - Feature: Fault injector - Description: Test fault injector - Expectation: Throw TypeError exception - """ - # load model - ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' - net = Net() - param_dict = load_checkpoint(ckpt_path) - load_param_into_net(net, param_dict) - model = Model(net) - - ds_eval = np.random.random((1000, 32, 32, 1)).astype(np.float32) - fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', - 'nan', 'inf', 'anti_activation', 'precision_loss'] - fi_mode = ['single_layer', 'all_layer'] - fi_size = [1] - - # Fault injection - with pytest.raises(TypeError) as exc_info: - fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) - _ = fi.kick_off() - _ = fi.metrics() - assert exc_info.type is TypeError - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_arm_ascend_training -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_wrong_fi_type(): - """ - Feature: Fault injector - Description: Test fault injector - Expectation: Throw AttributeError exception - """ - # load model - ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' - net = Net() - param_dict = load_checkpoint(ckpt_path) - load_param_into_net(net, param_dict) - model = Model(net) - - ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) - fi_type = ['bitflips_random_haha', 'bitflips_designated', 'random', 'zeros', - 'nan', 'inf', 'anti_activation', 'precision_loss'] - fi_mode = ['single_layer', 'all_layer'] - fi_size = [1] - - # Fault injection - with pytest.raises(AttributeError) as exc_info: - fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) - _ = fi.kick_off() - _ = fi.metrics() - assert exc_info.type is AttributeError - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_arm_ascend_training -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_wrong_fi_mode(): - """ - Feature: Fault injector - Description: Test fault injector - Expectation: Throw ValueError exception - """ - # load model - ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' - net = Net() - param_dict = load_checkpoint(ckpt_path) - load_param_into_net(net, param_dict) - model = Model(net) - - ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) - fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', - 'nan', 'inf', 'anti_activation', 'precision_loss'] - fi_mode = ['single_layer_tail', 'all_layer'] - fi_size = [1] - - # Fault injection - with pytest.raises(ValueError) as exc_info: - fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) - _ = fi.kick_off() - _ = fi.metrics() - assert exc_info.type is ValueError - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_arm_ascend_training -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_wrong_fi_size(): - """ - Feature: Fault injector - Description: Test fault injector - Expectation: Throw ValueError exception - """ - # load model - ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' - net = Net() - param_dict = load_checkpoint(ckpt_path) - load_param_into_net(net, param_dict) - model = Model(net) - - ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) - fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', - 'nan', 'inf', 'anti_activation', 'precision_loss'] - fi_mode = ['single_layer', 'all_layer'] - fi_size = [-1] - - # Fault injection - with pytest.raises(ValueError) as exc_info: - fi = FaultInjector(model, ds_eval, fi_type, fi_mode, fi_size) - _ = fi.kick_off() - _ = fi.metrics() - assert exc_info.type is ValueError +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test for fault injection. +""" + +import pytest +import numpy as np + +from mindspore import Model +import mindspore.dataset as ds +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +from mindarmour.utils.logger import LogUtil +from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector + +from tests.ut.python.utils.mock_net import Net + +LOGGER = LogUtil.get_instance() +TAG = 'Fault injection test' +LOGGER.set_level('INFO') + + +def dataset_generator(): + """mock training data.""" + batch_size = 32 + batches = 128 + data = np.random.random((batches*batch_size, 1, 32, 32)).astype( + np.float32) + label = np.random.randint(0, 10, batches*batch_size).astype(np.int32) + for i in range(batches): + yield data[i*batch_size:(i + 1)*batch_size],\ + label[i*batch_size:(i + 1)*batch_size] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_fault_injector(): + """ + Feature: Fault injector + Description: Test fault injector + Expectation: Run kick_off and metrics successfully + """ + # load model + ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' + net = Net() + param_dict = load_checkpoint(ckpt_path) + load_param_into_net(net, param_dict) + model = Model(net) + + ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) + test_images = [] + test_labels = [] + for data in ds_eval.create_tuple_iterator(output_numpy=True): + images = data[0].astype(np.float32) + labels = data[1] + test_images.append(images) + test_labels.append(labels) + ds_data = np.concatenate(test_images, axis=0) + ds_label = np.concatenate(test_labels, axis=0) + fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', + 'nan', 'inf', 'anti_activation', 'precision_loss'] + fi_mode = ['single_layer', 'all_layer'] + fi_size = [1] + + # Fault injection + fi = FaultInjector(model, fi_type, fi_mode, fi_size) + _ = fi.kick_off(ds_data, ds_label, iter_times=100) + _ = fi.metrics() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_wrong_model(): + """ + Feature: Fault injector + Description: Test fault injector + Expectation: Throw TypeError exception + """ + # load model + ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' + net = Net() + param_dict = load_checkpoint(ckpt_path) + load_param_into_net(net, param_dict) + + ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) + test_images = [] + test_labels = [] + for data in ds_eval.create_tuple_iterator(output_numpy=True): + images = data[0].astype(np.float32) + labels = data[1] + test_images.append(images) + test_labels.append(labels) + ds_data = np.concatenate(test_images, axis=0) + ds_label = np.concatenate(test_labels, axis=0) + fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', + 'nan', 'inf', 'anti_activation', 'precision_loss'] + fi_mode = ['single_layer', 'all_layer'] + fi_size = [1] + + # Fault injection + with pytest.raises(TypeError) as exc_info: + fi = FaultInjector(net, fi_type, fi_mode, fi_size) + _ = fi.kick_off(ds_data, ds_label, iter_times=100) + _ = fi.metrics() + assert exc_info.type is TypeError + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_wrong_data(): + """ + Feature: Fault injector + Description: Test fault injector + Expectation: Throw TypeError exception + """ + # load model + ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' + net = Net() + param_dict = load_checkpoint(ckpt_path) + load_param_into_net(net, param_dict) + model = Model(net) + + ds_data = ds.GeneratorDataset(dataset_generator, ['image', 'label']) + ds_label = ds.GeneratorDataset(dataset_generator, ['image', 'label']) + fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', + 'nan', 'inf', 'anti_activation', 'precision_loss'] + fi_mode = ['single_layer', 'all_layer'] + fi_size = [1] + + # Fault injection + with pytest.raises(TypeError) as exc_info: + fi = FaultInjector(model, fi_type, fi_mode, fi_size) + _ = fi.kick_off(ds_data, ds_label, iter_times=100) + _ = fi.metrics() + assert exc_info.type is TypeError + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_wrong_fi_type(): + """ + Feature: Fault injector + Description: Test fault injector + Expectation: Throw AttributeError exception + """ + # load model + ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' + net = Net() + param_dict = load_checkpoint(ckpt_path) + load_param_into_net(net, param_dict) + model = Model(net) + + ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) + test_images = [] + test_labels = [] + for data in ds_eval.create_tuple_iterator(output_numpy=True): + images = data[0].astype(np.float32) + labels = data[1] + test_images.append(images) + test_labels.append(labels) + ds_data = np.concatenate(test_images, axis=0) + ds_label = np.concatenate(test_labels, axis=0) + fi_type = ['bitflips_random_haha', 'bitflips_designated', 'random', 'zeros', + 'nan', 'inf', 'anti_activation', 'precision_loss'] + fi_mode = ['single_layer', 'all_layer'] + fi_size = [1] + + # Fault injection + with pytest.raises(AttributeError) as exc_info: + fi = FaultInjector(model, fi_type, fi_mode, fi_size) + _ = fi.kick_off(ds_data, ds_label, iter_times=100) + _ = fi.metrics() + assert exc_info.type is AttributeError + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_wrong_fi_mode(): + """ + Feature: Fault injector + Description: Test fault injector + Expectation: Throw ValueError exception + """ + # load model + ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' + net = Net() + param_dict = load_checkpoint(ckpt_path) + load_param_into_net(net, param_dict) + model = Model(net) + + ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) + test_images = [] + test_labels = [] + for data in ds_eval.create_tuple_iterator(output_numpy=True): + images = data[0].astype(np.float32) + labels = data[1] + test_images.append(images) + test_labels.append(labels) + ds_data = np.concatenate(test_images, axis=0) + ds_label = np.concatenate(test_labels, axis=0) + fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', + 'nan', 'inf', 'anti_activation', 'precision_loss'] + fi_mode = ['single_layer_tail', 'all_layer'] + fi_size = [1] + + # Fault injection + with pytest.raises(ValueError) as exc_info: + fi = FaultInjector(model, fi_type, fi_mode, fi_size) + _ = fi.kick_off(ds_data, ds_label, iter_times=100) + _ = fi.metrics() + assert exc_info.type is ValueError + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard +@pytest.mark.component_mindarmour +def test_wrong_fi_size(): + """ + Feature: Fault injector + Description: Test fault injector + Expectation: Throw ValueError exception + """ + # load model + ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' + net = Net() + param_dict = load_checkpoint(ckpt_path) + load_param_into_net(net, param_dict) + model = Model(net) + + ds_eval = ds.GeneratorDataset(dataset_generator, ['image', 'label']) + test_images = [] + test_labels = [] + for data in ds_eval.create_tuple_iterator(output_numpy=True): + images = data[0].astype(np.float32) + labels = data[1] + test_images.append(images) + test_labels.append(labels) + ds_data = np.concatenate(test_images, axis=0) + ds_label = np.concatenate(test_labels, axis=0) + + fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros', + 'nan', 'inf', 'anti_activation', 'precision_loss'] + fi_mode = ['single_layer', 'all_layer'] + fi_size = [-1] + + # Fault injection + with pytest.raises(ValueError) as exc_info: + fi = FaultInjector(model, fi_type, fi_mode, fi_size) + _ = fi.kick_off(ds_data, ds_label, iter_times=100) + _ = fi.metrics() + assert exc_info.type is ValueError