# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fault injection example.

Loads a trained network (LeNet5 evaluated on MNIST, or VGG16 / ResNet50
evaluated on CIFAR-10), runs MindArmour's FaultInjector over the evaluation
set, and prints every entry of the injection results followed by the summary
metrics.

Select the network with ``--model`` (lenet, resnet50, vgg16) and the backend
with ``--device_target`` (Ascend, GPU, CPU).

Download a checkpoint from https://www.mindspore.cn/resources/hub, or train
your own checkpoint.

Download the CIFAR-10 dataset from:
http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz.

File structure:
    --cifar10-batches-bin
        --train
            --data_batch_1.bin
            --data_batch_2.bin
            --data_batch_3.bin
            --data_batch_4.bin
            --data_batch_5.bin
        --test
            --test_batch.bin

Extract the archive and arrange the files as shown above.
"""
import argparse
import os

import numpy as np
from mindspore import Model, context
from mindspore.train.serialization import load_checkpoint

from mindarmour.reliability.model_fault_injection.fault_injection import FaultInjector

from examples.common.networks.lenet5.lenet5_net import LeNet5
from examples.common.networks.vgg.vgg import vgg16
from examples.common.networks.resnet.resnet import resnet50
from examples.common.dataset.data_processing import create_dataset_cifar, generate_mnist_dataset

# The asset paths below use this example's own directory as their base.
# Resolve them against __file__ so the script no longer depends on the
# current working directory it is launched from.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))


def _example_path(relative_path):
    """Return ``relative_path`` resolved against this script's directory."""
    return os.path.normpath(os.path.join(_SCRIPT_DIR, relative_path))


parser = argparse.ArgumentParser(description='layer_states')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'])
parser.add_argument('--model', type=str, default='lenet', choices=['lenet', 'resnet50', 'vgg16'])
parser.add_argument('--device_id', type=int, default=0)
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)

test_flag = args.model
if test_flag == 'lenet':
    # LeNet5 is evaluated on the MNIST test split.
    DATA_FILE = _example_path('../common/dataset/MNIST_Data/test')
    ckpt_path = _example_path('../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt')
    ds_eval = generate_mnist_dataset(DATA_FILE, batch_size=64)
    net = LeNet5()
elif test_flag == 'vgg16':
    # Imported lazily: the VGG config is only needed in this branch.
    from examples.common.networks.vgg.config import cifar_cfg as cfg
    DATA_FILE = _example_path('../common/dataset/cifar10-batches-bin')
    ckpt_path = _example_path('../common/networks/vgg16_ascend_v111_cifar10_offical_cv_bs64_acc93.ckpt')
    ds_eval = create_dataset_cifar(DATA_FILE, 224, 224, training=False)
    net = vgg16(10, cfg, 'test')
elif test_flag == 'resnet50':
    DATA_FILE = _example_path('../common/dataset/cifar10-batches-bin')
    ckpt_path = _example_path('../common/networks/resnet50_ascend_v111_cifar10_offical_cv_bs32_acc92.ckpt')
    ds_eval = create_dataset_cifar(DATA_FILE, 224, 224, training=False)
    net = resnet50(10)
else:
    # Unreachable while argparse enforces `choices`; kept as a safeguard.
    # parser.error() prints usage and exits with status 2, instead of relying
    # on the site-provided exit() builtin and exiting silently with status 0.
    parser.error('unsupported model: {}'.format(test_flag))

# Materialise the whole evaluation set as two numpy arrays (images cast to
# float32), since FaultInjector.kick_off below is called with arrays.
test_images = []
test_labels = []
for data in ds_eval.create_tuple_iterator(output_numpy=True):
    test_images.append(data[0].astype(np.float32))
    test_labels.append(data[1])
if not test_images:
    # np.concatenate would otherwise fail with an opaque
    # "need at least one array to concatenate" error.
    raise ValueError("evaluation dataset at '{}' yielded no batches".format(DATA_FILE))
ds_data = np.concatenate(test_images, axis=0)
ds_label = np.concatenate(test_labels, axis=0)

# Passing `net` loads the checkpoint parameters into the network in place;
# the returned parameter dict is not needed here.
load_checkpoint(ckpt_path, net=net)
model = Model(net)

# Fault types, injection modes and fault sizes handed to FaultInjector.
fi_type = ['bitflips_random', 'bitflips_designated', 'random', 'zeros',
           'nan', 'inf', 'anti_activation', 'precision_loss']
fi_mode = ['single_layer', 'all_layer']
fi_size = [1, 2, 3]

# Run the fault-injection experiments on the evaluation arrays.
fi = FaultInjector(model, fi_type, fi_mode, fi_size)
results = fi.kick_off(ds_data, ds_label, iter_times=100)
result_summary = fi.metrics()

# Print every result entry, then the summary metrics.
for result in results:
    print(result)
for result in result_summary:
    print(result)