diff --git a/examples/reliability/concept_drift_check_images_lenet.py b/examples/reliability/concept_drift_check_images_lenet.py index 5d77c10..1bdf3d5 100644 --- a/examples/reliability/concept_drift_check_images_lenet.py +++ b/examples/reliability/concept_drift_check_images_lenet.py @@ -17,9 +17,8 @@ from mindspore import Tensor from mindspore.train.model import Model from mindspore import Model, nn, context from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5 -from mindspore.train.summary.summary_record import _get_summary_tensor_data from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetector, result_eval +from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetectorFeatureCluster """ @@ -27,23 +26,6 @@ Examples for Lenet. """ -def feature_extract(data, feature_model, layer='output[:Tensor]'): - """ - Extract features. - Args: - data (numpy.ndarray): Input data. - feature_model (Model): The model for extracting features. - layer (str): The feature layer. The layer name could be 'output[:Tensor]', - '1[:Tensor]', '2[:Tensor]',...'10[:Tensor]'. - - Returns: - numpy.ndarray, the feature of input data. 
- """ - feature_model.predict(Tensor(data)) - layer_out = _get_summary_tensor_data() - return layer_out[layer].asnumpy() - - if __name__ == '__main__': # load model ckpt_path = '../../tests/ut/python/dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' @@ -53,13 +35,13 @@ if __name__ == '__main__': model = Model(net) # load data ds_train = np.load('../../tests/ut/python/dataset/concept_train_lenet.npy') - ds_test = np.load('../../tests/ut/python/dataset/concept_test_lenet.npy') - ds_train = feature_extract(ds_train, model, layer='output[:Tensor]') - ds_test = feature_extract(ds_test, model, layer='output[:Tensor]') - # ood detect - detector = OodDetector(ds_train, ds_test, n_cluster=10) - score = detector.ood_detector() - # Evaluation - num = int(len(ds_test)/2) + ds_eval = np.load('../../tests/ut/python/dataset/concept_test_lenet1.npy') + ds_test = np.load('../../tests/ut/python/dataset/concept_test_lenet2.npy') + # ood detector initialization + detector = OodDetectorFeatureCluster(model, ds_train, n_cluster=10, layer='output[:Tensor]') + # get optimal threshold with ds_eval + num = int(len(ds_eval) / 2) label = np.concatenate((np.zeros(num), np.ones(num)), axis=0) # ID data = 0, OOD data = 1 - dec_acc = result_eval(score, label, threshold=0.5) + optimal_threshold = detector.get_optimal_threshold(label, ds_eval) + # get result of ds_test2. We can also set threshold by ourselves. 
+ result = detector.ood_predict(optimal_threshold, ds_test) diff --git a/examples/reliability/concept_drift_check_images_resnet.py b/examples/reliability/concept_drift_check_images_resnet.py index f97045b..5457e14 100644 --- a/examples/reliability/concept_drift_check_images_resnet.py +++ b/examples/reliability/concept_drift_check_images_resnet.py @@ -17,9 +17,8 @@ from mindspore import Tensor from mindspore.train.model import Model from mindspore import Model, nn, context from examples.common.networks.resnet.resnet import resnet50 -from mindspore.train.summary.summary_record import _get_summary_tensor_data from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetector, result_eval +from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetectorFeatureCluster """ @@ -27,23 +26,6 @@ Examples for Resnet. """ -def feature_extract(data, feature_model, layer='output[:Tensor]'): - """ - Extract features. - Args: - data (numpy.ndarray): Input data. - feature_model (Model): The model for extracting features. - layer (str): The feature layer. The layer name could be 'output[:Tensor]', - '1[:Tensor]', '2[:Tensor]',...'10[:Tensor]'. - - Returns: - numpy.ndarray, the feature of input data. 
- """ - feature_model.predict(Tensor(data)) - layer_out = _get_summary_tensor_data() - return layer_out[layer].asnumpy() - - if __name__ == '__main__': # load model ckpt_path = '../../tests/ut/python/dataset/trained_ckpt_file/resnet_1-20_1875.ckpt' @@ -52,14 +34,14 @@ if __name__ == '__main__': load_param_into_net(net, load_dict) model = Model(net) # load data - ds_train = np.load('./train.npy') - ds_test = np.load('./test.npy') - ds_train = feature_extract(ds_train, model, layer='output[:Tensor]') - ds_test = feature_extract(ds_test, model, layer='output[:Tensor]') - # ood detect - detector = OodDetector(ds_train, ds_test, n_cluster=10) - score = detector.ood_detector() - # Evaluation - num = int(len(ds_test)/2) + ds_train = np.load('train.npy') + ds_eval = np.load('test1.npy') + ds_test = np.load('test2.npy') + # ood detector initialization + detector = OodDetectorFeatureCluster(model, ds_train, n_cluster=10, layer='output[:Tensor]') + # get optimal threshold with ds_eval + num = int(len(ds_eval) / 2) label = np.concatenate((np.zeros(num), np.ones(num)), axis=0) # ID data = 0, OOD data = 1 - dec_acc = result_eval(score, label, threshold=0.5) + optimal_threshold = detector.get_optimal_threshold(label, ds_eval) + # get result of ds_test2. We can also set threshold by ourselves. + result = detector.ood_predict(optimal_threshold, ds_test) diff --git a/mindarmour/reliability/concept_drift/README.md b/mindarmour/reliability/concept_drift/README.md index f9e7300..f80239b 100644 --- a/mindarmour/reliability/concept_drift/README.md +++ b/mindarmour/reliability/concept_drift/README.md @@ -109,8 +109,8 @@ For ResNet, the training data is Cifar10 as ID data. The testing data is Cifar10 ### Environment Requirements -- Hardware(Ascend) - - Prepare hardware environment with Ascend. +- Hardware + - Prepare hardware environment with Ascend, CPU and GPU. 
- Framework - MindSpore - For more information, please check the resources below: @@ -122,76 +122,176 @@ For ResNet, the training data is Cifar10 as ID data. The testing data is Cifar10 #### Import ```python -import logging -import pytest import numpy as np from mindspore import Tensor from mindspore.train.model import Model -from mindarmour.utils.logger import LogUtil from mindspore import Model, nn, context from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5 -from mindspore.train.summary.summary_record import _get_summary_tensor_data -from mindspore.train.serializaton import load_checkpoint, load_pram_into_net -from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetector, result_eval +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetectorFeatureCluster ``` #### Load Classification Model +For convenience, we use a pre-trained model file `checkpoint_lenet-10_1875.ckpt` +in 'mindarmour/tests/ut/python/dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'. + ```python -ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' +ckpt_path = 'checkpoint_lenet-10_1875.ckpt' net = LeNet5() load_dict = load_checkpoint(ckpt_path) -load_pram_into_net(net, load_dict) +load_param_into_net(net, load_dict) model = Model(net) ``` ->`ckpt_path(str)`:the model path. +>`ckpt_path(str)`: the model path. -#### Data processing -We extract the data features by the Lenet network. +We can also use a self-constructed model. +In that case, we must name the model layers so that their outputs can be retrieved by name. +Take LeNet as an example. +Firstly, we import the `TensorSummary` module. +Secondly, we initialize it as `self.summary = TensorSummary()`. +Finally, we add `self.summary('name', x)` after each layer whose output we want to use. Here, the `name` of each layer is chosen by the user. 
+After the above process, we can train the model and load it. ```python +from mindspore import nn +from mindspore.common.initializer import TruncatedNormal +from mindspore.ops import TensorSummary + +def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): + """Wrap conv.""" + weight = weight_variable() + return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, + weight_init=weight, has_bias=False, pad_mode="valid") + +def fc_with_initialize(input_channels, out_channels): + """Wrap initialize method of full connection layer.""" + weight = weight_variable() + bias = weight_variable() + return nn.Dense(input_channels, out_channels, weight, bias) + +def weight_variable(): + """Wrap initialize variable.""" + return TruncatedNormal(0.05) + +class LeNet5(nn.Cell): + """ + Lenet network + """ + def __init__(self): + super(LeNet5, self).__init__() + self.conv1 = conv(1, 6, 5) + self.conv2 = conv(6, 16, 5) + self.fc1 = fc_with_initialize(16*5*5, 120) + self.fc2 = fc_with_initialize(120, 84) + self.fc3 = fc_with_initialize(84, 10) + self.relu = nn.ReLU() + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + self.summary = TensorSummary() + + def construct(self, x): + """ + construct the network architecture + Returns: + x (tensor): network output + """ + x = self.conv1(x) + self.summary('1', x) + + x = self.relu(x) + self.summary('2', x) + + x = self.max_pool2d(x) + self.summary('3', x) + + x = self.conv2(x) + self.summary('4', x) + + x = self.relu(x) + self.summary('5', x) + + x = self.max_pool2d(x) + self.summary('6', x) + + x = self.flatten(x) + self.summary('7', x) + + x = self.fc1(x) + self.summary('8', x) + + x = self.relu(x) + self.summary('9', x) + + x = self.fc2(x) + self.summary('10', x) + + x = self.relu(x) + self.summary('11', x) + + x = self.fc3(x) + self.summary('output', x) + return x + +``` +#### Load Data + +We prepare three datasets. 
The training dataset is the same dataset that was used to train the LeNet. Of the two testing datasets, the first carries OOD labels (0 for non-OOD, 1 for OOD) and is used to find an optimal threshold for OOD detection. +The second testing dataset is for OOD validation. The first testing dataset is not necessary if we would like to set the threshold by ourselves. + +```python ds_train = np.load('../../dataset/concept_train_lenet.npy') -ds_test = np.load('../../dataset/concept_test_lenet.npy') -ds_train = feature_extract(ds_train, model, layer='9[:Tensor]') -ds_test = feature_extract(ds_test, model, layer='9[:Tensor]') +ds_eval = np.load('../../dataset/concept_test_lenet1.npy') +ds_test = np.load('../../dataset/concept_test_lenet2.npy') ``` > `ds_train(numpy.ndarray)`: the train data. -> `ds_test(numpy.ndarray)`: the test data. -> `model(Model)`: the Lenet model. +> `ds_eval(numpy.ndarray)`: the data for finding an optimal threshold. This dataset is optional. +> `ds_test(numpy.ndarray)`: the test data for ood detection. -#### Train the concept drift detector +#### OOD detector initialization OOD detector for Lenet. ```python -detector = OodDetector(ds_train, ds_test, n_cluster=10) -score = detector.ood_detector() +# ood detector initialization +detector = OodDetectorFeatureCluster(model, ds_train, n_cluster=10, layer='output[:Tensor]') ``` - -> `ds_train(numpy.ndarray)`: the train data. -> `ds_test(numpy.ndarray)`: the test data. +> `model(Model)`: the model trained on `ds_train`. +> `ds_train(numpy.ndarray)`: the training data. > `n_cluster(int)`: the feature cluster number. +> `layer(str)`: the name of the feature extraction layer. + +In our example, we input the layer name `output[:Tensor]`, which can also be `9[:Tensor]`, `10[:Tensor]`, `11[:Tensor]` for LeNet. -#### Evaluation +#### Optimal Threshold + +This step is optional. If we have a labeled dataset, named `ds_eval`, we can use the following code to find the optimal detection threshold. 
```python -num = int(len(ds_test)/2) +# get optimal threshold with ds_eval +num = int(len(ds_eval) / 2) label = np.concatenate((np.zeros(num), np.ones(num)), axis=0) # ID data = 0, OOD data = 1 -dec_acc = result_eval(score, label, threshold=0.5) +optimal_threshold = detector.get_optimal_threshold(label, ds_eval) ``` -> `ds_test(numpy.ndarray)`: the test data. -> `score(numpy.ndarray)`: the concept drift score. -> `label(numpy.ndarray)`: the drift label. -> `threshold(float)`: the threshold to judge out-of-distribution. +> `ds_eval(numpy.ndarray)`: the data for finding an optimal threshold. +> `label(numpy.ndarray)`: the ood label of ds_eval. 0 means non-ood data, and 1 means ood data. + +#### Detection result + +```python +result = detector.ood_predict(optimal_threshold, ds_test) +``` +> `ds_test(numpy.ndarray)`: the testing data for ood detection. +> `optimal_threshold(float)`: the optimal threshold to judge out-of-distribution data. We can also set the threshold value by ourselves. ## Script Description diff --git a/mindarmour/reliability/concept_drift/concept_drift_check_images.py b/mindarmour/reliability/concept_drift/concept_drift_check_images.py index 18f6436..10c3daa 100644 --- a/mindarmour/reliability/concept_drift/concept_drift_check_images.py +++ b/mindarmour/reliability/concept_drift/concept_drift_check_images.py @@ -16,12 +16,14 @@ import heapq import numpy as np +from mindspore import Tensor from sklearn.cluster import KMeans from mindarmour.utils._check_param import check_param_type, check_param_in_range +from mindspore.train.summary.summary_record import _get_summary_tensor_data + """ Out-of-Distribution detection for images. -The sample can be run on Ascend 910 AI processor. """ @@ -30,58 +32,166 @@ class OodDetector: Train the OOD detector. Args: + model (Model):The training model. ds_train (numpy.ndarray): The training dataset. - ds_test (numpy.ndarray): The testing dataset. 
""" + def __init__(self, model, ds_train): + self.model = model + self.ds_train = check_param_type('ds_train', ds_train, np.ndarray) + + def _feature_extract(self, model, data, layer='output[:Tensor]'): + """ + Extract features. + Args: + model (Model): The model for extracting features. + data (numpy.ndarray): Input data. + layer (str): The name of the feature layer. layer (str) is represented as + 'name[:Tensor]', where 'name' is given by users when training the model. + Please see more details about how to name the model layer in 'README.md'. + + Returns: + numpy.ndarray, the data feature extracted by a certain neural layer. + """ + model.predict(Tensor(data)) + layer_out = _get_summary_tensor_data() + return layer_out[layer].asnumpy() + + def get_optimal_threshold(self, label, ds_eval): + """ + Get the optimal threshold. - def __init__(self, ds_train, ds_test, n_cluster=10): + Args: + label (numpy.ndarray): The label whether an image is in-distribution and out-of-distribution. + ds_eval (numpy.ndarray): The testing dataset to help find the threshold. + + Returns: + - float, the optimal threshold. + """ + pass + + def ood_predict(self, threshold, ds_test): + """ + The out-of-distribution detection. + Args: + threshold (float): the threshold to judge ood data. One can set value by experience + or use function get_optimal_threshold. + ds_test (numpy.ndarray): The testing dataset. + + Returns: + - numpy.ndarray, the detection result. 0 means the data is not ood, 1 means the data is ood. + """ + pass + + +class OodDetectorFeatureCluster(OodDetector): + """ + Train the OOD detector. Extract the training data features, and obtain the clustering centers. The distance between + the testing data features and the clustering centers determines whether an image is an out-of-distribution(OOD) + image or not. + + Args: + model (Model):The training model. + ds_train (numpy.ndarray): The training dataset. + n_cluster (int): The cluster number. Belonging to [2,100]. 
+ Usually, n_cluster equals the class number of the training dataset. + If the OOD detector performs poorly on the testing dataset, we can increase the value of n_cluster + appropriately. + layer (str): The name of the feature layer. layer (str) is represented by + 'name[:Tensor]', where 'name' is given by users when training the model. + Please see more details about how to name the model layer in 'README.md'. + """ + + def __init__(self, model, ds_train, n_cluster, layer): + self.model = model self.ds_train = check_param_type('ds_train', ds_train, np.ndarray) - self.ds_test = check_param_type('ds_test', ds_test, np.ndarray) self.n_cluster = check_param_type('n_cluster', n_cluster, int) self.n_cluster = check_param_in_range('n_cluster', n_cluster, 2, 100) + self.layer = check_param_type('layer', layer, str) + self.feature = self._feature_extract(model, ds_train, layer=self.layer) - def ood_detector(self): + def _feature_cluster(self): """ - The out-of-distribution detection. + Get the feature cluster. Returns: - - numpy.ndarray, the detection score of images. + - numpy.ndarray, the feature cluster. """ - clf = KMeans(n_clusters=self.n_cluster) - clf.fit_predict(self.ds_train) + clf.fit_predict(self.feature) feature_cluster = clf.cluster_centers_ + return feature_cluster + + def _get_ood_score(self, ds_test): + """ + Get the ood score. + + Args: + ds_test (numpy.ndarray): The testing dataset. + + Returns: + - numpy.ndarray, the ood score of each sample in ds_test. 
+ """ + feature_cluster = self._feature_cluster() + ds_test = self._feature_extract(self.model, ds_test, layer=self.layer) score = [] - for i in range(len(self.ds_test)): + for i in range(len(ds_test)): dis = [] for j in range(len(feature_cluster)): - loc = list(map(list(feature_cluster[j]).index, heapq.nlargest(self.n_cluster, list(feature_cluster[j])))) - diff = sum(abs((feature_cluster[j][loc] - self.ds_test[i][loc]))) / sum(abs((feature_cluster[j][loc]))) + loc = list( + map(list(feature_cluster[j]).index, heapq.nlargest(self.n_cluster, list(feature_cluster[j])))) + diff = sum(abs((feature_cluster[j][loc] - ds_test[i][loc]))) / sum(abs((feature_cluster[j][loc]))) dis.append(diff) score.append(min(dis)) score = np.array(score) return score + def get_optimal_threshold(self, label, ds_eval): + """ + Get the optimal threshold. -def result_eval(score, label, threshold): - """ - Evaluate the detection results. + Args: + label (numpy.ndarray): The label whether an image is in-distribution and out-of-distribution. + ds_eval (numpy.ndarray): The testing dataset to help find the threshold. - Args: - score (numpy.ndarray): The detection score of images. - label (numpy.ndarray): The label whether an image is in-ditribution and out-of-distribution. - threshold (float): The threshold to judge out-of-distribution distance. + Returns: + - float, the optimal threshold. 
+ """ + check_param_type('label', label, np.ndarray) + check_param_type('ds_eval', ds_eval, np.ndarray) + score = self._get_ood_score(ds_eval) + acc = [] + threshold = [] + for threshold_change in np.arange(0.0, 1.0, 0.01): + count = 0 + for i in range(len(score)): + if score[i] < threshold_change and label[i] == 0: + count = count + 1 + if score[i] >= threshold_change and label[i] == 1: + count = count + 1 + acc.append(count / len(score)) + threshold.append(threshold_change) + acc = np.array(acc) + threshold = np.array(threshold) + optimal_threshold = threshold[np.where(acc==np.max(acc))[0]][0] + return optimal_threshold - Returns: - - float, the detection accuracy. - """ - check_param_type('label', label, np.ndarray) - check_param_type('threshold', threshold, float) - check_param_in_range('threshold', threshold, 0, 1) - count = 0 - for i in range(len(score)): - if score[i] < threshold and label[i] == 0: - count = count + 1 - if score[i] >= threshold and label[i] == 1: - count = count + 1 - return count / len(score) + def ood_predict(self, threshold, ds_test): + """ + The out-of-distribution detection. + Args: + threshold (float): the threshold to judge ood data. One can set value by experience + or use function get_optimal_threshold. + ds_test (numpy.ndarray): The testing dataset. + + Returns: + - numpy.ndarray, the detection result. 0 means the data is not ood, 1 means the data is ood. 
+ """ + score = self._get_ood_score(ds_test) + result = [] + for i in range(len(score)): + if score[i] < threshold: + result.append(0) + if score[i] >= threshold: + result.append(1) + result = np.array(result) + return result diff --git a/tests/ut/python/dataset/concept_test_lenet1.npy b/tests/ut/python/dataset/concept_test_lenet1.npy new file mode 100644 index 0000000..ae87bcc Binary files /dev/null and b/tests/ut/python/dataset/concept_test_lenet1.npy differ diff --git a/tests/ut/python/dataset/concept_test_lenet2.npy b/tests/ut/python/dataset/concept_test_lenet2.npy new file mode 100644 index 0000000..1fe447e Binary files /dev/null and b/tests/ut/python/dataset/concept_test_lenet2.npy differ diff --git a/tests/ut/python/reliability/concept_drift/test_concept_drift_images.py b/tests/ut/python/reliability/concept_drift/test_concept_drift_images.py index ef25c79..21010b9 100644 --- a/tests/ut/python/reliability/concept_drift/test_concept_drift_images.py +++ b/tests/ut/python/reliability/concept_drift/test_concept_drift_images.py @@ -24,30 +24,13 @@ from mindspore.train.model import Model from mindarmour.utils.logger import LogUtil from mindspore import Model, nn, context from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5 -from mindspore.train.summary.summary_record import _get_summary_tensor_data from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetector, result_eval +from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetectorFeatureCluster LOGGER = LogUtil.get_instance() TAG = 'Concept_Test' -def feature_extract(data, feature_model, layer='output[:Tensor]'): - """ - Extract features. - Args: - data (numpy.ndarray): Input data. - feature_model (Model): The model for extracting features. - layer (str): The feature layer. The layer name could be 'output[:Tensor]', - '1[:Tensor]', '2[:Tensor]',...'10[:Tensor]'. 
- - Returns: - numpy.ndarray, the feature of input data. - """ - feature_model.predict(Tensor(data)) - layer_out = _get_summary_tensor_data() - return layer_out[layer].asnumpy() - @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @@ -66,20 +49,20 @@ def test_cp(): model = Model(net) # load data ds_train = np.load('../../dataset/concept_train_lenet.npy') - ds_test = np.load('../../dataset/concept_test_lenet.npy') - ds_train = feature_extract(ds_train, model, layer='9[:Tensor]') - ds_test = feature_extract(ds_test, model, layer='9[:Tensor]') - # ood detect - detector = OodDetector(ds_train, ds_test, n_cluster=10) - score = detector.ood_detector() - # Evaluation - num = int(len(ds_test)/2) + ds_eval = np.load('../../dataset/concept_test_lenet1.npy') + ds_test = np.load('../../dataset/concept_test_lenet2.npy') + # ood detector initialization + detector = OodDetectorFeatureCluster(model, ds_train, n_cluster=10, layer='output[:Tensor]') + # get optimal threshold with ds_eval + num = int(len(ds_eval) / 2) label = np.concatenate((np.zeros(num), np.ones(num)), axis=0) # ID data = 0, OOD data = 1 - dec_acc = result_eval(score, label, threshold=0.5) + optimal_threshold = detector.get_optimal_threshold(label, ds_eval) + # get result of ds_test. We can also set threshold by ourselves. + result = detector.ood_predict(optimal_threshold, ds_test) # result log LOGGER.set_level(logging.DEBUG) - LOGGER.debug(TAG, '--start concept drift test--') - LOGGER.debug(score, '--concept drift check score--') - LOGGER.debug(dec_acc, '--concept drift check accuracy--') - LOGGER.debug(TAG, '--end concept drift test--') - assert np.any(score >= 0.0) + LOGGER.debug(TAG, '--start ood test--') + LOGGER.debug(result, '--ood result--') + LOGGER.debug(optimal_threshold, '--the optimal threshold--') + LOGGER.debug(TAG, '--end ood test--') + assert np.any(result >= 0.0)