Browse Source

!271 Concept drift for images

Merge pull request !271 from wuxiaoyu123456/master
tags/v1.6.0
i-robot Gitee 3 years ago
parent
commit
ccb1a17fce
10 changed files with 908 additions and 21 deletions
  1. +74
    -2
      examples/common/dataset/data_processing.py
  2. +0
    -0
      examples/common/networks/resnet/__init__.py
  3. +401
    -0
      examples/common/networks/resnet/resnet.py
  4. +65
    -0
      examples/reliability/concept_drift_check_images_lenet.py
  5. +65
    -0
      examples/reliability/concept_drift_check_images_resnet.py
  6. +131
    -19
      mindarmour/reliability/concept_drift/README.md
  7. +87
    -0
      mindarmour/reliability/concept_drift/concept_drift_check_images.py
  8. BIN
      tests/ut/python/dataset/concept_test_lenet.npy
  9. BIN
      tests/ut/python/dataset/concept_train_lenet.npy
  10. +85
    -0
      tests/ut/python/reliability/concept_drift/test_concept_drift_images.py

+ 74
- 2
examples/common/dataset/data_processing.py View File

@@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
"""data processing"""

import os
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
@@ -114,3 +115,74 @@ def vgg_create_dataset100(data_home, image_size, batch_size, rank_id=0, rank_siz
# apply repeat operations
data_set = data_set.repeat(repeat_num)
return data_set


def create_dataset_imagenet(path, batch_size=32, repeat_size=20, status="train", target="GPU"):
image_ds = ds.ImageFolderDataset(path, decode=True)
rescale = 1.0 / 255.0
shift = 0.0
cfg = {'num_classes': 10,
'learning_rate': 0.002,
'momentum': 0.9,
'epoch_size': 30,
'batch_size': 32,
'buffer_size': 1000,
'image_height': 224,
'image_width': 224,
'save_checkpoint_steps': 1562,
'keep_checkpoint_max': 10}
resize_op = CV.Resize((cfg['image_height'], cfg['image_width']))
rescale_op = CV.Rescale(rescale, shift)
normalize_op = CV.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

random_crop_op = CV.RandomCrop([32, 32], [4, 4, 4, 4])
random_horizontal_op = CV.RandomHorizontalFlip()
channel_swap_op = CV.HWC2CHW()
typecast_op = C.TypeCast(mstype.int32)
image_ds = image_ds.map(input_columns="label", operations=typecast_op, num_parallel_workers=6)

image_ds = image_ds.map(input_columns="image", operations=random_crop_op, num_parallel_workers=6)
image_ds = image_ds.map(input_columns="image", operations=random_horizontal_op, num_parallel_workers=6)
image_ds = image_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=6)
image_ds = image_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=6)
image_ds = image_ds.map(input_columns="image", operations=normalize_op, num_parallel_workers=6)
image_ds = image_ds.map(input_columns="image", operations=channel_swap_op, num_parallel_workers=6)

image_ds = image_ds.shuffle(buffer_size=cfg['buffer_size'])
image_ds = image_ds.repeat(repeat_size)
return image_ds


def create_dataset_cifar(data_path, image_height, image_width, repeat_num=1, training=True):
"""
create data for next use such as training or infering
"""
cifar_ds = ds.Cifar10Dataset(data_path)
resize_height = image_height # 224
resize_width = image_width # 224
rescale = 1.0 / 255.0
shift = 0.0
batch_size = 32
# define map operations
random_crop_op = CV.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT
random_horizontal_op = CV.RandomHorizontalFlip()
resize_op = CV.Resize((resize_height, resize_width)) # interpolation default BILINEAR
rescale_op = CV.Rescale(rescale, shift)
normalize_op = CV.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
changeswap_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
c_trans = []
if training:
c_trans = [random_crop_op, random_horizontal_op]
c_trans += [resize_op, rescale_op, normalize_op,
changeswap_op]
# apply map operations on images
cifar_ds = cifar_ds.map(operations=type_cast_op, input_columns="label")
cifar_ds = cifar_ds.map(operations=c_trans, input_columns="image")
# apply shuffle operations
cifar_ds = cifar_ds.shuffle(buffer_size=10)
# apply batch operations
cifar_ds = cifar_ds.batch(batch_size=batch_size, drop_remainder=True)
# apply repeat operations
cifar_ds = cifar_ds.repeat(repeat_num)
return cifar_ds

+ 0
- 0
examples/common/networks/resnet/__init__.py View File


+ 401
- 0
examples/common/networks/resnet/resnet.py View File

@@ -0,0 +1,401 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
from scipy.stats import truncnorm
from mindspore.ops import TensorSummary


def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
fan_in = in_channel * kernel_size * kernel_size
scale = 1.0
scale /= max(1., fan_in)
stddev = (scale ** 0.5) / .87962566103423978
mu, sigma = 0, stddev
weight = truncnorm(-2, 2, loc=mu, scale=sigma).rvs(out_channel * in_channel * kernel_size*kernel_size)
weight = np.reshape(weight, (out_channel, in_channel, kernel_size, kernel_size))
return Tensor(weight, dtype=mstype.float32)

def _weight_variable(shape, factor=0.01):
init_value = np.random.randn(*shape).astype(np.float32) * factor
return Tensor(init_value)


def _conv3x3(in_channel, out_channel, stride=1, use_se=False):
if use_se:
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
else:
weight_shape = (out_channel, in_channel, 3, 3)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _conv1x1(in_channel, out_channel, stride=1, use_se=False):
if use_se:
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
else:
weight_shape = (out_channel, in_channel, 1, 1)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _conv7x7(in_channel, out_channel, stride=1, use_se=False):
if use_se:
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
else:
weight_shape = (out_channel, in_channel, 7, 7)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _bn(channel):
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _bn_last(channel):
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _fc(in_channel, out_channel, use_se=False):
if use_se:
weight = np.random.normal(loc=0, scale=0.01, size=out_channel*in_channel)
weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=mstype.float32)
else:
weight_shape = (out_channel, in_channel)
weight = _weight_variable(weight_shape)
return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)


class ResidualBlock(nn.Cell):
"""
ResNet V1 residual block definition.

Args:
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
use_se (bool): enable SE-ResNet50 net. Default: False.
se_block (bool): use se block in SE-ResNet50 net. Default: False.

Returns:
Tensor, output tensor.

Examples:
>>> ResidualBlock(3, 256, stride=2)
"""
expansion = 4

def __init__(self,
in_channel,
out_channel,
stride=1,
use_se=False, se_block=False):
super(ResidualBlock, self).__init__()
self.summary = TensorSummary()
self.stride = stride
self.use_se = use_se
self.se_block = se_block
channel = out_channel // self.expansion
self.conv1 = _conv1x1(in_channel, channel, stride=1, use_se=self.use_se)
self.bn1 = _bn(channel)
if self.use_se and self.stride != 1:
self.e2 = nn.SequentialCell([_conv3x3(channel, channel, stride=1, use_se=True), _bn(channel),
nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')])
else:
self.conv2 = _conv3x3(channel, channel, stride=stride, use_se=self.use_se)
self.bn2 = _bn(channel)

self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
self.bn3 = _bn_last(out_channel)
if self.se_block:
self.se_global_pool = P.ReduceMean(keep_dims=False)
self.se_dense_0 = _fc(out_channel, int(out_channel/4), use_se=self.use_se)
self.se_dense_1 = _fc(int(out_channel/4), out_channel, use_se=self.use_se)
self.se_sigmoid = nn.Sigmoid()
self.se_mul = P.Mul()
self.relu = nn.ReLU()

self.down_sample = False

if stride != 1 or in_channel != out_channel:
self.down_sample = True
self.down_sample_layer = None

if self.down_sample:
if self.use_se:
if stride == 1:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel,
stride, use_se=self.use_se), _bn(out_channel)])
else:
self.down_sample_layer = nn.SequentialCell([nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same'),
_conv1x1(in_channel, out_channel, 1,
use_se=self.use_se), _bn(out_channel)])
else:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
use_se=self.use_se), _bn(out_channel)])
self.add = P.TensorAdd()

def construct(self, x):
identity = x

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
if self.use_se and self.stride != 1:
out = self.e2(out)
else:
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.se_block:
out_se = out
out = self.se_global_pool(out, (2, 3))
out = self.se_dense_0(out)
out = self.relu(out)
out = self.se_dense_1(out)
out = self.se_sigmoid(out)
out = F.reshape(out, F.shape(out) + (1, 1))
out = self.se_mul(out, out_se)

if self.down_sample:
identity = self.down_sample_layer(identity)

out = self.add(out, identity)
out = self.relu(out)

return out


class ResNet(nn.Cell):
"""
ResNet architecture.

Args:
block (Cell): Block for network.
layer_nums (list): Numbers of block in different layers.
in_channels (list): Input channel in each layer.
out_channels (list): Output channel in each layer.
strides (list): Stride size in each layer.
num_classes (int): The number of classes that the training images are belonging to.
use_se (bool): enable SE-ResNet50 net. Default: False.
se_block (bool): use se block in SE-ResNet50 net in layer 3 and layer 4. Default: False.
Returns:
Tensor, output tensor.

Examples:
>>> ResNet(ResidualBlock,
>>> [3, 4, 6, 3],
>>> [64, 256, 512, 1024],
>>> [256, 512, 1024, 2048],
>>> [1, 2, 2, 2],
>>> 10)
"""

def __init__(self,
block,
layer_nums,
in_channels,
out_channels,
strides,
num_classes,
use_se=False):
super(ResNet, self).__init__()

if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
self.use_se = use_se
self.se_block = False
if self.use_se:
self.se_block = True

if self.use_se:
self.conv1_0 = _conv3x3(3, 32, stride=2, use_se=self.use_se)
self.bn1_0 = _bn(32)
self.conv1_1 = _conv3x3(32, 32, stride=1, use_se=self.use_se)
self.bn1_1 = _bn(32)
self.conv1_2 = _conv3x3(32, 64, stride=1, use_se=self.use_se)
else:
self.conv1 = _conv7x7(3, 64, stride=2)
self.bn1 = _bn(64)
self.relu = P.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=in_channels[0],
out_channel=out_channels[0],
stride=strides[0],
use_se=self.use_se)
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
out_channel=out_channels[1],
stride=strides[1],
use_se=self.use_se)
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=strides[2],
use_se=self.use_se,
se_block=self.se_block)
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
out_channel=out_channels[3],
stride=strides[3],
use_se=self.use_se,
se_block=self.se_block)

self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.end_point = _fc(out_channels[3], num_classes, use_se=self.use_se)
self.summary = TensorSummary()

def _make_layer(self, block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False):
"""
Make stage network of ResNet.

Args:
block (Cell): Resnet block.
layer_num (int): Layer number.
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer.
se_block (bool): use se block in SE-ResNet50 net. Default: False.
Returns:
SequentialCell, the output layer.

Examples:
>>> _make_layer(ResidualBlock, 3, 128, 256, 2)
"""
layers = []

resnet_block = block(in_channel, out_channel, stride=stride, use_se=use_se)
layers.append(resnet_block)
if se_block:
for _ in range(1, layer_num - 1):
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
layers.append(resnet_block)
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block)
layers.append(resnet_block)
else:
for _ in range(1, layer_num):
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
layers.append(resnet_block)
return nn.SequentialCell(layers)

def construct(self, x):
if self.use_se:
x = self.conv1_0(x)
x = self.bn1_0(x)
x = self.relu(x)
x = self.conv1_1(x)
x = self.bn1_1(x)
x = self.relu(x)
x = self.conv1_2(x)
else:
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
c1 = self.maxpool(x)

c2 = self.layer1(c1)
c3 = self.layer2(c2)
c4 = self.layer3(c3)
c5 = self.layer4(c4)

out = self.mean(c5, (2, 3))
out = self.flatten(out)
self.summary('1', out)
out = self.end_point(out)
if self.training:
return out
self.summary('output', out)
return out


def resnet50(class_num=10):
"""
Get ResNet50 neural network.

Args:
class_num (int): Class number.

Returns:
Cell, cell instance of ResNet50 neural network.

Examples:
>>> net = resnet50(10)
"""
return ResNet(ResidualBlock,
[3, 4, 6, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num)

def se_resnet50(class_num=1001):
"""
Get SE-ResNet50 neural network.

Args:
class_num (int): Class number.

Returns:
Cell, cell instance of SE-ResNet50 neural network.

Examples:
>>> net = se-resnet50(1001)
"""
return ResNet(ResidualBlock,
[3, 4, 6, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num,
use_se=True)

def resnet101(class_num=1001):
"""
Get ResNet101 neural network.

Args:
class_num (int): Class number.

Returns:
Cell, cell instance of ResNet101 neural network.

Examples:
>>> net = resnet101(1001)
"""
return ResNet(ResidualBlock,
[3, 4, 23, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num)


+ 65
- 0
examples/reliability/concept_drift_check_images_lenet.py View File

@@ -0,0 +1,65 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from mindspore import Tensor
from mindspore.train.model import Model
from mindspore import Model, nn, context
from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5
from mindspore.train.summary.summary_record import _get_summary_tensor_data
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetector, result_eval


"""
Examples for Lenet.
"""


def feature_extract(data, feature_model, layer='output[:Tensor]'):
"""
Extract features.
Args:
data (numpy.ndarray): Input data.
feature_model (Model): The model for extracting features.
layer (str): The feature layer. The layer name could be 'output[:Tensor]',
'1[:Tensor]', '2[:Tensor]',...'10[:Tensor]'.

Returns:
numpy.ndarray, the feature of input data.
"""
feature_model.predict(Tensor(data))
layer_out = _get_summary_tensor_data()
return layer_out[layer].asnumpy()


if __name__ == '__main__':
# load model
ckpt_path = '../../tests/ut/python/dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, load_dict)
model = Model(net)
# load data
ds_train = np.load('../../tests/ut/python/dataset/concept_train_lenet.npy')
ds_test = np.load('../../tests/ut/python/dataset/concept_test_lenet.npy')
ds_train = feature_extract(ds_train, model, layer='output[:Tensor]')
ds_test = feature_extract(ds_test, model, layer='output[:Tensor]')
# ood detect
detector = OodDetector(ds_train, ds_test, n_cluster=10)
score = detector.ood_detector()
# Evaluation
num = int(len(ds_test)/2)
label = np.concatenate((np.zeros(num), np.ones(num)), axis=0) # ID data = 0, OOD data = 1
dec_acc = result_eval(score, label, threshold=0.5)

+ 65
- 0
examples/reliability/concept_drift_check_images_resnet.py View File

@@ -0,0 +1,65 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from mindspore import Tensor
from mindspore.train.model import Model
from mindspore import Model, nn, context
from examples.common.networks.resnet.resnet import resnet50
from mindspore.train.summary.summary_record import _get_summary_tensor_data
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetector, result_eval


"""
Examples for Resnet.
"""


def feature_extract(data, feature_model, layer='output[:Tensor]'):
"""
Extract features.
Args:
data (numpy.ndarray): Input data.
feature_model (Model): The model for extracting features.
layer (str): The feature layer. The layer name could be 'output[:Tensor]',
'1[:Tensor]', '2[:Tensor]',...'10[:Tensor]'.

Returns:
numpy.ndarray, the feature of input data.
"""
feature_model.predict(Tensor(data))
layer_out = _get_summary_tensor_data()
return layer_out[layer].asnumpy()


if __name__ == '__main__':
# load model
ckpt_path = '../../tests/ut/python/dataset/trained_ckpt_file/resnet_1-20_1875.ckpt'
net = resnet50()
load_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, load_dict)
model = Model(net)
# load data
ds_train = np.load('./train.npy')
ds_test = np.load('./test.npy')
ds_train = feature_extract(ds_train, model, layer='output[:Tensor]')
ds_test = feature_extract(ds_test, model, layer='output[:Tensor]')
# ood detect
detector = OodDetector(ds_train, ds_test, n_cluster=10)
score = detector.ood_detector()
# Evaluation
num = int(len(ds_test)/2)
label = np.concatenate((np.zeros(num), np.ones(num)), axis=0) # ID data = 0, OOD data = 1
dec_acc = result_eval(score, label, threshold=0.5)

+ 131
- 19
mindarmour/reliability/concept_drift/README.md View File

@@ -1,10 +1,10 @@
# Content
# Concept Drift

## Concept drift Description

In predictive analytics and machine learning, the concept drift means that the statistical properties of the target variable, which the model is trying to predict, change over time in unforeseen ways. This causes problems because the predictions become less accurate as time passes. Usually, concept drift is described as the change of data distribution over time.

## Method
## Method for time series

### Model Architecture

@@ -13,9 +13,9 @@ For time series concept drift detection

### Detector

For a time series, we select two adjacent time window and compare the features of the two window data to determine whether concept drift has occurred. For feature extraction, we choose to use the ESN network. The input of the ESN network is a certain window data, and the output is also the window data (like an auto-encoder). In this way, the ESN network is equivalent to a feature extractor. Features are represented by model parameters (weights and bias) of the ESN network. Finally, by comparing the difference of model parameters, we can determine whether the data has concept drift. It should be noted that the two windows are constantly sliding forward.
For time series, we select two adjacent time window and compare the features of the two window data to determine whether concept drift has occurred. For feature extraction, we choose to use the ESN network. The input of the ESN network is a certain window data, and the output is also the window data (like an auto-encoder). In this way, the ESN network is equivalent to a feature extractor. Features are represented by model parameters (weights and bias) of the ESN network. Finally, by comparing the difference of model parameters, we can determine whether the data has concept drift. It should be noted that the two windows are constantly sliding forward.

## Dataset
### Dataset

Download dataset https://www.kaggle.com/camnugent/sandp500.

@@ -32,7 +32,7 @@ Download dataset https://www.kaggle.com/camnugent/sandp500.
Please use the data in archive/individual_stocks_5yr/individual_stocks_5yr/XX.csv.
In each csv file, there are 'date','open','high','low','close','volume','Name' columns, please choose one column to begin your code. 'date' and 'Name' are non-data column.

## Environment Requirements
### Environment Requirements

- Hardware(CPU/Ascend/GPU)
- Prepare hardware environment with CPU, Ascend or GPU processor.
@@ -42,9 +42,9 @@ In each csv file, there are 'date','open','high','low','close','volume','Name' c
- MindSpore Tutorials
- MindSpore Python API

## Quick Start
### Quick Start

### Initialization
#### Initialization

```python
from mindarmour.reliability.concept_drift.concept_drift_check_time_series import ConceptDriftCheckTimeSeries
@@ -53,13 +53,13 @@ concept = ConceptDriftCheckTimeSeries(window_size=100, rolling_window=10, step=1
need_label=False)
```

>window_size(int): Size of a concept window, no less than 10. If given the input data, window_size belongs to [10, 1/3*len(input data)]. If the data is periodic, usually window_size equals 2-5 periods, such as, for monthly/weekly data, the data volume of 30/7 days is a period. Default: 100.
rolling_window(int): Smoothing window size, belongs to [1, window_size]. Default:10.
step(int): The jump length of the sliding window, belongs to [1,window_size]. Default:10.
threshold_index(float): The threshold index. Default: 1.5.
need_label(bool): False or True. If need_label=True, concept drift labels are needed. Default: False.
>`window_size(int)`: Size of a concept window, no less than 10. If given the input data, window_size belongs to [10, 1/3*len(input data)]. If the data is periodic, usually window_size equals 2-5 periods, such as, for monthly/weekly data, the data volume of 30/7 days is a period. Default: 100.
`rolling_window(int)`: Smoothing window size, belongs to [1, window_size]. Default:10.
`step(int)`: The jump length of the sliding window, belongs to [1,window_size]. Default:10.
`threshold_index(float)`: The threshold index. Default: 1.5.
`need_label(bool)`: False or True. If need_label=True, concept drift labels are needed. Default: False.

### Data
#### Data

```python
import numpy as np
@@ -68,27 +68,139 @@ data = np.loadtxt(file, str, delimiter=",")
data = data[1:, 2].astype('float64') # here we choose one column or multiple columns data[1:, 2:5].
```

>data(numpy.ndarray): Input data. The shape of data could be (n,1) or (n,m).
>`data(numpy.ndarray)`: Input data. The shape of data could be (n,1) or (n,m).

### Drift check
#### Drift check

```python
drift_score, threshold, concept_drift_location = concept.concept_check(data)
# the result is saved as pdf named 'concept_drift_check.pdf'
```

>drift_score(numpy.ndarray): The concept drift score of the example series.
threshold(float): The threshold to judge concept drift.
concept_drift_location(list): The location of the concept drift.
>`drift_score(numpy.ndarray)`: The concept drift score of the example series.
`threshold(float)`: The threshold to judge concept drift.
`concept_drift_location(list)`: The location of the concept drift.


## Method for images

Generally, neural networks are used to process images. Therefore, we use algorithms based on neural networks to detect concept drifts of images.
For image data, there is a special term that describes the concept drift in detail, Out-of-Distribution(`OOD`).
Hereinafter, we will use the term `OOD` to describe concept drifts in images. As for non-drift images, we use the term In-Distribution(`ID`).
### Model Architecture
The model structure can be any neural network structure, such as DNN, CNN, and RNN.
Here, we select LeNet and ResNet as examples.
### Detector
Firstly, obtain the features of the training data, the features are the outputs of a selected neural layer.
Secondly, the features are clustered to obtain the clustering centers.
Finally, the features of the testing data in the same neural network layer are obtained, and the distance between the testing data features and the clustering center is calculated.
When the distance exceeds the threshold, the image is determined as an out-of-distribution(OOD) image.
### DataSet
We prepared two pairs of dataset for LeNet and ResNet separately.
For LeNet, the training data is Mnist as ID data. The testing data is Mnist + Cifar10. Cifar10 is OOD data.
For ResNet, the training data is Cifar10 as ID data. The testing data is Cifar10 + ImageNet. ImageNet is OOD data.
### Environment Requirements
- Hardware(Ascend)
- Prepare hardware environment with Ascend.
- Framework
- MindSpore
- For more information, please check the resources below:
- MindSpore Tutorials
- MindSpore Python API

### Quick Start

#### Import

```python
import logging
import pytest
import numpy as np
from mindspore import Tensor
from mindspore.train.model import Model
from mindarmour.utils.logger import LogUtil
from mindspore import Model, nn, context
from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5
from mindspore.train.summary.summary_record import _get_summary_tensor_data
from mindspore.train.serializaton import load_checkpoint, load_pram_into_net
from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetector, result_eval
```

#### Load Classification Model

```python
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_path)
load_pram_into_net(net, load_dict)
model = Model(net)
```

>`ckpt_path(str)`:the model path.

#### Data processing

We extract the data features by the Lenet network.


```python
ds_train = np.load('../../dataset/concept_train_lenet.npy')
ds_test = np.load('../../dataset/concept_test_lenet.npy')
ds_train = feature_extract(ds_train, model, layer='9[:Tensor]')
ds_test = feature_extract(ds_test, model, layer='9[:Tensor]')
```

> `ds_train(numpy.ndarray)`: the train data.
> `ds_test(numpy.ndarray)`: the test data.
> `model(Model)`: the Lenet model.


#### Train the concept drift detector

OOD detector for Lenet.


```python
detector = OodDetector(ds_train, ds_test, n_cluster=10)
score = detector.ood_detector()
```

> `ds_train(numpy.ndarray)`: the train data.
> `ds_test(numpy.ndarray)`: the test data.
> `n_cluster(int)`: the feature cluster number.


#### Evaluation

```python
num = int(len(ds_test)/2)
label = np.concatenate((np.zeros(num), np.ones(num)), axis=0) # ID data = 0, OOD data = 1
dec_acc = result_eval(score, label, threshold=0.5)
```

> `ds_test(numpy.ndarray)`: the test data.
> `score(numpy.ndarray)`: the concept drift score.
> `label(numpy.ndarray)`: the drift label.
> `threshold(float)`: the threshold to judge out-of-distribution.


## Script Description

```bash
├── mindarmour
├── reliability # descriptions about GhostNet # shell script for evaluation with CPU, GPU or Ascend
├── reliability
├──concept_drift
├──__init__.py
├──concept_drift_check_images.py
├──concept_drift_check_time_series.py
├──README.md
```


+ 87
- 0
mindarmour/reliability/concept_drift/concept_drift_check_images.py View File

@@ -0,0 +1,87 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================


import heapq
import numpy as np
from sklearn.cluster import KMeans
from mindarmour.utils._check_param import check_param_type, check_param_in_range

"""
Out-of-Distribution detection for images.
The sample can be run on Ascend 910 AI processor.
"""


class OodDetector:
"""
Train the OOD detector.

Args:
ds_train (numpy.ndarray): The training dataset.
ds_test (numpy.ndarray): The testing dataset.
"""

def __init__(self, ds_train, ds_test, n_cluster=10):
self.ds_train = check_param_type('ds_train', ds_train, np.ndarray)
self.ds_test = check_param_type('ds_test', ds_test, np.ndarray)
self.n_cluster = check_param_type('n_cluster', n_cluster, int)
self.n_cluster = check_param_in_range('n_cluster', n_cluster, 2, 100)

def ood_detector(self):
"""
The out-of-distribution detection.

Returns:
- numpy.ndarray, the detection score of images.
"""

clf = KMeans(n_clusters=self.n_cluster)
clf.fit_predict(self.ds_train)
feature_cluster = clf.cluster_centers_
score = []
for i in range(len(self.ds_test)):
dis = []
for j in range(len(feature_cluster)):
loc = list(map(list(feature_cluster[j]).index, heapq.nlargest(self.n_cluster, list(feature_cluster[j]))))
diff = sum(abs((feature_cluster[j][loc] - self.ds_test[i][loc]))) / sum(abs((feature_cluster[j][loc])))
dis.append(diff)
score.append(min(dis))
score = np.array(score)
return score


def result_eval(score, label, threshold):
"""
Evaluate the detection results.

Args:
score (numpy.ndarray): The detection score of images.
label (numpy.ndarray): The label whether an image is in-ditribution and out-of-distribution.
threshold (float): The threshold to judge out-of-distribution distance.

Returns:
- float, the detection accuracy.
"""
check_param_type('label', label, np.ndarray)
check_param_type('threshold', threshold, float)
check_param_in_range('threshold', threshold, 0, 1)
count = 0
for i in range(len(score)):
if score[i] < threshold and label[i] == 0:
count = count + 1
if score[i] >= threshold and label[i] == 1:
count = count + 1
return count / len(score)

BIN
tests/ut/python/dataset/concept_test_lenet.npy View File


BIN
tests/ut/python/dataset/concept_train_lenet.npy View File


+ 85
- 0
tests/ut/python/reliability/concept_drift/test_concept_drift_images.py View File

@@ -0,0 +1,85 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Concept drift test for images.
"""

import logging
import pytest
import numpy as np
from mindspore import Tensor
from mindspore.train.model import Model
from mindarmour.utils.logger import LogUtil
from mindspore import Model, nn, context
from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5
from mindspore.train.summary.summary_record import _get_summary_tensor_data
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetector, result_eval

LOGGER = LogUtil.get_instance()
TAG = 'Concept_Test'


def feature_extract(data, feature_model, layer='output[:Tensor]'):
"""
Extract features.
Args:
data (numpy.ndarray): Input data.
feature_model (Model): The model for extracting features.
layer (str): The feature layer. The layer name could be 'output[:Tensor]',
'1[:Tensor]', '2[:Tensor]',...'10[:Tensor]'.

Returns:
numpy.ndarray, the feature of input data.
"""
feature_model.predict(Tensor(data))
layer_out = _get_summary_tensor_data()
return layer_out[layer].asnumpy()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_cp():
"""
Concept drift test
"""
# load model
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, load_dict)
model = Model(net)
# load data
ds_train = np.load('../../dataset/concept_train_lenet.npy')
ds_test = np.load('../../dataset/concept_test_lenet.npy')
ds_train = feature_extract(ds_train, model, layer='9[:Tensor]')
ds_test = feature_extract(ds_test, model, layer='9[:Tensor]')
# ood detect
detector = OodDetector(ds_train, ds_test, n_cluster=10)
score = detector.ood_detector()
# Evaluation
num = int(len(ds_test)/2)
label = np.concatenate((np.zeros(num), np.ones(num)), axis=0) # ID data = 0, OOD data = 1
dec_acc = result_eval(score, label, threshold=0.5)
# result log
LOGGER.set_level(logging.DEBUG)
LOGGER.debug(TAG, '--start concept drift test--')
LOGGER.debug(score, '--concept drift check score--')
LOGGER.debug(dec_acc, '--concept drift check accuracy--')
LOGGER.debug(TAG, '--end concept drift test--')
assert np.any(score >= 0.0)

Loading…
Cancel
Save