|
|
@@ -11,20 +11,22 @@ |
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
# limitations under the License. |
|
|
|
# ============================================================================ |
|
|
|
|
|
|
|
""" |
|
|
|
Out-of-Distribution detection for images. |
|
|
|
""" |
|
|
|
|
|
|
|
import heapq |
|
|
|
from abc import abstractmethod |
|
|
|
import numpy as np |
|
|
|
from mindspore import Tensor |
|
|
|
from sklearn.cluster import KMeans |
|
|
|
from mindarmour.utils._check_param import check_param_type, check_param_in_range |
|
|
|
from mindspore.train.summary.summary_record import _get_summary_tensor_data |
|
|
|
|
|
|
|
from mindspore.train.summary.summary_record import _get_summary_tensor_data |
|
|
|
from mindspore import Tensor |
|
|
|
from mindarmour.utils._check_param import check_param_type, check_param_in_range |
|
|
|
from mindarmour.utils.logger import LogUtil |
|
|
|
|
|
|
|
""" |
|
|
|
Out-of-Distribution detection for images. |
|
|
|
""" |
|
|
|
LOGGER = LogUtil.get_instance() |
|
|
|
TAG = 'concept drift detection' |
|
|
|
|
|
|
|
|
|
|
|
class OodDetector: |
|
|
@@ -56,6 +58,7 @@ class OodDetector: |
|
|
|
layer_out = _get_summary_tensor_data() |
|
|
|
return layer_out[layer].asnumpy() |
|
|
|
|
|
|
|
@abstractmethod |
|
|
|
def get_optimal_threshold(self, label, ds_eval): |
|
|
|
""" |
|
|
|
Get the optimal threshold. |
|
|
@@ -67,8 +70,12 @@ class OodDetector: |
|
|
|
Returns: |
|
|
|
- float, the optimal threshold. |
|
|
|
""" |
|
|
|
pass |
|
|
|
msg = 'The function generate() is an abstract function in class ' \ |
|
|
|
'`OodDetector` and should be implemented in child class.' |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise NotImplementedError(msg) |
|
|
|
|
|
|
|
@abstractmethod |
|
|
|
def ood_predict(self, threshold, ds_test): |
|
|
|
""" |
|
|
|
The out-of-distribution detection. |
|
|
@@ -81,7 +88,10 @@ class OodDetector: |
|
|
|
Returns: |
|
|
|
- numpy.ndarray, the detection result. 0 means the data is not ood, 1 means the data is ood. |
|
|
|
""" |
|
|
|
pass |
|
|
|
msg = 'The function generate() is an abstract function in class ' \ |
|
|
|
'`OodDetector` and should be implemented in child class.' |
|
|
|
LOGGER.error(TAG, msg) |
|
|
|
raise NotImplementedError(msg) |
|
|
|
|
|
|
|
|
|
|
|
class OodDetectorFeatureCluster(OodDetector): |
|
|
@@ -103,6 +113,7 @@ class OodDetectorFeatureCluster(OodDetector): |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, model, ds_train, n_cluster, layer): |
|
|
|
super().__init__(model, ds_train) |
|
|
|
self.model = model |
|
|
|
self.ds_train = check_param_type('ds_train', ds_train, np.ndarray) |
|
|
|
self.n_cluster = check_param_type('n_cluster', n_cluster, int) |
|
|
@@ -173,7 +184,7 @@ class OodDetectorFeatureCluster(OodDetector): |
|
|
|
threshold.append(threshold_change) |
|
|
|
acc = np.array(acc) |
|
|
|
threshold = np.array(threshold) |
|
|
|
optimal_threshold = threshold[np.where(acc==np.max(acc))[0]][0] |
|
|
|
optimal_threshold = threshold[np.where(acc == np.max(acc))[0]][0] |
|
|
|
return optimal_threshold |
|
|
|
|
|
|
|
def ood_predict(self, threshold, ds_test): |
|
|
|