diff --git a/gklearn/utils/utils.py b/gklearn/utils/utils.py index fca19dd..5758291 100644 --- a/gklearn/utils/utils.py +++ b/gklearn/utils/utils.py @@ -366,19 +366,62 @@ def get_edge_labels(Gn, edge_label): def get_graph_kernel_by_name(name, node_labels=None, edge_labels=None, node_attrs=None, edge_attrs=None, ds_infos=None, kernel_options={}, **kwargs): if len(kwargs) != 0: kernel_options = kwargs - if name == 'Marginalized': + + if name == 'CommonWalk' or name == 'common walk': + from gklearn.kernels import CommonWalk + graph_kernel = CommonWalk(node_labels=node_labels, + edge_labels=edge_labels, + ds_infos=ds_infos, + **kernel_options) + + elif name == 'Marginalized' or name == 'marginalized': from gklearn.kernels import Marginalized graph_kernel = Marginalized(node_labels=node_labels, edge_labels=edge_labels, ds_infos=ds_infos, **kernel_options) - elif name == 'ShortestPath': + + elif name == 'SylvesterEquation' or name == 'sylvester equation': + from gklearn.kernels import SylvesterEquation + graph_kernel = SylvesterEquation( + ds_infos=ds_infos, + **kernel_options) + + elif name == 'FixedPoint' or name == 'fixed point': + from gklearn.kernels import FixedPoint + graph_kernel = FixedPoint(node_labels=node_labels, + edge_labels=edge_labels, + node_attrs=node_attrs, + edge_attrs=edge_attrs, + ds_infos=ds_infos, + **kernel_options) + + elif name == 'ConjugateGradient' or name == 'conjugate gradient': + from gklearn.kernels import ConjugateGradient + graph_kernel = ConjugateGradient(node_labels=node_labels, + edge_labels=edge_labels, + node_attrs=node_attrs, + edge_attrs=edge_attrs, + ds_infos=ds_infos, + **kernel_options) + + elif name == 'SpectralDecomposition' or name == 'spectral decomposition': + from gklearn.kernels import SpectralDecomposition + graph_kernel = SpectralDecomposition(node_labels=node_labels, + edge_labels=edge_labels, + node_attrs=node_attrs, + edge_attrs=edge_attrs, + ds_infos=ds_infos, + **kernel_options) + + elif name == 'ShortestPath' or name == 'shortest path': from gklearn.kernels import ShortestPath graph_kernel = ShortestPath(node_labels=node_labels, node_attrs=node_attrs, ds_infos=ds_infos, **kernel_options) - elif name == 'StructuralSP': + + elif name == 'StructuralSP' or name == 'structural shortest path': from gklearn.kernels import StructuralSP graph_kernel = StructuralSP(node_labels=node_labels, edge_labels=edge_labels, @@ -386,25 +429,29 @@ def get_graph_kernel_by_name(name, node_labels=None, edge_labels=None, node_attr edge_attrs=edge_attrs, ds_infos=ds_infos, **kernel_options) - elif name == 'PathUpToH': + + elif name == 'PathUpToH' or name == 'path up to length h': from gklearn.kernels import PathUpToH graph_kernel = PathUpToH(node_labels=node_labels, edge_labels=edge_labels, ds_infos=ds_infos, **kernel_options) - elif name == 'Treelet': + + elif name == 'Treelet' or name == 'treelet': from gklearn.kernels import Treelet graph_kernel = Treelet(node_labels=node_labels, edge_labels=edge_labels, ds_infos=ds_infos, **kernel_options) - elif name == 'WLSubtree': + + elif name == 'WLSubtree' or name == 'weisfeiler-lehman subtree': from gklearn.kernels import WLSubtree graph_kernel = WLSubtree(node_labels=node_labels, edge_labels=edge_labels, ds_infos=ds_infos, **kernel_options) - elif name == 'WeisfeilerLehman': + + elif name == 'WeisfeilerLehman' or name == 'weisfeiler-lehman': from gklearn.kernels import WeisfeilerLehman graph_kernel = WeisfeilerLehman(node_labels=node_labels, edge_labels=edge_labels, @@ -541,10 +588,18 @@ def get_mlti_dim_edge_attrs(G, attr_names): def normalize_gram_matrix(gram_matrix): diag = gram_matrix.diagonal().copy() + old_settings = np.seterr(invalid='raise') # Catch FloatingPointError: invalid value encountered in sqrt. for i in range(len(gram_matrix)): for j in range(i, len(gram_matrix)): - gram_matrix[i][j] /= np.sqrt(diag[i] * diag[j]) - gram_matrix[j][i] = gram_matrix[i][j] + try: + gram_matrix[i][j] /= np.sqrt(diag[i] * diag[j]) + except: +# rollback() + np.seterr(**old_settings) + raise + else: + gram_matrix[j][i] = gram_matrix[i][j] + np.seterr(**old_settings) return gram_matrix