Browse Source

!284 fix ood detection

Merge pull request !284 from wuxiaoyu123456/master
tags/v1.6.0
i-robot Gitee 3 years ago
parent
commit
44a1a4e75c
7 changed files with 309 additions and 152 deletions
  1. +10
    -28
      examples/reliability/concept_drift_check_images_lenet.py
  2. +11
    -29
      examples/reliability/concept_drift_check_images_resnet.py
  3. +131
    -31
      mindarmour/reliability/concept_drift/README.md
  4. +142
    -32
      mindarmour/reliability/concept_drift/concept_drift_check_images.py
  5. BIN
      tests/ut/python/dataset/concept_test_lenet1.npy
  6. BIN
      tests/ut/python/dataset/concept_test_lenet2.npy
  7. +15
    -32
      tests/ut/python/reliability/concept_drift/test_concept_drift_images.py

+ 10
- 28
examples/reliability/concept_drift_check_images_lenet.py View File

@@ -17,9 +17,8 @@ from mindspore import Tensor
from mindspore.train.model import Model
from mindspore import Model, nn, context
from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5
from mindspore.train.summary.summary_record import _get_summary_tensor_data
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetector, result_eval
from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetectorFeatureCluster


"""
@@ -27,23 +26,6 @@ Examples for Lenet.
"""


def feature_extract(data, feature_model, layer='output[:Tensor]'):
    """
    Run one forward pass and return the recorded output of a named layer.

    Args:
        data (numpy.ndarray): Input data fed to the model.
        feature_model (Model): The model used for extracting features.
        layer (str): Key of the recorded layer output, e.g. 'output[:Tensor]',
            '1[:Tensor]', '2[:Tensor]', ..., '10[:Tensor]'.

    Returns:
        numpy.ndarray, the feature of the input data taken from `layer`.
    """
    inputs = Tensor(data)
    # The prediction itself is discarded; the pass is only run so that the
    # summary tensors can be read back right afterwards (presumably they are
    # filled by the summary operators of the network -- confirm).
    feature_model.predict(inputs)
    recorded = _get_summary_tensor_data()
    selected = recorded[layer]
    return selected.asnumpy()


if __name__ == '__main__':
# load model
ckpt_path = '../../tests/ut/python/dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
@@ -53,13 +35,13 @@ if __name__ == '__main__':
model = Model(net)
# load data
ds_train = np.load('../../tests/ut/python/dataset/concept_train_lenet.npy')
ds_test = np.load('../../tests/ut/python/dataset/concept_test_lenet.npy')
ds_train = feature_extract(ds_train, model, layer='output[:Tensor]')
ds_test = feature_extract(ds_test, model, layer='output[:Tensor]')
# ood detect
detector = OodDetector(ds_train, ds_test, n_cluster=10)
score = detector.ood_detector()
# Evaluation
num = int(len(ds_test)/2)
ds_eval = np.load('../../tests/ut/python/dataset/concept_test_lenet1.npy')
ds_test = np.load('../../tests/ut/python/dataset/concept_test_lenet2.npy')
# ood detector initialization
detector = OodDetectorFeatureCluster(model, ds_train, n_cluster=10, layer='output[:Tensor]')
# get optimal threshold with ds_eval
num = int(len(ds_eval) / 2)
label = np.concatenate((np.zeros(num), np.ones(num)), axis=0) # ID data = 0, OOD data = 1
dec_acc = result_eval(score, label, threshold=0.5)
optimal_threshold = detector.get_optimal_threshold(label, ds_eval)
# get result of ds_test2. We can also set threshold by ourselves.
result = detector.ood_predict(optimal_threshold, ds_test)

+ 11
- 29
examples/reliability/concept_drift_check_images_resnet.py View File

@@ -17,9 +17,8 @@ from mindspore import Tensor
from mindspore.train.model import Model
from mindspore import Model, nn, context
from examples.common.networks.resnet.resnet import resnet50
from mindspore.train.summary.summary_record import _get_summary_tensor_data
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetector, result_eval
from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetectorFeatureCluster


"""
@@ -27,23 +26,6 @@ Examples for Resnet.
"""


def feature_extract(data, feature_model, layer='output[:Tensor]'):
    """
    Feed `data` through the model once and fetch one recorded layer output.

    Args:
        data (numpy.ndarray): Input data fed to the model.
        feature_model (Model): The model used for extracting features.
        layer (str): Key of the recorded layer output, e.g. 'output[:Tensor]',
            '1[:Tensor]', '2[:Tensor]', ..., '10[:Tensor]'.

    Returns:
        numpy.ndarray, the feature of the input data taken from `layer`.
    """
    model_input = Tensor(data)
    # Return value of predict() is intentionally unused: only the summary
    # tensors collected after the pass are of interest here.
    feature_model.predict(model_input)
    summary_tensors = _get_summary_tensor_data()
    layer_tensor = summary_tensors[layer]
    return layer_tensor.asnumpy()


if __name__ == '__main__':
# load model
ckpt_path = '../../tests/ut/python/dataset/trained_ckpt_file/resnet_1-20_1875.ckpt'
@@ -52,14 +34,14 @@ if __name__ == '__main__':
load_param_into_net(net, load_dict)
model = Model(net)
# load data
ds_train = np.load('./train.npy')
ds_test = np.load('./test.npy')
ds_train = feature_extract(ds_train, model, layer='output[:Tensor]')
ds_test = feature_extract(ds_test, model, layer='output[:Tensor]')
# ood detect
detector = OodDetector(ds_train, ds_test, n_cluster=10)
score = detector.ood_detector()
# Evaluation
num = int(len(ds_test)/2)
ds_train = np.load('train.npy')
ds_eval = np.load('test1.npy')
ds_test = np.load('test2.npy')
# ood detector initialization
detector = OodDetectorFeatureCluster(model, ds_train, n_cluster=10, layer='output[:Tensor]')
# get optimal threshold with ds_eval
num = int(len(ds_eval) / 2)
label = np.concatenate((np.zeros(num), np.ones(num)), axis=0) # ID data = 0, OOD data = 1
dec_acc = result_eval(score, label, threshold=0.5)
optimal_threshold = detector.get_optimal_threshold(label, ds_eval)
# get result of ds_test2. We can also set threshold by ourselves.
result = detector.ood_predict(optimal_threshold, ds_test)

+ 131
- 31
mindarmour/reliability/concept_drift/README.md View File

@@ -109,8 +109,8 @@ For ResNet, the training data is Cifar10 as ID data. The testing data is Cifar10
### Environment Requirements
- Hardware(Ascend)
- Prepare hardware environment with Ascend.
- Hardware
- Prepare hardware environment with Ascend, CPU and GPU.
- Framework
- MindSpore
- For more information, please check the resources below:
@@ -122,76 +122,176 @@ For ResNet, the training data is Cifar10 as ID data. The testing data is Cifar10
#### Import

```python
import logging
import pytest
import numpy as np
from mindspore import Tensor
from mindspore.train.model import Model
from mindarmour.utils.logger import LogUtil
from mindspore import Model, nn, context
from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5
from mindspore.train.summary.summary_record import _get_summary_tensor_data
from mindspore.train.serializaton import load_checkpoint, load_pram_into_net
from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetector, result_eval
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetectorFeatureCluster
```

#### Load Classification Model

For convenience, we use a pre-trained model file `checkpoint_lenet-10_1875.ckpt`
in 'mindarmour/tests/ut/python/dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'.

```python
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
ckpt_path = 'checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_path)
load_pram_into_net(net, load_dict)
load_param_into_net(net, load_dict)
model = Model(net)
```

>`ckpt_path(str)`:the model path.
>`ckpt_path(str)`: the model path.

#### Data processing

We extract the data features by the Lenet network.
We can also use a self-constructed model.
It is important to name the model layers so that the layer outputs can be retrieved.
Take LeNet as an example.
Firstly, we import the `TensorSummary` module.
Secondly, we initialize it as `self.summary = TensorSummary()`.
Finally, we add `self.summary('name', x)` after each layer we pay attention to. Here, the `name` of each layer is given by the user.
After the above process, we can train the model and load it.


```python
from mindspore import nn
from mindspore.common.initializer import TruncatedNormal
from mindspore.ops import TensorSummary

def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Build a bias-free 2-D convolution layer that uses the 'valid' pad mode."""
    conv_options = {
        'kernel_size': kernel_size,
        'stride': stride,
        'padding': padding,
        'weight_init': weight_variable(),
        'has_bias': False,
        'pad_mode': "valid",
    }
    return nn.Conv2d(in_channels, out_channels, **conv_options)

def fc_with_initialize(input_channels, out_channels):
    """Build a fully connected layer whose weight and bias each get their own initializer."""
    # Two separate initializer objects are created, weight first, then bias.
    weight_init, bias_init = weight_variable(), weight_variable()
    return nn.Dense(input_channels, out_channels, weight_init, bias_init)

def weight_variable():
    """Create the initializer shared by the layer factories: TruncatedNormal(0.05)."""
    initializer = TruncatedNormal(0.05)
    return initializer

class LeNet5(nn.Cell):
    """
    LeNet network whose intermediate outputs are recorded with TensorSummary.

    Every stage of `construct` is followed by a `self.summary(name, tensor)`
    call; the names used here are '1' ... '11' and 'output'.
    """

    def __init__(self):
        super().__init__()
        # Convolutional stages.
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        # Fully connected stages.
        self.fc1 = fc_with_initialize(16*5*5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, 10)
        # Parameter-free operators shared by several stages.
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        # Recorder used to name and expose the intermediate outputs.
        self.summary = TensorSummary()

    def construct(self, x):
        """
        Construct the network architecture.

        Args:
            x (tensor): network input.

        Returns:
            tensor, the network output (the result of the last dense layer).
        """
        # First convolution block: conv -> relu -> max pool.
        feat = self.conv1(x)
        self.summary('1', feat)

        feat = self.relu(feat)
        self.summary('2', feat)

        feat = self.max_pool2d(feat)
        self.summary('3', feat)

        # Second convolution block: conv -> relu -> max pool.
        feat = self.conv2(feat)
        self.summary('4', feat)

        feat = self.relu(feat)
        self.summary('5', feat)

        feat = self.max_pool2d(feat)
        self.summary('6', feat)

        # Flatten before the dense layers.
        feat = self.flatten(feat)
        self.summary('7', feat)

        # Dense head: fc1 -> relu -> fc2 -> relu -> fc3.
        feat = self.fc1(feat)
        self.summary('8', feat)

        feat = self.relu(feat)
        self.summary('9', feat)

        feat = self.fc2(feat)
        self.summary('10', feat)

        feat = self.relu(feat)
        self.summary('11', feat)

        feat = self.fc3(feat)
        self.summary('output', feat)
        return feat

```
#### Load Data

We prepare three datasets: the training dataset, which is the same dataset used to train the LeNet, and two testing datasets. The first testing dataset carries OOD labels (0 for non-OOD, 1 for OOD) and is used to find an optimal threshold for OOD detection.
The second testing dataset is used for OOD validation. The first testing dataset is not necessary if we would like to set the threshold ourselves.

```python
ds_train = np.load('../../dataset/concept_train_lenet.npy')
ds_test = np.load('../../dataset/concept_test_lenet.npy')
ds_train = feature_extract(ds_train, model, layer='9[:Tensor]')
ds_test = feature_extract(ds_test, model, layer='9[:Tensor]')
ds_eval = np.load('../../dataset/concept_test_lenet1.npy')
ds_test = np.load('../../dataset/concept_test_lenet2.npy')
```

> `ds_train(numpy.ndarray)`: the train data.
> `ds_test(numpy.ndarray)`: the test data.
> `model(Model)`: the Lenet model.
> `ds_eval(numpy.ndarray)`: the data for finding an optimal threshold. This dataset is not necessary.
> `ds_test(numpy.ndarray)`: the test data for ood detection.


#### Train the concept drift detector
#### OOD detector initialization

OOD detector for Lenet.


```python
detector = OodDetector(ds_train, ds_test, n_cluster=10)
score = detector.ood_detector()
# ood detector initialization
detector = OodDetectorFeatureCluster(model, ds_train, n_cluster=10, layer='output[:Tensor]')
```

> `ds_train(numpy.ndarray)`: the train data.
> `ds_test(numpy.ndarray)`: the test data.
> `model(Model)`: the model trained by the `ds_train`.
> `ds_train(numpy.ndarray)`: the training data.
> `n_cluster(int)`: the feature cluster number.
> `layer(str)`: the name of the feature extraction layer.

In our example, we input the layer name `output[:Tensor]`, which can also be `9[:Tensor]`, `10[:Tensor]`, `11[:Tensor]` for LeNet.


#### Evaluation
#### Optimal Threshold

This step is optional. If we have a labeled dataset, named `ds_eval`, we can use the following code to find the optimal detection threshold.

```python
num = int(len(ds_test)/2)
# get optimal threshold with ds_eval
num = int(len(ds_eval) / 2)
label = np.concatenate((np.zeros(num), np.ones(num)), axis=0) # ID data = 0, OOD data = 1
dec_acc = result_eval(score, label, threshold=0.5)
optimal_threshold = detector.get_optimal_threshold(label, ds_eval)
```

> `ds_test(numpy.ndarray)`: the test data.
> `score(numpy.ndarray)`: the concept drift score.
> `label(numpy.ndarray)`: the drift label.
> `threshold(float)`: the threshold to judge out-of-distribution.
> `ds_eval(numpy.ndarray)`: the data for finding an optimal threshold.
> `label(numpy.ndarray)`: the ood label of ds_eval. 0 means non-ood data, and 1 means ood data.

#### Detection result

```python
result = detector.ood_predict(optimal_threshold, ds_test)
```

> `ds_test(numpy.ndarray)`: the testing data for ood detection.
> `optimal_threshold(float)`: the optimal threshold to judge out-of-distribution data. We can also set the threshold value by ourselves.

## Script Description



+ 142
- 32
mindarmour/reliability/concept_drift/concept_drift_check_images.py View File

@@ -16,12 +16,14 @@

import heapq
import numpy as np
from mindspore import Tensor
from sklearn.cluster import KMeans
from mindarmour.utils._check_param import check_param_type, check_param_in_range
from mindspore.train.summary.summary_record import _get_summary_tensor_data


"""
Out-of-Distribution detection for images.
The sample can be run on Ascend 910 AI processor.
"""


@@ -30,58 +32,166 @@ class OodDetector:
Train the OOD detector.

Args:
model (Model):The training model.
ds_train (numpy.ndarray): The training dataset.
ds_test (numpy.ndarray): The testing dataset.
"""
def __init__(self, model, ds_train):
    """
    Keep the model and the type-checked training dataset on the instance.

    Args:
        model (Model): The training model. Stored as given.
        ds_train (numpy.ndarray): The training dataset. Passed through
            `check_param_type` against `np.ndarray` before being stored.
    """
    checked_train = check_param_type('ds_train', ds_train, np.ndarray)
    self.model = model
    self.ds_train = checked_train

def _feature_extract(self, model, data, layer='output[:Tensor]'):
    """
    Run one forward pass and return the recorded output of a named layer.

    Args:
        model (Model): The model for extracting features.
        data (numpy.ndarray): Input data.
        layer (str): The name of the feature layer, written as 'name[:Tensor]',
            where 'name' is given by users when training the model.
            Please see more details about how to name the model layer in 'README.md'.

    Returns:
        numpy.ndarray, the data feature extracted by a certain neural layer.
    """
    model_input = Tensor(data)
    # The prediction result is discarded; only the summary tensors read back
    # after the pass are used.
    model.predict(model_input)
    summary_tensors = _get_summary_tensor_data()
    layer_tensor = summary_tensors[layer]
    return layer_tensor.asnumpy()

def get_optimal_threshold(self, label, ds_eval):
    """
    Get the optimal threshold.

    This base-class version is a placeholder: it performs no computation and
    returns None. Concrete detectors override it.

    Args:
        label (numpy.ndarray): The label whether an image is in-distribution or out-of-distribution.
        ds_eval (numpy.ndarray): The testing dataset to help find the threshold.

    Returns:
        - float, the optimal threshold (None in this placeholder implementation).
    """
    return None

def ood_predict(self, threshold, ds_test):
    """
    The out-of-distribution detection.

    The body is intentionally empty here, so calling it yields None.

    Args:
        threshold (float): The threshold to judge ood data. One can set the value by experience
            or use function get_optimal_threshold.
        ds_test (numpy.ndarray): The testing dataset.

    Returns:
        - numpy.ndarray, the detection result. 0 means the data is not ood, 1 means the data is ood.
    """


class OodDetectorFeatureCluster(OodDetector):
"""
Train the OOD detector. Extract the training data features, and obtain the clustering centers. The distance between
the testing data features and the clustering centers determines whether an image is an out-of-distribution(OOD)
image or not.

Args:
model (Model):The training model.
ds_train (numpy.ndarray): The training dataset.
n_cluster (int): The cluster number. Belonging to [2,100].
Usually, n_cluster equals to the class number of the training dataset.
If the OOD detector performs poor in the testing dataset, we can increase the value of n_cluster
appropriately.
layer (str): The name of the feature layer. layer (str) is represented by
'name[:Tensor]', where 'name' is given by users when training the model.
Please see more details about how to name the model layer in 'README.md'.
"""

def __init__(self, model, ds_train, n_cluster, layer):
    """
    Validate the arguments and extract the training features once.

    Args:
        model (Model): The training model.
        ds_train (numpy.ndarray): The training dataset.
        n_cluster (int): The cluster number, validated with
            `check_param_in_range('n_cluster', n_cluster, 2, 100)`.
        layer (str): The name of the feature layer, written as 'name[:Tensor]'.
    """
    # The base class stores `model` and the type-checked `ds_train`.
    super().__init__(model, ds_train)
    # Type first, then range; only the range-checked value is stored.
    check_param_type('n_cluster', n_cluster, int)
    self.n_cluster = check_param_in_range('n_cluster', n_cluster, 2, 100)
    self.layer = check_param_type('layer', layer, str)
    # Features of the training data are computed once at construction time.
    self.feature = self._feature_extract(model, self.ds_train, layer=self.layer)

def ood_detector(self):
def _feature_cluster(self):
    """
    Get the feature cluster.

    KMeans is fitted on the training features (`self.feature`) the first time
    this method is called; the resulting centers are cached on the instance
    so that every later call (threshold search and prediction alike) works
    with the same centers instead of a fresh, differently initialised fit.

    Returns:
        - numpy.ndarray, the feature cluster (the KMeans cluster centers).
    """
    centers = getattr(self, '_cluster_centers', None)
    if centers is None:
        clf = KMeans(n_clusters=self.n_cluster)
        # Only the centers are needed, so fit() replaces fit_predict().
        clf.fit(self.feature)
        centers = clf.cluster_centers_
        self._cluster_centers = centers
    return centers

def _get_ood_score(self, ds_test):
    """
    Get the ood score.

    For each cluster center the positions of its `n_cluster` largest
    components are selected. A test sample's distance to a center is the sum
    of absolute differences on those positions divided by the sum of the
    center's absolute values on the same positions. The score of a sample is
    its smallest distance over all centers.

    Args:
        ds_test (numpy.ndarray): The testing dataset.

    Returns:
        - numpy.ndarray, the ood score of every sample in `ds_test`.
    """
    feature_cluster = self._feature_cluster()
    features = self._feature_extract(self.model, ds_test, layer=self.layer)

    # The selected positions and the normalisers depend only on the centers,
    # so they are computed once instead of once per test sample.
    top_locs = []
    scales = []
    for center in feature_cluster:
        values = list(center)
        loc = list(map(values.index, heapq.nlargest(self.n_cluster, values)))
        top_locs.append(loc)
        scales.append(sum(abs(center[loc])))

    score = []
    for sample in features:
        dis = [
            sum(abs(center[loc] - sample[loc])) / scale
            for center, loc, scale in zip(feature_cluster, top_locs, scales)
        ]
        score.append(min(dis))
    return np.array(score)

def get_optimal_threshold(self, label, ds_eval):
    """
    Get the optimal threshold.

    The candidate thresholds 0.00, 0.01, ..., 0.99 are evaluated on the
    labelled data. A sample counts as correct when its score is below the
    threshold and its label is 0, or when its score is at or above the
    threshold and its label is 1. The first threshold reaching the highest
    accuracy is returned.

    Args:
        label (numpy.ndarray): The label whether an image is in-distribution or out-of-distribution.
            0 means in-distribution, 1 means out-of-distribution.
        ds_eval (numpy.ndarray): The testing dataset to help find the threshold.

    Returns:
        - float, the optimal threshold.

    Raises:
        ValueError: If no score is produced for `ds_eval` or the number of labels
            differs from the number of scores.
    """
    check_param_type('label', label, np.ndarray)
    check_param_type('ds_eval', ds_eval, np.ndarray)
    score = self._get_ood_score(ds_eval)
    flat_label = np.ravel(label)
    if len(score) == 0 or len(flat_label) != len(score):
        raise ValueError('label must contain exactly one entry per sample of ds_eval, '
                         'but got {} labels for {} samples.'.format(len(flat_label), len(score)))

    thresholds = np.arange(0.0, 1.0, 0.01)
    # Rows are candidate thresholds, columns are samples. Both comparisons
    # are kept explicit so that a NaN score is never counted as correct.
    below = score[np.newaxis, :] < thresholds[:, np.newaxis]
    at_or_above = score[np.newaxis, :] >= thresholds[:, np.newaxis]
    correct = (below & (flat_label == 0)) | (at_or_above & (flat_label == 1))
    acc = correct.sum(axis=1) / len(score)
    # argmax returns the first maximum, i.e. the smallest best threshold.
    optimal_threshold = thresholds[np.argmax(acc)]
    return optimal_threshold

Returns:
- float, the detection accuracy.
"""
check_param_type('label', label, np.ndarray)
check_param_type('threshold', threshold, float)
check_param_in_range('threshold', threshold, 0, 1)
count = 0
for i in range(len(score)):
if score[i] < threshold and label[i] == 0:
count = count + 1
if score[i] >= threshold and label[i] == 1:
count = count + 1
return count / len(score)
def ood_predict(self, threshold, ds_test):
    """
    The out-of-distribution detection.

    Args:
        threshold (float): The threshold to judge ood data. One can set the value by experience
            or use function get_optimal_threshold.
        ds_test (numpy.ndarray): The testing dataset.

    Returns:
        - numpy.ndarray, the detection result with one entry per score.
          0 means the data is not ood, 1 means the data is ood.
    """
    score = self._get_ood_score(ds_test)
    # Anything that is not strictly below the threshold is flagged as ood.
    # This keeps one output per input even for an undefined (NaN) score,
    # which would otherwise be dropped and shift all later results.
    return np.where(score < threshold, 0, 1)

BIN
tests/ut/python/dataset/concept_test_lenet1.npy View File


BIN
tests/ut/python/dataset/concept_test_lenet2.npy View File


+ 15
- 32
tests/ut/python/reliability/concept_drift/test_concept_drift_images.py View File

@@ -24,30 +24,13 @@ from mindspore.train.model import Model
from mindarmour.utils.logger import LogUtil
from mindspore import Model, nn, context
from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5
from mindspore.train.summary.summary_record import _get_summary_tensor_data
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetector, result_eval
from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetectorFeatureCluster

LOGGER = LogUtil.get_instance()
TAG = 'Concept_Test'


def feature_extract(data, feature_model, layer='output[:Tensor]'):
    """
    Extract the recorded output of one named layer for the given input.

    Args:
        data (numpy.ndarray): Input data fed to the model.
        feature_model (Model): The model used for extracting features.
        layer (str): Key of the recorded layer output, e.g. 'output[:Tensor]',
            '1[:Tensor]', '2[:Tensor]', ..., '10[:Tensor]'.

    Returns:
        numpy.ndarray, the feature of the input data taken from `layer`.
    """
    tensor_in = Tensor(data)
    # predict() is called without using its result; the summary tensors are
    # fetched immediately after the pass.
    feature_model.predict(tensor_in)
    collected = _get_summary_tensor_data()
    wanted = collected[layer]
    return wanted.asnumpy()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@@ -66,20 +49,20 @@ def test_cp():
model = Model(net)
# load data
ds_train = np.load('../../dataset/concept_train_lenet.npy')
ds_test = np.load('../../dataset/concept_test_lenet.npy')
ds_train = feature_extract(ds_train, model, layer='9[:Tensor]')
ds_test = feature_extract(ds_test, model, layer='9[:Tensor]')
# ood detect
detector = OodDetector(ds_train, ds_test, n_cluster=10)
score = detector.ood_detector()
# Evaluation
num = int(len(ds_test)/2)
ds_eval = np.load('../../dataset/concept_test_lenet1.npy')
ds_test = np.load('../../dataset/concept_test_lenet2.npy')
# ood detector initialization
detector = OodDetectorFeatureCluster(model, ds_train, n_cluster=10, layer='output[:Tensor]')
# get optimal threshold with ds_eval
num = int(len(ds_eval) / 2)
label = np.concatenate((np.zeros(num), np.ones(num)), axis=0) # ID data = 0, OOD data = 1
dec_acc = result_eval(score, label, threshold=0.5)
optimal_threshold = detector.get_optimal_threshold(label, ds_eval)
# get result of ds_test. We can also set threshold by ourselves.
result = detector.ood_predict(optimal_threshold, ds_test)
# result log
LOGGER.set_level(logging.DEBUG)
LOGGER.debug(TAG, '--start concept drift test--')
LOGGER.debug(score, '--concept drift check score--')
LOGGER.debug(dec_acc, '--concept drift check accuracy--')
LOGGER.debug(TAG, '--end concept drift test--')
assert np.any(score >= 0.0)
LOGGER.debug(TAG, '--start ood test--')
LOGGER.debug(result, '--ood result--')
LOGGER.debug(optimal_threshold, '--the optimal threshold--')
LOGGER.debug(TAG, '--end ood test--')
assert np.any(result >= 0.0)

Loading…
Cancel
Save