You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ged.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Thu Oct 17 18:44:59 2019
  5. @author: ljia
  6. """
  7. import numpy as np
  8. import networkx as nx
  9. from tqdm import tqdm
  10. import sys
  11. import multiprocessing
  12. from multiprocessing import Pool
  13. from functools import partial
  14. from gedlibpy import librariesImport, gedlibpy
  15. def GED(g1, g2, lib='gedlibpy', cost='CHEM_1', method='IPFP',
  16. edit_cost_constant=[], stabilizer='min', repeat=50):
  17. """
  18. Compute GED for 2 graphs.
  19. """
  20. if lib == 'gedlibpy':
  21. def convertGraph(G):
  22. """Convert a graph to the proper NetworkX format that can be
  23. recognized by library gedlibpy.
  24. """
  25. G_new = nx.Graph()
  26. for nd, attrs in G.nodes(data=True):
  27. G_new.add_node(str(nd), chem=attrs['atom'])
  28. # G_new.add_node(str(nd), x=str(attrs['attributes'][0]),
  29. # y=str(attrs['attributes'][1]))
  30. for nd1, nd2, attrs in G.edges(data=True):
  31. G_new.add_edge(str(nd1), str(nd2), valence=attrs['bond_type'])
  32. # G_new.add_edge(str(nd1), str(nd2))
  33. return G_new
  34. gedlibpy.restart_env()
  35. gedlibpy.add_nx_graph(convertGraph(g1), "")
  36. gedlibpy.add_nx_graph(convertGraph(g2), "")
  37. listID = gedlibpy.get_all_graph_ids()
  38. gedlibpy.set_edit_cost(cost, edit_cost_constant=edit_cost_constant)
  39. gedlibpy.init()
  40. gedlibpy.set_method(method, "")
  41. gedlibpy.init_method()
  42. g = listID[0]
  43. h = listID[1]
  44. if stabilizer == None:
  45. gedlibpy.run_method(g, h)
  46. pi_forward = gedlibpy.get_forward_map(g, h)
  47. pi_backward = gedlibpy.get_backward_map(g, h)
  48. upper = gedlibpy.get_upper_bound(g, h)
  49. lower = gedlibpy.get_lower_bound(g, h)
  50. elif stabilizer == 'mean':
  51. # @todo: to be finished...
  52. upper_list = [np.inf] * repeat
  53. for itr in range(repeat):
  54. gedlibpy.run_method(g, h)
  55. upper_list[itr] = gedlibpy.get_upper_bound(g, h)
  56. pi_forward = gedlibpy.get_forward_map(g, h)
  57. pi_backward = gedlibpy.get_backward_map(g, h)
  58. lower = gedlibpy.get_lower_bound(g, h)
  59. upper = np.mean(upper_list)
  60. elif stabilizer == 'median':
  61. if repeat % 2 == 0:
  62. repeat += 1
  63. upper_list = [np.inf] * repeat
  64. pi_forward_list = [0] * repeat
  65. pi_backward_list = [0] * repeat
  66. for itr in range(repeat):
  67. gedlibpy.run_method(g, h)
  68. upper_list[itr] = gedlibpy.get_upper_bound(g, h)
  69. pi_forward_list[itr] = gedlibpy.get_forward_map(g, h)
  70. pi_backward_list[itr] = gedlibpy.get_backward_map(g, h)
  71. lower = gedlibpy.get_lower_bound(g, h)
  72. upper = np.median(upper_list)
  73. idx_median = upper_list.index(upper)
  74. pi_forward = pi_forward_list[idx_median]
  75. pi_backward = pi_backward_list[idx_median]
  76. elif stabilizer == 'min':
  77. upper = np.inf
  78. for itr in range(repeat):
  79. gedlibpy.run_method(g, h)
  80. upper_tmp = gedlibpy.get_upper_bound(g, h)
  81. if upper_tmp < upper:
  82. upper = upper_tmp
  83. pi_forward = gedlibpy.get_forward_map(g, h)
  84. pi_backward = gedlibpy.get_backward_map(g, h)
  85. lower = gedlibpy.get_lower_bound(g, h)
  86. if upper == 0:
  87. break
  88. elif stabilizer == 'max':
  89. upper = 0
  90. for itr in range(repeat):
  91. gedlibpy.run_method(g, h)
  92. upper_tmp = gedlibpy.get_upper_bound(g, h)
  93. if upper_tmp > upper:
  94. upper = upper_tmp
  95. pi_forward = gedlibpy.get_forward_map(g, h)
  96. pi_backward = gedlibpy.get_backward_map(g, h)
  97. lower = gedlibpy.get_lower_bound(g, h)
  98. elif stabilizer == 'gaussian':
  99. pass
  100. dis = upper
  101. # make the map label correct (label remove map as np.inf)
  102. nodes1 = [n for n in g1.nodes()]
  103. nodes2 = [n for n in g2.nodes()]
  104. nb1 = nx.number_of_nodes(g1)
  105. nb2 = nx.number_of_nodes(g2)
  106. pi_forward = [nodes2[pi] if pi < nb2 else np.inf for pi in pi_forward]
  107. pi_backward = [nodes1[pi] if pi < nb1 else np.inf for pi in pi_backward]
  108. return dis, pi_forward, pi_backward
  109. def GED_n(Gn, lib='gedlibpy', cost='CHEM_1', method='IPFP',
  110. edit_cost_constant=[], stabilizer='min', repeat=50):
  111. """
  112. Compute GEDs for a group of graphs.
  113. """
  114. if lib == 'gedlibpy':
  115. def convertGraph(G):
  116. """Convert a graph to the proper NetworkX format that can be
  117. recognized by library gedlibpy.
  118. """
  119. G_new = nx.Graph()
  120. for nd, attrs in G.nodes(data=True):
  121. G_new.add_node(str(nd), chem=attrs['atom'])
  122. for nd1, nd2, attrs in G.edges(data=True):
  123. # G_new.add_edge(str(nd1), str(nd2), valence=attrs['bond_type'])
  124. G_new.add_edge(str(nd1), str(nd2))
  125. return G_new
  126. gedlibpy.restart_env()
  127. gedlibpy.add_nx_graph(convertGraph(g1), "")
  128. gedlibpy.add_nx_graph(convertGraph(g2), "")
  129. listID = gedlibpy.get_all_graph_ids()
  130. gedlibpy.set_edit_cost(cost, edit_cost_constant=edit_cost_constant)
  131. gedlibpy.init()
  132. gedlibpy.set_method(method, "")
  133. gedlibpy.init_method()
  134. g = listID[0]
  135. h = listID[1]
  136. if stabilizer == None:
  137. gedlibpy.run_method(g, h)
  138. pi_forward = gedlibpy.get_forward_map(g, h)
  139. pi_backward = gedlibpy.get_backward_map(g, h)
  140. upper = gedlibpy.get_upper_bound(g, h)
  141. lower = gedlibpy.get_lower_bound(g, h)
  142. elif stabilizer == 'min':
  143. upper = np.inf
  144. for itr in range(repeat):
  145. gedlibpy.run_method(g, h)
  146. upper_tmp = gedlibpy.get_upper_bound(g, h)
  147. if upper_tmp < upper:
  148. upper = upper_tmp
  149. pi_forward = gedlibpy.get_forward_map(g, h)
  150. pi_backward = gedlibpy.get_backward_map(g, h)
  151. lower = gedlibpy.get_lower_bound(g, h)
  152. if upper == 0:
  153. break
  154. dis = upper
  155. # make the map label correct (label remove map as np.inf)
  156. nodes1 = [n for n in g1.nodes()]
  157. nodes2 = [n for n in g2.nodes()]
  158. nb1 = nx.number_of_nodes(g1)
  159. nb2 = nx.number_of_nodes(g2)
  160. pi_forward = [nodes2[pi] if pi < nb2 else np.inf for pi in pi_forward]
  161. pi_backward = [nodes1[pi] if pi < nb1 else np.inf for pi in pi_backward]
  162. return dis, pi_forward, pi_backward
  163. def ged_median(Gn, Gn_median, verbose=False, params_ged={'lib': 'gedlibpy',
  164. 'cost': 'CHEM_1', 'method': 'IPFP', 'edit_cost_constant': [],
  165. 'stabilizer': 'min', 'repeat': 50}, parallel=False):
  166. if parallel:
  167. len_itr = int(len(Gn))
  168. pi_forward_list = [[] for i in range(len_itr)]
  169. dis_list = [0 for i in range(len_itr)]
  170. itr = range(0, len_itr)
  171. n_jobs = multiprocessing.cpu_count()
  172. if len_itr < 100 * n_jobs:
  173. chunksize = int(len_itr / n_jobs) + 1
  174. else:
  175. chunksize = 100
  176. def init_worker(gn_toshare, gn_median_toshare):
  177. global G_gn, G_gn_median
  178. G_gn = gn_toshare
  179. G_gn_median = gn_median_toshare
  180. do_partial = partial(_compute_ged_median, params_ged)
  181. pool = Pool(processes=n_jobs, initializer=init_worker, initargs=(Gn, Gn_median))
  182. if verbose:
  183. iterator = tqdm(pool.imap_unordered(do_partial, itr, chunksize),
  184. desc='computing GEDs', file=sys.stdout)
  185. else:
  186. iterator = pool.imap_unordered(do_partial, itr, chunksize)
  187. for i, dis_sum, pi_forward in iterator:
  188. pi_forward_list[i] = pi_forward
  189. dis_list[i] = dis_sum
  190. # print('\n-------------------------------------------')
  191. # print(i, j, idx_itr, dis)
  192. pool.close()
  193. pool.join()
  194. else:
  195. dis_list = []
  196. pi_forward_list = []
  197. for idx, G in tqdm(enumerate(Gn), desc='computing median distances',
  198. file=sys.stdout) if verbose else enumerate(Gn):
  199. dis_sum = 0
  200. pi_forward_list.append([])
  201. for G_p in Gn_median:
  202. dis_tmp, pi_tmp_forward, pi_tmp_backward = GED(G, G_p,
  203. **params_ged)
  204. pi_forward_list[idx].append(pi_tmp_forward)
  205. dis_sum += dis_tmp
  206. dis_list.append(dis_sum)
  207. return dis_list, pi_forward_list
  208. def _compute_ged_median(params_ged, itr):
  209. # print(itr)
  210. dis_sum = 0
  211. pi_forward = []
  212. for G_p in G_gn_median:
  213. dis_tmp, pi_tmp_forward, pi_tmp_backward = GED(G_gn[itr], G_p,
  214. **params_ged)
  215. pi_forward.append(pi_tmp_forward)
  216. dis_sum += dis_tmp
  217. return itr, dis_sum, pi_forward
  218. def get_nb_edit_operations(g1, g2, forward_map, backward_map):
  219. """Compute the number of each edit operations.
  220. """
  221. n_vi = 0
  222. n_vr = 0
  223. n_vs = 0
  224. n_ei = 0
  225. n_er = 0
  226. n_es = 0
  227. nodes1 = [n for n in g1.nodes()]
  228. for i, map_i in enumerate(forward_map):
  229. if map_i == np.inf:
  230. n_vr += 1
  231. elif g1.node[nodes1[i]]['atom'] != g2.node[map_i]['atom']:
  232. n_vs += 1
  233. for map_i in backward_map:
  234. if map_i == np.inf:
  235. n_vi += 1
  236. # idx_nodes1 = range(0, len(node1))
  237. edges1 = [e for e in g1.edges()]
  238. nb_edges2_cnted = 0
  239. for n1, n2 in edges1:
  240. idx1 = nodes1.index(n1)
  241. idx2 = nodes1.index(n2)
  242. # one of the nodes is removed, thus the edge is removed.
  243. if forward_map[idx1] == np.inf or forward_map[idx2] == np.inf:
  244. n_er += 1
  245. # corresponding edge is in g2. Edge label is not considered.
  246. elif (forward_map[idx1], forward_map[idx2]) in g2.edges() or \
  247. (forward_map[idx2], forward_map[idx1]) in g2.edges():
  248. nb_edges2_cnted += 1
  249. # corresponding nodes are in g2, however the edge is removed.
  250. else:
  251. n_er += 1
  252. n_ei = nx.number_of_edges(g2) - nb_edges2_cnted
  253. return n_vi, n_vr, n_vs, n_ei, n_er, n_es

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.