Browse Source

modify default window size to 10 and fix autoregression and LSTM

master
lhenry15 4 years ago
parent
commit
0cb2f05171
7 changed files with 163 additions and 9 deletions
  1. +52
    -0
      examples/sk_examples/LSTMOD_test.py
  2. +1
    -1
      examples/sk_examples/Telemanom_test.py
  3. +4
    -4
      tods/detection_algorithm/LSTMODetect.py
  4. +1
    -1
      tods/detection_algorithm/UODBasePrimitive.py
  5. +27
    -0
      tods/detection_algorithm/core/AutoRegOD.py
  6. +51
    -3
      tods/detection_algorithm/core/LSTMOD.py
  7. +27
    -0
      tods/detection_algorithm/core/MultiAutoRegOD.py

+ 52
- 0
examples/sk_examples/LSTMOD_test.py View File

@@ -0,0 +1,52 @@
import numpy as np
from tods.sk_interface.detection_algorithm.LSTMODetector_skinterface import LSTMODetectorSKI
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
from sklearn import metrics

#prepare the data
data = np.loadtxt("./500_UCR_Anomaly_robotDOG1_10000_19280_19360.txt")

X_train = np.expand_dims(data[:10000], axis=1)
X_test = np.expand_dims(data[10000:], axis=1)

transformer = LSTMODetectorSKI()
transformer.fit(X_train)

prediction_labels_train = transformer.predict(X_train)

prediction_labels = transformer.predict(X_test)
prediction_score = transformer.predict_score(X_test)

print("Prediction Labels\n", prediction_labels)
print("Prediction Score\n", prediction_score)

# y_true = prediction_labels_train[:1000]
# y_pred = prediction_labels[:1000]
y_true = prediction_labels_train
y_pred = prediction_labels

print('Accuracy Score: ', accuracy_score(y_true, y_pred))

confusion_matrix(y_true, y_pred)

print(classification_report(y_true, y_pred))

precision, recall, thresholds = precision_recall_curve(y_true, y_pred)
f1_scores = 2*recall*precision/(recall+precision)

print('Best threshold: ', thresholds[np.argmax(f1_scores)])
print('Best F1-Score: ', np.max(f1_scores))

fpr, tpr, threshold = metrics.roc_curve(y_true, y_pred)
roc_auc = metrics.auc(fpr, tpr)

plt.title('ROC')
plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
plt.legend(loc = 'lower right')
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.show()

+ 1
- 1
examples/sk_examples/Telemanom_test.py View File

@@ -1,5 +1,5 @@
import numpy as np import numpy as np
from tods.tods_skinterface.primitiveSKI.detection_algorithm.Telemanom_skinterface import TelemanomSKI
from tods.sk_interface.detection_algorithm.Telemanom_skinterface import TelemanomSKI
from sklearn.metrics import precision_recall_curve from sklearn.metrics import precision_recall_curve
from sklearn.metrics import accuracy_score from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix from sklearn.metrics import confusion_matrix


+ 4
- 4
tods/detection_algorithm/LSTMODetect.py View File

@@ -68,7 +68,7 @@ class Hyperparams(Hyperparams_ODBase):
) )


min_attack_time = hyperparams.Hyperparameter[int]( min_attack_time = hyperparams.Hyperparameter[int](
default=5,
default=10,
description='The minimum amount of recent time steps that is used to define a collective attack.', description='The minimum amount of recent time steps that is used to define a collective attack.',
semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter'] semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter']
) )
@@ -97,7 +97,7 @@ class Hyperparams(Hyperparams_ODBase):
) )


epochs = hyperparams.Hyperparameter[int]( epochs = hyperparams.Hyperparameter[int](
default=10,
default=50,
description='Number of epochs to train the model.', description='Number of epochs to train the model.',
semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter'] semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter']
) )
@@ -123,13 +123,13 @@ class Hyperparams(Hyperparams_ODBase):
) )


hidden_dim = hyperparams.Hyperparameter[int]( hidden_dim = hyperparams.Hyperparameter[int](
default=16,
default=8,
description='Hidden dim of LSTM.', description='Hidden dim of LSTM.',
semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter'] semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter']
) )


n_hidden_layer = hyperparams.Hyperparameter[int]( n_hidden_layer = hyperparams.Hyperparameter[int](
default=0,
default=2,
description='Hidden layer number of LSTM.', description='Hidden layer number of LSTM.',
semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter'] semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter']
) )


+ 1
- 1
tods/detection_algorithm/UODBasePrimitive.py View File

@@ -83,7 +83,7 @@ class Hyperparams_ODBase(hyperparams.Hyperparams):
) )


window_size = hyperparams.Hyperparameter[int]( window_size = hyperparams.Hyperparameter[int](
default=1,
default=10,
description='The moving window size.', description='The moving window size.',
semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter'] semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter']
) )


+ 27
- 0
tods/detection_algorithm/core/AutoRegOD.py View File

@@ -103,6 +103,33 @@ class AutoRegOD(CollectiveBaseDetector):
self._process_decision_scores() self._process_decision_scores()
return self return self


def predict(self, X): # pragma: no cover
"""Predict if a particular sample is an outlier or not.

Parameters
----------
X : numpy array of shape (n_samples, n_features)
The input samples.

Returns
-------
outlier_labels : numpy array of shape (n_samples,)
For each observation, tells whether or not
it should be considered as an outlier according to the
fitted model. 0 stands for inliers and 1 for outliers.
"""

check_is_fitted(self, ['decision_scores_', 'threshold_', 'labels_'])

pred_score, X_left_inds, X_right_inds = self.decision_function(X)

pred_score = np.concatenate((np.zeros((self.window_size,)), pred_score))
X_left_inds = np.concatenate((np.zeros((self.window_size,)), X_left_inds))
X_right_inds = np.concatenate((np.zeros((self.window_size,)), X_right_inds))

return (pred_score > self.threshold_).astype(
'int').ravel(), X_left_inds.ravel(), X_right_inds.ravel()

def decision_function(self, X: np.array): def decision_function(self, X: np.array):
"""Predict raw anomaly scores of X using the fitted detector. """Predict raw anomaly scores of X using the fitted detector.




+ 51
- 3
tods/detection_algorithm/core/LSTMOD.py View File

@@ -33,7 +33,7 @@ class LSTMOutlierDetector(CollectiveBaseDetector):
): ):


super(LSTMOutlierDetector, self).__init__(contamination=contamination, super(LSTMOutlierDetector, self).__init__(contamination=contamination,
window_size=min_attack_time,
# window_size=min_attack_time,
step_size=1, step_size=1,
) )


@@ -54,14 +54,34 @@ class LSTMOutlierDetector(CollectiveBaseDetector):
self.activation = activation self.activation = activation




# def _build_model(self):
# print('dim:', self.hidden_dim, self.feature_dim)
# model_ = Sequential()
# model_.add(LSTM(units=self.hidden_dim, input_shape=(self.feature_dim, 1),
# dropout=self.dropout_rate, activation=self.activation, return_sequences=True))

# for layer_idx in range(self.n_hidden_layer-1):
# model_.add(LSTM(units=self.hidden_dim, input_shape=(self.hidden_dim, 1),
# dropout=self.dropout_rate, activation=self.activation, return_sequences=True))

# model_.add(LSTM(units=self.hidden_dim, input_shape=(self.hidden_dim, 1),
# dropout=self.dropout_rate, activation=self.activation))

# model_.add(Dense(units=self.feature_dim, input_shape=(self.hidden_dim, 1), activation=None))

# model_.compile(loss=self.loss, optimizer=self.optimizer)
# return model_

def _build_model(self): def _build_model(self):
model_ = Sequential() model_ = Sequential()
model_.add(LSTM(units=self.hidden_dim, input_shape=(self.feature_dim, 1), model_.add(LSTM(units=self.hidden_dim, input_shape=(self.feature_dim, 1),
dropout=self.dropout_rate, activation=self.activation))
dropout=self.dropout_rate, activation=self.activation,
return_sequences=bool(self.n_hidden_layer>0)))


for layer_idx in range(self.n_hidden_layer): for layer_idx in range(self.n_hidden_layer):
model_.add(LSTM(units=self.hidden_dim, input_shape=(self.hidden_dim, 1), model_.add(LSTM(units=self.hidden_dim, input_shape=(self.hidden_dim, 1),
dropout=self.dropout_rate, activation=self.activation))
dropout=self.dropout_rate, activation=self.activation,
return_sequences=bool(layer_idx < self.n_hidden_layer - 1)))


model_.add(Dense(units=self.feature_dim, input_shape=(self.hidden_dim, 1), activation=None)) model_.add(Dense(units=self.feature_dim, input_shape=(self.hidden_dim, 1), activation=None))


@@ -84,6 +104,7 @@ class LSTMOutlierDetector(CollectiveBaseDetector):
self : object self : object
Fitted estimator. Fitted estimator.
""" """
print("XXXX:", X.shape)
X = check_array(X).astype(np.float) X = check_array(X).astype(np.float)
self._set_n_classes(None) self._set_n_classes(None)
X_buf, y_buf = self._get_sub_matrices(X) X_buf, y_buf = self._get_sub_matrices(X)
@@ -121,6 +142,33 @@ class LSTMOutlierDetector(CollectiveBaseDetector):


return relative_error return relative_error


def predict(self, X): # pragma: no cover
"""Predict if a particular sample is an outlier or not.

Parameters
----------
X : numpy array of shape (n_samples, n_features)
The input samples.

Returns
-------
outlier_labels : numpy array of shape (n_samples,)
For each observation, tells whether or not
it should be considered as an outlier according to the
fitted model. 0 stands for inliers and 1 for outliers.
"""

check_is_fitted(self, ['decision_scores_', 'threshold_', 'labels_'])

pred_score, X_left_inds, X_right_inds = self.decision_function(X)

pred_score = np.concatenate((np.zeros((self.window_size,)), pred_score))
X_left_inds = np.concatenate((np.zeros((self.window_size,)), X_left_inds))
X_right_inds = np.concatenate((np.zeros((self.window_size,)), X_right_inds))

return (pred_score > self.threshold_).astype(
'int').ravel(), X_left_inds.ravel(), X_right_inds.ravel()

def decision_function(self, X: np.array): def decision_function(self, X: np.array):
"""Predict raw anomaly scores of X using the fitted detector. """Predict raw anomaly scores of X using the fitted detector.




+ 27
- 0
tods/detection_algorithm/core/MultiAutoRegOD.py View File

@@ -157,6 +157,33 @@ class MultiAutoRegOD(CollectiveBaseDetector):
self._process_decision_scores() self._process_decision_scores()
return self return self


def predict(self, X): # pragma: no cover
"""Predict if a particular sample is an outlier or not.

Parameters
----------
X : numpy array of shape (n_samples, n_features)
The input samples.

Returns
-------
outlier_labels : numpy array of shape (n_samples,)
For each observation, tells whether or not
it should be considered as an outlier according to the
fitted model. 0 stands for inliers and 1 for outliers.
"""

check_is_fitted(self, ['decision_scores_', 'threshold_', 'labels_'])

pred_score, X_left_inds, X_right_inds = self.decision_function(X)

pred_score = np.concatenate((np.zeros((self.window_size,)), pred_score))
X_left_inds = np.concatenate((np.zeros((self.window_size,)), X_left_inds))
X_right_inds = np.concatenate((np.zeros((self.window_size,)), X_right_inds))

return (pred_score > self.threshold_).astype(
'int').ravel(), X_left_inds.ravel(), X_right_inds.ravel()

def decision_function(self, X: np.array): def decision_function(self, X: np.array):
"""Predict raw anomaly scores of X using the fitted detector. """Predict raw anomaly scores of X using the fitted detector.




Loading…
Cancel
Save