From 69528414a825e7270f268d51be104024d7eb1b92 Mon Sep 17 00:00:00 2001
From: lhenry15
Date: Thu, 6 May 2021 12:03:55 -0500
Subject: [PATCH] merge DAGMM

---
 tods/detection_algorithm/DAGMM.py                | 200 ++++++++++++++++
 tods/detection_algorithm/core/dagmm/__init__.py  |   6 +
 .../core/dagmm/compression_net.py                | 121 ++++++++++
 tods/detection_algorithm/core/dagmm/dagmm.py     | 251 +++++++++++++++++++++
 .../core/dagmm/estimation_net.py                 |  63 ++++++
 tods/detection_algorithm/core/dagmm/gmm.py       | 130 +++++++++++
 tods/tests/detection_algorithm/test_DAGMM.py    | 114 ++++++++++
 7 files changed, 885 insertions(+)
 create mode 100644 tods/detection_algorithm/DAGMM.py
 create mode 100644 tods/detection_algorithm/core/dagmm/__init__.py
 create mode 100644 tods/detection_algorithm/core/dagmm/compression_net.py
 create mode 100644 tods/detection_algorithm/core/dagmm/dagmm.py
 create mode 100644 tods/detection_algorithm/core/dagmm/estimation_net.py
 create mode 100644 tods/detection_algorithm/core/dagmm/gmm.py
 create mode 100644 tods/tests/detection_algorithm/test_DAGMM.py

diff --git a/tods/detection_algorithm/DAGMM.py b/tods/detection_algorithm/DAGMM.py
new file mode 100644
index 0000000..0a2bff4
--- /dev/null
+++ b/tods/detection_algorithm/DAGMM.py
@@ -0,0 +1,200 @@
+from typing import Any, Callable, List, Dict, Union, Optional, Sequence, Tuple
+
+from d3m.metadata import hyperparams, params, base as metadata_base
+from d3m.primitive_interfaces.base import CallResult, DockerContainer
+
+from tods.detection_algorithm.core.dagmm.dagmm import DAGMM
+import uuid
+
+from d3m import container, utils as d3m_utils
+
+from tods.detection_algorithm.UODBasePrimitive import Params_ODBase, Hyperparams_ODBase, UnsupervisedOutlierDetectorBase
+
+
+__all__ = ('DAGMMPrimitive',)
+
+Inputs = container.DataFrame
+Outputs = container.DataFrame
+
+
+class Params(Params_ODBase):
+    ######## Add more Attributes #######
+    pass
+
+
+class Hyperparams(Hyperparams_ODBase):
+    comp_hiddens = hyperparams.List(
+        default=[16, 8, 1],
+        elements=hyperparams.Hyperparameter[int](1),
+        description='Sizes of hidden layers of the compression network.',
+        semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter']
+    )
+
+    est_hiddens = hyperparams.List(
+        default=[8, 4],
+        elements=hyperparams.Hyperparameter[int](1),
+        description='Sizes of hidden layers of the estimation network.',
+        semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter']
+    )
+
+    est_dropout_ratio = hyperparams.Hyperparameter[float](
+        default=0.25,
+        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
+        description="Dropout ratio of the estimation network."
+    )
+
+    minibatch_size = hyperparams.Hyperparameter[int](
+        default=3,
+        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
+        description="Mini-batch size used during training."
+    )
+
+    epoch_size = hyperparams.Hyperparameter[int](
+        default=100,
+        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
+        description="Number of training epochs."
+    )
+
+    rand_seed = hyperparams.Hyperparameter[int](
+        default=0,
+        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
+        description="(optional) Random seed used when fit() is called."
+    )
+
+    learning_rate = hyperparams.Hyperparameter[float](
+        default=0.0001,
+        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
+        description="Learning rate used during training."
+    )
+
+    lambda1 = hyperparams.Hyperparameter[float](
+        default=0.1,
+        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
+        description="A parameter of the loss function (weight of the energy term)."
+    )
+
+    lambda2 = hyperparams.Hyperparameter[float](
+        default=0.1,
+        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
+        description="A parameter of the loss function (weight of the penalty on the diagonal elements of the covariance)."
+    )
+
+    normalize = hyperparams.Hyperparameter[bool](
+        default=True,
+        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
+        description="Specify whether the input data needs to be normalized."
+    )
+
+    contamination = hyperparams.Uniform(
+        lower=0.,
+        upper=0.5,
+        default=0.1,
+        description='The amount of contamination of the data set, i.e. the proportion of outliers in the data set. Used when fitting to define the threshold on the decision function.',
+        semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter']
+    )
+
+
+class DAGMMPrimitive(UnsupervisedOutlierDetectorBase[Inputs, Outputs, Params, Hyperparams]):
+    """
+    Deep Autoencoding Gaussian Mixture Model (DAGMM) for unsupervised outlier
+    detection. See the `Hyperparams` class above for the tunable parameters.
+    """
+
+    __author__ = "DATA Lab at Texas A&M University"
+    metadata = metadata_base.PrimitiveMetadata(
+        {
+            '__author__': "DATA Lab @ Texas A&M University",
+            'name': "DAGMM",
+            'python_path': 'd3m.primitives.tods.detection_algorithm.dagmm',
+            'source': {'name': "DATA Lab @ Texas A&M University", 'contact': 'mailto:khlai037@tamu.edu',
+                       'uris': ['https://gitlab.com/lhenry15/tods/-/blob/Yile/anomaly-primitives/anomaly_primitives/DAGMM.py']},
+            'algorithm_types': [metadata_base.PrimitiveAlgorithmType.DEEPLOG],
+            'primitive_family': metadata_base.PrimitiveFamily.ANOMALY_DETECTION,
+            'id': str(uuid.uuid3(uuid.NAMESPACE_DNS, 'DAGMMPrimitive')),
+            'hyperparams_to_tune': ['comp_hiddens', 'est_hiddens', 'est_dropout_ratio', 'minibatch_size', 'epoch_size', 'rand_seed',
+                                    'learning_rate', 'lambda1', 'lambda2', 'contamination'],
+            'version': '0.0.1',
+        }
+    )
+
+    def __init__(self, *,
+                 hyperparams: Hyperparams,
+                 random_seed: int = 0,
+                 docker_containers: Dict[str, DockerContainer] = None) -> None:
+        super().__init__(hyperparams=hyperparams, random_seed=random_seed, docker_containers=docker_containers)
+        self._clf = DAGMM(comp_hiddens=hyperparams['comp_hiddens'],
+                          est_hiddens=hyperparams['est_hiddens'],
+                          est_dropout_ratio=hyperparams['est_dropout_ratio'],
+                          minibatch_size=hyperparams['minibatch_size'],
+                          epoch_size=hyperparams['epoch_size'],
+                          random_seed=hyperparams['rand_seed'],
+                          learning_rate=hyperparams['learning_rate'],
+                          lambda1=hyperparams['lambda1'],
+                          lambda2=hyperparams['lambda2'],
+                          normalize=hyperparams['normalize'],
+                          contamination=hyperparams['contamination']
+                          )
+
+    def set_training_data(self, *, inputs: Inputs) -> None:
+        """
+        Set training data for outlier detection.
+        Args:
+            inputs: Container DataFrame
+
+        Returns:
+            None
+        """
+        super().set_training_data(inputs=inputs)
+
+    def fit(self, *, timeout: float = None, iterations: int = None) -> CallResult[None]:
+        """
+        Fit the model with the training data previously set.
+        Args:
+            None
+
+        Returns:
+            None
+        """
+        return super().fit()
+
+    def produce(self, *, inputs: Inputs, timeout: float = None, iterations: int = None) -> CallResult[Outputs]:
+        """
+        Process the testing data.
+        Args:
+            inputs: Container DataFrame. Time series data to run outlier detection on.
+
+        Returns:
+            Container DataFrame
+            1 marks outliers, 0 marks normal samples.
+        """
+        return super().produce(inputs=inputs, timeout=timeout, iterations=iterations)
+
+    def get_params(self) -> Params:
+        """
+        Return parameters.
+        Args:
+            None
+
+        Returns:
+            class Params
+        """
+        return super().get_params()
+
+    def set_params(self, *, params: Params) -> None:
+        """
+        Set parameters for outlier detection.
+        Args:
+            params: class Params
+
+        Returns:
+            None
+        """
+        super().set_params(params=params)
diff --git a/tods/detection_algorithm/core/dagmm/__init__.py b/tods/detection_algorithm/core/dagmm/__init__.py
new file mode 100644
index 0000000..cc05500
--- /dev/null
+++ b/tods/detection_algorithm/core/dagmm/__init__.py
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+
+from .compression_net import CompressionNet
+from .estimation_net import EstimationNet
+from .gmm import GMM
+from .dagmm import DAGMM
diff --git a/tods/detection_algorithm/core/dagmm/compression_net.py b/tods/detection_algorithm/core/dagmm/compression_net.py
new file mode 100644
index 0000000..759dbdc
--- /dev/null
+++ b/tods/detection_algorithm/core/dagmm/compression_net.py
@@ -0,0 +1,121 @@
+import tensorflow.compat.v1 as tf
+tf.disable_v2_behavior()
+
+
+class CompressionNet:
+    """ Compression Network.
+    This network converts the input data to the representations
+    suitable for the calculation of anomaly scores by the "Estimation Network".
+
+    Outputs of the network consist of two components:
+    1) reduced low-dimensional representations learned by an AutoEncoder.
+    2) features derived from the reconstruction error.
+    """
+    def __init__(self, hidden_layer_sizes, activation=tf.nn.tanh):
+        """
+        Parameters
+        ----------
+        hidden_layer_sizes : list of int
+            list of the sizes of the hidden layers.
+            For example, if the sizes are [n1, n2],
+            the structure of the created network is:
+            input_size -> n1 -> n2 -> n1 -> input_size
+            (the network outputs the representation of the "n2" layer)
+        activation : function
+            activation function of the hidden layers.
+            The last layer uses a linear function.
+        """
+        self.hidden_layer_sizes = hidden_layer_sizes
+        self.activation = activation
+
+    def compress(self, x):
+        self.input_size = x.shape[1]
+
+        with tf.variable_scope("Encoder"):
+            z = x
+            n_layer = 0
+            for size in self.hidden_layer_sizes[:-1]:
+                n_layer += 1
+                z = tf.layers.dense(z, size, activation=self.activation,
+                                    name="layer_{}".format(n_layer))
+
+            # activation function of the last layer is linear
+            n_layer += 1
+            z = tf.layers.dense(z, self.hidden_layer_sizes[-1],
+                                name="layer_{}".format(n_layer))
+
+        return z
+
+    def reverse(self, z):
+        with tf.variable_scope("Decoder"):
+            n_layer = 0
+            for size in self.hidden_layer_sizes[:-1][::-1]:
+                n_layer += 1
+                z = tf.layers.dense(z, size, activation=self.activation,
+                                    name="layer_{}".format(n_layer))
+
+            # activation function of the last layer is linear
+            n_layer += 1
+            x_dash = tf.layers.dense(z, self.input_size,
+                                     name="layer_{}".format(n_layer))
+
+        return x_dash
+
+    def loss(self, x, x_dash):
+        def euclid_norm(x):
+            return tf.sqrt(tf.reduce_sum(tf.square(x), axis=1))
+
+        # Calculate Euclidean norms and the distance
+        norm_x = euclid_norm(x)
+        norm_x_dash = euclid_norm(x_dash)
+        dist_x = euclid_norm(x - x_dash)
+        dot_x = tf.reduce_sum(x * x_dash, axis=1)
+
+        # Based on the original paper, the features derived from the
+        # reconstruction error are composed of these two terms:
+        # 1. loss_E : relative Euclidean distance
+        # 2. loss_C : cosine similarity
+        min_val = 1e-3
+        loss_E = dist_x / (norm_x + min_val)
+        loss_C = 0.5 * (1.0 - dot_x / (norm_x * norm_x_dash + min_val))
+        return tf.concat([loss_E[:, None], loss_C[:, None]], axis=1)
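+
+    # For intuition, a small worked example of the two features above
+    # (values chosen arbitrarily, ignoring the small min_val stabilizer):
+    # for x = [3, 4] and x_dash = [0, 4],
+    #   loss_E = ||x - x_dash|| / ||x||                       = 3/5 = 0.6
+    #   loss_C = 0.5 * (1 - <x, x_dash> / (||x|| ||x_dash||)) = 0.5 * (1 - 16/20) = 0.1
+    # A perfectly reconstructed sample gives loss_E = 0 and loss_C = 0.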
+
+    def extract_feature(self, x, x_dash, z_c):
+        z_r = self.loss(x, x_dash)
+        return tf.concat([z_c, z_r], axis=1)
+
+    def inference(self, x):
+        """ Convert the input to the output tensor, which is composed of the
+        low-dimensional representation and the reconstruction-error features.
+
+        Parameters
+        ----------
+        x : tf.Tensor shape : (n_samples, n_features)
+            Input data
+
+        Returns
+        -------
+        z : tf.Tensor shape : (n_samples, n2 + 2)
+            Result data.
+            The second dimension of this data is equal to the
+            sum of the compressed representation size and the
+            number of loss features (= 2).
+        x_dash : tf.Tensor shape : (n_samples, n_features)
+            Reconstructed data, used for the calculation of the
+            reconstruction error.
+        """
+        with tf.variable_scope("CompNet"):
+            # AutoEncoder
+            z_c = self.compress(x)
+            x_dash = self.reverse(z_c)
+
+            # compose the feature vector
+            z = self.extract_feature(x, x_dash, z_c)
+
+        return z, x_dash
+
+    def reconstruction_error(self, x, x_dash):
+        return tf.reduce_mean(tf.reduce_sum(
+            tf.square(x - x_dash), axis=1), axis=0)
diff --git a/tods/detection_algorithm/core/dagmm/dagmm.py b/tods/detection_algorithm/core/dagmm/dagmm.py
new file mode 100644
index 0000000..47c70c9
--- /dev/null
+++ b/tods/detection_algorithm/core/dagmm/dagmm.py
@@ -0,0 +1,251 @@
+import numpy as np
+import joblib
+from sklearn.preprocessing import StandardScaler
+
+from .compression_net import CompressionNet
+from .estimation_net import EstimationNet
+from .gmm import GMM
+from pyod.utils.stat_models import pairwise_distances_no_broadcast
+
+from os import makedirs
+from os.path import exists, join
+
+import tensorflow.compat.v1 as tf
+tf.disable_v2_behavior()
+
+from pyod.models.base import BaseDetector
+
+
+class DAGMM(BaseDetector):
+    """ Deep Autoencoding Gaussian Mixture Model.
+
+    This implementation is based on the paper:
+    Bo Zong et al. (2018) Deep Autoencoding Gaussian Mixture Model
+    for Unsupervised Anomaly Detection, ICLR 2018
+    (this is an UNOFFICIAL implementation)
+    """
+
+    MODEL_FILENAME = "DAGMM_model"
+    SCALER_FILENAME = "DAGMM_scaler"
+
+    def __init__(self, comp_hiddens: list = [16, 8, 1],
+                 est_hiddens: list = [8, 4], est_dropout_ratio: float = 0.5,
+                 minibatch_size: int = 1024, epoch_size: int = 100,
+                 learning_rate: float = 0.0001, lambda1: float = 0.1, lambda2: float = 0.0001,
+                 normalize: bool = True, random_seed: int = 123, contamination: float = 0.001):
+        """
+        Parameters
+        ----------
+        comp_hiddens : list of int
+            sizes of the hidden layers of the compression network.
+            For example, if the sizes are [n1, n2],
+            the structure of the compression network is:
+            input_size -> n1 -> n2 -> n1 -> input_size
+        est_hiddens : list of int
+            sizes of the hidden layers of the estimation network.
+            The last element of this list is assigned as n_comp.
+            For example, if the sizes are [n1, n2],
+            the structure of the estimation network is:
+            input_size -> n1 -> n2 (= n_comp)
+        est_dropout_ratio : float (optional)
+            dropout ratio of the estimation network applied during training.
+            If 0 or None, dropout is not applied.
+        minibatch_size : int (optional)
+            mini-batch size during training
+        epoch_size : int (optional)
+            epoch size during training
+        learning_rate : float (optional)
+            learning rate during training
+        lambda1 : float (optional)
+            a parameter of the loss function (for the energy term)
+        lambda2 : float (optional)
+            a parameter of the loss function
+            (for the sum of the diagonal elements of the covariance)
+        normalize : bool (optional)
+            specify whether the input data needs to be normalized.
+            By default, the input data is normalized.
+        random_seed : int (optional)
+            random seed used when fit() is called.
+        """
+        est_activation = tf.nn.tanh
+        comp_activation = tf.nn.tanh
+        super(DAGMM, self).__init__(contamination=contamination)
+        self.comp_net = CompressionNet(comp_hiddens, comp_activation)
+        self.est_net = EstimationNet(est_hiddens, est_activation)
+        self.est_dropout_ratio = est_dropout_ratio
+
+        n_comp = est_hiddens[-1]
+        self.gmm = GMM(n_comp)
+
+        self.minibatch_size = minibatch_size
+        self.epoch_size = epoch_size
+        self.learning_rate = learning_rate
+        self.lambda1 = lambda1
+        self.lambda2 = lambda2
+
+        self.normalize = normalize
+        self.scaler = None
+        self.seed = random_seed
+
+        self.graph = None
+        self.sess = None
+
+    # def __del__(self):
+    #     if self.sess is not None:
+    #         self.sess.close()
+
+    def fit(self, X, y=None):
+        """ Fit the DAGMM model according to the given data.
+
+        Parameters
+        ----------
+        X : array-like, shape (n_samples, n_features)
+            Training data.
+        """
+        n_samples, n_features = X.shape
+
+        if self.normalize:
+            self.scaler = scaler = StandardScaler()
+            X = scaler.fit_transform(X)
+
+        with tf.Graph().as_default() as graph:
+            self.graph = graph
+            tf.set_random_seed(self.seed)
+            np.random.seed(seed=self.seed)
+
+            # Create placeholders
+            self.input = input = tf.placeholder(
+                dtype=tf.float32, shape=[None, n_features])
+            self.drop = drop = tf.placeholder(dtype=tf.float32, shape=[])
+
+            # Build the graph
+            z, x_dash = self.comp_net.inference(input)
+            gamma = self.est_net.inference(z, drop)
+            self.gmm.fit(z, gamma)
+            energy = self.gmm.energy(z)
+
+            self.x_dash = x_dash
+
+            # Loss function
+            loss = (self.comp_net.reconstruction_error(input, x_dash) +
+                    self.lambda1 * tf.reduce_mean(energy) +
+                    self.lambda2 * self.gmm.cov_diag_loss())
+
+            # Minimizer
+            minimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(loss)
+
+            # Number of batches
+            n_batch = (n_samples - 1) // self.minibatch_size + 1
+
+            # Create the TensorFlow session and initialize
+            init = tf.global_variables_initializer()
+
+            self.sess = tf.Session(graph=graph)
+            self.sess.run(init)
+
+            # Training
+            idx = np.arange(X.shape[0])
+            np.random.shuffle(idx)
+
+            for epoch in range(self.epoch_size):
+                for batch in range(n_batch):
+                    i_start = batch * self.minibatch_size
+                    i_end = (batch + 1) * self.minibatch_size
+                    x_batch = X[idx[i_start:i_end]]
+
+                    self.sess.run(minimizer, feed_dict={
+                        input: x_batch, drop: self.est_dropout_ratio})
+                if (epoch + 1) % 10 == 0:
+                    loss_val = self.sess.run(loss, feed_dict={input: X, drop: 0})
+                    print(" epoch {}/{} : loss = {:.3f}".format(epoch + 1, self.epoch_size, loss_val))
+
+            # Fix the GMM parameters
+            fix = self.gmm.fix_op()
+            self.sess.run(fix, feed_dict={input: X, drop: 0})
+            self.energy = self.gmm.energy(z)
+
+            tf.add_to_collection("save", self.input)
+            tf.add_to_collection("save", self.energy)
+
+            self.saver = tf.train.Saver()
+
+        # X was already normalized above; evaluate the energies directly
+        # rather than via decision_function(), which would apply the scaler
+        # a second time.
+        pred_scores = self.sess.run(self.energy, feed_dict={self.input: X}).reshape(-1)
+        self.decision_scores_ = pred_scores
+        self._process_decision_scores()
+        return self
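+
+    # The loss assembled in fit() has the three terms of the DAGMM objective:
+    #
+    #   J = mean_i ||x_i - x'_i||^2               (reconstruction error)
+    #       + lambda1 * mean_i E(z_i)             (sample energy)
+    #       + lambda2 * sum_{k,j} 1/Sigma_k[j,j]  (covariance penalty)
+    #
+    # lambda1 weights the density term; lambda2 penalizes small diagonal
+    # covariance entries, which keeps the mixture from collapsing to
+    # singular covariances.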
+
+    def decision_function(self, X):
+        """ Calculate anomaly scores (sample energy) on samples in X.
+
+        Parameters
+        ----------
+        X : array-like, shape (n_samples, n_features)
+            Data for which anomaly scores are calculated.
+            n_features must equal n_features of the fitted data.
+
+        Returns
+        -------
+        energies : array-like, shape (n_samples,)
+            Calculated sample energies.
+        """
+        if self.sess is None:
+            raise Exception("Trained model does not exist.")
+
+        if self.normalize:
+            X = self.scaler.transform(X)
+
+        energies = self.sess.run(self.energy, feed_dict={self.input: X})
+
+        # return one score per sample (pyod expects shape (n_samples,))
+        return energies.reshape(-1)
+
+    def save(self, fdir):
+        """ Save the trained model to the designated directory.
+        This method has to be called after training
+        (if not, an exception is raised).
+
+        Parameters
+        ----------
+        fdir : str
+            Path of the directory the trained model is saved to.
+            If it does not exist, it is created automatically.
+        """
+        if self.sess is None:
+            raise Exception("Trained model does not exist.")
+
+        if not exists(fdir):
+            makedirs(fdir)
+
+        model_path = join(fdir, self.MODEL_FILENAME)
+        self.saver.save(self.sess, model_path)
+
+        if self.normalize:
+            scaler_path = join(fdir, self.SCALER_FILENAME)
+            joblib.dump(self.scaler, scaler_path)
+
+    def restore(self, fdir):
+        """ Restore a trained model from the designated directory.
+
+        Parameters
+        ----------
+        fdir : str
+            Path of the directory the trained model was saved to.
+        """
+        if not exists(fdir):
+            raise Exception("Model directory does not exist.")
+
+        model_path = join(fdir, self.MODEL_FILENAME)
+        meta_path = model_path + ".meta"
+
+        with tf.Graph().as_default() as graph:
+            self.graph = graph
+            self.sess = tf.Session(graph=graph)
+            self.saver = tf.train.import_meta_graph(meta_path)
+            self.saver.restore(self.sess, model_path)
+
+            self.input, self.energy = tf.get_collection("save")
+
+        if self.normalize:
+            scaler_path = join(fdir, self.SCALER_FILENAME)
+            self.scaler = joblib.load(scaler_path)
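+
+
+if __name__ == "__main__":
+    # Minimal usage sketch of the core detector on synthetic data; layer
+    # sizes, epoch count and contamination below are illustrative only.
+    rng = np.random.RandomState(42)
+    X_train = rng.randn(200, 5).astype(np.float32)
+
+    model = DAGMM(comp_hiddens=[8, 4, 1], est_hiddens=[4, 2],
+                  minibatch_size=32, epoch_size=10, contamination=0.05)
+    model.fit(X_train)
+
+    scores = model.decision_function(X_train)  # higher energy = more anomalous
+    print(scores.shape, int(model.labels_.sum()), "samples flagged as outliers")
+
+    # Round trip through save()/restore() (writes to ./dagmm_model_demo):
+    model.save("dagmm_model_demo")
+    model.restore("dagmm_model_demo")
+    print(model.decision_function(X_train).shape)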
diff --git a/tods/detection_algorithm/core/dagmm/estimation_net.py b/tods/detection_algorithm/core/dagmm/estimation_net.py
new file mode 100644
index 0000000..6dbe7cc
--- /dev/null
+++ b/tods/detection_algorithm/core/dagmm/estimation_net.py
@@ -0,0 +1,63 @@
+# -*- coding: utf-8 -*-
+import tensorflow.compat.v1 as tf
+tf.disable_v2_behavior()
+
+
+class EstimationNet:
+    """ Estimation Network
+
+    This network converts the input feature vector to softmax probabilities.
+    Because the loss function for this network is not defined here,
+    it should be implemented outside of this class.
+    """
+    def __init__(self, hidden_layer_sizes, activation=tf.nn.relu):
+        """
+        Parameters
+        ----------
+        hidden_layer_sizes : list of int
+            list of the sizes of the hidden layers.
+            For example, if the sizes are [n1, n2],
+            the layer sizes of the network are:
+            input_size -> n1 -> n2
+            (the network outputs the softmax probabilities of the "n2" layer)
+        activation : function
+            activation function of the hidden layers.
+            The last layer uses the softmax function.
+        """
+        self.hidden_layer_sizes = hidden_layer_sizes
+        self.activation = activation
+
+    def inference(self, z, dropout_ratio=None):
+        """ Output softmax probabilities.
+
+        Parameters
+        ----------
+        z : tf.Tensor shape : (n_samples, n_features)
+            Data fed through this network
+        dropout_ratio : tf.Tensor shape : 0-dimension float (optional)
+            Specify the dropout ratio
+            (if None, dropout is not applied)
+
+        Returns
+        -------
+        probs : tf.Tensor shape : (n_samples, n_classes)
+            Calculated probabilities
+        """
+        with tf.variable_scope("EstNet"):
+            n_layer = 0
+            for size in self.hidden_layer_sizes[:-1]:
+                n_layer += 1
+                z = tf.layers.dense(z, size, activation=self.activation,
+                                    name="layer_{}".format(n_layer))
+                if dropout_ratio is not None:
+                    z = tf.layers.dropout(z, dropout_ratio,
+                                          name="drop_{}".format(n_layer))
+
+            # The last layer uses a linear function (= logits)
+            size = self.hidden_layer_sizes[-1]
+            logits = tf.layers.dense(z, size, activation=None, name="logits")
+
+            # Softmax output
+            output = tf.nn.softmax(logits)
+
+        return output
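+
+
+# Shape sketch (illustrative sizes): with hidden_layer_sizes=[8, 4] and a
+# batch of 32 samples, inference() maps z of shape (32, n_features) through
+# an 8-unit hidden layer to logits of shape (32, 4); the softmax output gamma
+# then has shape (32, 4) with each row summing to 1, i.e. the soft assignment
+# of each sample to the 4 mixture components consumed by GMM.fit(z, gamma).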
diff --git a/tods/detection_algorithm/core/dagmm/gmm.py b/tods/detection_algorithm/core/dagmm/gmm.py
new file mode 100644
index 0000000..6da4070
--- /dev/null
+++ b/tods/detection_algorithm/core/dagmm/gmm.py
@@ -0,0 +1,130 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import tensorflow.compat.v1 as tf
+tf.disable_v2_behavior()
+
+
+class GMM:
+    """ Gaussian Mixture Model (GMM) """
+    def __init__(self, n_comp):
+        self.n_comp = n_comp
+        self.phi = self.mu = self.sigma = None
+        self.training = False
+
+    def create_variables(self, n_features):
+        with tf.variable_scope("GMM"):
+            phi = tf.Variable(tf.zeros(shape=[self.n_comp]),
+                              dtype=tf.float32, name="phi")
+            mu = tf.Variable(tf.zeros(shape=[self.n_comp, n_features]),
+                             dtype=tf.float32, name="mu")
+            sigma = tf.Variable(tf.zeros(
+                shape=[self.n_comp, n_features, n_features]),
+                dtype=tf.float32, name="sigma")
+            L = tf.Variable(tf.zeros(
+                shape=[self.n_comp, n_features, n_features]),
+                dtype=tf.float32, name="L")
+
+        return phi, mu, sigma, L
+
+    def fit(self, z, gamma):
+        """ Fit data to the GMM model.
+
+        Parameters
+        ----------
+        z : tf.Tensor, shape (n_samples, n_features)
+            data fitted to the GMM.
+        gamma : tf.Tensor, shape (n_samples, n_comp)
+            probabilities. Each row corresponds to a row of z.
+        """
+        with tf.variable_scope("GMM"):
+            # Calculate mu, sigma
+            # i : index of samples
+            # k : index of components
+            # l,m : index of features
+            gamma_sum = tf.reduce_sum(gamma, axis=0)
+            self.phi = phi = tf.reduce_mean(gamma, axis=0)
+            self.mu = mu = tf.einsum('ik,il->kl', gamma, z) / gamma_sum[:, None]
+            z_centered = tf.sqrt(gamma[:, :, None]) * (z[:, None, :] - mu[None, :, :])
+            self.sigma = sigma = tf.einsum(
+                'ikl,ikm->klm', z_centered, z_centered) / gamma_sum[:, None, None]
+
+            # Calculate the Cholesky decomposition of the covariance in advance
+            n_features = z.shape[1]
+            min_vals = tf.diag(tf.ones(n_features, dtype=tf.float32)) * 1e-6
+            self.L = tf.cholesky(sigma + min_vals[None, :, :])
+
+        self.training = False
+        return self
+
+    def fix_op(self):
+        """ Return an operator that fixes the parameters of the GMM.
+        By running this operator outside of this class,
+        you can fix the current parameters to static tensor variables.
+
+        After you call this method, you have to run the resulting
+        operator immediately, and call energy() to use the static
+        variables of the model parameters.
+
+        Returns
+        -------
+        op : TensorFlow operator
+            operator that assigns the current parameters to the variables
+        """
+        phi, mu, sigma, L = self.create_variables(self.mu.shape[1])
+
+        op = tf.group(
+            tf.assign(phi, self.phi),
+            tf.assign(mu, self.mu),
+            tf.assign(sigma, self.sigma),
+            tf.assign(L, self.L)
+        )
+
+        self.phi, self.phi_org = phi, self.phi
+        self.mu, self.mu_org = mu, self.mu
+        self.sigma, self.sigma_org = sigma, self.sigma
+        self.L, self.L_org = L, self.L
+
+        self.training = False
+
+        return op
+
+    def energy(self, z):
+        """ Calculate the energy of each row of z.
+
+        Parameters
+        ----------
+        z : tf.Tensor, shape (n_samples, n_features)
+            data for which the energy of each row is calculated.
+
+        Returns
+        -------
+        energy : tf.Tensor, shape (n_samples)
+            calculated energies
+        """
+        if self.training and self.phi is None:
+            self.phi, self.mu, self.sigma, self.L = self.create_variables(z.shape[1])
+
+        with tf.variable_scope("GMM_energy"):
+            # Instead of the inverse covariance matrix, exploit the Cholesky
+            # decomposition for stability of the calculation.
+            z_centered = z[:, None, :] - self.mu[None, :, :]  # ikl
+            v = tf.matrix_triangular_solve(self.L, tf.transpose(z_centered, [1, 2, 0]))  # kli
+
+            # log(det(Sigma)) = 2 * sum[log(diag(L))]
+            log_det_sigma = 2.0 * tf.reduce_sum(tf.log(tf.matrix_diag_part(self.L)), axis=1)
+
+            # To calculate the energies, use "log-sum-exp" (different from the original paper)
+            d = z.get_shape().as_list()[1]
+            logits = tf.log(self.phi[:, None]) - 0.5 * (tf.reduce_sum(tf.square(v), axis=1)
+                                                        + d * tf.log(2.0 * np.pi) + log_det_sigma[:, None])
+            energies = - tf.reduce_logsumexp(logits, axis=0)
+
+        return energies
+
+    def cov_diag_loss(self):
+        with tf.variable_scope("GMM_diag_loss"):
+            diag_loss = tf.reduce_sum(tf.divide(1, tf.matrix_diag_part(self.sigma)))
+
+        return diag_loss
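+
+
+# Note on energy(): rather than inverting Sigma_k, the code solves
+# L_k v = (z - mu_k)^T with the Cholesky factor L_k (Sigma_k = L_k L_k^T), so
+#   (z - mu_k)^T Sigma_k^{-1} (z - mu_k) = ||v||^2
+#   log det(Sigma_k) = 2 * sum(log(diag(L_k)))
+# and the per-sample energy is evaluated with log-sum-exp for numerical
+# stability:
+#   E(z) = -log sum_k exp(log(phi_k)
+#                         - 0.5 * (||v||^2 + d*log(2*pi) + log det(Sigma_k)))
+# which equals -log p(z) under the fitted mixture.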
diff --git a/tods/tests/detection_algorithm/test_DAGMM.py b/tods/tests/detection_algorithm/test_DAGMM.py
new file mode 100644
index 0000000..0c8ab59
--- /dev/null
+++ b/tods/tests/detection_algorithm/test_DAGMM.py
@@ -0,0 +1,114 @@
+import unittest
+
+from d3m import container, utils
+from d3m.metadata import base as metadata_base
+from tods.detection_algorithm.DAGMM import DAGMMPrimitive
+
+
+class DAGMMTest(unittest.TestCase):
+    def test_basic(self):
+        self.maxDiff = None
+        self.main = container.DataFrame({'a': [3., 5., 7., 2.], 'b': [1., 4., 7., 2.], 'c': [6., 3., 9., 17.]},
+                                        columns=['a', 'b', 'c'],
+                                        generate_metadata=True)
+
+        self.assertEqual(utils.to_json_structure(self.main.metadata.to_internal_simple_structure()), [{
+            'selector': [],
+            'metadata': {
+                # 'top_level': 'main',
+                'schema': metadata_base.CONTAINER_SCHEMA_VERSION,
+                'structural_type': 'd3m.container.pandas.DataFrame',
+                'semantic_types': ['https://metadata.datadrivendiscovery.org/types/Table'],
+                'dimension': {
+                    'name': 'rows',
+                    'semantic_types': ['https://metadata.datadrivendiscovery.org/types/TabularRow'],
+                    'length': 4,
+                },
+            },
+        }, {
+            'selector': ['__ALL_ELEMENTS__'],
+            'metadata': {
+                'dimension': {
+                    'name': 'columns',
+                    'semantic_types': ['https://metadata.datadrivendiscovery.org/types/TabularColumn'],
+                    'length': 3,
+                },
+            },
+        }, {
+            'selector': ['__ALL_ELEMENTS__', 0],
+            'metadata': {'structural_type': 'numpy.float64', 'name': 'a'},
+        }, {
+            'selector': ['__ALL_ELEMENTS__', 1],
+            'metadata': {'structural_type': 'numpy.float64', 'name': 'b'},
+        }, {
+            'selector': ['__ALL_ELEMENTS__', 2],
+            'metadata': {'structural_type': 'numpy.float64', 'name': 'c'}
+        }])
+
+        self.assertIsInstance(self.main, container.DataFrame)
+
+        hyperparams_class = DAGMMPrimitive.metadata.get_hyperparams()
+        hyperparams = hyperparams_class.defaults()
+        hyperparams = hyperparams.replace({'minibatch_size': 4})
+
+        self.primitive = DAGMMPrimitive(hyperparams=hyperparams)
+        self.primitive.set_training_data(inputs=self.main)
+        # print("*****************", self.primitive.get_params())
+
+        self.primitive.fit()
+        self.new_main = self.primitive.produce(inputs=self.main).value
+        self.new_main_score = self.primitive.produce_score(inputs=self.main).value
+        print(self.new_main)
+        print(self.new_main_score)
+
+        params = self.primitive.get_params()
+        self.primitive.set_params(params=params)
+
+        self.assertEqual(utils.to_json_structure(self.main.metadata.to_internal_simple_structure()), [{
+            'selector': [],
+            'metadata': {
+                # 'top_level': 'main',
+                'schema': metadata_base.CONTAINER_SCHEMA_VERSION,
+                'structural_type': 'd3m.container.pandas.DataFrame',
+                'semantic_types': ['https://metadata.datadrivendiscovery.org/types/Table'],
+                'dimension': {
+                    'name': 'rows',
+                    'semantic_types': ['https://metadata.datadrivendiscovery.org/types/TabularRow'],
+                    'length': 4,
+                },
+            },
+        }, {
+            'selector': ['__ALL_ELEMENTS__'],
+            'metadata': {
+                'dimension': {
+                    'name': 'columns',
+                    'semantic_types': ['https://metadata.datadrivendiscovery.org/types/TabularColumn'],
+                    'length': 3,
+                },
+            },
+        }, {
+            'selector': ['__ALL_ELEMENTS__', 0],
+            'metadata': {'structural_type': 'numpy.float64', 'name': 'a'},
+        }, {
+            'selector': ['__ALL_ELEMENTS__', 1],
+            'metadata': {'structural_type': 'numpy.float64', 'name': 'b'},
+        }, {
+            'selector': ['__ALL_ELEMENTS__', 2],
+            'metadata': {'structural_type': 'numpy.float64', 'name': 'c'}
+        }])
+
+    # def test_params(self):
+    #     params = self.primitive.get_params()
+    #     self.primitive.set_params(params=params)
+
+
+if __name__ == '__main__':
+    unittest.main()
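+
+# Beyond this unit test, the primitive can be exercised on a larger frame in
+# the same way (illustrative values):
+#
+#   import numpy as np
+#   df = container.DataFrame(np.random.randn(100, 3),
+#                            columns=['a', 'b', 'c'], generate_metadata=True)
+#   hp = DAGMMPrimitive.metadata.get_hyperparams().defaults().replace(
+#       {'minibatch_size': 16, 'epoch_size': 20})
+#   primitive = DAGMMPrimitive(hyperparams=hp)
+#   primitive.set_training_data(inputs=df)
+#   primitive.fit()
+#   labels = primitive.produce(inputs=df).value        # 0/1 per row
+#   scores = primitive.produce_score(inputs=df).value  # raw sample energies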