@@ -392,7 +392,7 @@ class MedianGraphEstimator(object): | |||
# Update the median. # @todo!!!!!!!!!!!!!!!!!!!!!! | |||
median_modified = self.__update_median(graphs, median) | |||
if not median_modified or self.__itrs[median_pos] == 0: | |||
decreased_order = False | |||
decreased_order = self.__decrease_order(graphs, median) | |||
if not decreased_order or self.__itrs[median_pos] == 0: | |||
increased_order = False | |||
@@ -742,6 +742,81 @@ class MedianGraphEstimator(object): | |||
return node_maps_were_modified | |||
def __decrease_order(self, graphs, median): | |||
# Print information about current iteration | |||
if self.__print_to_stdout == 2: | |||
print('Trying to decrease order: ... ', end='') | |||
# Initialize ID of the node that is to be deleted. | |||
id_deleted_node = None # @todo: or np.inf | |||
decreased_order = False | |||
# Decrease the order as long as the best deletion delta is negative. | |||
while self.__compute_best_deletion_delta(graphs, median, [id_deleted_node]) < -self.__epsilon: | |||
decreased_order = True | |||
self.__delete_node_from_median(id_deleted_node, median) | |||
# Print information about current iteration. | |||
if self.__print_to_stdout == 2: | |||
print('done.') | |||
# Return true iff the order was decreased. | |||
return decreased_order | |||
def __compute_best_deletion_delta(self, graphs, median, id_deleted_node): | |||
best_delta = 0.0 | |||
# Determine node that should be deleted (if any). | |||
for i in range(0, nx.number_of_nodes(median)): | |||
# Compute cost delta. | |||
delta = 0.0 | |||
for graph_id, graph in graphs.items(): | |||
k = self.__get_node_image_from_map(self.__node_maps_from_median[graph_id], i) | |||
if k == np.inf: | |||
delta -= self.__node_del_cost | |||
else: | |||
delta += self.__node_ins_cost - self.__ged_env.get_node_rel_cost(median.nodes[i], graph.nodes[k]) | |||
for j, j_label in median[i]: | |||
l = self.__get_node_image_from_map(self.__node_maps_from_median[graph_id], j) | |||
if k == np.inf or l == np.inf: | |||
delta -= self.__edge_del_cost | |||
elif not graph.has_edge(k, l): | |||
delta -= self.__edge_del_cost | |||
else: | |||
delta += self.__edge_ins_cost - self.__ged_env.get_edge_rel_cost(j_label, graph.edges[(k, l)]) | |||
# Update best deletion delta. | |||
if delta < best_delta - self.__epsilon: | |||
best_delta = delta | |||
id_deleted_node[0] = i | |||
return best_delta | |||
def __delete_node_from_median(self, id_deleted_node, median): | |||
# Update the nodes of the median. | |||
median.remove_node(id_deleted_node) # @todo: test if it is right. | |||
# Update the node maps. | |||
for _, node_map in self.__node_maps_from_median.items(): | |||
new_node_map = {nx.number_of_nodes(median): ''} # @todo | |||
is_unassigned_target_node = ['', True] | |||
for i in range(0, nx.number_of_nodes(median)): | |||
if i != id_deleted_node: | |||
new_i = (i if i < id_deleted_node else i - 1) | |||
k = self.__get_node_image_from_map(node_map, i) | |||
new_node_map["ds"] # @todo | |||
if k != np.inf: | |||
is_unassigned_target_node[k] = False | |||
for k in range(0, ''): | |||
if is_unassigned_target_node[k]: | |||
new_node_map.sdf[] | |||
node_map = new_node_map | |||
# Increase overall number of decreases. | |||
self.__num_decrease_order += 1 | |||
def __improve_sum_of_distances(self, timer): | |||
pass | |||
@@ -12,4 +12,4 @@ from gklearn.kernels.structural_sp import StructuralSP | |||
from gklearn.kernels.shortest_path import ShortestPath | |||
from gklearn.kernels.path_up_to_h import PathUpToH | |||
from gklearn.kernels.treelet import Treelet | |||
from gklearn.kernels.weisfeiler_lehman import WeisfeilerLehman | |||
from gklearn.kernels.weisfeiler_lehman import WeisfeilerLehman, WLSubtree |
@@ -472,4 +472,11 @@ class WeisfeilerLehman(GraphKernel): # @todo: total parallelization and sp, edge | |||
if len(self.__node_labels) == 0: | |||
for G in Gn: | |||
nx.set_node_attributes(G, '0', 'dummy') | |||
self.__node_labels.append('dummy') | |||
self.__node_labels.append('dummy') | |||
class WLSubtree(WeisfeilerLehman): | |||
def __init__(self, **kwargs): | |||
kwargs['base_kernel'] = 'subtree' | |||
super().__init__(**kwargs) |
@@ -260,20 +260,20 @@ def test_Treelet(ds_name, parallel): | |||
@pytest.mark.parametrize('ds_name', ['Acyclic']) | |||
#@pytest.mark.parametrize('base_kernel', ['subtree', 'sp', 'edge']) | |||
@pytest.mark.parametrize('base_kernel', ['subtree']) | |||
# @pytest.mark.parametrize('base_kernel', ['subtree']) | |||
@pytest.mark.parametrize('parallel', ['imap_unordered', None]) | |||
def test_WeisfeilerLehman(ds_name, parallel, base_kernel): | |||
"""Test Weisfeiler-Lehman kernel. | |||
def test_WLSubtree(ds_name, parallel): | |||
"""Test Weisfeiler-Lehman subtree kernel. | |||
""" | |||
from gklearn.kernels import WeisfeilerLehman | |||
from gklearn.kernels import WLSubtree | |||
dataset = chooseDataset(ds_name) | |||
try: | |||
graph_kernel = WeisfeilerLehman(node_labels=dataset.node_labels, | |||
graph_kernel = WLSubtree(node_labels=dataset.node_labels, | |||
edge_labels=dataset.edge_labels, | |||
ds_infos=dataset.get_dataset_infos(keys=['directed']), | |||
height=2, base_kernel=base_kernel) | |||
height=2) | |||
gram_matrix, run_time = graph_kernel.compute(dataset.graphs, | |||
parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True) | |||
kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:], | |||
@@ -325,6 +325,12 @@ def get_graph_kernel_by_name(name, node_labels=None, edge_labels=None, node_attr | |||
edge_labels=edge_labels, | |||
ds_infos=ds_infos, | |||
**kernel_options) | |||
elif name == 'WLSubtree': | |||
from gklearn.kernels import WLSubtree | |||
graph_kernel = WLSubtree(node_labels=node_labels, | |||
edge_labels=edge_labels, | |||
ds_infos=ds_infos, | |||
**kernel_options) | |||
elif name == 'WeisfeilerLehman': | |||
from gklearn.kernels import WeisfeilerLehman | |||
graph_kernel = WeisfeilerLehman(node_labels=node_labels, | |||
@@ -332,7 +338,7 @@ def get_graph_kernel_by_name(name, node_labels=None, edge_labels=None, node_attr | |||
ds_infos=ds_infos, | |||
**kernel_options) | |||
else: | |||
raise Exception('The graph kernel given is not defined. Possible choices include: "StructuralSP", "ShortestPath", "PathUpToH", "Treelet", "WeisfeilerLehman".') | |||
raise Exception('The graph kernel given is not defined. Possible choices include: "StructuralSP", "ShortestPath", "PathUpToH", "Treelet", "WLSubtree", "WeisfeilerLehman".') | |||
return graph_kernel | |||