@@ -0,0 +1,52 @@
import numpy as np
from tods.sk_interface.detection_algorithm.LSTMODetector_skinterface import LSTMODetectorSKI
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
from sklearn import metrics

# prepare the data
data = np.loadtxt("./500_UCR_Anomaly_robotDOG1_10000_19280_19360.txt")
X_train = np.expand_dims(data[:10000], axis=1)
X_test = np.expand_dims(data[10000:], axis=1)

transformer = LSTMODetectorSKI()
transformer.fit(X_train)
prediction_labels_train = transformer.predict(X_train)
prediction_labels = transformer.predict(X_test)
prediction_score = transformer.predict_score(X_test)

print("Prediction Labels\n", prediction_labels)
print("Prediction Score\n", prediction_score)

# y_true = prediction_labels_train[:1000]
# y_pred = prediction_labels[:1000]
y_true = prediction_labels_train
y_pred = prediction_labels

print('Accuracy Score: ', accuracy_score(y_true, y_pred))
confusion_matrix(y_true, y_pred)
print(classification_report(y_true, y_pred))

precision, recall, thresholds = precision_recall_curve(y_true, y_pred)
f1_scores = 2 * recall * precision / (recall + precision)
print('Best threshold: ', thresholds[np.argmax(f1_scores)])
print('Best F1-Score: ', np.max(f1_scores))

fpr, tpr, threshold = metrics.roc_curve(y_true, y_pred)
roc_auc = metrics.auc(fpr, tpr)
plt.title('ROC')
plt.plot(fpr, tpr, 'b', label='AUC = %0.2f' % roc_auc)
plt.legend(loc='lower right')
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.show()
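# --- Illustrative sketch, not part of this diff ---
# The script above evaluates the test predictions against the training-set
# predictions, which mainly exercises the metric calls. For a more meaningful
# evaluation, ground-truth labels could be derived from the anomaly window
# encoded in the UCR filename (19280-19360, train/test split at 10000).
# The helper below is a hypothetical example, not part of the TODS API.
def make_ground_truth(series_length, train_size=10000,
                      anomaly_start=19280, anomaly_end=19360):
    labels = np.zeros(series_length, dtype=int)
    labels[anomaly_start:anomaly_end] = 1   # mark the known anomalous region
    return labels[train_size:]              # keep only the test portion

# Example usage (assumes `data` and `prediction_labels` from the script above
# and that the label vector lengths line up):
# y_test = make_ground_truth(len(data))
# print(classification_report(y_test, prediction_labels))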
@@ -1,5 +1,5 @@
import numpy as np
-from tods.tods_skinterface.primitiveSKI.detection_algorithm.Telemanom_skinterface import TelemanomSKI
+from tods.sk_interface.detection_algorithm.Telemanom_skinterface import TelemanomSKI
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
@@ -68,7 +68,7 @@ class Hyperparams(Hyperparams_ODBase):
    )
    min_attack_time = hyperparams.Hyperparameter[int](
-        default=5,
+        default=10,
        description='The minimum amount of recent time steps that is used to define a collective attack.',
        semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter']
    )
@@ -97,7 +97,7 @@ class Hyperparams(Hyperparams_ODBase):
    )
    epochs = hyperparams.Hyperparameter[int](
-        default=10,
+        default=50,
        description='Number of epochs to train the model.',
        semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter']
    )
@@ -123,13 +123,13 @@ class Hyperparams(Hyperparams_ODBase):
    )
    hidden_dim = hyperparams.Hyperparameter[int](
-        default=16,
+        default=8,
        description='Hidden dim of LSTM.',
        semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter']
    )
    n_hidden_layer = hyperparams.Hyperparameter[int](
-        default=0,
+        default=2,
        description='Hidden layer number of LSTM.',
        semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter']
    )
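# --- Illustrative sketch, not part of this diff ---
# These changes only move the starting configuration. Assuming the sklearn-style
# wrapper forwards keyword arguments to the underlying hyperparameters (as other
# TODS SKI examples do), the previous defaults can still be requested explicitly:
#
#     from tods.sk_interface.detection_algorithm.LSTMODetector_skinterface import LSTMODetectorSKI
#     detector = LSTMODetectorSKI(min_attack_time=5, epochs=10,
#                                 hidden_dim=16, n_hidden_layer=0)
#     detector.fit(X_train)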
@@ -83,7 +83,7 @@ class Hyperparams_ODBase(hyperparams.Hyperparams):
    )
    window_size = hyperparams.Hyperparameter[int](
-        default=1,
+        default=10,
        description='The moving window size.',
        semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter']
    )
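# --- Illustrative sketch, not part of this diff ---
# With window_size=10 the detectors score sliding sub-sequences rather than
# single points. Conceptually (hypothetical helper, not the TODS code):
import numpy as np

def sliding_windows(x, window_size=10):
    # 1-D series -> matrix of shape (len(x) - window_size + 1, window_size)
    return np.lib.stride_tricks.sliding_window_view(x, window_size)

# sliding_windows(np.arange(15)).shape == (6, 10)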
@@ -103,6 +103,33 @@ class AutoRegOD(CollectiveBaseDetector):
        self._process_decision_scores()
        return self

    def predict(self, X):  # pragma: no cover
        """Predict if a particular sample is an outlier or not.

        Parameters
        ----------
        X : numpy array of shape (n_samples, n_features)
            The input samples.

        Returns
        -------
        outlier_labels : numpy array of shape (n_samples,)
            For each observation, tells whether or not
            it should be considered as an outlier according to the
            fitted model. 0 stands for inliers and 1 for outliers.
        """
        check_is_fitted(self, ['decision_scores_', 'threshold_', 'labels_'])

        pred_score, X_left_inds, X_right_inds = self.decision_function(X)
        pred_score = np.concatenate((np.zeros((self.window_size,)), pred_score))
        X_left_inds = np.concatenate((np.zeros((self.window_size,)), X_left_inds))
        X_right_inds = np.concatenate((np.zeros((self.window_size,)), X_right_inds))

        return (pred_score > self.threshold_).astype(
            'int').ravel(), X_left_inds.ravel(), X_right_inds.ravel()
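# --- Illustrative usage sketch, not part of this diff ---
# The new predict() pads the first `window_size` positions with zero scores so
# the returned labels line up with the input length. Assuming a fitted
# AutoRegOD instance `clf` and a series reshaped to (n_samples, 1):
#
#     labels, left_inds, right_inds = clf.predict(X)
#     # labels[i] == 1 flags X[i] as an outlier; the first window_size
#     # entries are padding and therefore always 0.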
    def decision_function(self, X: np.array):
        """Predict raw anomaly scores of X using the fitted detector.
@@ -33,7 +33,7 @@ class LSTMOutlierDetector(CollectiveBaseDetector):
                 ):
        super(LSTMOutlierDetector, self).__init__(contamination=contamination,
-                                                  window_size=min_attack_time,
+                                                  # window_size=min_attack_time,
                                                  step_size=1,
                                                  )
@@ -54,14 +54,34 @@ class LSTMOutlierDetector(CollectiveBaseDetector):
        self.activation = activation

    # def _build_model(self):
    #     print('dim:', self.hidden_dim, self.feature_dim)
    #     model_ = Sequential()
    #     model_.add(LSTM(units=self.hidden_dim, input_shape=(self.feature_dim, 1),
    #                     dropout=self.dropout_rate, activation=self.activation, return_sequences=True))
    #     for layer_idx in range(self.n_hidden_layer-1):
    #         model_.add(LSTM(units=self.hidden_dim, input_shape=(self.hidden_dim, 1),
    #                         dropout=self.dropout_rate, activation=self.activation, return_sequences=True))
    #     model_.add(LSTM(units=self.hidden_dim, input_shape=(self.hidden_dim, 1),
    #                     dropout=self.dropout_rate, activation=self.activation))
    #     model_.add(Dense(units=self.feature_dim, input_shape=(self.hidden_dim, 1), activation=None))
    #     model_.compile(loss=self.loss, optimizer=self.optimizer)
    #     return model_

    def _build_model(self):
        model_ = Sequential()
        model_.add(LSTM(units=self.hidden_dim, input_shape=(self.feature_dim, 1),
-                        dropout=self.dropout_rate, activation=self.activation))
+                        dropout=self.dropout_rate, activation=self.activation,
+                        return_sequences=bool(self.n_hidden_layer > 0)))
        for layer_idx in range(self.n_hidden_layer):
            model_.add(LSTM(units=self.hidden_dim, input_shape=(self.hidden_dim, 1),
-                            dropout=self.dropout_rate, activation=self.activation))
+                            dropout=self.dropout_rate, activation=self.activation,
+                            return_sequences=bool(layer_idx < self.n_hidden_layer - 1)))
        model_.add(Dense(units=self.feature_dim, input_shape=(self.hidden_dim, 1), activation=None))
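# --- Illustrative sketch, not part of this diff ---
# Why return_sequences matters when stacking: every LSTM layer that feeds
# another LSTM must return the full sequence, and only the last recurrent
# layer returns a single vector. A minimal standalone example (assumes
# tensorflow.keras; sizes are arbitrary):
from tensorflow.keras import Sequential
from tensorflow.keras.layers import LSTM, Dense

hidden_dim, feature_dim, n_hidden_layer = 8, 1, 2
model = Sequential()
# first layer keeps the time axis only if more LSTM layers follow
model.add(LSTM(hidden_dim, input_shape=(feature_dim, 1),
               return_sequences=n_hidden_layer > 0))
for i in range(n_hidden_layer):
    # intermediate layers return sequences; the final one returns a vector
    model.add(LSTM(hidden_dim, return_sequences=i < n_hidden_layer - 1))
model.add(Dense(feature_dim))
model.compile(loss='mse', optimizer='adam')
model.summary()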
@@ -84,6 +104,7 @@ class LSTMOutlierDetector(CollectiveBaseDetector):
        self : object
            Fitted estimator.
        """
        print("XXXX:", X.shape)
        X = check_array(X).astype(np.float)
        self._set_n_classes(None)
        X_buf, y_buf = self._get_sub_matrices(X)
@@ -121,6 +142,33 @@ class LSTMOutlierDetector(CollectiveBaseDetector):
        return relative_error

    def predict(self, X):  # pragma: no cover
        """Predict if a particular sample is an outlier or not.

        Parameters
        ----------
        X : numpy array of shape (n_samples, n_features)
            The input samples.

        Returns
        -------
        outlier_labels : numpy array of shape (n_samples,)
            For each observation, tells whether or not
            it should be considered as an outlier according to the
            fitted model. 0 stands for inliers and 1 for outliers.
        """
        check_is_fitted(self, ['decision_scores_', 'threshold_', 'labels_'])

        pred_score, X_left_inds, X_right_inds = self.decision_function(X)
        pred_score = np.concatenate((np.zeros((self.window_size,)), pred_score))
        X_left_inds = np.concatenate((np.zeros((self.window_size,)), X_left_inds))
        X_right_inds = np.concatenate((np.zeros((self.window_size,)), X_right_inds))

        return (pred_score > self.threshold_).astype(
            'int').ravel(), X_left_inds.ravel(), X_right_inds.ravel()

    def decision_function(self, X: np.array):
        """Predict raw anomaly scores of X using the fitted detector.
@@ -157,6 +157,33 @@ class MultiAutoRegOD(CollectiveBaseDetector):
        self._process_decision_scores()
        return self

    def predict(self, X):  # pragma: no cover
        """Predict if a particular sample is an outlier or not.

        Parameters
        ----------
        X : numpy array of shape (n_samples, n_features)
            The input samples.

        Returns
        -------
        outlier_labels : numpy array of shape (n_samples,)
            For each observation, tells whether or not
            it should be considered as an outlier according to the
            fitted model. 0 stands for inliers and 1 for outliers.
        """
        check_is_fitted(self, ['decision_scores_', 'threshold_', 'labels_'])

        pred_score, X_left_inds, X_right_inds = self.decision_function(X)
        pred_score = np.concatenate((np.zeros((self.window_size,)), pred_score))
        X_left_inds = np.concatenate((np.zeros((self.window_size,)), X_left_inds))
        X_right_inds = np.concatenate((np.zeros((self.window_size,)), X_right_inds))

        return (pred_score > self.threshold_).astype(
            'int').ravel(), X_left_inds.ravel(), X_right_inds.ravel()

    def decision_function(self, X: np.array):
        """Predict raw anomaly scores of X using the fitted detector.