@@ -1,3 +1,21 @@ | |||
# MindArmour 1.5.0 | |||
## MindArmour 1.5.0 Release Notes | |||
### Major Features and Improvements | |||
#### Reliability | |||
* [BETA] Reconstruct AI Fuzz and Neuron Coverage Metrics | |||
### Bug fixes | |||
### Contributors | |||
Thanks goes to these wonderful people: | |||
Wu Xiaoyu, Liu Zhidan, Jin Xiulang, Liu Luobin, Liu Liu | |||
# MindArmour 1.3.0-rc1 | |||
## MindArmour 1.3.0 Release Notes | |||
@@ -1,4 +1,4 @@ | |||
# Copyright 2019 Huawei Technologies Co., Ltd | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
@@ -11,8 +11,9 @@ | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import os | |||
"""data processing""" | |||
import os | |||
import mindspore.dataset as ds | |||
import mindspore.dataset.vision.c_transforms as CV | |||
import mindspore.dataset.transforms.c_transforms as C | |||
@@ -114,3 +115,74 @@ def vgg_create_dataset100(data_home, image_size, batch_size, rank_id=0, rank_siz | |||
# apply repeat operations | |||
data_set = data_set.repeat(repeat_num) | |||
return data_set | |||
def create_dataset_imagenet(path, batch_size=32, repeat_size=20, status="train", target="GPU"):
    """
    Build an augmented image pipeline from an image-folder dataset.

    Labels are cast to int32. Images are randomly cropped, randomly flipped,
    resized to 224x224, scaled by 1/255, normalized and converted from HWC to
    CHW. The dataset is then shuffled and repeated; it is NOT batched here.

    Args:
        path (str): Root directory handed to ``ds.ImageFolderDataset``.
        batch_size (int): Kept in the signature but not used by this function.
        repeat_size (int): Number of times the dataset is repeated. Default: 20.
        status (str): Kept in the signature but not used by this function.
        target (str): Kept in the signature but not used by this function.

    Returns:
        Dataset, the shuffled and repeated (unbatched) dataset.
    """
    dataset = ds.ImageFolderDataset(path, decode=True)

    # Only 'image_height', 'image_width' and 'buffer_size' are read below.
    cfg = {'num_classes': 10,
           'learning_rate': 0.002,
           'momentum': 0.9,
           'epoch_size': 30,
           'batch_size': 32,
           'buffer_size': 1000,
           'image_height': 224,
           'image_width': 224,
           'save_checkpoint_steps': 1562,
           'keep_checkpoint_max': 10}
    workers = 6

    # Image transforms, listed in the exact order they are applied.
    image_ops = [
        CV.RandomCrop([32, 32], [4, 4, 4, 4]),
        CV.RandomHorizontalFlip(),
        CV.Resize((cfg['image_height'], cfg['image_width'])),
        CV.Rescale(1.0 / 255.0, 0.0),
        CV.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        CV.HWC2CHW(),
    ]
    label_cast = C.TypeCast(mstype.int32)

    dataset = dataset.map(input_columns="label", operations=label_cast, num_parallel_workers=workers)
    # One map stage per transform, matching the original pipeline layout.
    for image_op in image_ops:
        dataset = dataset.map(input_columns="image", operations=image_op, num_parallel_workers=workers)

    dataset = dataset.shuffle(buffer_size=cfg['buffer_size'])
    return dataset.repeat(repeat_size)
def create_dataset_cifar(data_path, image_height, image_width, repeat_num=1, training=True):
    """
    Build a CIFAR-10 pipeline for later use such as training or inferring.

    Args:
        data_path (str): Directory handed to ``ds.Cifar10Dataset``.
        image_height (int): Height images are resized to.
        image_width (int): Width images are resized to.
        repeat_num (int): Number of times the dataset is repeated. Default: 1.
        training (bool): When True, random crop and random horizontal flip are
            applied before the deterministic transforms. Default: True.

    Returns:
        Dataset, shuffled, batched (batch size 32, remainder dropped) and repeated.
    """
    dataset = ds.Cifar10Dataset(data_path)

    # Augmentation ops (RandomCrop uses the default CONSTANT padding mode).
    crop_op = CV.RandomCrop((32, 32), (4, 4, 4, 4))
    flip_op = CV.RandomHorizontalFlip()

    # Augment only while training; the deterministic transforms always run.
    image_ops = [crop_op, flip_op] if training else []
    image_ops.extend([
        CV.Resize((image_height, image_width)),  # interpolation default BILINEAR
        CV.Rescale(1.0 / 255.0, 0.0),
        CV.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        CV.HWC2CHW(),
    ])

    # Labels first, then the image transform chain in a single map stage.
    dataset = dataset.map(operations=C.TypeCast(mstype.int32), input_columns="label")
    dataset = dataset.map(operations=image_ops, input_columns="image")

    dataset = dataset.shuffle(buffer_size=10)
    dataset = dataset.batch(batch_size=32, drop_remainder=True)
    return dataset.repeat(repeat_num)
@@ -0,0 +1,401 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
import numpy as np | |||
import mindspore.nn as nn | |||
import mindspore.common.dtype as mstype | |||
from mindspore.ops import operations as P | |||
from mindspore.ops import functional as F | |||
from mindspore.common.tensor import Tensor | |||
from scipy.stats import truncnorm | |||
from mindspore.ops import TensorSummary | |||
def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
    """
    Draw convolution weights from a truncated normal scaled by fan-in.

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        kernel_size (int): Side length of the square kernel.

    Returns:
        Tensor, float32 weights of shape
        (out_channel, in_channel, kernel_size, kernel_size).
    """
    fan_in = in_channel * kernel_size * kernel_size
    variance = 1.0 / max(1., fan_in)
    stddev = (variance ** 0.5) / .87962566103423978

    # Sample flat, then reshape, so the draw order matches the weight layout.
    weight_shape = (out_channel, in_channel, kernel_size, kernel_size)
    sample_count = out_channel * in_channel * kernel_size * kernel_size
    samples = truncnorm(-2, 2, loc=0, scale=stddev).rvs(sample_count)
    return Tensor(np.reshape(samples, weight_shape), dtype=mstype.float32)
def _weight_variable(shape, factor=0.01):
    """
    Create a weight Tensor of standard-normal samples scaled by ``factor``.

    Args:
        shape (tuple): Shape of the weight.
        factor (float): Multiplier applied to the samples. Default: 0.01.

    Returns:
        Tensor, the scaled float32 samples.
    """
    gaussian = np.random.randn(*shape).astype(np.float32)
    return Tensor(gaussian * factor)
def _conv3x3(in_channel, out_channel, stride=1, use_se=False):
    """
    Build a 3x3 'same'-padded convolution with explicitly initialized weights.

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        stride (int): Convolution stride. Default: 1.
        use_se (bool): Use the variance-scaling initializer instead of the
            scaled-normal one. Default: False.

    Returns:
        Cell, the ``nn.Conv2d`` layer.
    """
    if not use_se:
        weight = _weight_variable((out_channel, in_channel, 3, 3))
    else:
        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
    return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
                     padding=0, pad_mode='same', weight_init=weight)
def _conv1x1(in_channel, out_channel, stride=1, use_se=False):
    """
    Build a 1x1 'same'-padded convolution with explicitly initialized weights.

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        stride (int): Convolution stride. Default: 1.
        use_se (bool): Use the variance-scaling initializer instead of the
            scaled-normal one. Default: False.

    Returns:
        Cell, the ``nn.Conv2d`` layer.
    """
    if not use_se:
        weight = _weight_variable((out_channel, in_channel, 1, 1))
    else:
        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
    return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
                     padding=0, pad_mode='same', weight_init=weight)
def _conv7x7(in_channel, out_channel, stride=1, use_se=False):
    """
    Build a 7x7 'same'-padded convolution with explicitly initialized weights.

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        stride (int): Convolution stride. Default: 1.
        use_se (bool): Use the variance-scaling initializer instead of the
            scaled-normal one. Default: False.

    Returns:
        Cell, the ``nn.Conv2d`` layer.
    """
    if not use_se:
        weight = _weight_variable((out_channel, in_channel, 7, 7))
    else:
        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
    return nn.Conv2d(in_channel, out_channel, kernel_size=7, stride=stride,
                     padding=0, pad_mode='same', weight_init=weight)
def _bn(channel):
    """
    Build a BatchNorm2d layer with gamma initialized to 1.

    Args:
        channel (int): Number of channels to normalize.

    Returns:
        Cell, the ``nn.BatchNorm2d`` layer.
    """
    bn_options = dict(eps=1e-4, momentum=0.9,
                      gamma_init=1, beta_init=0,
                      moving_mean_init=0, moving_var_init=1)
    return nn.BatchNorm2d(channel, **bn_options)
def _bn_last(channel):
    """
    Build a BatchNorm2d layer with gamma initialized to 0.

    Used as the last normalization of a residual branch (see ResidualBlock.bn3).

    Args:
        channel (int): Number of channels to normalize.

    Returns:
        Cell, the ``nn.BatchNorm2d`` layer.
    """
    bn_options = dict(eps=1e-4, momentum=0.9,
                      gamma_init=0, beta_init=0,
                      moving_mean_init=0, moving_var_init=1)
    return nn.BatchNorm2d(channel, **bn_options)
def _fc(in_channel, out_channel, use_se=False):
    """
    Build a biased fully connected layer with explicitly initialized weights.

    Args:
        in_channel (int): Number of input features.
        out_channel (int): Number of output features.
        use_se (bool): Draw weights from N(0, 0.01) via ``np.random.normal``
            instead of using ``_weight_variable``. Default: False.

    Returns:
        Cell, the ``nn.Dense`` layer with zero-initialized bias.
    """
    weight_shape = (out_channel, in_channel)
    if not use_se:
        weight = _weight_variable(weight_shape)
    else:
        flat = np.random.normal(loc=0, scale=0.01, size=out_channel*in_channel)
        weight = Tensor(np.reshape(flat, weight_shape), dtype=mstype.float32)
    return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)
class ResidualBlock(nn.Cell):
    """
    ResNet V1 bottleneck residual block.

    The main branch is 1x1 conv -> 3x3 conv -> 1x1 conv, each followed by
    batch normalization. An optional squeeze-and-excitation gate rescales the
    branch output before it is added to the (possibly projected) input.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride size of the block. Default: 1.
        use_se (bool): Enable the SE-ResNet50 variant. Default: False.
        se_block (bool): Add the SE gate to this block. Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlock(3, 256, stride=2)
    """
    expansion = 4

    def __init__(self,
                 in_channel,
                 out_channel,
                 stride=1,
                 use_se=False, se_block=False):
        super().__init__()
        self.summary = TensorSummary()
        self.stride = stride
        self.use_se = use_se
        self.se_block = se_block
        mid_channel = out_channel // self.expansion

        # NOTE: attribute names and creation order are kept as-is; they define
        # the parameter names and the order random weights are drawn.
        self.conv1 = _conv1x1(in_channel, mid_channel, stride=1, use_se=use_se)
        self.bn1 = _bn(mid_channel)
        if use_se and stride != 1:
            # SE variant downsamples with max-pooling after a stride-1 conv.
            self.e2 = nn.SequentialCell([
                _conv3x3(mid_channel, mid_channel, stride=1, use_se=True),
                _bn(mid_channel),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same'),
            ])
        else:
            self.conv2 = _conv3x3(mid_channel, mid_channel, stride=stride, use_se=use_se)
            self.bn2 = _bn(mid_channel)

        self.conv3 = _conv1x1(mid_channel, out_channel, stride=1, use_se=use_se)
        self.bn3 = _bn_last(out_channel)

        if se_block:
            squeezed = int(out_channel/4)
            self.se_global_pool = P.ReduceMean(keep_dims=False)
            self.se_dense_0 = _fc(out_channel, squeezed, use_se=use_se)
            self.se_dense_1 = _fc(squeezed, out_channel, use_se=use_se)
            self.se_sigmoid = nn.Sigmoid()
            self.se_mul = P.Mul()
        self.relu = nn.ReLU()

        # A projection shortcut is needed whenever the shape changes.
        self.down_sample = stride != 1 or in_channel != out_channel
        self.down_sample_layer = None
        if self.down_sample:
            if use_se and stride != 1:
                shortcut_cells = [nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same'),
                                  _conv1x1(in_channel, out_channel, 1, use_se=use_se),
                                  _bn(out_channel)]
            else:
                shortcut_cells = [_conv1x1(in_channel, out_channel, stride, use_se=use_se),
                                  _bn(out_channel)]
            self.down_sample_layer = nn.SequentialCell(shortcut_cells)
        self.add = P.TensorAdd()

    def construct(self, x):
        """Apply the bottleneck branch and add the shortcut."""
        shortcut = x

        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu(y)

        if self.use_se and self.stride != 1:
            y = self.e2(y)
        else:
            y = self.conv2(y)
            y = self.bn2(y)
            y = self.relu(y)

        y = self.conv3(y)
        y = self.bn3(y)

        if self.se_block:
            # Squeeze over the spatial axes, excite, then rescale the features.
            features = y
            y = self.se_global_pool(y, (2, 3))
            y = self.se_dense_0(y)
            y = self.relu(y)
            y = self.se_dense_1(y)
            y = self.se_sigmoid(y)
            y = F.reshape(y, F.shape(y) + (1, 1))
            y = self.se_mul(y, features)

        if self.down_sample:
            shortcut = self.down_sample_layer(shortcut)

        y = self.add(y, shortcut)
        y = self.relu(y)
        return y
class ResNet(nn.Cell):
    """
    ResNet architecture.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Numbers of block in different layers.
        in_channels (list): Input channel in each layer.
        out_channels (list): Output channel in each layer.
        strides (list): Stride size in each layer.
        num_classes (int): The number of classes that the training images are belonging to.
        use_se (bool): Enable the SE-ResNet50 net. When True, SE blocks are
            also used in layer 3 and layer 4. Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResNet(ResidualBlock,
        >>>        [3, 4, 6, 3],
        >>>        [64, 256, 512, 1024],
        >>>        [256, 512, 1024, 2048],
        >>>        [1, 2, 2, 2],
        >>>        10)
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides,
                 num_classes,
                 use_se=False):
        super().__init__()

        if any(len(seq) != 4 for seq in (layer_nums, in_channels, out_channels)):
            raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
        self.use_se = use_se
        # SE blocks are switched on together with the SE variant.
        self.se_block = bool(self.use_se)

        # Stem: three 3x3 convolutions for the SE variant, one 7x7 otherwise.
        if self.use_se:
            self.conv1_0 = _conv3x3(3, 32, stride=2, use_se=self.use_se)
            self.bn1_0 = _bn(32)
            self.conv1_1 = _conv3x3(32, 32, stride=1, use_se=self.use_se)
            self.bn1_1 = _bn(32)
            self.conv1_2 = _conv3x3(32, 64, stride=1, use_se=self.use_se)
        else:
            self.conv1 = _conv7x7(3, 64, stride=2)
        self.bn1 = _bn(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        def build_stage(index, se_block=False):
            # Assemble the stage described by position `index` of the config lists.
            return self._make_layer(block,
                                    layer_nums[index],
                                    in_channel=in_channels[index],
                                    out_channel=out_channels[index],
                                    stride=strides[index],
                                    use_se=self.use_se,
                                    se_block=se_block)

        # Only the last two stages may carry SE blocks.
        self.layer1 = build_stage(0)
        self.layer2 = build_stage(1)
        self.layer3 = build_stage(2, se_block=self.se_block)
        self.layer4 = build_stage(3, se_block=self.se_block)

        self.mean = P.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.end_point = _fc(out_channels[3], num_classes, use_se=self.use_se)
        self.summary = TensorSummary()

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False):
        """
        Make stage network of ResNet.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Layer number.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size for the first convolutional layer.
            use_se (bool): Enable the SE-ResNet50 variant in every block. Default: False.
            se_block (bool): Use an SE block as the last block of the stage. Default: False.

        Returns:
            SequentialCell, the output layer.

        Examples:
            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
        """
        # The first block changes channels/stride; the rest keep the shape.
        stage = [block(in_channel, out_channel, stride=stride, use_se=use_se)]

        # With an SE tail, one of the remaining slots is taken by the SE block.
        plain_blocks = layer_num - 2 if se_block else layer_num - 1
        for _ in range(plain_blocks):
            stage.append(block(out_channel, out_channel, stride=1, use_se=use_se))
        if se_block:
            stage.append(block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block))
        return nn.SequentialCell(stage)

    def construct(self, x):
        """Run the stem, the four stages, global pooling and the classifier."""
        if self.use_se:
            x = self.conv1_0(x)
            x = self.bn1_0(x)
            x = self.relu(x)
            x = self.conv1_1(x)
            x = self.bn1_1(x)
            x = self.relu(x)
            x = self.conv1_2(x)
        else:
            x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        stem_out = self.maxpool(x)
        stage1_out = self.layer1(stem_out)
        stage2_out = self.layer2(stage1_out)
        stage3_out = self.layer3(stage2_out)
        stage4_out = self.layer4(stage3_out)

        out = self.mean(stage4_out, (2, 3))
        out = self.flatten(out)
        # Record the pooled feature under summary tag '1'.
        self.summary('1', out)
        out = self.end_point(out)
        # The 'output' summary is recorded only outside training.
        if not self.training:
            self.summary('output', out)
        return out
def resnet50(class_num=10):
    """
    Get ResNet50 neural network.

    Args:
        class_num (int): Class number. Default: 10.

    Returns:
        Cell, cell instance of ResNet50 neural network.

    Examples:
        >>> net = resnet50(10)
    """
    return ResNet(block=ResidualBlock,
                  layer_nums=[3, 4, 6, 3],
                  in_channels=[64, 256, 512, 1024],
                  out_channels=[256, 512, 1024, 2048],
                  strides=[1, 2, 2, 2],
                  num_classes=class_num)
def se_resnet50(class_num=1001):
    """
    Get SE-ResNet50 neural network.

    Args:
        class_num (int): Class number. Default: 1001.

    Returns:
        Cell, cell instance of SE-ResNet50 neural network.

    Examples:
        >>> net = se_resnet50(1001)
    """
    return ResNet(block=ResidualBlock,
                  layer_nums=[3, 4, 6, 3],
                  in_channels=[64, 256, 512, 1024],
                  out_channels=[256, 512, 1024, 2048],
                  strides=[1, 2, 2, 2],
                  num_classes=class_num,
                  use_se=True)
def resnet101(class_num=1001):
    """
    Get ResNet101 neural network.

    Args:
        class_num (int): Class number. Default: 1001.

    Returns:
        Cell, cell instance of ResNet101 neural network.

    Examples:
        >>> net = resnet101(1001)
    """
    return ResNet(block=ResidualBlock,
                  layer_nums=[3, 4, 23, 3],
                  in_channels=[64, 256, 512, 1024],
                  out_channels=[256, 512, 1024, 2048],
                  strides=[1, 2, 2, 2],
                  num_classes=class_num)
@@ -0,0 +1,47 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import numpy as np | |||
from mindspore import Tensor | |||
from mindspore.train.model import Model | |||
from mindspore import Model, nn, context | |||
from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5 | |||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetectorFeatureCluster | |||
"""
Examples for Lenet.
"""
if __name__ == '__main__':
    # Restore the trained LeNet5 weights and wrap the network in a Model.
    ckpt_path = '../../tests/ut/python/dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
    network = LeNet5()
    load_param_into_net(network, load_checkpoint(ckpt_path))
    model = Model(network)

    # Training data plus two test sets: the first tunes the threshold,
    # the second is the data actually checked for OOD samples.
    ds_train = np.load('../../tests/ut/python/dataset/concept_train_lenet.npy')
    ds_test1 = np.load('../../tests/ut/python/dataset/concept_test_lenet1.npy')
    ds_test2 = np.load('../../tests/ut/python/dataset/concept_test_lenet2.npy')

    # Cluster the features taken from the chosen layer of the model.
    detector = OodDetectorFeatureCluster(model, ds_train, n_cluster=10, layer='output[:Tensor]')

    # Labels mark the first half of ds_test1 as ID (0) and the second half as OOD (1).
    half = len(ds_test1) // 2
    label = np.concatenate((np.zeros(half), np.ones(half)), axis=0)
    optimal_threshold = detector.get_optimal_threshold(label, ds_test1)

    # Predict on ds_test2. The threshold can also be chosen by ourselves.
    result = detector.ood_predict(optimal_threshold, ds_test2)
@@ -0,0 +1,47 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import numpy as np | |||
from mindspore import Tensor | |||
from mindspore.train.model import Model | |||
from mindspore import Model, nn, context | |||
from examples.common.networks.resnet.resnet import resnet50 | |||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetectorFeatureCluster | |||
"""
Examples for Resnet.
"""
if __name__ == '__main__':
    # Restore the trained ResNet50 weights and wrap the network in a Model.
    ckpt_path = '../../tests/ut/python/dataset/trained_ckpt_file/resnet_1-20_1875.ckpt'
    network = resnet50()
    load_param_into_net(network, load_checkpoint(ckpt_path))
    model = Model(network)

    # Training data plus two test sets: the first tunes the threshold,
    # the second is the data actually checked for OOD samples.
    ds_train = np.load('train.npy')
    ds_test1 = np.load('test1.npy')
    ds_test2 = np.load('test2.npy')

    # Cluster the features taken from the chosen layer of the model.
    detector = OodDetectorFeatureCluster(model, ds_train, n_cluster=10, layer='output[:Tensor]')

    # Labels mark the first half of ds_test1 as ID (0) and the second half as OOD (1).
    half = len(ds_test1) // 2
    label = np.concatenate((np.zeros(half), np.ones(half)), axis=0)
    optimal_threshold = detector.get_optimal_threshold(label, ds_test1)

    # Predict on ds_test2. The threshold can also be chosen by ourselves.
    result = detector.ood_predict(optimal_threshold, ds_test2)
@@ -1,10 +1,10 @@ | |||
# Content | |||
# Concept Drift | |||
## Concept drift Description | |||
In predictive analytics and machine learning, the concept drift means that the statistical properties of the target variable, which the model is trying to predict, change over time in unforeseen ways. This causes problems because the predictions become less accurate as time passes. Usually, concept drift is described as the change of data distribution over time. | |||
## Method | |||
## Method for time series | |||
### Model Architecture | |||
@@ -13,9 +13,9 @@ For time series concept drift detection | |||
### Detector | |||
For a time series, we select two adjacent time windows and compare the features of the data in the two windows to determine whether concept drift has occurred. For feature extraction, we choose to use the ESN network. The input of the ESN network is the data of one window, and the output is also the window data (like an auto-encoder). In this way, the ESN network is equivalent to a feature extractor. Features are represented by model parameters (weights and bias) of the ESN network. Finally, by comparing the difference of model parameters, we can determine whether the data has concept drift. It should be noted that the two windows are constantly sliding forward. | |||
For time series, we select two adjacent time windows and compare the features of the data in the two windows to determine whether concept drift has occurred. For feature extraction, we choose to use the ESN network. The input of the ESN network is the data of one window, and the output is also the window data (like an auto-encoder). In this way, the ESN network is equivalent to a feature extractor. Features are represented by model parameters (weights and bias) of the ESN network. Finally, by comparing the difference of model parameters, we can determine whether the data has concept drift. It should be noted that the two windows are constantly sliding forward. | |||
## Dataset | |||
### Dataset | |||
Download dataset https://www.kaggle.com/camnugent/sandp500. | |||
@@ -32,7 +32,7 @@ Download dataset https://www.kaggle.com/camnugent/sandp500. | |||
Please use the data in archive/individual_stocks_5yr/individual_stocks_5yr/XX.csv. | |||
In each csv file, there are 'date','open','high','low','close','volume','Name' columns, please choose one column to begin your code. 'date' and 'Name' are non-data columns. | |||
## Environment Requirements | |||
### Environment Requirements | |||
- Hardware(CPU/Ascend/GPU) | |||
- Prepare hardware environment with CPU, Ascend or GPU processor. | |||
@@ -42,9 +42,9 @@ In each csv file, there are 'date','open','high','low','close','volume','Name' c | |||
- MindSpore Tutorials | |||
- MindSpore Python API | |||
## Quick Start | |||
### Quick Start | |||
### Initialization | |||
#### Initialization | |||
```python | |||
from mindarmour.reliability.concept_drift.concept_drift_check_time_series import ConceptDriftCheckTimeSeries | |||
@@ -53,13 +53,13 @@ concept = ConceptDriftCheckTimeSeries(window_size=100, rolling_window=10, step=1 | |||
need_label=False) | |||
``` | |||
>window_size(int): Size of a concept window, no less than 10. If given the input data, window_size belongs to [10, 1/3*len(input data)]. If the data is periodic, usually window_size equals 2-5 periods, such as, for monthly/weekly data, the data volume of 30/7 days is a period. Default: 100. | |||
rolling_window(int): Smoothing window size, belongs to [1, window_size]. Default:10. | |||
step(int): The jump length of the sliding window, belongs to [1,window_size]. Default:10. | |||
threshold_index(float): The threshold index. Default: 1.5. | |||
need_label(bool): False or True. If need_label=True, concept drift labels are needed. Default: False. | |||
>`window_size(int)`: Size of a concept window, no less than 10. If given the input data, window_size belongs to [10, 1/3*len(input data)]. If the data is periodic, usually window_size equals 2-5 periods, such as, for monthly/weekly data, the data volume of 30/7 days is a period. Default: 100. | |||
`rolling_window(int)`: Smoothing window size, belongs to [1, window_size]. Default:10. | |||
`step(int)`: The jump length of the sliding window, belongs to [1,window_size]. Default:10. | |||
`threshold_index(float)`: The threshold index. Default: 1.5. | |||
`need_label(bool)`: False or True. If need_label=True, concept drift labels are needed. Default: False. | |||
### Data | |||
#### Data | |||
```python | |||
import numpy as np | |||
@@ -68,27 +68,144 @@ data = np.loadtxt(file, str, delimiter=",") | |||
data = data[1:, 2].astype('float64') # here we choose one column or multiple columns data[1:, 2:5]. | |||
``` | |||
>data(numpy.ndarray): Input data. The shape of data could be (n,1) or (n,m). | |||
>`data(numpy.ndarray)`: Input data. The shape of data could be (n,1) or (n,m). | |||
### Drift check | |||
#### Drift check | |||
```python | |||
drift_score, threshold, concept_drift_location = concept.concept_check(data) | |||
# the result is saved as pdf named 'concept_drift_check.pdf' | |||
``` | |||
>drift_score(numpy.ndarray): The concept drift score of the example series. | |||
threshold(float): The threshold to judge concept drift. | |||
concept_drift_location(list): The location of the concept drift. | |||
>`drift_score(numpy.ndarray)`: The concept drift score of the example series. | |||
`threshold(float)`: The threshold to judge concept drift. | |||
`concept_drift_location(list)`: The location of the concept drift. | |||
## Method for images | |||
Generally, neural networks are used to process images. Therefore, we use algorithms based on neural networks to detect concept drifts of images. | |||
For image data, there is a special term that describes the concept drift in detail, Out-of-Distribution(`OOD`). | |||
Hereinafter, we will use the term `OOD` to describe concept drifts in images. As for non-drift images, we use the term In-Distribution(`ID`). | |||
### Model Architecture | |||
The model structure can be any neural network structure, such as DNN, CNN, and RNN. | |||
Here, we select LeNet and ResNet as examples. | |||
### Detector | |||
Firstly, obtain the features of the training data, the features are the outputs of a selected neural layer. | |||
Secondly, the features are clustered to obtain the clustering centers. | |||
Finally, the features of the testing data in the same neural network layer are obtained, and the distance between the testing data features and the clustering center is calculated. | |||
When the distance exceeds the threshold, the image is determined as an out-of-distribution(OOD) image. | |||
### DataSet | |||
We prepared two pairs of datasets for LeNet and ResNet separately. | |||
For LeNet, the training data is Mnist as ID data. The testing data is Mnist + Cifar10. Cifar10 is OOD data. | |||
For ResNet, the training data is Cifar10 as ID data. The testing data is Cifar10 + ImageNet. ImageNet is OOD data. | |||
### Environment Requirements | |||
- Hardware | |||
- Prepare hardware environment with Ascend, CPU and GPU. | |||
- Framework | |||
- MindSpore | |||
- For more information, please check the resources below: | |||
- MindSpore Tutorials | |||
- MindSpore Python API | |||
### Quick Start | |||
#### Import | |||
```python | |||
import numpy as np | |||
from mindspore import Tensor | |||
from mindspore.train.model import Model | |||
from mindspore import Model, nn, context | |||
from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5 | |||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetectorFeatureCluster | |||
``` | |||
#### Load Classification Model | |||
```python | |||
ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||
net = LeNet5() | |||
load_dict = load_checkpoint(ckpt_path) | |||
load_param_into_net(net, load_dict) | |||
model = Model(net) | |||
``` | |||
>`ckpt_path(str)`: the model path. | |||
#### Load Data | |||
We prepare three datasets. The training dataset is the same dataset that was used to train LeNet. There are two testing datasets: the first testing dataset has OOD labels (0 for non-ood, and 1 for ood) and is used for finding an optimal threshold for ood detection. | |||
The second testing dataset is for ood validation. The first testing dataset is not necessary if we would like to set the threshold by ourselves. | |||
```python | |||
ds_train = np.load('../../dataset/concept_train_lenet.npy') | |||
ds_test1 = np.load('../../dataset/concept_test_lenet1.npy') | |||
ds_test2 = np.load('../../dataset/concept_test_lenet2.npy') | |||
``` | |||
> `ds_train(numpy.ndarray)`: the train data. | |||
> `ds_test1(numpy.ndarray)`: the data for finding an optimal threshold. This dataset is not necessary. | |||
> `ds_test2(numpy.ndarray)`: the test data for ood detection. | |||
#### OOD detector initialization | |||
OOD detector for Lenet. | |||
```python | |||
# ood detector initialization | |||
detector = OodDetectorFeatureCluster(model, ds_train, n_cluster=10, layer='output[:Tensor]') | |||
``` | |||
> `model(Model)`: the model trained by the `ds_train`. | |||
> `ds_train(numpy.ndarray)`: the training data. | |||
> `n_cluster(int)`: the feature cluster number. | |||
> `layer(str)`: the feature extraction layer. In our example, The layer name could be 'output[:Tensor]', '9[:Tensor]', '10[:Tensor]', '11[:Tensor]' for LeNet. | |||
#### Optimal Threshold | |||
This step is optional. If we have a labeled dataset, named ds_test1, we can use the following code to find the optimal detection threshold. | |||
```python | |||
# get optimal threshold with ds_test1 | |||
num = int(len(ds_test1) / 2) | |||
label = np.concatenate((np.zeros(num), np.ones(num)), axis=0) # ID data = 0, OOD data = 1 | |||
optimal_threshold = detector.get_optimal_threshold(label, ds_test1) | |||
``` | |||
> `ds_test1(numpy.ndarray)`: the data for finding an optimal threshold. | |||
> `label(numpy.ndarray)`: the ood label of ds_test1. 0 means non-ood data, and 1 means ood data. | |||
#### Detection result | |||
```python | |||
result = detector.ood_predict(optimal_threshold, ds_test2) | |||
``` | |||
> `ds_test2(numpy.ndarray)`: the testing data for ood detection. | |||
> `optimal_threshold(float)`: the optimal threshold to judge out-of-distribution data. We can also set the threshold value by ourselves. | |||
## Script Description | |||
```bash | |||
├── mindarmour | |||
├── reliability # descriptions about GhostNet # shell script for evaluation with CPU, GPU or Ascend | |||
├── reliability | |||
├──concept_drift | |||
├──__init__.py | |||
├──concept_drift_check_images.py | |||
├──concept_drift_check_time_series.py | |||
├──README.md | |||
``` | |||
@@ -0,0 +1,171 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================ | |||
import heapq | |||
import numpy as np | |||
from mindspore import Tensor | |||
from sklearn.cluster import KMeans | |||
from mindarmour.utils._check_param import check_param_type, check_param_in_range | |||
from mindspore.train.summary.summary_record import _get_summary_tensor_data | |||
""" | |||
Out-of-Distribution detection for images. | |||
The sample can be run on Ascend 910 AI processor. | |||
""" | |||
class OodDetector:
    """
    Base class of the out-of-distribution (OOD) detectors.

    It keeps the model and the training data and provides the shared feature
    extraction helper. `get_optimal_threshold` and `ood_predict` are empty
    placeholders here and are meant to be overridden by concrete detectors.

    Args:
        model (Model): The training model.
        ds_train (numpy.ndarray): The training dataset.
    """

    def __init__(self, model, ds_train):
        checked_train = check_param_type('ds_train', ds_train, np.ndarray)
        self.model = model
        self.ds_train = checked_train

    def _feature_extract(self, model, data, layer='output[:Tensor]'):
        """
        Extract the features of `data` at the given layer.

        Args:
            model (Model): The model for extracting features.
            data (numpy.ndarray): Input data.
            layer (str): The feature layer. The layer name could be 'output[:Tensor]',
                '9[:Tensor]', '10[:Tensor]', '11[:Tensor]' for LeNet, and 'output[:Tensor]',
                '1[:Tensor]' for Resnet.

        Returns:
            numpy.ndarray, the feature of input data.
        """
        # The value returned by predict() is discarded on purpose: the features are
        # read from the summary tensor data afterwards.
        # NOTE(review): presumably the network must record its layers with summary
        # operators for `layer` to be present in that data - confirm.
        model.predict(Tensor(data))
        summary_tensors = _get_summary_tensor_data()
        layer_tensor = summary_tensors[layer]
        return layer_tensor.asnumpy()

    def get_optimal_threshold(self, score, label, ds_test1):
        """Placeholder without behaviour in the base class; always returns None."""
        return None

    def ood_predict(self, threshold, ds_test2):
        """Placeholder without behaviour in the base class; always returns None."""
        return None
class OodDetectorFeatureCluster(OodDetector):
    """
    Out-of-distribution (OOD) detector based on clustering of layer features.

    The features of the training data are extracted from `layer` when the detector
    is created. The K-Means centres of these features are fitted once, on first use,
    and reused afterwards, so that a threshold found by `get_optimal_threshold` is
    applied to the very same centres in `ood_predict`.

    Args:
        model (Model): The training model.
        ds_train (numpy.ndarray): The training dataset.
        n_cluster (int): The cluster number. It is checked by `check_param_in_range`
            against the bounds 2 and 100.
        layer (str): The name of the feature extraction layer, e.g. 'output[:Tensor]'.
    """

    def __init__(self, model, ds_train, n_cluster, layer):
        super().__init__(model, ds_train)
        n_cluster = check_param_type('n_cluster', n_cluster, int)
        self.n_cluster = check_param_in_range('n_cluster', n_cluster, 2, 100)
        self.layer = check_param_type('layer', layer, str)
        self.feature = self._feature_extract(model, ds_train, layer=self.layer)
        # Cluster centres are fitted lazily by _feature_cluster() and then cached.
        self._cluster_centers = None

    def _feature_cluster(self):
        """
        Get the feature cluster.

        The K-Means model is created without a fixed random seed, so two fits may
        give different centres. The centres are therefore fitted only once and cached;
        every score of this detector is measured against the same centres.

        Returns:
            - numpy.ndarray, the feature cluster centres.
        """
        if self._cluster_centers is None:
            clf = KMeans(n_clusters=self.n_cluster)
            clf.fit(self.feature)
            self._cluster_centers = clf.cluster_centers_
        return self._cluster_centers

    def _get_ood_score(self, ds_test):
        """
        Get the ood score.

        For every cluster centre, the positions of its `n_cluster` largest components
        are selected. The distance of a sample to that centre is the sum of absolute
        differences at these positions, divided by the sum of the absolute centre
        values at these positions. The score of a sample is its smallest distance
        over all centres.

        Args:
            ds_test (numpy.ndarray): The testing dataset.

        Returns:
            - numpy.ndarray, the ood score of every sample in `ds_test`.
        """
        centers = np.asarray(self._feature_cluster())
        # assumes the extracted features are 2-D (samples, features), which K-Means
        # fitting on the training features requires as well
        features = np.asarray(self._feature_extract(self.model, ds_test, layer=self.layer))
        per_center = []
        for center in centers:
            # The selection depends on the centre only, so it is computed once per
            # centre instead of once per sample. The stable sort yields distinct
            # positions even when components tie.
            loc = np.argsort(-center, kind='stable')[:self.n_cluster]
            reference = center[loc]
            # A centre whose selected components are all zero gives inf/nan here;
            # ood_predict treats such scores as ood.
            distance = np.abs(reference - features[:, loc]).sum(axis=1) / np.abs(reference).sum()
            per_center.append(distance)
        return np.min(np.stack(per_center, axis=1), axis=1)

    def get_optimal_threshold(self, label, test_data_threshold):
        """
        Get the optimal threshold.

        The candidate thresholds 0.00, 0.01, ..., 0.99 are tried. A sample counts as
        correct when its score is below the threshold and its label is 0, or when its
        score is not below the threshold and its label is 1. The first candidate with
        the highest accuracy is returned.

        Args:
            label (numpy.ndarray): The label whether an image is in-distribution (0) or
                out-of-distribution (1).
            test_data_threshold (numpy.ndarray): The testing dataset to help find the threshold.

        Returns:
            - float, the optimal threshold.

        Raises:
            ValueError: If `test_data_threshold` is empty or `label` does not hold one
                entry per sample.
        """
        label = check_param_type('label', label, np.ndarray)
        test_data_threshold = check_param_type('test_data_threshold', test_data_threshold, np.ndarray)
        score = self._get_ood_score(test_data_threshold)
        label = np.ravel(label)
        if score.size == 0:
            raise ValueError('test_data_threshold must contain at least one sample.')
        if label.size != score.size:
            raise ValueError('label must have one entry per sample of test_data_threshold, '
                             'but got {} labels for {} samples.'.format(label.size, score.size))
        thresholds = np.arange(0.0, 1.0, 0.01)
        # Rows correspond to candidate thresholds, columns to samples.
        below = score[np.newaxis, :] < thresholds[:, np.newaxis]
        not_below = score[np.newaxis, :] >= thresholds[:, np.newaxis]
        correct = (below & (label == 0)) | (not_below & (label == 1))
        acc = correct.sum(axis=1) / score.size
        # argmax returns the first maximum, i.e. the smallest best threshold.
        return thresholds[np.argmax(acc)]

    def ood_predict(self, threshold, ds_test):
        """
        The out-of-distribution detection.

        Args:
            threshold (float): the threshold to judge ood data. One can set value by experience
                or use function get_optimal_threshold.
            ds_test (numpy.ndarray): The testing dataset.

        Returns:
            - numpy.ndarray, the detection result, one entry per sample. 0 means the data
              is not ood, 1 means the data is ood. A sample whose score is not below the
              threshold (including a NaN score) is reported as ood, so the result always
              stays aligned with `ds_test`.
        """
        score = self._get_ood_score(ds_test)
        return np.where(score < threshold, 0, 1)
@@ -0,0 +1,68 @@ | |||
# Copyright 2021 Huawei Technologies Co., Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
""" | |||
Concept drift test for images. | |||
""" | |||
import logging | |||
import pytest | |||
import numpy as np | |||
from mindspore import Tensor | |||
from mindspore.train.model import Model | |||
from mindarmour.utils.logger import LogUtil | |||
from mindspore import Model, nn, context | |||
from examples.common.networks.lenet5.lenet5_net_for_fuzzing import LeNet5 | |||
from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetectorFeatureCluster | |||
LOGGER = LogUtil.get_instance() | |||
TAG = 'Concept_Test' | |||
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_cp():
    """
    Concept drift test.

    Loads a trained LeNet5 checkpoint, builds an OodDetectorFeatureCluster on the
    training data, finds the optimal threshold with ds_test1 and checks the detection
    result on ds_test2.
    """
    # load model
    # NOTE: the paths below are relative, so the test depends on the working
    # directory it is started from.
    ckpt_path = '../../dataset/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
    net = LeNet5()
    load_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, load_dict)
    model = Model(net)
    # load data
    ds_train = np.load('../../dataset/concept_train_lenet.npy')
    ds_test1 = np.load('../../dataset/concept_test_lenet1.npy')
    ds_test2 = np.load('../../dataset/concept_test_lenet2.npy')
    # ood detector initialization
    detector = OodDetectorFeatureCluster(model, ds_train, n_cluster=10, layer='output[:Tensor]')
    # get optimal threshold with ds_test1
    num = int(len(ds_test1) / 2)
    label = np.concatenate((np.zeros(num), np.ones(num)), axis=0)  # ID data = 0, OOD data = 1
    optimal_threshold = detector.get_optimal_threshold(label, ds_test1)
    # get result of ds_test2. We can also set threshold by ourself.
    result = detector.ood_predict(optimal_threshold, ds_test2)
    # result log: the first argument of the logger is the tag, the values are
    # passed as format arguments of the message
    LOGGER.set_level(logging.DEBUG)
    LOGGER.debug(TAG, '--start ood test--')
    LOGGER.debug(TAG, '--ood result--: %s', result)
    LOGGER.debug(TAG, '--the optimal threshold--: %s', optimal_threshold)
    LOGGER.debug(TAG, '--end ood test--')
    # one 0/1 decision is expected for every sample of ds_test2
    assert result.size > 0
    assert result.shape[0] == ds_test2.shape[0]
    assert np.all((result == 0) | (result == 1))