diff --git a/tods/tests/sk_interface/detection_algorithm/test_ski_LSTMODetector.py b/tods/tests/sk_interface/detection_algorithm/test_ski_LSTMODetector.py index dd6359b..bfe11df 100644 --- a/tods/tests/sk_interface/detection_algorithm/test_ski_LSTMODetector.py +++ b/tods/tests/sk_interface/detection_algorithm/test_ski_LSTMODetector.py @@ -29,7 +29,8 @@ class LSTMODetectorSKI_TestCase(unittest.TestCase): self.y_test = self.y_test[1:] self.y_train = self.y_train[1:] - self.transformer = LSTMODetectorSKI(contamination=self.contamination) + + self.transformer = LSTMODetectorSKI(contamination=self.contamination, feature_dim=self.X_train.shape[1]) self.transformer.fit(self.X_train) def test_prediction_labels(self):