@@ -0,0 +1,200 @@
from typing import Any, Callable, List, Dict, Union, Optional, Sequence, Tuple
from d3m.metadata import hyperparams, params, base as metadata_base
from d3m.primitive_interfaces.base import CallResult, DockerContainer
from tods.detection_algorithm.core.dagmm.dagmm import DAGMM
import uuid
from d3m import container, utils as d3m_utils
from tods.detection_algorithm.UODBasePrimitive import Params_ODBase, Hyperparams_ODBase, UnsupervisedOutlierDetectorBase

__all__ = ('DAGMMPrimitive',)

Inputs = container.DataFrame
Outputs = container.DataFrame


class Params(Params_ODBase):
    ######## Add more Attributes #######
    pass


class Hyperparams(Hyperparams_ODBase):
    comp_hiddens = hyperparams.List(
        default=[16, 8, 1],
        elements=hyperparams.Hyperparameter[int](1),
        description='Sizes of hidden layers of compression network.',
        semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter']
    )
    est_hiddens = hyperparams.List(
        default=[8, 4],
        elements=hyperparams.Hyperparameter[int](1),
        description='Sizes of hidden layers of estimation network.',
        semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter']
    )
    est_dropout_ratio = hyperparams.Hyperparameter[float](
        default=0.25,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Dropout rate of estimation network."
    )
    minibatch_size = hyperparams.Hyperparameter[int](
        default=3,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Mini-batch size used during training."
    )
    epoch_size = hyperparams.Hyperparameter[int](
        default=100,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Number of training epochs."
    )
    rand_seed = hyperparams.Hyperparameter[int](
        default=0,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Random seed used when fit() is called."
    )
    learning_rate = hyperparams.Hyperparameter[float](
        default=0.0001,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Learning rate used during training."
    )
    lambda1 = hyperparams.Hyperparameter[float](
        default=0.1,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Weight of the energy term in the loss function."
    )
    lambda2 = hyperparams.Hyperparameter[float](
        default=0.1,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Weight of the covariance-diagonal term in the loss function."
    )
    normalize = hyperparams.Hyperparameter[bool](
        default=True,
        semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
        description="Specify whether input data needs to be normalized."
    )
    contamination = hyperparams.Uniform(
        lower=0.,
        upper=0.5,
        default=0.1,
        description='The amount of contamination of the data set, i.e. the proportion of outliers in the data set. Used when fitting to define the threshold on the decision function.',
        semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter']
    )


class DAGMMPrimitive(UnsupervisedOutlierDetectorBase[Inputs, Outputs, Params, Hyperparams]):
    """
    Wrapper of the Deep Autoencoding Gaussian Mixture Model (DAGMM) for
    unsupervised outlier detection (Zong et al., ICLR 2018).

    Parameters
    ----------
    The tunable parameters are listed in the ``Hyperparams`` class above.
    """
__author__ = "DATA Lab at Texas A&M University", | |||
metadata = metadata_base.PrimitiveMetadata( | |||
{ | |||
'__author__': "DATA Lab @Texas A&M University", | |||
'name': "DAGMM", | |||
'python_path': 'd3m.primitives.tods.detection_algorithm.dagmm', | |||
'source': {'name': "DATALAB @Taxes A&M University", 'contact': 'mailto:khlai037@tamu.edu', | |||
'uris': ['https://gitlab.com/lhenry15/tods/-/blob/Yile/anomaly-primitives/anomaly_primitives/DAGMM.py']}, | |||
'algorithm_types': [metadata_base.PrimitiveAlgorithmType.DEEPLOG], | |||
'primitive_family': metadata_base.PrimitiveFamily.ANOMALY_DETECTION, | |||
'id': str(uuid.uuid3(uuid.NAMESPACE_DNS, 'DAGMMPrimitive')), | |||
'hyperparams_to_tune': ['comp_hiddens','est_hiddens','est_dropout_ratio','minibatch_size','epoch_size','rand_seed', | |||
'learning_rate','lambda1','lambda2','contamination'], | |||
'version': '0.0.1', | |||
} | |||
) | |||
    def __init__(self, *,
                 hyperparams: Hyperparams,
                 random_seed: int = 0,
                 docker_containers: Optional[Dict[str, DockerContainer]] = None) -> None:
        super().__init__(hyperparams=hyperparams, random_seed=random_seed, docker_containers=docker_containers)

        self._clf = DAGMM(comp_hiddens=hyperparams['comp_hiddens'],
                          est_hiddens=hyperparams['est_hiddens'],
                          est_dropout_ratio=hyperparams['est_dropout_ratio'],
                          minibatch_size=hyperparams['minibatch_size'],
                          epoch_size=hyperparams['epoch_size'],
                          random_seed=hyperparams['rand_seed'],
                          learning_rate=hyperparams['learning_rate'],
                          lambda1=hyperparams['lambda1'],
                          lambda2=hyperparams['lambda2'],
                          normalize=hyperparams['normalize'],
                          contamination=hyperparams['contamination']
                          )
    def set_training_data(self, *, inputs: Inputs) -> None:
        """
        Set training data for outlier detection.
        Args:
            inputs: Container DataFrame

        Returns:
            None
        """
        super().set_training_data(inputs=inputs)

    def fit(self, *, timeout: float = None, iterations: int = None) -> CallResult[None]:
        """
        Fit the model to the training data set via set_training_data().
        Args:
            timeout: Maximum time (in seconds) the fit is allowed to run.
            iterations: Number of iterations to run, if applicable.

        Returns:
            CallResult[None]
        """
        return super().fit()

    def produce(self, *, inputs: Inputs, timeout: float = None, iterations: int = None) -> CallResult[Outputs]:
        """
        Process the testing data.
        Args:
            inputs: Container DataFrame. Time series data to run outlier detection on.

        Returns:
            Container DataFrame
            1 marks outliers, 0 marks normal points.
        """
        return super().produce(inputs=inputs, timeout=timeout, iterations=iterations)

    def get_params(self) -> Params:
        """
        Return parameters.
        Args:
            None

        Returns:
            An instance of Params.
        """
        return super().get_params()

    def set_params(self, *, params: Params) -> None:
        """
        Set parameters for outlier detection.
        Args:
            params: An instance of Params.

        Returns:
            None
        """
        super().set_params(params=params)
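

# A minimal usage sketch (assumption: a tiny synthetic DataFrame, mirroring the
# unit test at the end of this diff; this block is illustrative and was not
# part of the original module).
if __name__ == '__main__':
    df = container.DataFrame({'a': [3., 5., 7., 2.], 'b': [1., 4., 7., 2.], 'c': [6., 3., 9., 17.]},
                             columns=['a', 'b', 'c'], generate_metadata=True)
    hp = DAGMMPrimitive.metadata.get_hyperparams().defaults().replace({'minibatch_size': 4})
    primitive = DAGMMPrimitive(hyperparams=hp)
    primitive.set_training_data(inputs=df)
    primitive.fit()
    print(primitive.produce(inputs=df).value)  # 0 marks normal points, 1 marks outliers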
@@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
from .compression_net import CompressionNet
from .estimation_net import EstimationNet
from .gmm import GMM
from .dagmm import DAGMM
@@ -0,0 +1,121 @@
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()


class CompressionNet:
    """ Compression Network.
    This network converts the input data to representations
    suitable for the calculation of anomaly scores by the "Estimation Network".
    The output of the network consists of the following two components:
    1) the reduced low-dimensional representation learned by an autoencoder, and
    2) features derived from the reconstruction error.
    """
    def __init__(self, hidden_layer_sizes, activation=tf.nn.tanh):
        """
        Parameters
        ----------
        hidden_layer_sizes : list of int
            list of the sizes of the hidden layers.
            For example, if the sizes are [n1, n2],
            the structure of the created network is:
            input_size -> n1 -> n2 -> n1 -> input_size
            (the network outputs the representation of the "n2" layer)
        activation : function
            activation function of the hidden layers.
            The last layer of each half uses a linear activation.
        """
        self.hidden_layer_sizes = hidden_layer_sizes
        self.activation = activation
    def compress(self, x):
        self.input_size = x.shape[1]

        with tf.variable_scope("Encoder"):
            z = x
            n_layer = 0
            for size in self.hidden_layer_sizes[:-1]:
                n_layer += 1
                z = tf.layers.dense(z, size, activation=self.activation,
                                    name="layer_{}".format(n_layer))

            # activation function of the last layer is linear
            n_layer += 1
            z = tf.layers.dense(z, self.hidden_layer_sizes[-1],
                                name="layer_{}".format(n_layer))

        return z

    def reverse(self, z):
        with tf.variable_scope("Decoder"):
            n_layer = 0
            for size in self.hidden_layer_sizes[:-1][::-1]:
                n_layer += 1
                z = tf.layers.dense(z, size, activation=self.activation,
                                    name="layer_{}".format(n_layer))

            # activation function of the last layer is linear
            n_layer += 1
            x_dash = tf.layers.dense(z, self.input_size,
                                     name="layer_{}".format(n_layer))

        return x_dash
    def loss(self, x, x_dash):
        def euclid_norm(x):
            return tf.sqrt(tf.reduce_sum(tf.square(x), axis=1))

        # Calculate Euclidean norms and the reconstruction distance
        norm_x = euclid_norm(x)
        norm_x_dash = euclid_norm(x_dash)
        dist_x = euclid_norm(x - x_dash)
        dot_x = tf.reduce_sum(x * x_dash, axis=1)

        # Based on the original paper, the features of the reconstruction error
        # are composed of these two terms:
        # 1. loss_E : relative Euclidean distance
        # 2. loss_C : cosine similarity
        min_val = 1e-3
        loss_E = dist_x / (norm_x + min_val)
        loss_C = 0.5 * (1.0 - dot_x / (norm_x * norm_x_dash + min_val))
        return tf.concat([loss_E[:, None], loss_C[:, None]], axis=1)

    def extract_feature(self, x, x_dash, z_c):
        z_r = self.loss(x, x_dash)
        return tf.concat([z_c, z_r], axis=1)
    def inference(self, x):
        """ Convert the input to the output tensor, which is composed of the
        low-dimensional representation and the reconstruction-error features.

        Parameters
        ----------
        x : tf.Tensor shape : (n_samples, n_features)
            Input data

        Returns
        -------
        z : tf.Tensor shape : (n_samples, n2 + 2)
            Result data.
            The second dimension of this data is equal to the
            sum of the compressed representation size and the
            number of reconstruction-error features (= 2).
        x_dash : tf.Tensor shape : (n_samples, n_features)
            Reconstructed data, used for the calculation of the
            reconstruction error.
        """
        with tf.variable_scope("CompNet"):
            # AutoEncoder
            z_c = self.compress(x)
            x_dash = self.reverse(z_c)

            # compose feature vector
            z = self.extract_feature(x, x_dash, z_c)

        return z, x_dash

    def reconstruction_error(self, x, x_dash):
        return tf.reduce_mean(tf.reduce_sum(
            tf.square(x - x_dash), axis=1), axis=0)
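

# A minimal standalone sketch (assumption: random input data; this block is
# illustrative and was not part of the original module). It shows the two
# outputs of CompressionNet: z has hidden_layer_sizes[-1] + 2 columns (latent
# code plus the two reconstruction-error features), x_dash has the input width.
if __name__ == '__main__':
    import numpy as np

    x_np = np.random.RandomState(0).randn(32, 5).astype('float32')
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, shape=[None, 5])
        net = CompressionNet([16, 8, 1])
        z, x_dash = net.inference(x)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            z_val, x_val = sess.run([z, x_dash], feed_dict={x: x_np})
            print(z_val.shape, x_val.shape)  # (32, 3) (32, 5)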
@@ -0,0 +1,251 @@
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

import joblib
from sklearn.preprocessing import StandardScaler
from pyod.models.base import BaseDetector

from os import makedirs
from os.path import exists, join

from .compression_net import CompressionNet
from .estimation_net import EstimationNet
from .gmm import GMM


class DAGMM(BaseDetector):
    """ Deep Autoencoding Gaussian Mixture Model.
    This implementation is based on the paper:
    Bo Zong et al. (2018) Deep Autoencoding Gaussian Mixture Model
    for Unsupervised Anomaly Detection, ICLR 2018
    (this is an UNOFFICIAL implementation)
    """
    MODEL_FILENAME = "DAGMM_model"
    SCALER_FILENAME = "DAGMM_scaler"

    def __init__(self, comp_hiddens: list = [16, 8, 1],
                 est_hiddens: list = [8, 4], est_dropout_ratio: float = 0.5,
                 minibatch_size: int = 1024, epoch_size: int = 100,
                 learning_rate: float = 0.0001, lambda1: float = 0.1, lambda2: float = 0.0001,
                 normalize: bool = True, random_seed: int = 123, contamination: float = 0.001):
""" | |||
Parameters | |||
---------- | |||
comp_hiddens : list of int | |||
sizes of hidden layers of compression network | |||
For example, if the sizes are [n1, n2], | |||
structure of compression network is: | |||
input_size -> n1 -> n2 -> n1 -> input_sizes | |||
est_hiddens : list of int | |||
sizes of hidden layers of estimation network. | |||
The last element of this list is assigned as n_comp. | |||
For example, if the sizes are [n1, n2], | |||
structure of estimation network is: | |||
input_size -> n1 -> n2 (= n_comp) | |||
est_dropout_ratio : float (optional) | |||
dropout ratio of estimation network applied during training | |||
if 0 or None, dropout is not applied. | |||
minibatch_size: int (optional) | |||
mini batch size during training | |||
epoch_size : int (optional) | |||
epoch size during training | |||
learning_rate : float (optional) | |||
learning rate during training | |||
lambda1 : float (optional) | |||
a parameter of loss function (for energy term) | |||
lambda2 : float (optional) | |||
a parameter of loss function | |||
(for sum of diagonal elements of covariance) | |||
normalize : bool (optional) | |||
specify whether input data need to be normalized. | |||
by default, input data is normalized. | |||
random_seed : int (optional) | |||
random seed used when fit() is called. | |||
""" | |||
        est_activation = tf.nn.tanh
        comp_activation = tf.nn.tanh
        super(DAGMM, self).__init__(contamination=contamination)
        self.comp_net = CompressionNet(comp_hiddens, comp_activation)
        self.est_net = EstimationNet(est_hiddens, est_activation)
        self.est_dropout_ratio = est_dropout_ratio
        n_comp = est_hiddens[-1]
        self.gmm = GMM(n_comp)

        self.minibatch_size = minibatch_size
        self.epoch_size = epoch_size
        self.learning_rate = learning_rate
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.normalize = normalize
        self.scaler = None
        self.seed = random_seed
        self.graph = None
        self.sess = None

    # def __del__(self):
    #     if self.sess is not None:
    #         self.sess.close()
    def fit(self, X, y=None):
        """ Fit the DAGMM model according to the given data.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Training data.
        y : ignored (present for API consistency)
        """
        n_samples, n_features = X.shape

        # Keep a handle to the raw data: decision_function() normalizes
        # internally, so it must not receive already-normalized data.
        X_orig = X
        if self.normalize:
            self.scaler = scaler = StandardScaler()
            X = scaler.fit_transform(X)

        with tf.Graph().as_default() as graph:
            self.graph = graph
            tf.set_random_seed(self.seed)
            np.random.seed(seed=self.seed)

            # Create placeholders
            self.input = input = tf.placeholder(
                dtype=tf.float32, shape=[None, n_features])
            self.drop = drop = tf.placeholder(dtype=tf.float32, shape=[])

            # Build graph
            z, x_dash = self.comp_net.inference(input)
            gamma = self.est_net.inference(z, drop)
            self.gmm.fit(z, gamma)
            energy = self.gmm.energy(z)

            self.x_dash = x_dash

            # Loss function
            loss = (self.comp_net.reconstruction_error(input, x_dash) +
                    self.lambda1 * tf.reduce_mean(energy) +
                    self.lambda2 * self.gmm.cov_diag_loss())

            # Minimizer
            minimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(loss)

            # Number of batches
            n_batch = (n_samples - 1) // self.minibatch_size + 1

            # Create TensorFlow session and initialize variables
            init = tf.global_variables_initializer()

            self.sess = tf.Session(graph=graph)
            self.sess.run(init)

            # Training
            idx = np.arange(X.shape[0])
            np.random.shuffle(idx)

            for epoch in range(self.epoch_size):
                for batch in range(n_batch):
                    i_start = batch * self.minibatch_size
                    i_end = (batch + 1) * self.minibatch_size
                    x_batch = X[idx[i_start:i_end]]

                    self.sess.run(minimizer, feed_dict={
                        input: x_batch, drop: self.est_dropout_ratio})

                if (epoch + 1) % 10 == 0:
                    loss_val = self.sess.run(loss, feed_dict={input: X, drop: 0})
                    print(" epoch {}/{} : loss = {:.3f}".format(epoch + 1, self.epoch_size, loss_val))

            # Fix GMM parameters
            fix = self.gmm.fix_op()
            self.sess.run(fix, feed_dict={input: X, drop: 0})
            self.energy = self.gmm.energy(z)

            tf.add_to_collection("save", self.input)
            tf.add_to_collection("save", self.energy)

            self.saver = tf.train.Saver()

        pred_scores = self.decision_function(X_orig)
        self.decision_scores_ = pred_scores
        self._process_decision_scores()

        return self
    def decision_function(self, X):
        """ Calculate anomaly scores (sample energies) for the samples in X.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Data for which anomaly scores are calculated.
            n_features must be equal to n_features of the fitted data.

        Returns
        -------
        energies : array-like, shape (n_samples,)
            Calculated sample energies.
        """
        if self.sess is None:
            raise Exception("Trained model does not exist.")

        if self.normalize:
            X = self.scaler.transform(X)

        energies = self.sess.run(self.energy, feed_dict={self.input: X})
        return energies.reshape(-1)
    def save(self, fdir):
        """ Save the trained model to the designated directory.
        This method has to be called after training;
        otherwise an exception is raised.

        Parameters
        ----------
        fdir : str
            Path of the directory where the trained model is saved.
            If it does not exist, it is created automatically.
        """
        if self.sess is None:
            raise Exception("Trained model does not exist.")

        if not exists(fdir):
            makedirs(fdir)

        model_path = join(fdir, self.MODEL_FILENAME)
        self.saver.save(self.sess, model_path)

        if self.normalize:
            scaler_path = join(fdir, self.SCALER_FILENAME)
            joblib.dump(self.scaler, scaler_path)
    def restore(self, fdir):
        """ Restore a trained model from the designated directory.

        Parameters
        ----------
        fdir : str
            Path of the directory where the trained model was saved.
        """
        if not exists(fdir):
            raise Exception("Model directory does not exist.")

        model_path = join(fdir, self.MODEL_FILENAME)
        meta_path = model_path + ".meta"

        with tf.Graph().as_default() as graph:
            self.graph = graph
            self.sess = tf.Session(graph=graph)

            self.saver = tf.train.import_meta_graph(meta_path)
            self.saver.restore(self.sess, model_path)

            self.input, self.energy = tf.get_collection("save")

        if self.normalize:
            scaler_path = join(fdir, self.SCALER_FILENAME)
            self.scaler = joblib.load(scaler_path)
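

# A minimal standalone sketch (assumption: synthetic Gaussian data with a few
# shifted rows acting as outliers; this block is illustrative and was not part
# of the original module). It uses the core detector directly, outside the
# TODS primitive wrapper.
if __name__ == '__main__':
    rng = np.random.RandomState(42)
    X = rng.randn(300, 5).astype('float32')
    X[:5] += 6.0  # shift a few rows far from the bulk

    model = DAGMM(comp_hiddens=[8, 4, 1], est_hiddens=[4, 2],
                  minibatch_size=64, epoch_size=50, contamination=0.02)
    model.fit(X)
    scores = model.decision_function(X)  # higher energy = more anomalous
    print(scores.shape, model.labels_[:5])  # the shifted rows should be flagged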
@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()


class EstimationNet:
    """ Estimation Network.
    This network converts an input feature vector to softmax probabilities.
    Because the loss function for this network is not defined here,
    it should be implemented outside of this class.
    """
    def __init__(self, hidden_layer_sizes, activation=tf.nn.relu):
        """
        Parameters
        ----------
        hidden_layer_sizes : list of int
            list of the sizes of the hidden layers.
            For example, if the sizes are [n1, n2],
            the layer sizes of the network are:
            input_size -> n1 -> n2
            (the network outputs the softmax probabilities of the "n2" layer)
        activation : function
            activation function of the hidden layers.
            The last layer uses the softmax function.
        """
        self.hidden_layer_sizes = hidden_layer_sizes
        self.activation = activation
    def inference(self, z, dropout_ratio=None):
        """ Output softmax probabilities.

        Parameters
        ----------
        z : tf.Tensor shape : (n_samples, n_features)
            Data inferenced by this network
        dropout_ratio : tf.Tensor shape : 0-dimension float (optional)
            Specify the dropout ratio
            (if None, dropout is not applied)

        Returns
        -------
        probs : tf.Tensor shape : (n_samples, n_classes)
            Calculated probabilities
        """
        with tf.variable_scope("EstNet"):
            n_layer = 0
            for size in self.hidden_layer_sizes[:-1]:
                n_layer += 1
                z = tf.layers.dense(z, size, activation=self.activation,
                                    name="layer_{}".format(n_layer))
                if dropout_ratio is not None:
                    # training=True so the fed rate is actually applied;
                    # the caller feeds a rate of 0 at evaluation time.
                    z = tf.layers.dropout(z, dropout_ratio, training=True,
                                          name="drop_{}".format(n_layer))

            # Last layer uses a linear function (= logits)
            size = self.hidden_layer_sizes[-1]
            logits = tf.layers.dense(z, size, activation=None, name="logits")

            # Softmax output
            output = tf.nn.softmax(logits)

        return output
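

# A minimal standalone sketch (assumption: random features; this block is
# illustrative and was not part of the original module). EstimationNet maps
# each sample to softmax membership probabilities over hidden_layer_sizes[-1]
# mixture components, so each row sums to 1.
if __name__ == '__main__':
    import numpy as np

    z_np = np.random.RandomState(0).randn(16, 3).astype('float32')
    with tf.Graph().as_default():
        z = tf.placeholder(tf.float32, shape=[None, 3])
        net = EstimationNet([8, 4])
        gamma = net.inference(z)  # dropout skipped since dropout_ratio is None
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            g = sess.run(gamma, feed_dict={z: z_np})
            print(g.shape, g.sum(axis=1))  # (16, 4), each row sums to ~1.0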
@@ -0,0 +1,130 @@
# -*- coding: utf-8 -*-
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()


class GMM:
    """ Gaussian Mixture Model (GMM) """

    def __init__(self, n_comp):
        self.n_comp = n_comp
        self.phi = self.mu = self.sigma = None
        self.training = False

    def create_variables(self, n_features):
        with tf.variable_scope("GMM"):
            phi = tf.Variable(tf.zeros(shape=[self.n_comp]),
                              dtype=tf.float32, name="phi")
            mu = tf.Variable(tf.zeros(shape=[self.n_comp, n_features]),
                             dtype=tf.float32, name="mu")
            sigma = tf.Variable(tf.zeros(
                shape=[self.n_comp, n_features, n_features]),
                dtype=tf.float32, name="sigma")
            L = tf.Variable(tf.zeros(
                shape=[self.n_comp, n_features, n_features]),
                dtype=tf.float32, name="L")

        return phi, mu, sigma, L
    def fit(self, z, gamma):
        """ Fit data to the GMM model.

        Parameters
        ----------
        z : tf.Tensor, shape (n_samples, n_features)
            data fitted to the GMM.
        gamma : tf.Tensor, shape (n_samples, n_comp)
            membership probabilities; each row corresponds to a row of z.
        """
        with tf.variable_scope("GMM"):
            # Calculate mu, sigma
            # i   : index of samples
            # k   : index of components
            # l,m : index of features
            gamma_sum = tf.reduce_sum(gamma, axis=0)
            self.phi = phi = tf.reduce_mean(gamma, axis=0)
            self.mu = mu = tf.einsum('ik,il->kl', gamma, z) / gamma_sum[:, None]
            z_centered = tf.sqrt(gamma[:, :, None]) * (z[:, None, :] - mu[None, :, :])
            self.sigma = sigma = tf.einsum(
                'ikl,ikm->klm', z_centered, z_centered) / gamma_sum[:, None, None]

            # Calculate the Cholesky decomposition of the covariances in advance
            n_features = z.shape[1]
            min_vals = tf.diag(tf.ones(n_features, dtype=tf.float32)) * 1e-6
            self.L = tf.cholesky(sigma + min_vals[None, :, :])

        self.training = False
        return self
    def fix_op(self):
        """ Return an operator that fixes the parameters of the GMM.
        Using this operator outside of this class,
        you can fix the current parameters to static tensor variables.
        After you call this method, you have to run the resulting
        operator immediately, and then call energy() to use the static
        variables of the model parameters.

        Returns
        -------
        op : TensorFlow operator
            operator to assign the current parameters to the variables
        """
        phi, mu, sigma, L = self.create_variables(self.mu.shape[1])

        op = tf.group(
            tf.assign(phi, self.phi),
            tf.assign(mu, self.mu),
            tf.assign(sigma, self.sigma),
            tf.assign(L, self.L)
        )

        self.phi, self.phi_org = phi, self.phi
        self.mu, self.mu_org = mu, self.mu
        self.sigma, self.sigma_org = sigma, self.sigma
        self.L, self.L_org = L, self.L

        self.training = False
        return op
    def energy(self, z):
        """ Calculate the energy of each row of z.
        The sample energy is E(z) = -log sum_k phi_k N(z; mu_k, sigma_k).

        Parameters
        ----------
        z : tf.Tensor, shape (n_samples, n_features)
            data whose per-row energies are calculated.

        Returns
        -------
        energy : tf.Tensor, shape (n_samples,)
            calculated energies
        """
        if self.training and self.phi is None:
            self.phi, self.mu, self.sigma, self.L = self.create_variables(z.shape[1])

        with tf.variable_scope("GMM_energy"):
            # Instead of inverting the covariance matrices, exploit their
            # Cholesky decompositions for numerical stability:
            # the Mahalanobis term is ||L^-1 (z - mu)||^2 and
            # log(det(Sigma)) = 2 * sum[log(diag(L))]
            z_centered = z[:, None, :] - self.mu[None, :, :]  # ikl
            v = tf.matrix_triangular_solve(self.L, tf.transpose(z_centered, [1, 2, 0]))  # kli
            log_det_sigma = 2.0 * tf.reduce_sum(tf.log(tf.matrix_diag_part(self.L)), axis=1)

            # To calculate the energies, use "log-sum-exp" (different from the original paper)
            d = z.get_shape().as_list()[1]
            logits = tf.log(self.phi[:, None]) - 0.5 * (tf.reduce_sum(tf.square(v), axis=1)
                                                        + d * tf.log(2.0 * np.pi) + log_det_sigma[:, None])
            energies = - tf.reduce_logsumexp(logits, axis=0)

        return energies
    def cov_diag_loss(self):
        with tf.variable_scope("GMM_diag_loss"):
            diag_loss = tf.reduce_sum(tf.divide(1, tf.matrix_diag_part(self.sigma)))

        return diag_loss
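

# A minimal standalone sketch (assumption: random features and random soft
# assignments; this block is illustrative and was not part of the original
# module). GMM.fit builds the mixture parameters from soft assignments as pure
# tensor ops, so no variable initialization is needed before running energy().
if __name__ == '__main__':
    z_np = np.random.RandomState(0).randn(64, 3).astype('float32')
    gamma_np = np.random.RandomState(1).rand(64, 2).astype('float32')
    gamma_np /= gamma_np.sum(axis=1, keepdims=True)  # rows sum to 1

    with tf.Graph().as_default():
        z = tf.placeholder(tf.float32, shape=[None, 3])
        gamma = tf.placeholder(tf.float32, shape=[None, 2])
        gmm = GMM(n_comp=2)
        gmm.fit(z, gamma)
        energy = gmm.energy(z)  # one energy value per sample
        with tf.Session() as sess:
            e = sess.run(energy, feed_dict={z: z_np, gamma: gamma_np})
            print(e.shape)  # (64,)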
@@ -0,0 +1,114 @@
import unittest

from d3m import container, utils
from d3m.metadata import base as metadata_base
from tods.detection_algorithm.DAGMM import DAGMMPrimitive


class DAGMMTest(unittest.TestCase):
    def test_basic(self):
        self.maxDiff = None
        self.main = container.DataFrame({'a': [3., 5., 7., 2.], 'b': [1., 4., 7., 2.], 'c': [6., 3., 9., 17.]},
                                        columns=['a', 'b', 'c'],
                                        generate_metadata=True)
        self.assertEqual(utils.to_json_structure(self.main.metadata.to_internal_simple_structure()), [{
            'selector': [],
            'metadata': {
                # 'top_level': 'main',
                'schema': metadata_base.CONTAINER_SCHEMA_VERSION,
                'structural_type': 'd3m.container.pandas.DataFrame',
                'semantic_types': ['https://metadata.datadrivendiscovery.org/types/Table'],
                'dimension': {
                    'name': 'rows',
                    'semantic_types': ['https://metadata.datadrivendiscovery.org/types/TabularRow'],
                    'length': 4,
                },
            },
        }, {
            'selector': ['__ALL_ELEMENTS__'],
            'metadata': {
                'dimension': {
                    'name': 'columns',
                    'semantic_types': ['https://metadata.datadrivendiscovery.org/types/TabularColumn'],
                    'length': 3,
                },
            },
        }, {
            'selector': ['__ALL_ELEMENTS__', 0],
            'metadata': {'structural_type': 'numpy.float64', 'name': 'a'},
        }, {
            'selector': ['__ALL_ELEMENTS__', 1],
            'metadata': {'structural_type': 'numpy.float64', 'name': 'b'},
        }, {
            'selector': ['__ALL_ELEMENTS__', 2],
            'metadata': {'structural_type': 'numpy.float64', 'name': 'c'},
        }])
        self.assertIsInstance(self.main, container.DataFrame)

        hyperparams_class = DAGMMPrimitive.metadata.get_hyperparams()
        hyperparams = hyperparams_class.defaults()
        hyperparams = hyperparams.replace({'minibatch_size': 4})

        self.primitive = DAGMMPrimitive(hyperparams=hyperparams)
        self.primitive.set_training_data(inputs=self.main)
        # print("*****************", self.primitive.get_params())
        self.primitive.fit()
        self.new_main = self.primitive.produce(inputs=self.main).value
        self.new_main_score = self.primitive.produce_score(inputs=self.main).value
        print(self.new_main)
        print(self.new_main_score)

        params = self.primitive.get_params()
        self.primitive.set_params(params=params)
        # The input metadata must be unchanged after fit/produce.
        self.assertEqual(utils.to_json_structure(self.main.metadata.to_internal_simple_structure()), [{
            'selector': [],
            'metadata': {
                # 'top_level': 'main',
                'schema': metadata_base.CONTAINER_SCHEMA_VERSION,
                'structural_type': 'd3m.container.pandas.DataFrame',
                'semantic_types': ['https://metadata.datadrivendiscovery.org/types/Table'],
                'dimension': {
                    'name': 'rows',
                    'semantic_types': ['https://metadata.datadrivendiscovery.org/types/TabularRow'],
                    'length': 4,
                },
            },
        }, {
            'selector': ['__ALL_ELEMENTS__'],
            'metadata': {
                'dimension': {
                    'name': 'columns',
                    'semantic_types': ['https://metadata.datadrivendiscovery.org/types/TabularColumn'],
                    'length': 3,
                },
            },
        }, {
            'selector': ['__ALL_ELEMENTS__', 0],
            'metadata': {'structural_type': 'numpy.float64', 'name': 'a'},
        }, {
            'selector': ['__ALL_ELEMENTS__', 1],
            'metadata': {'structural_type': 'numpy.float64', 'name': 'b'},
        }, {
            'selector': ['__ALL_ELEMENTS__', 2],
            'metadata': {'structural_type': 'numpy.float64', 'name': 'c'},
        }])
    # def test_params(self):
    #     params = self.primitive.get_params()
    #     self.primitive.set_params(params=params)


if __name__ == '__main__':
    unittest.main()