@@ -637,6 +637,10 @@ class GEDEnv(object): | |||||
return [i for i in self.__internal_to_original_node_ids[graph_id].values()] | return [i for i in self.__internal_to_original_node_ids[graph_id].values()] | ||||
def get_node_cost(self, node_label_1, node_label_2): | |||||
return self.__ged_data.node_cost(node_label_1, node_label_2) | |||||
def get_node_rel_cost(self, node_label_1, node_label_2): | def get_node_rel_cost(self, node_label_1, node_label_2): | ||||
""" | """ | ||||
/*! | /*! | ||||
@@ -650,7 +654,7 @@ class GEDEnv(object): | |||||
node_label_1 = tuple(sorted(node_label_1.items(), key=lambda kv: kv[0])) | node_label_1 = tuple(sorted(node_label_1.items(), key=lambda kv: kv[0])) | ||||
if isinstance(node_label_2, dict): | if isinstance(node_label_2, dict): | ||||
node_label_2 = tuple(sorted(node_label_2.items(), key=lambda kv: kv[0])) | node_label_2 = tuple(sorted(node_label_2.items(), key=lambda kv: kv[0])) | ||||
return self.__ged_data._edit_cost.node_rel_cost_fun(node_label_1, node_label_2) | |||||
return self.__ged_data._edit_cost.node_rel_cost_fun(node_label_1, node_label_2) # @todo: may need to use node_cost() instead (or change node_cost() and modify ged_method for pre-defined cost matrices.) | |||||
def get_node_del_cost(self, node_label): | def get_node_del_cost(self, node_label): | ||||
@@ -677,6 +681,10 @@ class GEDEnv(object): | |||||
if isinstance(node_label, dict): | if isinstance(node_label, dict): | ||||
node_label = tuple(sorted(node_label.items(), key=lambda kv: kv[0])) | node_label = tuple(sorted(node_label.items(), key=lambda kv: kv[0])) | ||||
return self.__ged_data._edit_cost.node_ins_cost_fun(node_label) | return self.__ged_data._edit_cost.node_ins_cost_fun(node_label) | ||||
def get_edge_cost(self, edge_label_1, edge_label_2): | |||||
return self.__ged_data.edge_cost(edge_label_1, edge_label_2) | |||||
def get_edge_rel_cost(self, edge_label_1, edge_label_2): | def get_edge_rel_cost(self, edge_label_1, edge_label_2): | ||||
@@ -1,3 +1,4 @@ | |||||
from gklearn.ged.median.median_graph_estimator import MedianGraphEstimator | from gklearn.ged.median.median_graph_estimator import MedianGraphEstimator | ||||
from gklearn.ged.median.median_graph_estimator_py import MedianGraphEstimatorPy | from gklearn.ged.median.median_graph_estimator_py import MedianGraphEstimatorPy | ||||
from gklearn.ged.median.median_graph_estimator_cml import MedianGraphEstimatorCML | |||||
from gklearn.ged.median.utils import constant_node_costs, mge_options_to_string | from gklearn.ged.median.utils import constant_node_costs, mge_options_to_string |
@@ -1,3 +1,3 @@ | |||||
from gklearn.ged.util.lsape_solver import LSAPESolver | from gklearn.ged.util.lsape_solver import LSAPESolver | ||||
from gklearn.ged.util.util import compute_geds, ged_options_to_string | from gklearn.ged.util.util import compute_geds, ged_options_to_string | ||||
from gklearn.ged.util.util import compute_geds_cml | |||||
from gklearn.ged.util.util import compute_geds_cml, label_costs_to_matrix |
@@ -14,9 +14,10 @@ from gklearn.preimage import PreimageGenerator | |||||
from gklearn.preimage.utils import compute_k_dis | from gklearn.preimage.utils import compute_k_dis | ||||
from gklearn.ged.env import GEDEnv | from gklearn.ged.env import GEDEnv | ||||
from gklearn.ged.learning import CostMatricesLearner | from gklearn.ged.learning import CostMatricesLearner | ||||
from gklearn.ged.median import MedianGraphEstimatorPy | |||||
from gklearn.ged.median import MedianGraphEstimatorCML | |||||
from gklearn.ged.median import constant_node_costs, mge_options_to_string | from gklearn.ged.median import constant_node_costs, mge_options_to_string | ||||
from gklearn.utils.utils import get_graph_kernel_by_name | from gklearn.utils.utils import get_graph_kernel_by_name | ||||
from gklearn.ged.util import label_costs_to_matrix | |||||
class MedianPreimageGeneratorCML(PreimageGenerator): | class MedianPreimageGeneratorCML(PreimageGenerator): | ||||
@@ -347,12 +348,19 @@ class MedianPreimageGeneratorCML(PreimageGenerator): | |||||
for g in graphs: | for g in graphs: | ||||
ged_env.add_nx_graph(g, '') | ged_env.add_nx_graph(g, '') | ||||
graph_ids = ged_env.get_all_graph_ids() | graph_ids = ged_env.get_all_graph_ids() | ||||
node_labels = ged_env.get_all_node_labels() | |||||
edge_labels = ged_env.get_all_edge_labels() | |||||
node_label_costs = label_costs_to_matrix(self.__node_label_costs, len(node_labels)) | |||||
edge_label_costs = label_costs_to_matrix(self.__edge_label_costs, len(edge_labels)) | |||||
ged_env.set_label_costs(node_label_costs, edge_label_costs) | |||||
set_median_id = ged_env.add_graph('set_median') | set_median_id = ged_env.add_graph('set_median') | ||||
gen_median_id = ged_env.add_graph('gen_median') | gen_median_id = ged_env.add_graph('gen_median') | ||||
ged_env.init(init_type=self.__ged_options['init_option']) | ged_env.init(init_type=self.__ged_options['init_option']) | ||||
# Set up the madian graph estimator. | # Set up the madian graph estimator. | ||||
self.__mge = MedianGraphEstimatorPy(ged_env, constant_node_costs(self.__ged_options['edit_cost'])) | |||||
self.__mge = MedianGraphEstimatorCML(ged_env, constant_node_costs(self.__ged_options['edit_cost'])) | |||||
self.__mge.set_refine_method(self.__ged_options['method'], self.__ged_options) | self.__mge.set_refine_method(self.__ged_options['method'], self.__ged_options) | ||||
options = self.__mge_options.copy() | options = self.__mge_options.copy() | ||||
if not 'seed' in options: | if not 'seed' in options: | ||||