@@ -0,0 +1,52 @@ | |||||
import numpy as np | |||||
from tods.sk_interface.detection_algorithm.LSTMODetector_skinterface import LSTMODetectorSKI | |||||
from sklearn.metrics import precision_recall_curve | |||||
from sklearn.metrics import accuracy_score | |||||
from sklearn.metrics import confusion_matrix | |||||
from sklearn.metrics import classification_report | |||||
import matplotlib.pyplot as plt | |||||
from sklearn import metrics | |||||
# Prepare the data.
# The file name follows the UCR anomaly archive convention
# <id>_UCR_Anomaly_<name>_<train_end>_<anomaly_begin>_<anomaly_end>.txt,
# so the split point and the labelled anomaly range are taken from it.
train_end, anomaly_begin, anomaly_end = 10000, 19280, 19360
data = np.loadtxt("./500_UCR_Anomaly_robotDOG1_10000_19280_19360.txt")
X_train = np.expand_dims(data[:train_end], axis=1)
X_test = np.expand_dims(data[train_end:], axis=1)

# Fit on the training part only, then label and score the held-out part.
transformer = LSTMODetectorSKI()
transformer.fit(X_train)
prediction_labels = transformer.predict(X_test)
prediction_score = transformer.predict_score(X_test)
print("Prediction Labels\n", prediction_labels)
print("Prediction Score\n", prediction_score)

# Ground truth for the test split: 1 inside the labelled anomaly range and
# 0 elsewhere. Previously the detector's predictions on the *training* split
# were used as "truth", which measures nothing and has a different length.
# NOTE(review): the range is treated as 0-based and inclusive positions in
# the full series - confirm against the archive's indexing convention.
y_true = np.zeros(len(X_test), dtype=int)
y_true[anomaly_begin - train_end:anomaly_end - train_end + 1] = 1
y_pred = np.asarray(prediction_labels).ravel()
y_score = np.asarray(prediction_score).ravel()

print('Accuracy Score: ', accuracy_score(y_true, y_pred))
print('Confusion Matrix\n', confusion_matrix(y_true, y_pred))
print(classification_report(y_true, y_pred))

# Threshold search needs continuous scores; hard 0/1 labels only give a
# single operating point.
precision, recall, thresholds = precision_recall_curve(y_true, y_score)
pr_sum = precision + recall
# Define F1 as 0 where precision + recall == 0 instead of producing NaN.
f1_scores = np.divide(2 * recall * precision, pr_sum,
                      out=np.zeros_like(pr_sum), where=pr_sum > 0)
# The final (precision=1, recall=0) point has no matching threshold, so it
# is excluded from the search to keep the index valid for `thresholds`.
best_idx = np.argmax(f1_scores[:-1])
print('Best threshold: ', thresholds[best_idx])
print('Best F1-Score: ', f1_scores[best_idx])

# ROC curve, also computed from the continuous scores.
fpr, tpr, threshold = metrics.roc_curve(y_true, y_score)
roc_auc = metrics.auc(fpr, tpr)

plt.title('ROC')
plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
plt.legend(loc = 'lower right')
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.show()
@@ -1,5 +1,5 @@ | |||||
import numpy as np | import numpy as np | ||||
from tods.tods_skinterface.primitiveSKI.detection_algorithm.Telemanom_skinterface import TelemanomSKI | |||||
from tods.sk_interface.detection_algorithm.Telemanom_skinterface import TelemanomSKI | |||||
from sklearn.metrics import precision_recall_curve | from sklearn.metrics import precision_recall_curve | ||||
from sklearn.metrics import accuracy_score | from sklearn.metrics import accuracy_score | ||||
from sklearn.metrics import confusion_matrix | from sklearn.metrics import confusion_matrix | ||||
@@ -68,7 +68,7 @@ class Hyperparams(Hyperparams_ODBase): | |||||
) | ) | ||||
min_attack_time = hyperparams.Hyperparameter[int]( | min_attack_time = hyperparams.Hyperparameter[int]( | ||||
default=5, | |||||
default=10, | |||||
description='The minimum amount of recent time steps that is used to define a collective attack.', | description='The minimum amount of recent time steps that is used to define a collective attack.', | ||||
semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter'] | semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter'] | ||||
) | ) | ||||
@@ -97,7 +97,7 @@ class Hyperparams(Hyperparams_ODBase): | |||||
) | ) | ||||
epochs = hyperparams.Hyperparameter[int]( | epochs = hyperparams.Hyperparameter[int]( | ||||
default=10, | |||||
default=50, | |||||
description='Number of epochs to train the model.', | description='Number of epochs to train the model.', | ||||
semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter'] | semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter'] | ||||
) | ) | ||||
@@ -123,13 +123,13 @@ class Hyperparams(Hyperparams_ODBase): | |||||
) | ) | ||||
hidden_dim = hyperparams.Hyperparameter[int]( | hidden_dim = hyperparams.Hyperparameter[int]( | ||||
default=16, | |||||
default=8, | |||||
description='Hidden dim of LSTM.', | description='Hidden dim of LSTM.', | ||||
semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter'] | semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter'] | ||||
) | ) | ||||
n_hidden_layer = hyperparams.Hyperparameter[int]( | n_hidden_layer = hyperparams.Hyperparameter[int]( | ||||
default=0, | |||||
default=2, | |||||
description='Hidden layer number of LSTM.', | description='Hidden layer number of LSTM.', | ||||
semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter'] | semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter'] | ||||
) | ) | ||||
@@ -83,7 +83,7 @@ class Hyperparams_ODBase(hyperparams.Hyperparams): | |||||
) | ) | ||||
window_size = hyperparams.Hyperparameter[int]( | window_size = hyperparams.Hyperparameter[int]( | ||||
default=1, | |||||
default=10, | |||||
description='The moving window size.', | description='The moving window size.', | ||||
semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter'] | semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter'] | ||||
) | ) | ||||
@@ -103,6 +103,33 @@ class AutoRegOD(CollectiveBaseDetector): | |||||
self._process_decision_scores() | self._process_decision_scores() | ||||
return self | return self | ||||
def predict(self, X):  # pragma: no cover
    """Predict if a particular sample is an outlier or not.

    The three arrays returned by ``decision_function`` are each left-padded
    with ``window_size`` zeros before the scores are thresholded.

    Parameters
    ----------
    X : numpy array of shape (n_samples, n_features)
        The input samples.

    Returns
    -------
    outlier_labels : numpy array
        1 where the padded score is strictly greater than ``threshold_``,
        0 otherwise. The ``window_size`` leading entries are scored 0.
    X_left_inds : numpy array
        Flattened left indices from ``decision_function``, zero-padded at
        the front. The dtype of the incoming indices is preserved.
    X_right_inds : numpy array
        Flattened right indices from ``decision_function``, zero-padded at
        the front. The dtype of the incoming indices is preserved.
    """
    check_is_fitted(self, ['decision_scores_', 'threshold_', 'labels_'])

    pred_score, X_left_inds, X_right_inds = self.decision_function(X)
    X_left_inds = np.asarray(X_left_inds)
    X_right_inds = np.asarray(X_right_inds)

    n_pad = self.window_size
    pred_score = np.concatenate((np.zeros((n_pad,)), pred_score))
    # Pad with zeros of the indices' own dtype: concatenating with the
    # default float64 zeros would silently turn integer indices into floats.
    X_left_inds = np.concatenate(
        (np.zeros((n_pad,), dtype=X_left_inds.dtype), X_left_inds))
    X_right_inds = np.concatenate(
        (np.zeros((n_pad,), dtype=X_right_inds.dtype), X_right_inds))

    outlier_labels = (pred_score > self.threshold_).astype('int').ravel()
    return outlier_labels, X_left_inds.ravel(), X_right_inds.ravel()
def decision_function(self, X: np.array): | def decision_function(self, X: np.array): | ||||
"""Predict raw anomaly scores of X using the fitted detector. | """Predict raw anomaly scores of X using the fitted detector. | ||||
@@ -33,7 +33,7 @@ class LSTMOutlierDetector(CollectiveBaseDetector): | |||||
): | ): | ||||
super(LSTMOutlierDetector, self).__init__(contamination=contamination, | super(LSTMOutlierDetector, self).__init__(contamination=contamination, | ||||
window_size=min_attack_time, | |||||
# window_size=min_attack_time, | |||||
step_size=1, | step_size=1, | ||||
) | ) | ||||
@@ -54,14 +54,34 @@ class LSTMOutlierDetector(CollectiveBaseDetector): | |||||
self.activation = activation | self.activation = activation | ||||
# def _build_model(self): | |||||
# print('dim:', self.hidden_dim, self.feature_dim) | |||||
# model_ = Sequential() | |||||
# model_.add(LSTM(units=self.hidden_dim, input_shape=(self.feature_dim, 1), | |||||
# dropout=self.dropout_rate, activation=self.activation, return_sequences=True)) | |||||
# for layer_idx in range(self.n_hidden_layer-1): | |||||
# model_.add(LSTM(units=self.hidden_dim, input_shape=(self.hidden_dim, 1), | |||||
# dropout=self.dropout_rate, activation=self.activation, return_sequences=True)) | |||||
# model_.add(LSTM(units=self.hidden_dim, input_shape=(self.hidden_dim, 1), | |||||
# dropout=self.dropout_rate, activation=self.activation)) | |||||
# model_.add(Dense(units=self.feature_dim, input_shape=(self.hidden_dim, 1), activation=None)) | |||||
# model_.compile(loss=self.loss, optimizer=self.optimizer) | |||||
# return model_ | |||||
def _build_model(self): | def _build_model(self): | ||||
model_ = Sequential() | model_ = Sequential() | ||||
model_.add(LSTM(units=self.hidden_dim, input_shape=(self.feature_dim, 1), | model_.add(LSTM(units=self.hidden_dim, input_shape=(self.feature_dim, 1), | ||||
dropout=self.dropout_rate, activation=self.activation)) | |||||
dropout=self.dropout_rate, activation=self.activation, | |||||
return_sequences=bool(self.n_hidden_layer>0))) | |||||
for layer_idx in range(self.n_hidden_layer): | for layer_idx in range(self.n_hidden_layer): | ||||
model_.add(LSTM(units=self.hidden_dim, input_shape=(self.hidden_dim, 1), | model_.add(LSTM(units=self.hidden_dim, input_shape=(self.hidden_dim, 1), | ||||
dropout=self.dropout_rate, activation=self.activation)) | |||||
dropout=self.dropout_rate, activation=self.activation, | |||||
return_sequences=bool(layer_idx < self.n_hidden_layer - 1))) | |||||
model_.add(Dense(units=self.feature_dim, input_shape=(self.hidden_dim, 1), activation=None)) | model_.add(Dense(units=self.feature_dim, input_shape=(self.hidden_dim, 1), activation=None)) | ||||
@@ -84,6 +104,7 @@ class LSTMOutlierDetector(CollectiveBaseDetector): | |||||
self : object | self : object | ||||
Fitted estimator. | Fitted estimator. | ||||
""" | """ | ||||
print("XXXX:", X.shape) | |||||
X = check_array(X).astype(np.float) | X = check_array(X).astype(np.float) | ||||
self._set_n_classes(None) | self._set_n_classes(None) | ||||
X_buf, y_buf = self._get_sub_matrices(X) | X_buf, y_buf = self._get_sub_matrices(X) | ||||
@@ -121,6 +142,33 @@ class LSTMOutlierDetector(CollectiveBaseDetector): | |||||
return relative_error | return relative_error | ||||
def predict(self, X):  # pragma: no cover
    """Predict if a particular sample is an outlier or not.

    The three arrays returned by ``decision_function`` are each left-padded
    with ``window_size`` zeros before the scores are thresholded.

    Parameters
    ----------
    X : numpy array of shape (n_samples, n_features)
        The input samples.

    Returns
    -------
    outlier_labels : numpy array
        1 where the padded score is strictly greater than ``threshold_``,
        0 otherwise. The ``window_size`` leading entries are scored 0.
    X_left_inds : numpy array
        Flattened left indices from ``decision_function``, zero-padded at
        the front. The dtype of the incoming indices is preserved.
    X_right_inds : numpy array
        Flattened right indices from ``decision_function``, zero-padded at
        the front. The dtype of the incoming indices is preserved.
    """
    check_is_fitted(self, ['decision_scores_', 'threshold_', 'labels_'])

    pred_score, X_left_inds, X_right_inds = self.decision_function(X)
    X_left_inds = np.asarray(X_left_inds)
    X_right_inds = np.asarray(X_right_inds)

    n_pad = self.window_size
    pred_score = np.concatenate((np.zeros((n_pad,)), pred_score))
    # Pad with zeros of the indices' own dtype: concatenating with the
    # default float64 zeros would silently turn integer indices into floats.
    X_left_inds = np.concatenate(
        (np.zeros((n_pad,), dtype=X_left_inds.dtype), X_left_inds))
    X_right_inds = np.concatenate(
        (np.zeros((n_pad,), dtype=X_right_inds.dtype), X_right_inds))

    outlier_labels = (pred_score > self.threshold_).astype('int').ravel()
    return outlier_labels, X_left_inds.ravel(), X_right_inds.ravel()
def decision_function(self, X: np.array): | def decision_function(self, X: np.array): | ||||
"""Predict raw anomaly scores of X using the fitted detector. | """Predict raw anomaly scores of X using the fitted detector. | ||||
@@ -157,6 +157,33 @@ class MultiAutoRegOD(CollectiveBaseDetector): | |||||
self._process_decision_scores() | self._process_decision_scores() | ||||
return self | return self | ||||
def predict(self, X):  # pragma: no cover
    """Predict if a particular sample is an outlier or not.

    The three arrays returned by ``decision_function`` are each left-padded
    with ``window_size`` zeros before the scores are thresholded.

    Parameters
    ----------
    X : numpy array of shape (n_samples, n_features)
        The input samples.

    Returns
    -------
    outlier_labels : numpy array
        1 where the padded score is strictly greater than ``threshold_``,
        0 otherwise. The ``window_size`` leading entries are scored 0.
    X_left_inds : numpy array
        Flattened left indices from ``decision_function``, zero-padded at
        the front. The dtype of the incoming indices is preserved.
    X_right_inds : numpy array
        Flattened right indices from ``decision_function``, zero-padded at
        the front. The dtype of the incoming indices is preserved.
    """
    check_is_fitted(self, ['decision_scores_', 'threshold_', 'labels_'])

    pred_score, X_left_inds, X_right_inds = self.decision_function(X)
    X_left_inds = np.asarray(X_left_inds)
    X_right_inds = np.asarray(X_right_inds)

    n_pad = self.window_size
    pred_score = np.concatenate((np.zeros((n_pad,)), pred_score))
    # Pad with zeros of the indices' own dtype: concatenating with the
    # default float64 zeros would silently turn integer indices into floats.
    X_left_inds = np.concatenate(
        (np.zeros((n_pad,), dtype=X_left_inds.dtype), X_left_inds))
    X_right_inds = np.concatenate(
        (np.zeros((n_pad,), dtype=X_right_inds.dtype), X_right_inds))

    outlier_labels = (pred_score > self.threshold_).astype('int').ravel()
    return outlier_labels, X_left_inds.ravel(), X_right_inds.ravel()
def decision_function(self, X: np.array): | def decision_function(self, X: np.array): | ||||
"""Predict raw anomaly scores of X using the fitted detector. | """Predict raw anomaly scores of X using the fitted detector. | ||||