From b463852b2d1f9fb3bcfce9d2002d53a457f1c869 Mon Sep 17 00:00:00 2001 From: jajupmochi Date: Sat, 18 Apr 2020 16:52:42 +0200 Subject: [PATCH] Move method get_graph_kernel_by_name(). --- gklearn/preimage/median_preimage_generator.py | 46 +++++---------------------- 1 file changed, 8 insertions(+), 38 deletions(-) diff --git a/gklearn/preimage/median_preimage_generator.py b/gklearn/preimage/median_preimage_generator.py index f342465..ead6a9a 100644 --- a/gklearn/preimage/median_preimage_generator.py +++ b/gklearn/preimage/median_preimage_generator.py @@ -18,6 +18,7 @@ from gklearn.ged.median import MedianGraphEstimator from gklearn.ged.median import constant_node_costs,mge_options_to_string from gklearn.gedlib import librariesImport, gedlibpy from gklearn.utils import Timer +from gklearn.utils.utils import get_graph_kernel_by_name # from gklearn.utils.dataset import Dataset class MedianPreimageGenerator(PreimageGenerator): @@ -81,7 +82,13 @@ class MedianPreimageGenerator(PreimageGenerator): def run(self): - self.__set_graph_kernel_by_name() + self._graph_kernel = get_graph_kernel_by_name(self._kernel_options['name'], + node_labels=self._dataset.node_labels, + edge_labels=self._dataset.edge_labels, + node_attrs=self._dataset.node_attrs, + edge_attrs=self._dataset.edge_attrs, + ds_infos=self._dataset.get_dataset_infos(keys=['directed']), + **self._kernel_options) # record start time. start = time.time() @@ -722,43 +729,6 @@ class MedianPreimageGenerator(PreimageGenerator): print('distance in kernel space for generalized median:', self.__k_dis_gen_median) print('minimum distance in kernel space for each graph in median set:', self.__k_dis_dataset) print('distance in kernel space for each graph in median set:', k_dis_median_set) - - - def __set_graph_kernel_by_name(self): - if self._kernel_options['name'] == 'ShortestPath': - from gklearn.kernels import ShortestPath - self._graph_kernel = ShortestPath(node_labels=self._dataset.node_labels, - node_attrs=self._dataset.node_attrs, - ds_infos=self._dataset.get_dataset_infos(keys=['directed']), - **self._kernel_options) - elif self._kernel_options['name'] == 'StructuralSP': - from gklearn.kernels import StructuralSP - self._graph_kernel = StructuralSP(node_labels=self._dataset.node_labels, - edge_labels=self._dataset.edge_labels, - node_attrs=self._dataset.node_attrs, - edge_attrs=self._dataset.edge_attrs, - ds_infos=self._dataset.get_dataset_infos(keys=['directed']), - **self._kernel_options) - elif self._kernel_options['name'] == 'PathUpToH': - from gklearn.kernels import PathUpToH - self._graph_kernel = PathUpToH(node_labels=self._dataset.node_labels, - edge_labels=self._dataset.edge_labels, - ds_infos=self._dataset.get_dataset_infos(keys=['directed']), - **self._kernel_options) - elif self._kernel_options['name'] == 'Treelet': - from gklearn.kernels import Treelet - self._graph_kernel = Treelet(node_labels=self._dataset.node_labels, - edge_labels=self._dataset.edge_labels, - ds_infos=self._dataset.get_dataset_infos(keys=['directed']), - **self._kernel_options) - elif self._kernel_options['name'] == 'WeisfeilerLehman': - from gklearn.kernels import WeisfeilerLehman - self._graph_kernel = WeisfeilerLehman(node_labels=self._dataset.node_labels, - edge_labels=self._dataset.edge_labels, - ds_infos=self._dataset.get_dataset_infos(keys=['directed']), - **self._kernel_options) - else: - raise Exception('The graph kernel given is not defined. Possible choices include: "StructuralSP", "ShortestPath", "PathUpToH", "Treelet", "WeisfeilerLehman".') # def __clean_graph(self, G, node_labels=[], edge_labels=[], node_attrs=[], edge_attrs=[]):