|
|
@@ -13,18 +13,16 @@ |
|
|
|
# limitations under the License. |
|
|
|
# ============================================================================ |
|
|
|
|
|
|
|
""" |
|
|
|
Out-of-Distribution detection module for images. |
|
|
|
""" |
|
|
|
|
|
|
|
import heapq |
|
|
|
import numpy as np |
|
|
|
from mindspore import Tensor |
|
|
|
from sklearn.cluster import KMeans |
|
|
|
from mindarmour.utils._check_param import check_param_type, check_param_in_range |
|
|
|
from mindspore import Tensor |
|
|
|
from mindspore.train.summary.summary_record import _get_summary_tensor_data |
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
|
Out-of-Distribution detection for images. |
|
|
|
""" |
|
|
|
from mindarmour.utils._check_param import check_param_type, check_param_in_range |
|
|
|
|
|
|
|
|
|
|
|
class OodDetector: |
|
|
@@ -67,7 +65,7 @@ class OodDetector: |
|
|
|
Returns: |
|
|
|
- float, the optimal threshold. |
|
|
|
""" |
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
def ood_predict(self, threshold, ds_test): |
|
|
|
""" |
|
|
@@ -81,7 +79,6 @@ class OodDetector: |
|
|
|
Returns: |
|
|
|
- numpy.ndarray, the detection result. 0 means the data is not ood, 1 means the data is ood. |
|
|
|
""" |
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
class OodDetectorFeatureCluster(OodDetector): |
|
|
@@ -90,6 +87,8 @@ class OodDetectorFeatureCluster(OodDetector): |
|
|
|
the testing data features and the clustering centers determines whether an image is an out-of-distribution(OOD) |
|
|
|
image or not. |
|
|
|
|
|
|
|
For details, please check `Tutorial <https://mindspore.cn/mindarmour/docs/zh-CN/master/concept_drift_images.html>`_ |
|
|
|
|
|
|
|
Args: |
|
|
|
model (Model):The training model. |
|
|
|
ds_train (numpy.ndarray): The training dataset. |
|
|
@@ -100,9 +99,38 @@ class OodDetectorFeatureCluster(OodDetector): |
|
|
|
layer (str): The name of the feature layer. layer (str) is represented by |
|
|
|
'name[:Tensor]', where 'name' is given by users when training the model. |
|
|
|
Please see more details about how to name the model layer in 'README.md'. |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> from mindspore import Model |
|
|
|
>>> from mindspore.ops import TensorSummary |
|
|
|
>>> import mindspore.ops.operations as P |
|
|
|
>>> from mindarmour.reliability.concept_drift.concept_drift_check_images import OodDetectorFeatureCluster |
|
|
|
>>> class Net(nn.Cell): |
|
|
|
... def __init__(self): |
|
|
|
... super(Net, self).__init__() |
|
|
|
... self._softmax = P.Softmax() |
|
|
|
... self._Dense = nn.Dense(10,10) |
|
|
|
... self._squeeze = P.Squeeze(1) |
|
|
|
... self._summary = TensorSummary() |
|
|
|
... def construct(self, inputs): |
|
|
|
... out = self._softmax(inputs) |
|
|
|
... out = self._Dense(out) |
|
|
|
... self._summary('output', out) |
|
|
|
... return self._squeeze(out) |
|
|
|
>>> net = Net() |
|
|
|
>>> model = Model(net) |
|
|
|
>>> batch_size = 16 |
|
|
|
>>> batches = 1 |
|
|
|
>>> ds_train = np.random.randn(batches * batch_size, 1, 10).astype(np.float32) |
|
|
|
>>> ds_eval = np.random.randn(batches * batch_size, 1, 10).astype(np.float32) |
|
|
|
>>> detector = OodDetectorFeatureCluster(model, ds_train, n_cluster=10, layer='output[:Tensor]') |
|
|
|
>>> num = int(len(ds_eval) / 2) |
|
|
|
>>> ood_label = np.concatenate((np.zeros(num), np.ones(num)), axis=0) |
|
|
|
>>> optimal_threshold = detector.get_optimal_threshold(ood_label, ds_eval) |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, model, ds_train, n_cluster, layer): |
|
|
|
super(OodDetectorFeatureCluster, self).__init__(model, ds_train) |
|
|
|
self.model = model |
|
|
|
self.ds_train = check_param_type('ds_train', ds_train, np.ndarray) |
|
|
|
self.n_cluster = check_param_type('n_cluster', n_cluster, int) |
|
|
@@ -173,7 +201,7 @@ class OodDetectorFeatureCluster(OodDetector): |
|
|
|
threshold.append(threshold_change) |
|
|
|
acc = np.array(acc) |
|
|
|
threshold = np.array(threshold) |
|
|
|
optimal_threshold = threshold[np.where(acc==np.max(acc))[0]][0] |
|
|
|
optimal_threshold = threshold[np.where(acc == np.max(acc))[0]][0] |
|
|
|
return optimal_threshold |
|
|
|
|
|
|
|
def ood_predict(self, threshold, ds_test): |
|
|
|