
merge DAGMM

master
lhenry15 4 years ago
parent
commit
69528414a8
7 changed files with 885 additions and 0 deletions
  1. +200
    -0
      tods/detection_algorithm/DAGMM.py
  2. +6
    -0
      tods/detection_algorithm/core/dagmm/__init__.py
  3. +121
    -0
      tods/detection_algorithm/core/dagmm/compression_net.py
  4. +251
    -0
      tods/detection_algorithm/core/dagmm/dagmm.py
  5. +63
    -0
      tods/detection_algorithm/core/dagmm/estimation_net.py
  6. +130
    -0
      tods/detection_algorithm/core/dagmm/gmm.py
  7. +114
    -0
      tods/tests/detection_algorithm/test_DAGMM.py

+ 200
- 0
tods/detection_algorithm/DAGMM.py

@@ -0,0 +1,200 @@
from typing import Any, Callable, List, Dict, Union, Optional, Sequence, Tuple



from d3m.metadata import hyperparams, params, base as metadata_base

from d3m.primitive_interfaces.base import CallResult, DockerContainer

from tods.detection_algorithm.core.dagmm.dagmm import DAGMM
import uuid

from d3m import container, utils as d3m_utils

from tods.detection_algorithm.UODBasePrimitive import Params_ODBase, Hyperparams_ODBase, UnsupervisedOutlierDetectorBase



__all__ = ('DAGMMPrimitive',)

Inputs = container.DataFrame
Outputs = container.DataFrame


class Params(Params_ODBase):
######## Add more Attributes #######

pass


class Hyperparams(Hyperparams_ODBase):
comp_hiddens = hyperparams.List(
default=[16,8,1],
elements=hyperparams.Hyperparameter[int](1),
description='Sizes of hidden layers of compression network.',
semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter']
)

est_hiddens = hyperparams.List(
default=[8,4],
elements=hyperparams.Hyperparameter[int](1),
description='Sizes of hidden layers of estimation network.',
semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter']
)
est_dropout_ratio = hyperparams.Hyperparameter[float](
default=0.25,
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
description="Dropout rate of estimation network"
)

minibatch_size = hyperparams.Hyperparameter[int](
default=3,
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
description="Mini Batch size"
)

epoch_size = hyperparams.Hyperparameter[int](
default=100,
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
description="Epoch"
)

rand_seed = hyperparams.Hyperparameter[int](
default=0,
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
description="(optional )random seed used when fit() is called"
)

learning_rate = hyperparams.Hyperparameter[float](
default=0.0001,
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
description="learning rate"
)
lambda1 = hyperparams.Hyperparameter[float](
default=0.1,
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
description="a parameter of loss function (for energy term)"
)
lambda2 = hyperparams.Hyperparameter[float](
default=0.1,
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
description="a parameter of loss function"
)

normalize = hyperparams.Hyperparameter[bool](
default=True,
semantic_types=['https://metadata.datadrivendiscovery.org/types/ControlParameter'],
description="Specify whether input data need to be normalized."
)

contamination = hyperparams.Uniform(
lower=0.,
upper=0.5,
default=0.1,
description='The amount of contamination of the data set, i.e. the proportion of outliers in the data set. Used when fitting to define the threshold on the decision function.',
semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter']
)


class DAGMMPrimitive(UnsupervisedOutlierDetectorBase[Inputs, Outputs, Params, Hyperparams]):
"""
Deep Autoencoding Gaussian Mixture Model
Parameters
----------


"""

__author__ = "DATA Lab at Texas A&M University",
metadata = metadata_base.PrimitiveMetadata(
{
'__author__': "DATA Lab @Texas A&M University",
'name': "DAGMM",
'python_path': 'd3m.primitives.tods.detection_algorithm.dagmm',
'source': {'name': "DATALAB @Taxes A&M University", 'contact': 'mailto:khlai037@tamu.edu',
'uris': ['https://gitlab.com/lhenry15/tods/-/blob/Yile/anomaly-primitives/anomaly_primitives/DAGMM.py']},
'algorithm_types': [metadata_base.PrimitiveAlgorithmType.DEEPLOG],
'primitive_family': metadata_base.PrimitiveFamily.ANOMALY_DETECTION,
'id': str(uuid.uuid3(uuid.NAMESPACE_DNS, 'DAGMMPrimitive')),
'hyperparams_to_tune': ['comp_hiddens','est_hiddens','est_dropout_ratio','minibatch_size','epoch_size','rand_seed',
'learning_rate','lambda1','lambda2','contamination'],
'version': '0.0.1',
}
)

def __init__(self, *,
hyperparams: Hyperparams, #
random_seed: int = 0,
docker_containers: Dict[str, DockerContainer] = None) -> None:
super().__init__(hyperparams=hyperparams, random_seed=random_seed, docker_containers=docker_containers)
self._clf = DAGMM(comp_hiddens=hyperparams['comp_hiddens'],
est_hiddens=hyperparams['est_hiddens'],
est_dropout_ratio=hyperparams['est_dropout_ratio'],
minibatch_size=hyperparams['minibatch_size'],
epoch_size=hyperparams['epoch_size'],
random_seed=hyperparams['rand_seed'],
learning_rate=hyperparams['learning_rate'],
lambda1=hyperparams['lambda1'],
lambda2=hyperparams['lambda2'],
normalize=hyperparams['normalize'],
contamination=hyperparams['contamination']
)

def set_training_data(self, *, inputs: Inputs) -> None:
"""
Set training data for outlier detection.
Args:
inputs: Container DataFrame

Returns:
None
"""
super().set_training_data(inputs=inputs)

def fit(self, *, timeout: float = None, iterations: int = None) -> CallResult[None]:
"""
Fit model with training data.
Args:
None (the training data provided via set_training_data() is used).

Returns:
None
"""

return super().fit()

def produce(self, *, inputs: Inputs, timeout: float = None, iterations: int = None) -> CallResult[Outputs]:
"""
Process the testing data.
Args:
inputs: Container DataFrame. Time series data to run outlier detection on.

Returns:
Container DataFrame
1 marks outliers, 0 marks normal.
"""
return super().produce(inputs=inputs, timeout=timeout, iterations=iterations)

def get_params(self) -> Params:
"""
Return parameters.
Args:
None

Returns:
class Params
"""
return super().get_params()

def set_params(self, *, params: Params) -> None:
"""
Set parameters for outlier detection.
Args:
params: class Params

Returns:
None
"""
super().set_params(params=params)



+ 6
- 0
tods/detection_algorithm/core/dagmm/__init__.py

@@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-

from .compression_net import CompressionNet
from .estimation_net import EstimationNet
from .gmm import GMM
from .dagmm import DAGMM

+ 121
- 0
tods/detection_algorithm/core/dagmm/compression_net.py

@@ -0,0 +1,121 @@
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

class CompressionNet:
""" Compression Network.
This network converts the input data to the representations
suitable for calculation of anormaly scores by "Estimation Network".

Outputs of network consist of next 2 components:
1) reduced low-dimensional representations learned by AutoEncoder.
2) the features derived from reconstruction error.
"""
def __init__(self, hidden_layer_sizes, activation=tf.nn.tanh):
"""
Parameters
----------
hidden_layer_sizes : list of int
list of the size of hidden layers.
For example, if the sizes are [n1, n2],
the sizes of created networks are:
input_size -> n1 -> n2 -> n1 -> input_size
(network outputs the representation of "n2" layer)
activation : function
activation function of hidden layer.
the last layer uses linear function.
"""
self.hidden_layer_sizes = hidden_layer_sizes
self.activation = activation

def compress(self, x):
self.input_size = x.shape[1]

with tf.variable_scope("Encoder"):
z = x
n_layer = 0
for size in self.hidden_layer_sizes[:-1]:
n_layer += 1
z = tf.layers.dense(z, size, activation=self.activation,
name="layer_{}".format(n_layer))

# activation function of last layer is linear
n_layer += 1
z = tf.layers.dense(z, self.hidden_layer_sizes[-1],
name="layer_{}".format(n_layer))

return z

def reverse(self, z):
with tf.variable_scope("Decoder"):
n_layer = 0
for size in self.hidden_layer_sizes[:-1][::-1]:
n_layer += 1
z = tf.layers.dense(z, size, activation=self.activation,
name="layer_{}".format(n_layer))

# activation function of last layer is linear
n_layer += 1
x_dash = tf.layers.dense(z, self.input_size,
name="layer_{}".format(n_layer))

return x_dash

def loss(self, x, x_dash):
def euclid_norm(x):
return tf.sqrt(tf.reduce_sum(tf.square(x), axis=1))

# Calculate Euclidean norms and the reconstruction distance
norm_x = euclid_norm(x)
norm_x_dash = euclid_norm(x_dash)
dist_x = euclid_norm(x - x_dash)
dot_x = tf.reduce_sum(x * x_dash, axis=1)

# Based on the original paper, features of reconstruction error
# are composed of these loss functions:
# 1. loss_E : relative Euclidean distance
# 2. loss_C : cosine similarity
min_val = 1e-3
loss_E = dist_x / (norm_x + min_val)
loss_C = 0.5 * (1.0 - dot_x / (norm_x * norm_x_dash + min_val))
return tf.concat([loss_E[:,None], loss_C[:,None]], axis=1)

def extract_feature(self, x, x_dash, z_c):
z_r = self.loss(x, x_dash)
return tf.concat([z_c, z_r], axis=1)

def inference(self, x):
""" convert input to output tensor, which is composed of
low-dimensional representation and reconstruction error.

Parameters
----------
x : tf.Tensor shape : (n_samples, n_features)
Input data

Returns
-------
z : tf.Tensor shape : (n_samples, n2 + 2)
Result data
Second dimension of this data is equal to
sum of compressed representation size and
number of loss function (=2)

x_dash : tf.Tensor shape : (n_samples, n_features)
Reconstructed data for calculation of
reconstruction error.
"""

with tf.variable_scope("CompNet"):
# AutoEncoder
z_c = self.compress(x)
x_dash = self.reverse(z_c)

# compose feature vector
z = self.extract_feature(x, x_dash, z_c)

return z, x_dash

def reconstruction_error(self, x, x_dash):
return tf.reduce_mean(tf.reduce_sum(
tf.square(x - x_dash), axis=1), axis=0)

+ 251
- 0
tods/detection_algorithm/core/dagmm/dagmm.py

@@ -0,0 +1,251 @@
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from sklearn.preprocessing import StandardScaler
# sklearn.externals.joblib has been removed from recent scikit-learn releases;
# the standalone joblib package provides the same dump/load API.
import joblib

from .compression_net import CompressionNet
from .estimation_net import EstimationNet
from .gmm import GMM
from pyod.utils.stat_models import pairwise_distances_no_broadcast

from os import makedirs
from os.path import exists, join

from pyod.models.base import BaseDetector

class DAGMM(BaseDetector):
""" Deep Autoencoding Gaussian Mixture Model.

This implementation is based on the paper:
Bo Zong et al. (2018) Deep Autoencoding Gaussian Mixture Model
for Unsupervised Anomaly Detection, ICLR 2018
(this is an UNOFFICIAL implementation)
"""

MODEL_FILENAME = "DAGMM_model"
SCALER_FILENAME = "DAGMM_scaler"

def __init__(self, comp_hiddens:list = [16,8,1],
est_hiddens:list = [8,4], est_dropout_ratio:float =0.5,
minibatch_size:int = 1024, epoch_size:int =100,
learning_rate:float =0.0001, lambda1:float =0.1, lambda2:float =0.0001,
normalize:bool=True, random_seed:int=123 , contamination:float = 0.001 ):
"""
Parameters
----------
comp_hiddens : list of int
sizes of hidden layers of compression network
For example, if the sizes are [n1, n2],
structure of compression network is:
input_size -> n1 -> n2 -> n1 -> input_size

est_hiddens : list of int
sizes of hidden layers of estimation network.
The last element of this list is assigned as n_comp.
For example, if the sizes are [n1, n2],
structure of estimation network is:
input_size -> n1 -> n2 (= n_comp)

est_dropout_ratio : float (optional)
dropout ratio of estimation network applied during training
if 0 or None, dropout is not applied.
minibatch_size: int (optional)
mini batch size during training
epoch_size : int (optional)
epoch size during training
learning_rate : float (optional)
learning rate during training
lambda1 : float (optional)
a parameter of loss function (for energy term)
lambda2 : float (optional)
a parameter of loss function
(for sum of diagonal elements of covariance)
normalize : bool (optional)
specify whether input data need to be normalized.
by default, input data is normalized.
random_seed : int (optional)
random seed used when fit() is called.
"""
est_activation = tf.nn.tanh
comp_activation = tf.nn.tanh
super(DAGMM, self).__init__(contamination=contamination)
self.comp_net = CompressionNet(comp_hiddens, comp_activation)
self.est_net = EstimationNet(est_hiddens, est_activation)
self.est_dropout_ratio = est_dropout_ratio

n_comp = est_hiddens[-1]
self.gmm = GMM(n_comp)

self.minibatch_size = minibatch_size
self.epoch_size = epoch_size
self.learning_rate = learning_rate
self.lambda1 = lambda1
self.lambda2 = lambda2

self.normalize = normalize
self.scaler = None
self.seed = random_seed

self.graph = None
self.sess = None

#def __del__(self):
# if self.sess is not None:
# self.sess.close()

def fit(self,X,y=None):
""" Fit the DAGMM model according to the given data.

Parameters
----------
X : array-like, shape (n_samples, n_features)
Training data.
"""

n_samples, n_features = X.shape

if self.normalize:
self.scaler = scaler = StandardScaler()
X = scaler.fit_transform(X)

with tf.Graph().as_default() as graph:
self.graph = graph
tf.set_random_seed(self.seed)
np.random.seed(seed=self.seed)

# Create Placeholder
self.input = input = tf.placeholder(
dtype=tf.float32, shape=[None, n_features])
self.drop = drop = tf.placeholder(dtype=tf.float32, shape=[])

# Build graph
z, x_dash = self.comp_net.inference(input)
gamma = self.est_net.inference(z, drop)
self.gmm.fit(z, gamma)
energy = self.gmm.energy(z)

self.x_dash = x_dash

# Loss function
loss = (self.comp_net.reconstruction_error(input, x_dash) +
self.lambda1 * tf.reduce_mean(energy) +
self.lambda2 * self.gmm.cov_diag_loss())

# Minimizer
minimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(loss)

# Number of batches (ceiling of n_samples / minibatch_size)
n_batch = (n_samples - 1) // self.minibatch_size + 1

# Create tensorflow session and initialize
init = tf.global_variables_initializer()

self.sess = tf.Session(graph=graph)
self.sess.run(init)

# Training
idx = np.arange(X.shape[0])
np.random.shuffle(idx)

for epoch in range(self.epoch_size):
for batch in range(n_batch):
i_start = batch * self.minibatch_size
i_end = (batch + 1) * self.minibatch_size
x_batch = X[idx[i_start:i_end]]

self.sess.run(minimizer, feed_dict={
input:x_batch, drop:self.est_dropout_ratio})
if (epoch + 1) % 10 == 0:
loss_val = self.sess.run(loss, feed_dict={input:X, drop:0})
print(" epoch {}/{} : loss = {:.3f}".format(epoch + 1, self.epoch_size, loss_val))

# Fix GMM parameter
fix = self.gmm.fix_op()
self.sess.run(fix, feed_dict={input:X, drop:0})
self.energy = self.gmm.energy(z)

tf.add_to_collection("save", self.input)
tf.add_to_collection("save", self.energy)

self.saver = tf.train.Saver()

pred_scores = self.decision_function(X)
self.decision_scores_ = pred_scores
self._process_decision_scores()
#return self

def decision_function(self, X):
""" Calculate anormaly scores (sample energy) on samples in X.

Parameters
----------
X : array-like, shape (n_samples, n_features)
Data for which anomaly scores are calculated.
n_features must be equal to n_features of the fitted data.

Returns
-------
energies : array-like, shape (n_samples)
Calculated sample energies.
"""
if self.sess is None:
raise Exception("Trained model does not exist.")

if self.normalize:
X = self.scaler.transform(X)

energies = self.sess.run(self.energy, feed_dict={self.input:X})

return energies.reshape(1,-1)

def save(self, fdir):
""" Save trained model to designated directory.
This method have to be called after training.
(If not, throw an exception)

Parameters
----------
fdir : str
Path of the directory where the trained model is saved.
If it does not exist, it is created automatically.
"""
if self.sess is None:
raise Exception("Trained model does not exist.")

if not exists(fdir):
makedirs(fdir)

model_path = join(fdir, self.MODEL_FILENAME)
self.saver.save(self.sess, model_path)

if self.normalize:
scaler_path = join(fdir, self.SCALER_FILENAME)
joblib.dump(self.scaler, scaler_path)

def restore(self, fdir):
""" Restore trained model from designated directory.

Parameters
----------
fdir : str
Path of the directory where the trained model was saved.
"""
if not exists(fdir):
raise Exception("Model directory does not exist.")

model_path = join(fdir, self.MODEL_FILENAME)
meta_path = model_path + ".meta"

with tf.Graph().as_default() as graph:
self.graph = graph
self.sess = tf.Session(graph=graph)
self.saver = tf.train.import_meta_graph(meta_path)
self.saver.restore(self.sess, model_path)

self.input, self.energy = tf.get_collection("save")

if self.normalize:
scaler_path = join(fdir, self.SCALER_FILENAME)
self.scaler = joblib.load(scaler_path)

+ 63
- 0
tods/detection_algorithm/core/dagmm/estimation_net.py

@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

class EstimationNet:
""" Estimation Network

This network converts the input feature vector to softmax probabilities.
Because the loss function for this network is not defined here,
it should be implemented outside of this class.
"""
def __init__(self, hidden_layer_sizes, activation=tf.nn.relu):
"""
Parameters
----------
hidden_layer_sizes : list of int
list of sizes of hidden layers.
For example, if the sizes are [n1, n2],
layer sizes of the network are:
input_size -> n1 -> n2
(network outputs the softmax probabilities of "n2" layer)
activation : function
activation function of hidden layer.
the last layer uses the softmax function.
"""
self.hidden_layer_sizes = hidden_layer_sizes
self.activation = activation

def inference(self, z, dropout_ratio=None):
""" Output softmax probabilities

Parameters
----------
z : tf.Tensor shape : (n_samples, n_features)
Data inferenced by this network
dropout_ratio : tf.Tensor shape : 0-dimension float (optional)
Specify dropout ratio
(if None, dropout is not applied)

Returns
-------
probs : tf.Tensor shape : (n_samples, n_classes)
Calculated probabilities
"""
with tf.variable_scope("EstNet"):
n_layer = 0
for size in self.hidden_layer_sizes[:-1]:
n_layer += 1
z = tf.layers.dense(z, size, activation=self.activation,
name="layer_{}".format(n_layer))
if dropout_ratio is not None:
z = tf.layers.dropout(z, dropout_ratio,
name="drop_{}".format(n_layer))

# Last layer uses linear function (=logits)
size = self.hidden_layer_sizes[-1]
logits = tf.layers.dense(z, size, activation=None, name="logits")

# Softmax output
output = tf.nn.softmax(logits)

return output

+ 130
- 0
tods/detection_algorithm/core/dagmm/gmm.py

@@ -0,0 +1,130 @@
# -*- coding: utf-8 -*-
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

class GMM:
""" Gaussian Mixture Model (GMM) """
def __init__(self, n_comp):
self.n_comp = n_comp
self.phi = self.mu = self.sigma = None
self.training = False

def create_variables(self, n_features):
with tf.variable_scope("GMM"):
phi = tf.Variable(tf.zeros(shape=[self.n_comp]),
dtype=tf.float32, name="phi")
mu = tf.Variable(tf.zeros(shape=[self.n_comp, n_features]),
dtype=tf.float32, name="mu")
sigma = tf.Variable(tf.zeros(
shape=[self.n_comp, n_features, n_features]),
dtype=tf.float32, name="sigma")
L = tf.Variable(tf.zeros(
shape=[self.n_comp, n_features, n_features]),
dtype=tf.float32, name="L")

return phi, mu, sigma, L

def fit(self, z, gamma):
""" fit data to GMM model

Parameters
----------
z : tf.Tensor, shape (n_samples, n_features)
data fitted to GMM.
gamma : tf.Tensor, shape (n_samples, n_comp)
mixture membership probabilities; each row corresponds to a row of z.
"""

with tf.variable_scope("GMM"):
# Calculate mu, sigma
# i : index of samples
# k : index of components
# l,m : index of features
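# phi_k   = (1/N) sum_i gamma_ik
# mu_k    = sum_i gamma_ik z_i / sum_i gamma_ik
# Sigma_k = sum_i gamma_ik (z_i - mu_k)(z_i - mu_k)^T / sum_i gamma_ik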
gamma_sum = tf.reduce_sum(gamma, axis=0)
self.phi = phi = tf.reduce_mean(gamma, axis=0)
self.mu = mu = tf.einsum('ik,il->kl', gamma, z) / gamma_sum[:,None]
z_centered = tf.sqrt(gamma[:,:,None]) * (z[:,None,:] - mu[None,:,:])
self.sigma = sigma = tf.einsum(
'ikl,ikm->klm', z_centered, z_centered) / gamma_sum[:,None,None]

# Calculate a cholesky decomposition of covariance in advance
n_features = z.shape[1]
min_vals = tf.diag(tf.ones(n_features, dtype=tf.float32)) * 1e-6
self.L = tf.cholesky(sigma + min_vals[None,:,:])

self.training = False
return self

def fix_op(self):
""" return operator to fix paramters of GMM
Using this operator outside of this class,
you can fix current parameter to static tensor variable.

After you call this method, you have to run result
operator immediatelly, and call energy() to use static
variables of model parameter.

Returns
-------
op : operator of tensorflow
operator to assign current parameter to variables
"""

phi, mu, sigma, L = self.create_variables(self.mu.shape[1])

op = tf.group(
tf.assign(phi, self.phi),
tf.assign(mu, self.mu),
tf.assign(sigma, self.sigma),
tf.assign(L, self.L)
)

self.phi, self.phi_org = phi, self.phi
self.mu, self.mu_org = mu, self.mu
self.sigma, self.sigma_org = sigma, self.sigma
self.L, self.L_org = L, self.L

self.training = False

return op

def energy(self, z):
""" calculate an energy of each row of z

Parameters
----------
z : tf.Tensor, shape (n_samples, n_features)
Data for which the per-row energies are calculated.

Returns
-------
energy : tf.Tensor, shape (n_samples)
calculated energies
"""

if self.training and self.phi is None:
self.phi, self.mu, self.sigma, self.L = self.create_variables(z.shape[1])

with tf.variable_scope("GMM_energy"):
# Instead of inverse covariance matrix, exploit cholesky decomposition
# for stability of calculation.
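# Sample energy from the DAGMM paper: E(z) = -log( sum_k phi_k * N(z | mu_k, Sigma_k) ),
# evaluated in log-space below via logsumexp for numerical stability.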
z_centered = z[:,None,:] - self.mu[None,:,:] #ikl
v = tf.matrix_triangular_solve(self.L, tf.transpose(z_centered, [1, 2, 0])) # kli

# log(det(Sigma)) = 2 * sum[log(diag(L))]
log_det_sigma = 2.0 * tf.reduce_sum(tf.log(tf.matrix_diag_part(self.L)), axis=1)

# To calculate energies, use "log-sum-exp" (different from orginal paper)
d = z.get_shape().as_list()[1]
logits = tf.log(self.phi[:,None]) - 0.5 * (tf.reduce_sum(tf.square(v), axis=1)
+ d * tf.log(2.0 * np.pi) + log_det_sigma[:,None])
energies = - tf.reduce_logsumexp(logits, axis=0)

return energies

def cov_diag_loss(self):
with tf.variable_scope("GMM_diag_loss"):
diag_loss = tf.reduce_sum(tf.divide(1, tf.matrix_diag_part(self.sigma)))

return diag_loss

+ 114
- 0
tods/tests/detection_algorithm/test_DAGMM.py

@@ -0,0 +1,114 @@
import unittest

from d3m import container, utils
from d3m.metadata import base as metadata_base
from tods.detection_algorithm.DAGMM import DAGMMPrimitive



class DAGMMTest(unittest.TestCase):
def test_basic(self):
self.maxDiff = None
self.main = container.DataFrame({'a': [3.,5.,7.,2.], 'b': [1.,4.,7.,2.], 'c': [6.,3.,9.,17.]},
columns=['a', 'b', 'c'],
generate_metadata=True)




self.assertEqual(utils.to_json_structure(self.main.metadata.to_internal_simple_structure()), [{
'selector': [],
'metadata': {
# 'top_level': 'main',
'schema': metadata_base.CONTAINER_SCHEMA_VERSION,
'structural_type': 'd3m.container.pandas.DataFrame',
'semantic_types': ['https://metadata.datadrivendiscovery.org/types/Table'],
'dimension': {
'name': 'rows',
'semantic_types': ['https://metadata.datadrivendiscovery.org/types/TabularRow'],
'length': 4,
},
},
}, {
'selector': ['__ALL_ELEMENTS__'],
'metadata': {
'dimension': {
'name': 'columns',
'semantic_types': ['https://metadata.datadrivendiscovery.org/types/TabularColumn'],
'length': 3,
},
},
}, {
'selector': ['__ALL_ELEMENTS__', 0],
'metadata': {'structural_type': 'numpy.float64', 'name': 'a'},
}, {
'selector': ['__ALL_ELEMENTS__', 1],
'metadata': {'structural_type': 'numpy.float64', 'name': 'b'},
}, {
'selector': ['__ALL_ELEMENTS__', 2],
'metadata': {'structural_type': 'numpy.float64', 'name': 'c'}
}])


self.assertIsInstance(self.main, container.DataFrame)


hyperparams_class = DAGMMPrimitive.metadata.get_hyperparams()
hyperparams = hyperparams_class.defaults()
hyperparams = hyperparams.replace({'minibatch_size': 4})


self.primitive = DAGMMPrimitive(hyperparams=hyperparams)
self.primitive.set_training_data(inputs=self.main)
#print("*****************",self.primitive.get_params())

self.primitive.fit()
self.new_main = self.primitive.produce(inputs=self.main).value
self.new_main_score = self.primitive.produce_score(inputs=self.main).value
print(self.new_main)
print(self.new_main_score)

params = self.primitive.get_params()
self.primitive.set_params(params=params)

self.assertEqual(utils.to_json_structure(self.main.metadata.to_internal_simple_structure()), [{
'selector': [],
'metadata': {
# 'top_level': 'main',
'schema': metadata_base.CONTAINER_SCHEMA_VERSION,
'structural_type': 'd3m.container.pandas.DataFrame',
'semantic_types': ['https://metadata.datadrivendiscovery.org/types/Table'],
'dimension': {
'name': 'rows',
'semantic_types': ['https://metadata.datadrivendiscovery.org/types/TabularRow'],
'length': 4,
},
},
}, {
'selector': ['__ALL_ELEMENTS__'],
'metadata': {
'dimension': {
'name': 'columns',
'semantic_types': ['https://metadata.datadrivendiscovery.org/types/TabularColumn'],
'length': 3,
},
},
}, {
'selector': ['__ALL_ELEMENTS__', 0],
'metadata': {'structural_type': 'numpy.float64', 'name': 'a'},
}, {
'selector': ['__ALL_ELEMENTS__', 1],
'metadata': {'structural_type': 'numpy.float64', 'name': 'b'},
}, {
'selector': ['__ALL_ELEMENTS__', 2],
'metadata': {'structural_type': 'numpy.float64', 'name': 'c'}
}])

# def test_params(self):
# params = self.primitive.get_params()
# self.primitive.set_params(params=params)



if __name__ == '__main__':
unittest.main()
