|
|
@@ -18,6 +18,7 @@ from gklearn.ged.median import MedianGraphEstimator |
|
|
|
from gklearn.ged.median import constant_node_costs,mge_options_to_string |
|
|
|
from gklearn.gedlib import librariesImport, gedlibpy |
|
|
|
from gklearn.utils import Timer |
|
|
|
from gklearn.utils.utils import get_graph_kernel_by_name |
|
|
|
# from gklearn.utils.dataset import Dataset |
|
|
|
|
|
|
|
class MedianPreimageGenerator(PreimageGenerator): |
|
|
@@ -81,7 +82,13 @@ class MedianPreimageGenerator(PreimageGenerator): |
|
|
|
|
|
|
|
|
|
|
|
def run(self): |
|
|
|
self.__set_graph_kernel_by_name() |
|
|
|
self._graph_kernel = get_graph_kernel_by_name(self._kernel_options['name'], |
|
|
|
node_labels=self._dataset.node_labels, |
|
|
|
edge_labels=self._dataset.edge_labels, |
|
|
|
node_attrs=self._dataset.node_attrs, |
|
|
|
edge_attrs=self._dataset.edge_attrs, |
|
|
|
ds_infos=self._dataset.get_dataset_infos(keys=['directed']), |
|
|
|
**self._kernel_options) |
|
|
|
|
|
|
|
# record start time. |
|
|
|
start = time.time() |
|
|
@@ -722,43 +729,6 @@ class MedianPreimageGenerator(PreimageGenerator): |
|
|
|
print('distance in kernel space for generalized median:', self.__k_dis_gen_median) |
|
|
|
print('minimum distance in kernel space for each graph in median set:', self.__k_dis_dataset) |
|
|
|
print('distance in kernel space for each graph in median set:', k_dis_median_set) |
|
|
|
|
|
|
|
|
|
|
|
def __set_graph_kernel_by_name(self): |
|
|
|
if self._kernel_options['name'] == 'ShortestPath': |
|
|
|
from gklearn.kernels import ShortestPath |
|
|
|
self._graph_kernel = ShortestPath(node_labels=self._dataset.node_labels, |
|
|
|
node_attrs=self._dataset.node_attrs, |
|
|
|
ds_infos=self._dataset.get_dataset_infos(keys=['directed']), |
|
|
|
**self._kernel_options) |
|
|
|
elif self._kernel_options['name'] == 'StructuralSP': |
|
|
|
from gklearn.kernels import StructuralSP |
|
|
|
self._graph_kernel = StructuralSP(node_labels=self._dataset.node_labels, |
|
|
|
edge_labels=self._dataset.edge_labels, |
|
|
|
node_attrs=self._dataset.node_attrs, |
|
|
|
edge_attrs=self._dataset.edge_attrs, |
|
|
|
ds_infos=self._dataset.get_dataset_infos(keys=['directed']), |
|
|
|
**self._kernel_options) |
|
|
|
elif self._kernel_options['name'] == 'PathUpToH': |
|
|
|
from gklearn.kernels import PathUpToH |
|
|
|
self._graph_kernel = PathUpToH(node_labels=self._dataset.node_labels, |
|
|
|
edge_labels=self._dataset.edge_labels, |
|
|
|
ds_infos=self._dataset.get_dataset_infos(keys=['directed']), |
|
|
|
**self._kernel_options) |
|
|
|
elif self._kernel_options['name'] == 'Treelet': |
|
|
|
from gklearn.kernels import Treelet |
|
|
|
self._graph_kernel = Treelet(node_labels=self._dataset.node_labels, |
|
|
|
edge_labels=self._dataset.edge_labels, |
|
|
|
ds_infos=self._dataset.get_dataset_infos(keys=['directed']), |
|
|
|
**self._kernel_options) |
|
|
|
elif self._kernel_options['name'] == 'WeisfeilerLehman': |
|
|
|
from gklearn.kernels import WeisfeilerLehman |
|
|
|
self._graph_kernel = WeisfeilerLehman(node_labels=self._dataset.node_labels, |
|
|
|
edge_labels=self._dataset.edge_labels, |
|
|
|
ds_infos=self._dataset.get_dataset_infos(keys=['directed']), |
|
|
|
**self._kernel_options) |
|
|
|
else: |
|
|
|
raise Exception('The graph kernel given is not defined. Possible choices include: "StructuralSP", "ShortestPath", "PathUpToH", "Treelet", "WeisfeilerLehman".') |
|
|
|
|
|
|
|
|
|
|
|
# def __clean_graph(self, G, node_labels=[], edge_labels=[], node_attrs=[], edge_attrs=[]): |
|
|
|