You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ged_env.py 16 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Wed Jun 17 12:02:36 2020
  5. @author: ljia
  6. """
  7. import numpy as np
  8. import networkx as nx
  9. from gklearn.ged.env import Options, OptionsStringMap
  10. from gklearn.ged.env import GEDData
  11. class GEDEnv(object):
  12. def __init__(self):
  13. self.__initialized = False
  14. self.__new_graph_ids = []
  15. self.__ged_data = GEDData()
  16. # Variables needed for approximating ged_instance_.
  17. self.__lower_bounds = {}
  18. self.__upper_bounds = {}
  19. self.__runtimes = {}
  20. self.__node_maps = {}
  21. self.__original_to_internal_node_ids = []
  22. self.__internal_to_original_node_ids = []
  23. self.__ged_method = None
  24. def set_edit_cost(self, edit_cost, edit_cost_constants=[]):
  25. """
  26. /*!
  27. * @brief Sets the edit costs to one of the predefined edit costs.
  28. * @param[in] edit_costs Select one of the predefined edit costs.
  29. * @param[in] edit_cost_constants Constants passed to the constructor of the edit cost class selected by @p edit_costs.
  30. */
  31. """
  32. self.__ged_data._set_edit_cost(edit_cost, edit_cost_constants)
  33. def add_graph(self, graph_name='', graph_class=''):
  34. """
  35. /*!
  36. * @brief Adds a new uninitialized graph to the environment. Call init() after calling this method.
  37. * @param[in] graph_name The name of the added graph. Empty if not specified.
  38. * @param[in] graph_class The class of the added graph. Empty if not specified.
  39. * @return The ID of the newly added graph.
  40. */
  41. """
  42. # @todo: graphs are not uninitialized.
  43. self.__initialized = False
  44. graph_id = self.__ged_data._num_graphs_without_shuffled_copies
  45. self.__ged_data._num_graphs_without_shuffled_copies += 1
  46. self.__new_graph_ids.append(graph_id)
  47. self.__ged_data._graphs.append(nx.Graph())
  48. self.__ged_data._graph_names.append(graph_name)
  49. self.__ged_data._graph_classes.append(graph_class)
  50. self.__original_to_internal_node_ids.append({})
  51. self.__internal_to_original_node_ids.append({})
  52. self.__ged_data._strings_to_internal_node_ids.append({})
  53. self.__ged_data._internal_node_ids_to_strings.append({})
  54. return graph_id
  55. def add_node(self, graph_id, node_id, node_label):
  56. """
  57. /*!
  58. * @brief Adds a labeled node.
  59. * @param[in] graph_id ID of graph that has been added to the environment.
  60. * @param[in] node_id The user-specific ID of the vertex that has to be added.
  61. * @param[in] node_label The label of the vertex that has to be added. Set to ged::NoLabel() if template parameter @p UserNodeLabel equals ged::NoLabel.
  62. */
  63. """
  64. # @todo: check ids.
  65. self.__initialized = False
  66. internal_node_id = nx.number_of_nodes(self.__ged_data._graphs[graph_id])
  67. self.__ged_data._graphs[graph_id].add_node(internal_node_id, label=node_label)
  68. self.__original_to_internal_node_ids[graph_id][node_id] = internal_node_id
  69. self.__internal_to_original_node_ids[graph_id][internal_node_id] = node_id
  70. self.__ged_data._strings_to_internal_node_ids[graph_id][str(node_id)] = internal_node_id
  71. self.__ged_data._internal_node_ids_to_strings[graph_id][internal_node_id] = str(node_id)
  72. # @todo: node_label_to_id_
  73. def add_edge(self, graph_id, nd_from, nd_to, edge_label, ignore_duplicates=True):
  74. """
  75. /*!
  76. * @brief Adds a labeled edge.
  77. * @param[in] graph_id ID of graph that has been added to the environment.
  78. * @param[in] tail The user-specific ID of the tail of the edge that has to be added.
  79. * @param[in] head The user-specific ID of the head of the edge that has to be added.
  80. * @param[in] edge_label The label of the vertex that has to be added. Set to ged::NoLabel() if template parameter @p UserEdgeLabel equals ged::NoLabel.
  81. * @param[in] ignore_duplicates If @p true, duplicate edges are ignores. Otherwise, an exception is thrown if an existing edge is added to the graph.
  82. */
  83. """
  84. # @todo: check everything.
  85. self.__initialized = False
  86. # @todo: check ignore_duplicates.
  87. self.__ged_data._graphs[graph_id].add_edge(self.__original_to_internal_node_ids[graph_id][nd_from], self.__original_to_internal_node_ids[graph_id][nd_to], label=edge_label)
  88. # @todo: edge_id and label_id, edge_label_to_id_.
  89. def add_nx_graph(self, g, classe, ignore_duplicates=True) :
  90. """
  91. Add a Graph (made by networkx) on the environment. Be careful to respect the same format as GXL graphs for labelling nodes and edges.
  92. :param g: The graph to add (networkx graph)
  93. :param ignore_duplicates: If True, duplicate edges are ignored, otherwise it's raise an error if an existing edge is added. True by default
  94. :type g: networkx.graph
  95. :type ignore_duplicates: bool
  96. :return: The ID of the newly added graphe
  97. :rtype: size_t
  98. .. note:: The NX graph must respect the GXL structure. Please see how a GXL graph is construct.
  99. """
  100. graph_id = self.add_graph(g.name, classe) # check if the graph name already exists.
  101. for node in g.nodes: # @todo: if the keys of labels include int and str at the same time.
  102. self.add_node(graph_id, node, tuple(sorted(g.nodes[node].items(), key=lambda kv: kv[0])))
  103. for edge in g.edges:
  104. self.add_edge(graph_id, edge[0], edge[1], tuple(sorted(g.edges[(edge[0], edge[1])].items(), key=lambda kv: kv[0])), ignore_duplicates)
  105. return graph_id
  106. def init(self, init_type=Options.InitType.EAGER_WITHOUT_SHUFFLED_COPIES, print_to_stdout=False):
  107. if isinstance(init_type, str):
  108. init_type = OptionsStringMap.InitType[init_type]
  109. # Throw an exception if no edit costs have been selected.
  110. if self.__ged_data._edit_cost is None:
  111. raise Exception('No edit costs have been selected. Call set_edit_cost() before calling init().')
  112. # Return if the environment is initialized.
  113. if self.__initialized:
  114. return
  115. # Set initialization type.
  116. self.__ged_data._init_type = init_type
  117. # @todo: Construct shuffled graph copies if necessary.
  118. # Re-initialize adjacency matrices (also previously initialized graphs must be re-initialized because of possible re-allocation).
  119. # @todo: setup_adjacency_matrix, don't know if neccessary.
  120. self.__ged_data._max_num_nodes = np.max([nx.number_of_nodes(g) for g in self.__ged_data._graphs])
  121. self.__ged_data._max_num_edges = np.max([nx.number_of_edges(g) for g in self.__ged_data._graphs])
  122. # Initialize cost matrices if necessary.
  123. if self.__ged_data._eager_init():
  124. pass # @todo: init_cost_matrices_: 1. Update node cost matrix if new node labels have been added to the environment; 2. Update edge cost matrix if new edge labels have been added to the environment.
  125. # Mark environment as initialized.
  126. self.__initialized = True
  127. self.__new_graph_ids.clear()
  128. def set_method(self, method, options=''):
  129. """
  130. /*!
  131. * @brief Sets the GEDMethod to be used by run_method().
  132. * @param[in] method Select the method that is to be used.
  133. * @param[in] options An options string of the form @"[--@<option@> @<arg@>] [...]@" passed to the selected method.
  134. */
  135. """
  136. del self.__ged_method
  137. if isinstance(method, str):
  138. method = OptionsStringMap.GEDMethod[method]
  139. if method == Options.GEDMethod.BRANCH:
  140. self.__ged_method = Branch(self.__ged_data)
  141. elif method == Options.GEDMethod.BRANCH_FAST:
  142. self.__ged_method = BranchFast(self.__ged_data)
  143. elif method == Options.GEDMethod.BRANCH_FAST:
  144. self.__ged_method = BranchFast(self.__ged_data)
  145. elif method == Options.GEDMethod.BRANCH_TIGHT:
  146. self.__ged_method = BranchTight(self.__ged_data)
  147. elif method == Options.GEDMethod.BRANCH_UNIFORM:
  148. self.__ged_method = BranchUniform(self.__ged_data)
  149. elif method == Options.GEDMethod.BRANCH_COMPACT:
  150. self.__ged_method = BranchCompact(self.__ged_data)
  151. elif method == Options.GEDMethod.PARTITION:
  152. self.__ged_method = Partition(self.__ged_data)
  153. elif method == Options.GEDMethod.HYBRID:
  154. self.__ged_method = Hybrid(self.__ged_data)
  155. elif method == Options.GEDMethod.RING:
  156. self.__ged_method = Ring(self.__ged_data)
  157. elif method == Options.GEDMethod.ANCHOR_AWARE_GED:
  158. self.__ged_method = AnchorAwareGED(self.__ged_data)
  159. elif method == Options.GEDMethod.WALKS:
  160. self.__ged_method = Walks(self.__ged_data)
  161. elif method == Options.GEDMethod.IPFP:
  162. self.__ged_method = IPFP(self.__ged_data)
  163. elif method == Options.GEDMethod.BIPARTITE:
  164. from gklearn.ged.methods import Bipartite
  165. self.__ged_method = Bipartite(self.__ged_data)
  166. elif method == Options.GEDMethod.SUBGRAPH:
  167. self.__ged_method = Subgraph(self.__ged_data)
  168. elif method == Options.GEDMethod.NODE:
  169. self.__ged_method = Node(self.__ged_data)
  170. elif method == Options.GEDMethod.RING_ML:
  171. self.__ged_method = RingML(self.__ged_data)
  172. elif method == Options.GEDMethod.BIPARTITE_ML:
  173. self.__ged_method = BipartiteML(self.__ged_data)
  174. elif method == Options.GEDMethod.REFINE:
  175. self.__ged_method = Refine(self.__ged_data)
  176. elif method == Options.GEDMethod.BP_BEAM:
  177. self.__ged_method = BPBeam(self.__ged_data)
  178. elif method == Options.GEDMethod.SIMULATED_ANNEALING:
  179. self.__ged_method = SimulatedAnnealing(self.__ged_data)
  180. elif method == Options.GEDMethod.HED:
  181. self.__ged_method = HED(self.__ged_data)
  182. elif method == Options.GEDMethod.STAR:
  183. self.__ged_method = STAR(self.__ged_data)
  184. # #ifdef GUROBI
  185. elif method == Options.GEDMethod.F1:
  186. self.__ged_method = F1(self.__ged_data)
  187. elif method == Options.GEDMethod.F2:
  188. self.__ged_method = F2(self.__ged_data)
  189. elif method == Options.GEDMethod.COMPACT_MIP:
  190. self.__ged_method = CompactMIP(self.__ged_data)
  191. elif method == Options.GEDMethod.BLP_NO_EDGE_LABELS:
  192. self.__ged_method = BLPNoEdgeLabels(self.__ged_data)
  193. self.__ged_method.set_options(options)
  194. def run_method(self, g_id, h_id):
  195. """
  196. /*!
  197. * @brief Runs the GED method specified by call to set_method() between the graphs with IDs @p g_id and @p h_id.
  198. * @param[in] g_id ID of an input graph that has been added to the environment.
  199. * @param[in] h_id ID of an input graph that has been added to the environment.
  200. */
  201. """
  202. if g_id >= self.__ged_data.num_graphs():
  203. raise Exception('The graph with ID', str(g_id), 'has not been added to the environment.')
  204. if h_id >= self.__ged_data.num_graphs():
  205. raise Exception('The graph with ID', str(h_id), 'has not been added to the environment.')
  206. if not self.__initialized:
  207. raise Exception('The environment is uninitialized. Call init() after adding all graphs to the environment.')
  208. if self.__ged_method is None:
  209. raise Exception('No method has been set. Call set_method() before calling run().')
  210. # Call selected GEDMethod and store results.
  211. if self.__ged_data.shuffled_graph_copies_available() and (g_id == h_id):
  212. self.__ged_method.run(g_id, self.__ged_data.id_shuffled_graph_copy(h_id)) # @todo: why shuffle?
  213. else:
  214. self.__ged_method.run(g_id, h_id)
  215. self.__lower_bounds[(g_id, h_id)] = self.__ged_method.get_lower_bound()
  216. self.__upper_bounds[(g_id, h_id)] = self.__ged_method.get_upper_bound()
  217. self.__runtimes[(g_id, h_id)] = self.__ged_method.get_runtime()
  218. self.__node_maps[(g_id, h_id)] = self.__ged_method.get_node_map()
  219. def init_method(self):
  220. """Initializes the method specified by call to set_method().
  221. """
  222. if not self.__initialized:
  223. raise Exception('The environment is uninitialized. Call init() before calling init_method().')
  224. if self.__ged_method is None:
  225. raise Exception('No method has been set. Call set_method() before calling init_method().')
  226. self.__ged_method.init()
  227. def get_upper_bound(self, g_id, h_id):
  228. """
  229. /*!
  230. * @brief Returns upper bound for edit distance between the input graphs.
  231. * @param[in] g_id ID of an input graph that has been added to the environment.
  232. * @param[in] h_id ID of an input graph that has been added to the environment.
  233. * @return Upper bound computed by the last call to run_method() with arguments @p g_id and @p h_id.
  234. */
  235. """
  236. if (g_id, h_id) not in self.__upper_bounds:
  237. raise Exception('Call run(' + str(g_id) + ',' + str(h_id) + ') before calling get_upper_bound(' + str(g_id) + ',' + str(h_id) + ').')
  238. return self.__upper_bounds[(g_id, h_id)]
  239. def get_lower_bound(self, g_id, h_id):
  240. """
  241. /*!
  242. * @brief Returns lower bound for edit distance between the input graphs.
  243. * @param[in] g_id ID of an input graph that has been added to the environment.
  244. * @param[in] h_id ID of an input graph that has been added to the environment.
  245. * @return Lower bound computed by the last call to run_method() with arguments @p g_id and @p h_id.
  246. */
  247. """
  248. if (g_id, h_id) not in self.__lower_bounds:
  249. raise Exception('Call run(' + str(g_id) + ',' + str(h_id) + ') before calling get_lower_bound(' + str(g_id) + ',' + str(h_id) + ').')
  250. return self.__lower_bounds[(g_id, h_id)]
  251. def get_runtime(self, g_id, h_id):
  252. """
  253. /*!
  254. * @brief Returns runtime.
  255. * @param[in] g_id ID of an input graph that has been added to the environment.
  256. * @param[in] h_id ID of an input graph that has been added to the environment.
  257. * @return Runtime of last call to run_method() with arguments @p g_id and @p h_id.
  258. */
  259. """
  260. if (g_id, h_id) not in self.__runtimes:
  261. raise Exception('Call run(' + str(g_id) + ',' + str(h_id) + ') before calling get_runtime(' + str(g_id) + ',' + str(h_id) + ').')
  262. return self.__runtimes[(g_id, h_id)]
  263. def get_init_time(self):
  264. """
  265. /*!
  266. * @brief Returns initialization time.
  267. * @return Runtime of the last call to init_method().
  268. */
  269. """
  270. return self.__ged_method.get_init_time()
  271. def get_node_map(self, g_id, h_id):
  272. """
  273. /*!
  274. * @brief Returns node map between the input graphs.
  275. * @param[in] g_id ID of an input graph that has been added to the environment.
  276. * @param[in] h_id ID of an input graph that has been added to the environment.
  277. * @return Node map computed by the last call to run_method() with arguments @p g_id and @p h_id.
  278. */
  279. """
  280. if (g_id, h_id) not in self.__node_maps:
  281. raise Exception('Call run(' + str(g_id) + ',' + str(h_id) + ') before calling get_node_map(' + str(g_id) + ',' + str(h_id) + ').')
  282. return self.__node_maps[(g_id, h_id)]
  283. def get_forward_map(self, g_id, h_id) :
  284. """
  285. Returns the forward map (or the half of the adjacence matrix) between nodes of the two indicated graphs.
  286. :param g: The Id of the first compared graph
  287. :param h: The Id of the second compared graph
  288. :type g: size_t
  289. :type h: size_t
  290. :return: The forward map to the adjacence matrix between nodes of the two graphs
  291. :rtype: list[npy_uint32]
  292. .. seealso:: run_method(), get_upper_bound(), get_lower_bound(), get_backward_map(), get_runtime(), quasimetric_cost(), get_node_map(), get_assignment_matrix()
  293. .. warning:: run_method() between the same two graph must be called before this function.
  294. .. note:: I don't know how to connect the two map to reconstruct the adjacence matrix. Please come back when I know how it's work !
  295. """
  296. return self.get_node_map(g_id, h_id).forward_map
  297. def get_backward_map(self, g_id, h_id) :
  298. """
  299. Returns the backward map (or the half of the adjacence matrix) between nodes of the two indicated graphs.
  300. :param g: The Id of the first compared graph
  301. :param h: The Id of the second compared graph
  302. :type g: size_t
  303. :type h: size_t
  304. :return: The backward map to the adjacence matrix between nodes of the two graphs
  305. :rtype: list[npy_uint32]
  306. .. seealso:: run_method(), get_upper_bound(), get_lower_bound(), get_forward_map(), get_runtime(), quasimetric_cost(), get_node_map(), get_assignment_matrix()
  307. .. warning:: run_method() between the same two graph must be called before this function.
  308. .. note:: I don't know how to connect the two map to reconstruct the adjacence matrix. Please come back when I know how it's work !
  309. """
  310. return self.get_node_map(g_id, h_id).backward_map
  311. def get_all_graph_ids(self):
  312. return [i for i in range(0, self.__ged_data._num_graphs_without_shuffled_copies)]

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.