You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ged.py 18 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Thu Oct 17 18:44:59 2019
  5. @author: ljia
  6. """
  7. import numpy as np
  8. import networkx as nx
  9. from tqdm import tqdm
  10. import sys
  11. import multiprocessing
  12. from multiprocessing import Pool
  13. from functools import partial
  14. #from gedlibpy_linlin import librariesImport, gedlibpy
  15. from libs import *
  16. def GED(g1, g2, dataset='monoterpenoides', lib='gedlibpy', cost='CHEM_1', method='IPFP',
  17. edit_cost_constant=[], algo_options='', stabilizer='min', repeat=50):
  18. """
  19. Compute GED for 2 graphs.
  20. """
  21. def convertGraph(G, cost):
  22. """Convert a graph to the proper NetworkX format that can be
  23. recognized by library gedlibpy.
  24. """
  25. G_new = nx.Graph()
  26. if cost == 'LETTER' or cost == 'LETTER2':
  27. for nd, attrs in G.nodes(data=True):
  28. G_new.add_node(str(nd), x=str(attrs['attributes'][0]),
  29. y=str(attrs['attributes'][1]))
  30. for nd1, nd2, attrs in G.edges(data=True):
  31. G_new.add_edge(str(nd1), str(nd2))
  32. elif cost == 'NON_SYMBOLIC':
  33. for nd, attrs in G.nodes(data=True):
  34. G_new.add_node(str(nd))
  35. for a_name in G.graph['node_attrs']:
  36. G_new.nodes[str(nd)][a_name] = str(attrs[a_name])
  37. for nd1, nd2, attrs in G.edges(data=True):
  38. G_new.add_edge(str(nd1), str(nd2))
  39. for a_name in G.graph['edge_attrs']:
  40. G_new.edges[str(nd1), str(nd2)][a_name] = str(attrs[a_name])
  41. else:
  42. for nd, attrs in G.nodes(data=True):
  43. G_new.add_node(str(nd), chem=attrs['atom'])
  44. for nd1, nd2, attrs in G.edges(data=True):
  45. G_new.add_edge(str(nd1), str(nd2), valence=attrs['bond_type'])
  46. # G_new.add_edge(str(nd1), str(nd2))
  47. return G_new
  48. # dataset = dataset.lower()
  49. if lib == 'gedlibpy':
  50. gedlibpy.restart_env()
  51. gedlibpy.add_nx_graph(convertGraph(g1, cost), "")
  52. gedlibpy.add_nx_graph(convertGraph(g2, cost), "")
  53. listID = gedlibpy.get_all_graph_ids()
  54. gedlibpy.set_edit_cost(cost, edit_cost_constant=edit_cost_constant)
  55. gedlibpy.init()
  56. gedlibpy.set_method(method, algo_options)
  57. gedlibpy.init_method()
  58. g = listID[0]
  59. h = listID[1]
  60. if stabilizer is None:
  61. gedlibpy.run_method(g, h)
  62. pi_forward = gedlibpy.get_forward_map(g, h)
  63. pi_backward = gedlibpy.get_backward_map(g, h)
  64. upper = gedlibpy.get_upper_bound(g, h)
  65. lower = gedlibpy.get_lower_bound(g, h)
  66. elif stabilizer == 'mean':
  67. # @todo: to be finished...
  68. upper_list = [np.inf] * repeat
  69. for itr in range(repeat):
  70. gedlibpy.run_method(g, h)
  71. upper_list[itr] = gedlibpy.get_upper_bound(g, h)
  72. pi_forward = gedlibpy.get_forward_map(g, h)
  73. pi_backward = gedlibpy.get_backward_map(g, h)
  74. lower = gedlibpy.get_lower_bound(g, h)
  75. upper = np.mean(upper_list)
  76. elif stabilizer == 'median':
  77. if repeat % 2 == 0:
  78. repeat += 1
  79. upper_list = [np.inf] * repeat
  80. pi_forward_list = [0] * repeat
  81. pi_backward_list = [0] * repeat
  82. for itr in range(repeat):
  83. gedlibpy.run_method(g, h)
  84. upper_list[itr] = gedlibpy.get_upper_bound(g, h)
  85. pi_forward_list[itr] = gedlibpy.get_forward_map(g, h)
  86. pi_backward_list[itr] = gedlibpy.get_backward_map(g, h)
  87. lower = gedlibpy.get_lower_bound(g, h)
  88. upper = np.median(upper_list)
  89. idx_median = upper_list.index(upper)
  90. pi_forward = pi_forward_list[idx_median]
  91. pi_backward = pi_backward_list[idx_median]
  92. elif stabilizer == 'min':
  93. upper = np.inf
  94. for itr in range(repeat):
  95. gedlibpy.run_method(g, h)
  96. upper_tmp = gedlibpy.get_upper_bound(g, h)
  97. if upper_tmp < upper:
  98. upper = upper_tmp
  99. pi_forward = gedlibpy.get_forward_map(g, h)
  100. pi_backward = gedlibpy.get_backward_map(g, h)
  101. lower = gedlibpy.get_lower_bound(g, h)
  102. if upper == 0:
  103. break
  104. elif stabilizer == 'max':
  105. upper = 0
  106. for itr in range(repeat):
  107. gedlibpy.run_method(g, h)
  108. upper_tmp = gedlibpy.get_upper_bound(g, h)
  109. if upper_tmp > upper:
  110. upper = upper_tmp
  111. pi_forward = gedlibpy.get_forward_map(g, h)
  112. pi_backward = gedlibpy.get_backward_map(g, h)
  113. lower = gedlibpy.get_lower_bound(g, h)
  114. elif stabilizer == 'gaussian':
  115. pass
  116. dis = upper
  117. elif lib == 'gedlib-bash':
  118. import time
  119. import random
  120. import sys
  121. import os
  122. sys.path.insert(0, "../")
  123. from gklearn.utils.graphfiles import saveDataset
  124. tmp_dir = '/media/ljia/DATA/research-repo/codes/others/gedlib/tests_linlin/output/tmp_ged/'
  125. if not os.path.exists(tmp_dir):
  126. os.makedirs(tmp_dir)
  127. fn_collection = tmp_dir + 'collection.' + str(time.time()) + str(random.randint(0, 1e9))
  128. xparams = {'method': 'gedlib', 'graph_dir': fn_collection}
  129. saveDataset([g1, g2], ['dummy', 'dummy'], gformat='gxl', group='xml',
  130. filename=fn_collection, xparams=xparams)
  131. command = 'GEDLIB_HOME=\'/media/ljia/DATA/research-repo/codes/others/gedlib/gedlib2\'\n'
  132. command += 'LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$GEDLIB_HOME/lib\n'
  133. command += 'export LD_LIBRARY_PATH\n'
  134. command += 'cd \'/media/ljia/DATA/research-repo/codes/others/gedlib/tests_linlin/bin\'\n'
  135. command += './ged_for_python_bash monoterpenoides ' + fn_collection \
  136. + ' \'' + algo_options + '\' '
  137. for ec in edit_cost_constant:
  138. command += str(ec) + ' '
  139. # output = os.system(command)
  140. stream = os.popen(command)
  141. output = stream.readlines()
  142. # print(output)
  143. dis = float(output[0].strip())
  144. runtime = float(output[1].strip())
  145. size_forward = int(output[2].strip())
  146. pi_forward = [int(item.strip()) for item in output[3:3+size_forward]]
  147. pi_backward = [int(item.strip()) for item in output[3+size_forward:]]
  148. # print(dis)
  149. # print(runtime)
  150. # print(size_forward)
  151. # print(pi_forward)
  152. # print(pi_backward)
  153. # make the map label correct (label remove map as np.inf)
  154. nodes1 = [n for n in g1.nodes()]
  155. nodes2 = [n for n in g2.nodes()]
  156. nb1 = nx.number_of_nodes(g1)
  157. nb2 = nx.number_of_nodes(g2)
  158. pi_forward = [nodes2[pi] if pi < nb2 else np.inf for pi in pi_forward]
  159. pi_backward = [nodes1[pi] if pi < nb1 else np.inf for pi in pi_backward]
  160. # print(pi_forward)
  161. return dis, pi_forward, pi_backward
  162. def GED_n(Gn, lib='gedlibpy', cost='CHEM_1', method='IPFP',
  163. edit_cost_constant=[], stabilizer='min', repeat=50):
  164. """
  165. Compute GEDs for a group of graphs.
  166. """
  167. if lib == 'gedlibpy':
  168. def convertGraph(G):
  169. """Convert a graph to the proper NetworkX format that can be
  170. recognized by library gedlibpy.
  171. """
  172. G_new = nx.Graph()
  173. for nd, attrs in G.nodes(data=True):
  174. G_new.add_node(str(nd), chem=attrs['atom'])
  175. for nd1, nd2, attrs in G.edges(data=True):
  176. # G_new.add_edge(str(nd1), str(nd2), valence=attrs['bond_type'])
  177. G_new.add_edge(str(nd1), str(nd2))
  178. return G_new
  179. gedlibpy.restart_env()
  180. gedlibpy.add_nx_graph(convertGraph(g1), "")
  181. gedlibpy.add_nx_graph(convertGraph(g2), "")
  182. listID = gedlibpy.get_all_graph_ids()
  183. gedlibpy.set_edit_cost(cost, edit_cost_constant=edit_cost_constant)
  184. gedlibpy.init()
  185. gedlibpy.set_method(method, "")
  186. gedlibpy.init_method()
  187. g = listID[0]
  188. h = listID[1]
  189. if stabilizer is None:
  190. gedlibpy.run_method(g, h)
  191. pi_forward = gedlibpy.get_forward_map(g, h)
  192. pi_backward = gedlibpy.get_backward_map(g, h)
  193. upper = gedlibpy.get_upper_bound(g, h)
  194. lower = gedlibpy.get_lower_bound(g, h)
  195. elif stabilizer == 'min':
  196. upper = np.inf
  197. for itr in range(repeat):
  198. gedlibpy.run_method(g, h)
  199. upper_tmp = gedlibpy.get_upper_bound(g, h)
  200. if upper_tmp < upper:
  201. upper = upper_tmp
  202. pi_forward = gedlibpy.get_forward_map(g, h)
  203. pi_backward = gedlibpy.get_backward_map(g, h)
  204. lower = gedlibpy.get_lower_bound(g, h)
  205. if upper == 0:
  206. break
  207. dis = upper
  208. # make the map label correct (label remove map as np.inf)
  209. nodes1 = [n for n in g1.nodes()]
  210. nodes2 = [n for n in g2.nodes()]
  211. nb1 = nx.number_of_nodes(g1)
  212. nb2 = nx.number_of_nodes(g2)
  213. pi_forward = [nodes2[pi] if pi < nb2 else np.inf for pi in pi_forward]
  214. pi_backward = [nodes1[pi] if pi < nb1 else np.inf for pi in pi_backward]
  215. return dis, pi_forward, pi_backward
  216. def ged_median(Gn, Gn_median, verbose=False, params_ged={'lib': 'gedlibpy',
  217. 'cost': 'CHEM_1', 'method': 'IPFP', 'edit_cost_constant': [],
  218. 'algo_options': '--threads 8 --initial-solutions 40 --ratio-runs-from-initial-solutions 1',
  219. 'stabilizer': None}, parallel=False):
  220. if parallel:
  221. len_itr = int(len(Gn))
  222. pi_forward_list = [[] for i in range(len_itr)]
  223. dis_list = [0 for i in range(len_itr)]
  224. itr = range(0, len_itr)
  225. n_jobs = multiprocessing.cpu_count()
  226. if len_itr < 100 * n_jobs:
  227. chunksize = int(len_itr / n_jobs) + 1
  228. else:
  229. chunksize = 100
  230. def init_worker(gn_toshare, gn_median_toshare):
  231. global G_gn, G_gn_median
  232. G_gn = gn_toshare
  233. G_gn_median = gn_median_toshare
  234. do_partial = partial(_compute_ged_median, params_ged)
  235. pool = Pool(processes=n_jobs, initializer=init_worker, initargs=(Gn, Gn_median))
  236. if verbose:
  237. iterator = tqdm(pool.imap_unordered(do_partial, itr, chunksize),
  238. desc='computing GEDs', file=sys.stdout)
  239. else:
  240. iterator = pool.imap_unordered(do_partial, itr, chunksize)
  241. for i, dis_sum, pi_forward in iterator:
  242. pi_forward_list[i] = pi_forward
  243. dis_list[i] = dis_sum
  244. # print('\n-------------------------------------------')
  245. # print(i, j, idx_itr, dis)
  246. pool.close()
  247. pool.join()
  248. else:
  249. dis_list = []
  250. pi_forward_list = []
  251. for idx, G in tqdm(enumerate(Gn), desc='computing median distances',
  252. file=sys.stdout) if verbose else enumerate(Gn):
  253. dis_sum = 0
  254. pi_forward_list.append([])
  255. for G_p in Gn_median:
  256. dis_tmp, pi_tmp_forward, pi_tmp_backward = GED(G, G_p,
  257. **params_ged)
  258. pi_forward_list[idx].append(pi_tmp_forward)
  259. dis_sum += dis_tmp
  260. dis_list.append(dis_sum)
  261. return dis_list, pi_forward_list
  262. def _compute_ged_median(params_ged, itr):
  263. # print(itr)
  264. dis_sum = 0
  265. pi_forward = []
  266. for G_p in G_gn_median:
  267. dis_tmp, pi_tmp_forward, pi_tmp_backward = GED(G_gn[itr], G_p,
  268. **params_ged)
  269. pi_forward.append(pi_tmp_forward)
  270. dis_sum += dis_tmp
  271. return itr, dis_sum, pi_forward
  272. def get_nb_edit_operations(g1, g2, forward_map, backward_map):
  273. """Compute the number of each edit operations.
  274. """
  275. n_vi = 0
  276. n_vr = 0
  277. n_vs = 0
  278. n_ei = 0
  279. n_er = 0
  280. n_es = 0
  281. nodes1 = [n for n in g1.nodes()]
  282. for i, map_i in enumerate(forward_map):
  283. if map_i == np.inf:
  284. n_vr += 1
  285. elif g1.node[nodes1[i]]['atom'] != g2.node[map_i]['atom']:
  286. n_vs += 1
  287. for map_i in backward_map:
  288. if map_i == np.inf:
  289. n_vi += 1
  290. # idx_nodes1 = range(0, len(node1))
  291. edges1 = [e for e in g1.edges()]
  292. nb_edges2_cnted = 0
  293. for n1, n2 in edges1:
  294. idx1 = nodes1.index(n1)
  295. idx2 = nodes1.index(n2)
  296. # one of the nodes is removed, thus the edge is removed.
  297. if forward_map[idx1] == np.inf or forward_map[idx2] == np.inf:
  298. n_er += 1
  299. # corresponding edge is in g2.
  300. elif (forward_map[idx1], forward_map[idx2]) in g2.edges():
  301. nb_edges2_cnted += 1
  302. # edge labels are different.
  303. if g2.edges[((forward_map[idx1], forward_map[idx2]))]['bond_type'] \
  304. != g1.edges[(n1, n2)]['bond_type']:
  305. n_es += 1
  306. elif (forward_map[idx2], forward_map[idx1]) in g2.edges():
  307. nb_edges2_cnted += 1
  308. # edge labels are different.
  309. if g2.edges[((forward_map[idx2], forward_map[idx1]))]['bond_type'] \
  310. != g1.edges[(n1, n2)]['bond_type']:
  311. n_es += 1
  312. # corresponding nodes are in g2, however the edge is removed.
  313. else:
  314. n_er += 1
  315. n_ei = nx.number_of_edges(g2) - nb_edges2_cnted
  316. return n_vi, n_vr, n_vs, n_ei, n_er, n_es
  317. def get_nb_edit_operations_letter(g1, g2, forward_map, backward_map):
  318. """Compute the number of each edit operations.
  319. """
  320. n_vi = 0
  321. n_vr = 0
  322. n_vs = 0
  323. sod_vs = 0
  324. n_ei = 0
  325. n_er = 0
  326. nodes1 = [n for n in g1.nodes()]
  327. for i, map_i in enumerate(forward_map):
  328. if map_i == np.inf:
  329. n_vr += 1
  330. else:
  331. n_vs += 1
  332. diff_x = float(g1.nodes[nodes1[i]]['x']) - float(g2.nodes[map_i]['x'])
  333. diff_y = float(g1.nodes[nodes1[i]]['y']) - float(g2.nodes[map_i]['y'])
  334. sod_vs += np.sqrt(np.square(diff_x) + np.square(diff_y))
  335. for map_i in backward_map:
  336. if map_i == np.inf:
  337. n_vi += 1
  338. # idx_nodes1 = range(0, len(node1))
  339. edges1 = [e for e in g1.edges()]
  340. nb_edges2_cnted = 0
  341. for n1, n2 in edges1:
  342. idx1 = nodes1.index(n1)
  343. idx2 = nodes1.index(n2)
  344. # one of the nodes is removed, thus the edge is removed.
  345. if forward_map[idx1] == np.inf or forward_map[idx2] == np.inf:
  346. n_er += 1
  347. # corresponding edge is in g2. Edge label is not considered.
  348. elif (forward_map[idx1], forward_map[idx2]) in g2.edges() or \
  349. (forward_map[idx2], forward_map[idx1]) in g2.edges():
  350. nb_edges2_cnted += 1
  351. # corresponding nodes are in g2, however the edge is removed.
  352. else:
  353. n_er += 1
  354. n_ei = nx.number_of_edges(g2) - nb_edges2_cnted
  355. return n_vi, n_vr, n_vs, sod_vs, n_ei, n_er
  356. def get_nb_edit_operations_nonsymbolic(g1, g2, forward_map, backward_map):
  357. """Compute the number of each edit operations.
  358. """
  359. n_vi = 0
  360. n_vr = 0
  361. n_vs = 0
  362. sod_vs = 0
  363. n_ei = 0
  364. n_er = 0
  365. n_es = 0
  366. sod_es = 0
  367. nodes1 = [n for n in g1.nodes()]
  368. for i, map_i in enumerate(forward_map):
  369. if map_i == np.inf:
  370. n_vr += 1
  371. else:
  372. n_vs += 1
  373. sum_squares = 0
  374. for a_name in g1.graph['node_attrs']:
  375. diff = float(g1.nodes[nodes1[i]][a_name]) - float(g2.nodes[map_i][a_name])
  376. sum_squares += np.square(diff)
  377. sod_vs += np.sqrt(sum_squares)
  378. for map_i in backward_map:
  379. if map_i == np.inf:
  380. n_vi += 1
  381. # idx_nodes1 = range(0, len(node1))
  382. edges1 = [e for e in g1.edges()]
  383. for n1, n2 in edges1:
  384. idx1 = nodes1.index(n1)
  385. idx2 = nodes1.index(n2)
  386. n1_g2 = forward_map[idx1]
  387. n2_g2 = forward_map[idx2]
  388. # one of the nodes is removed, thus the edge is removed.
  389. if n1_g2 == np.inf or n2_g2 == np.inf:
  390. n_er += 1
  391. # corresponding edge is in g2.
  392. elif (n1_g2, n2_g2) in g2.edges():
  393. n_es += 1
  394. sum_squares = 0
  395. for a_name in g1.graph['edge_attrs']:
  396. diff = float(g1.edges[n1, n2][a_name]) - float(g2.nodes[n1_g2, n2_g2][a_name])
  397. sum_squares += np.square(diff)
  398. sod_es += np.sqrt(sum_squares)
  399. elif (n2_g2, n1_g2) in g2.edges():
  400. n_es += 1
  401. sum_squares = 0
  402. for a_name in g1.graph['edge_attrs']:
  403. diff = float(g1.edges[n2, n1][a_name]) - float(g2.nodes[n2_g2, n1_g2][a_name])
  404. sum_squares += np.square(diff)
  405. sod_es += np.sqrt(sum_squares)
  406. # corresponding nodes are in g2, however the edge is removed.
  407. else:
  408. n_er += 1
  409. n_ei = nx.number_of_edges(g2) - n_es
  410. return n_vi, n_vr, sod_vs, n_ei, n_er, sod_es
  411. if __name__ == '__main__':
  412. print('check test_ged.py')

A Python package for graph kernels, graph edit distances and graph pre-image problem.