diff --git a/gklearn/preimage/median_preimage_generator.py b/gklearn/preimage/median_preimage_generator.py index f342465..ead6a9a 100644 --- a/gklearn/preimage/median_preimage_generator.py +++ b/gklearn/preimage/median_preimage_generator.py @@ -18,6 +18,7 @@ from gklearn.ged.median import MedianGraphEstimator from gklearn.ged.median import constant_node_costs,mge_options_to_string from gklearn.gedlib import librariesImport, gedlibpy from gklearn.utils import Timer +from gklearn.utils.utils import get_graph_kernel_by_name # from gklearn.utils.dataset import Dataset class MedianPreimageGenerator(PreimageGenerator): @@ -81,7 +82,13 @@ class MedianPreimageGenerator(PreimageGenerator): def run(self): - self.__set_graph_kernel_by_name() + self._graph_kernel = get_graph_kernel_by_name(self._kernel_options['name'], + node_labels=self._dataset.node_labels, + edge_labels=self._dataset.edge_labels, + node_attrs=self._dataset.node_attrs, + edge_attrs=self._dataset.edge_attrs, + ds_infos=self._dataset.get_dataset_infos(keys=['directed']), + **self._kernel_options) # record start time. start = time.time() @@ -722,43 +729,6 @@ class MedianPreimageGenerator(PreimageGenerator): print('distance in kernel space for generalized median:', self.__k_dis_gen_median) print('minimum distance in kernel space for each graph in median set:', self.__k_dis_dataset) print('distance in kernel space for each graph in median set:', k_dis_median_set) - - - def __set_graph_kernel_by_name(self): - if self._kernel_options['name'] == 'ShortestPath': - from gklearn.kernels import ShortestPath - self._graph_kernel = ShortestPath(node_labels=self._dataset.node_labels, - node_attrs=self._dataset.node_attrs, - ds_infos=self._dataset.get_dataset_infos(keys=['directed']), - **self._kernel_options) - elif self._kernel_options['name'] == 'StructuralSP': - from gklearn.kernels import StructuralSP - self._graph_kernel = StructuralSP(node_labels=self._dataset.node_labels, - edge_labels=self._dataset.edge_labels, - node_attrs=self._dataset.node_attrs, - edge_attrs=self._dataset.edge_attrs, - ds_infos=self._dataset.get_dataset_infos(keys=['directed']), - **self._kernel_options) - elif self._kernel_options['name'] == 'PathUpToH': - from gklearn.kernels import PathUpToH - self._graph_kernel = PathUpToH(node_labels=self._dataset.node_labels, - edge_labels=self._dataset.edge_labels, - ds_infos=self._dataset.get_dataset_infos(keys=['directed']), - **self._kernel_options) - elif self._kernel_options['name'] == 'Treelet': - from gklearn.kernels import Treelet - self._graph_kernel = Treelet(node_labels=self._dataset.node_labels, - edge_labels=self._dataset.edge_labels, - ds_infos=self._dataset.get_dataset_infos(keys=['directed']), - **self._kernel_options) - elif self._kernel_options['name'] == 'WeisfeilerLehman': - from gklearn.kernels import WeisfeilerLehman - self._graph_kernel = WeisfeilerLehman(node_labels=self._dataset.node_labels, - edge_labels=self._dataset.edge_labels, - ds_infos=self._dataset.get_dataset_infos(keys=['directed']), - **self._kernel_options) - else: - raise Exception('The graph kernel given is not defined. Possible choices include: "StructuralSP", "ShortestPath", "PathUpToH", "Treelet", "WeisfeilerLehman".') # def __clean_graph(self, G, node_labels=[], edge_labels=[], node_attrs=[], edge_attrs=[]):