You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

iam.py 36 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Fri Apr 26 11:49:12 2019
  5. Iterative alternate minimizations using GED.
  6. @author: ljia
  7. """
  8. import numpy as np
  9. import random
  10. import networkx as nx
  11. from tqdm import tqdm
  12. import sys
  13. from gedlibpy import librariesImport, gedlibpy
  14. sys.path.insert(0, "../")
  15. from pygraph.utils.graphdataset import get_dataset_attributes
  16. from pygraph.utils.utils import graph_isIdentical, get_node_labels, get_edge_labels
def iam_upgraded(Gn_median, Gn_candidate, c_ei=3, c_er=3, c_es=1, ite_max=50,
                 epsilon=0.001, node_label='atom', edge_label='bond_type',
                 connected=False, removeNodes=True, allBestInit=False, allBestNodes=False,
                 allBestEdges=False, allBestOutput=False,
                 params_ged={'ged_cost': 'CHEM_1', 'ged_method': 'IPFP', 'saveGXL': 'benoit'}):
    """Compute median graph(s) of ``Gn_median`` by iterative alternate
    minimization (IAM), initialized from the best graph(s) in ``Gn_candidate``.

    Parameters
    ----------
    Gn_median : list of networkx graphs
        The graphs whose (generalized) median is sought.
    Gn_candidate : list of networkx graphs
        Candidates from which the initial graph(s) are chosen (smallest sum
        of GEDs to ``Gn_median``).
    c_ei, c_er, c_es : int or float
        Cost constants for edge insertion, removal and substitution used in
        the edge-update decision rules.
    ite_max : int
        Maximum number of IAM iterations per initial graph.
    epsilon : float
        Convergence threshold on the change of the sum of distances (SOD).
    node_label, edge_label : str
        Attribute names holding symbolic node / edge labels.
    connected : bool
        If True, prefer connected result graphs (disconnected ones are
        dropped when at least one connected graph remains).
    removeNodes : bool
        If True, node removal is considered as an extra "label" during the
        vertex-label update.
    allBestInit, allBestNodes, allBestEdges, allBestOutput : bool
        Keep all tied-best alternatives at the respective step instead of
        picking one at random.
    params_ged : dict
        Keyword arguments forwarded to ``median_distance`` (and from there
        to ``GED``).
        NOTE(review): mutable default argument — shared across calls; do not
        mutate it. Kept as-is to preserve the public interface.

    Returns
    -------
    (G_min_list, dis_min)
        The best median graph(s) found and the corresponding minimal sums
        of distances.
    """
    # Gn_median = Gn_median[0:10]
    # Gn_median = [nx.convert_node_labels_to_integers(g) for g in Gn_median]
    if removeNodes:
        node_ir = np.inf  # corresponding to the node remove and insertion.
        label_r = 'thanksdanny'  # the label for node remove. # @todo: make this label unrepeatable.
    ds_attrs = get_dataset_attributes(Gn_median + Gn_candidate,
                                      attr_names=['edge_labeled', 'node_attr_dim', 'edge_attr_dim'],
                                      edge_label=edge_label)

    def generate_graph(G, pi_p_forward, label_set):
        """One alternate-minimization step on ``G``: first update node labels
        (or continuous node attributes), then the edge labels / adjacency,
        driven by the forward node maps ``pi_p_forward`` (one per graph in
        ``Gn_median``). Returns all "best" graphs generated, their forward
        maps and their distances."""
        G_new_list = [G.copy()]  # all "best" graphs generated in this iteration.
        # nx.draw_networkx(G)
        # import matplotlib.pyplot as plt
        # plt.show()
        # print(pi_p_forward)

        # update vertex labels.
        # pre-compute h_i0 for each label.
        # for label in get_node_labels(Gn, node_label):
        #     print(label)
        # for nd in G.nodes(data=True):
        #     pass
        if not ds_attrs['node_attr_dim']:  # labels are symbolic
            for ndi, (nd, _) in enumerate(G.nodes(data=True)):
                # h_i0 counts, per candidate label, how many median graphs map
                # this node to a node carrying that label.
                h_i0_list = []
                label_list = []
                for label in label_set:
                    h_i0 = 0
                    for idx, g in enumerate(Gn_median):
                        pi_i = pi_p_forward[idx][ndi]
                        if pi_i != node_ir and g.nodes[pi_i][node_label] == label:
                            h_i0 += 1
                    h_i0_list.append(h_i0)
                    label_list.append(label)
                # case when the node is to be removed.
                if removeNodes:
                    h_i0_remove = 0  # @todo: maybe this can be added to the label_set above.
                    for idx, g in enumerate(Gn_median):
                        pi_i = pi_p_forward[idx][ndi]
                        if pi_i == node_ir:
                            h_i0_remove += 1
                    h_i0_list.append(h_i0_remove)
                    label_list.append(label_r)
                # get the best labels.
                idx_max = np.argwhere(h_i0_list == np.max(h_i0_list)).flatten().tolist()
                if allBestNodes:  # choose all best graphs.
                    nlabel_best = [label_list[idx] for idx in idx_max]
                    # generate "best" graphs with regard to "best" node labels.
                    G_new_list_nd = []
                    for g in G_new_list:  # @todo: seems it can be simplified. The G_new_list will only contain 1 graph for now.
                        for nl in nlabel_best:
                            g_tmp = g.copy()
                            if nl == label_r:
                                g_tmp.remove_node(nd)
                            else:
                                g_tmp.nodes[nd][node_label] = nl
                            G_new_list_nd.append(g_tmp)
                            # nx.draw_networkx(g_tmp)
                            # import matplotlib.pyplot as plt
                            # plt.show()
                            # print(g_tmp.nodes(data=True))
                            # print(g_tmp.edges(data=True))
                    G_new_list = [ggg.copy() for ggg in G_new_list_nd]
                else:
                    # choose one of the best randomly.
                    h_ij0_max = h_i0_list[idx_max[0]]
                    idx_rdm = random.randint(0, len(idx_max) - 1)
                    best_label = label_list[idx_max[idx_rdm]]
                    # check whether a_ij is 0 or 1.
                    g_new = G_new_list[0]
                    if best_label == label_r:
                        g_new.remove_node(nd)
                    else:
                        g_new.nodes[nd][node_label] = best_label
                    G_new_list = [g_new]
        else:  # labels are non-symbolic
            for ndi, (nd, _) in enumerate(G.nodes(data=True)):
                # average the continuous attribute vectors of the mapped nodes.
                Si_norm = 0
                phi_i_bar = np.array([0.0 for _ in range(ds_attrs['node_attr_dim'])])
                for idx, g in enumerate(Gn_median):
                    pi_i = pi_p_forward[idx][ndi]
                    if g.has_node(pi_i):  # @todo: what if no g has node? phi_i_bar = 0?
                        Si_norm += 1
                        phi_i_bar += np.array([float(itm) for itm in g.nodes[pi_i]['attributes']])
                phi_i_bar /= Si_norm
                G_new_list[0].nodes[nd]['attributes'] = phi_i_bar
        # for g in G_new_list:
        #     import matplotlib.pyplot as plt
        #     nx.draw(g, labels=nx.get_node_attributes(g, 'atom'), with_labels=True)
        #     plt.show()
        #     print(g.nodes(data=True))
        #     print(g.edges(data=True))

        # update edge labels and adjacency matrix.
        if ds_attrs['edge_labeled']:
            G_new_list_edge = []
            for g_new in G_new_list:
                nd_list = [n for n in g_new.nodes()]
                g_tmp_list = [g_new.copy()]
                for nd1i in range(nx.number_of_nodes(g_new)):
                    nd1 = nd_list[nd1i]  # @todo: not just edges, but all pairs of nodes
                    for nd2i in range(nd1i + 1, nx.number_of_nodes(g_new)):
                        nd2 = nd_list[nd2i]
                        # for nd1, nd2, _ in g_new.edges(data=True):
                        h_ij0_list = []
                        label_list = []
                        # @todo: compute edge label set before.
                        for label in get_edge_labels(Gn_median, edge_label):
                            h_ij0 = 0
                            for idx, g in enumerate(Gn_median):
                                pi_i = pi_p_forward[idx][nd1i]
                                pi_j = pi_p_forward[idx][nd2i]
                                h_ij0_p = (g.has_node(pi_i) and g.has_node(pi_j) and
                                           g.has_edge(pi_i, pi_j) and
                                           g.edges[pi_i, pi_j][edge_label] == label)
                                h_ij0 += h_ij0_p
                            h_ij0_list.append(h_ij0)
                            label_list.append(label)
                        # get the best labels.
                        idx_max = np.argwhere(h_ij0_list == np.max(h_ij0_list)).flatten().tolist()
                        if allBestEdges:  # choose all best graphs.
                            elabel_best = [label_list[idx] for idx in idx_max]
                            h_ij0_max = [h_ij0_list[idx] for idx in idx_max]
                            # generate "best" graphs with regard to "best" node labels.
                            G_new_list_ed = []
                            for g_tmp in g_tmp_list:  # @todo: seems it can be simplified. The G_new_list will only contain 1 graph for now.
                                for idxl, el in enumerate(elabel_best):
                                    g_tmp_copy = g_tmp.copy()
                                    # check whether a_ij is 0 or 1.
                                    sij_norm = 0
                                    for idx, g in enumerate(Gn_median):
                                        pi_i = pi_p_forward[idx][nd1i]
                                        pi_j = pi_p_forward[idx][nd2i]
                                        if g.has_node(pi_i) and g.has_node(pi_j) and \
                                           g.has_edge(pi_i, pi_j):
                                            sij_norm += 1
                                    # decision rule: keep/insert the edge when its
                                    # label support beats the removal cost threshold.
                                    if h_ij0_max[idxl] > len(Gn_median) * c_er / c_es + \
                                       sij_norm * (1 - (c_er + c_ei) / c_es):
                                        if not g_tmp_copy.has_edge(nd1, nd2):
                                            g_tmp_copy.add_edge(nd1, nd2)
                                        g_tmp_copy.edges[nd1, nd2][edge_label] = elabel_best[idxl]
                                    else:
                                        if g_tmp_copy.has_edge(nd1, nd2):
                                            g_tmp_copy.remove_edge(nd1, nd2)
                                    G_new_list_ed.append(g_tmp_copy)
                            g_tmp_list = [ggg.copy() for ggg in G_new_list_ed]
                        else:  # choose one of the best randomly.
                            h_ij0_max = h_ij0_list[idx_max[0]]
                            idx_rdm = random.randint(0, len(idx_max) - 1)
                            best_label = label_list[idx_max[idx_rdm]]
                            # check whether a_ij is 0 or 1.
                            sij_norm = 0
                            for idx, g in enumerate(Gn_median):
                                pi_i = pi_p_forward[idx][nd1i]
                                pi_j = pi_p_forward[idx][nd2i]
                                if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                                    sij_norm += 1
                            if h_ij0_max > len(Gn_median) * c_er / c_es + sij_norm * (1 - (c_er + c_ei) / c_es):
                                if not g_new.has_edge(nd1, nd2):
                                    g_new.add_edge(nd1, nd2)
                                g_new.edges[nd1, nd2][edge_label] = best_label
                            else:
                                if g_new.has_edge(nd1, nd2):
                                    g_new.remove_edge(nd1, nd2)
                            g_tmp_list = [g_new]
                G_new_list_edge += g_tmp_list
            G_new_list = [ggg.copy() for ggg in G_new_list_edge]
        else:  # if edges are unlabeled
            # @todo: is this even right? G or g_tmp? check if the new one is right
            # @todo: works only for undirected graphs.
            for g_tmp in G_new_list:
                nd_list = [n for n in g_tmp.nodes()]
                for nd1i in range(nx.number_of_nodes(g_tmp)):
                    nd1 = nd_list[nd1i]
                    for nd2i in range(nd1i + 1, nx.number_of_nodes(g_tmp)):
                        nd2 = nd_list[nd2i]
                        sij_norm = 0
                        for idx, g in enumerate(Gn_median):
                            pi_i = pi_p_forward[idx][nd1i]
                            pi_j = pi_p_forward[idx][nd2i]
                            if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                                sij_norm += 1
                        if sij_norm > len(Gn_median) * c_er / (c_er + c_ei):
                            # @todo: should we consider if nd1 and nd2 in g_tmp?
                            # or just add the edge anyway?
                            if g_tmp.has_node(nd1) and g_tmp.has_node(nd2) \
                               and not g_tmp.has_edge(nd1, nd2):
                                g_tmp.add_edge(nd1, nd2)
                        # else: # @todo: which to use?
                        elif sij_norm < len(Gn_median) * c_er / (c_er + c_ei):
                            if g_tmp.has_edge(nd1, nd2):
                                g_tmp.remove_edge(nd1, nd2)
                        # do not change anything when equal.
        # for i, g in enumerate(G_new_list):
        #     import matplotlib.pyplot as plt
        #     nx.draw(g, labels=nx.get_node_attributes(g, 'atom'), with_labels=True)
        ##     plt.savefig("results/gk_iam/simple_two/xx" + str(i) + ".png", format="PNG")
        #     plt.show()
        #     print(g.nodes(data=True))
        #     print(g.edges(data=True))

        # # find the best graph generated in this iteration and update pi_p.
        # @todo: should we update all graphs generated or just the best ones?
        dis_list, pi_forward_list = median_distance(G_new_list, Gn_median,
                                                    **params_ged)
        # @todo: should we remove the identical and connectivity check?
        # Don't know which is faster.
        if ds_attrs['node_attr_dim'] == 0 and ds_attrs['edge_attr_dim'] == 0:
            G_new_list, idx_list = remove_duplicates(G_new_list)
            pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
            dis_list = [dis_list[idx] for idx in idx_list]
        # if connected == True:
        #     G_new_list, idx_list = remove_disconnected(G_new_list)
        #     pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
        # idx_min_list = np.argwhere(dis_list == np.min(dis_list)).flatten().tolist()
        # dis_min = dis_list[idx_min_tmp_list[0]]
        # pi_forward_list = [pi_forward_list[idx] for idx in idx_min_list]
        # G_new_list = [G_new_list[idx] for idx in idx_min_list]
        # for g in G_new_list:
        #     import matplotlib.pyplot as plt
        #     nx.draw_networkx(g)
        #     plt.show()
        #     print(g.nodes(data=True))
        #     print(g.edges(data=True))
        return G_new_list, pi_forward_list, dis_list

    def best_median_graphs(Gn_candidate, pi_all_forward, dis_all):
        """Return the graph(s) with the smallest sum of distances, together
        with their forward maps and that minimal distance."""
        idx_min_list = np.argwhere(dis_all == np.min(dis_all)).flatten().tolist()
        dis_min = dis_all[idx_min_list[0]]
        pi_forward_min_list = [pi_all_forward[idx] for idx in idx_min_list]
        G_min_list = [Gn_candidate[idx] for idx in idx_min_list]
        return G_min_list, pi_forward_min_list, dis_min

    def iteration_proc(G, pi_p_forward, cur_sod):
        """Run generate_graph repeatedly from ``G`` until the sum of
        distances (SOD) changes by at most ``epsilon`` or ``ite_max``
        iterations are reached; return the best graphs found."""
        G_list = [G]
        pi_forward_list = [pi_p_forward]
        old_sod = cur_sod * 2  # ensures the loop condition holds on entry.
        sod_list = [cur_sod]
        dis_list = [cur_sod]
        # iterations.
        itr = 0
        # @todo: what if difference == 0?
        # while itr < ite_max and (np.abs(old_sod - cur_sod) > epsilon or
        #                          np.abs(old_sod - cur_sod) == 0):
        while itr < ite_max and np.abs(old_sod - cur_sod) > epsilon:
            # for itr in range(0, 5): # the convergence condition?
            print('itr_iam is', itr)
            G_new_list = []
            pi_forward_new_list = []
            dis_new_list = []
            for idx, g in enumerate(G_list):
                label_set = get_node_labels(Gn_median + [g], node_label)
                G_tmp_list, pi_forward_tmp_list, dis_tmp_list = generate_graph(
                    g, pi_forward_list[idx], label_set)
                G_new_list += G_tmp_list
                pi_forward_new_list += pi_forward_tmp_list
                dis_new_list += dis_tmp_list
                # @todo: need to remove duplicates here?
            G_list = [ggg.copy() for ggg in G_new_list]
            pi_forward_list = [pitem.copy() for pitem in pi_forward_new_list]
            dis_list = dis_new_list[:]
            old_sod = cur_sod
            cur_sod = np.min(dis_list)
            sod_list.append(cur_sod)
            itr += 1
        # @todo: do we return all graphs or the best ones?
        # get the best ones of the generated graphs.
        G_list, pi_forward_list, dis_min = best_median_graphs(
            G_list, pi_forward_list, dis_list)
        if ds_attrs['node_attr_dim'] == 0 and ds_attrs['edge_attr_dim'] == 0:
            G_list, idx_list = remove_duplicates(G_list)
            pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
            # dis_list = [dis_list[idx] for idx in idx_list]
        # import matplotlib.pyplot as plt
        # for g in G_list:
        #     nx.draw_networkx(g)
        #     plt.show()
        #     print(g.nodes(data=True))
        #     print(g.edges(data=True))
        print('\nsods:', sod_list, '\n')
        return G_list, pi_forward_list, dis_min

    def remove_duplicates(Gn):
        """Remove duplicate graphs from list.

        Returns the de-duplicated list and the indices (into ``Gn``) of the
        kept graphs.
        """
        Gn_new = []
        idx_list = []
        for idx, g in enumerate(Gn):
            dupl = False
            for g_new in Gn_new:
                if graph_isIdentical(g_new, g):
                    dupl = True
                    break
            if not dupl:
                Gn_new.append(g)
                idx_list.append(idx)
        return Gn_new, idx_list

    def remove_disconnected(Gn):
        """Remove disconnected graphs from list.

        Returns the filtered list and the indices (into ``Gn``) of the kept
        (connected) graphs.
        """
        Gn_new = []
        idx_list = []
        for idx, g in enumerate(Gn):
            if nx.is_connected(g):
                Gn_new.append(g)
                idx_list.append(idx)
        return Gn_new, idx_list

    ###########################################################################

    # phase 1: initilize.
    # compute set-median.
    dis_min = np.inf
    dis_list, pi_forward_all = median_distance(Gn_candidate, Gn_median,
                                               **params_ged)
    # find all smallest distances.
    if allBestInit:  # try all best init graphs.
        idx_min_list = range(len(dis_list))
        dis_min = dis_list
    else:
        idx_min_list = np.argwhere(dis_list == np.min(dis_list)).flatten().tolist()
        dis_min = [dis_list[idx_min_list[0]]] * len(idx_min_list)

    # phase 2: iteration.
    G_list = []
    dis_list = []
    pi_forward_list = []
    for idx_tmp, idx_min in enumerate(idx_min_list):
        # print('idx_min is', idx_min)
        G = Gn_candidate[idx_min].copy()
        # list of edit operations.
        pi_p_forward = pi_forward_all[idx_min]
        # pi_p_backward = pi_all_backward[idx_min]
        Gi_list, pi_i_forward_list, dis_i_min = iteration_proc(G, pi_p_forward, dis_min[idx_tmp])
        G_list += Gi_list
        dis_list += [dis_i_min] * len(Gi_list)
        pi_forward_list += pi_i_forward_list

    if ds_attrs['node_attr_dim'] == 0 and ds_attrs['edge_attr_dim'] == 0:
        G_list, idx_list = remove_duplicates(G_list)
        dis_list = [dis_list[idx] for idx in idx_list]
        pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
    if connected == True:
        G_list_con, idx_list = remove_disconnected(G_list)
        # if there is no connected graphs at all, then remain the disconnected ones.
        if len(G_list_con) > 0:  # @todo: ??????????????????????????
            G_list = G_list_con
            dis_list = [dis_list[idx] for idx in idx_list]
            pi_forward_list = [pi_forward_list[idx] for idx in idx_list]

    # import matplotlib.pyplot as plt
    # for g in G_list:
    #     nx.draw_networkx(g)
    #     plt.show()
    #     print(g.nodes(data=True))
    #     print(g.edges(data=True))

    # get the best median graphs
    G_min_list, pi_forward_min_list, dis_min = best_median_graphs(
        G_list, pi_forward_list, dis_list)
    # for g in G_min_list:
    #     nx.draw_networkx(g)
    #     plt.show()
    #     print(g.nodes(data=True))
    #     print(g.edges(data=True))
    if not allBestOutput:
        # randomly choose one graph.
        idx_rdm = random.randint(0, len(G_min_list) - 1)
        G_min_list = [G_min_list[idx_rdm]]
    return G_min_list, dis_min
  381. ###############################################################################
  382. # Useful functions.
  383. def GED(g1, g2, lib='gedlibpy', cost='CHEM_1', method='IPFP', saveGXL='benoit',
  384. stabilizer='min'):
  385. """
  386. Compute GED.
  387. """
  388. if lib == 'gedlibpy':
  389. def convertGraph(G):
  390. """Convert a graph to the proper NetworkX format that can be
  391. recognized by library gedlibpy.
  392. """
  393. G_new = nx.Graph()
  394. for nd, attrs in G.nodes(data=True):
  395. G_new.add_node(str(nd), chem=attrs['atom'])
  396. for nd1, nd2, attrs in G.edges(data=True):
  397. # G_new.add_edge(str(nd1), str(nd2), valence=attrs['bond_type'])
  398. G_new.add_edge(str(nd1), str(nd2))
  399. return G_new
  400. gedlibpy.restart_env()
  401. gedlibpy.add_nx_graph(convertGraph(g1), "")
  402. gedlibpy.add_nx_graph(convertGraph(g2), "")
  403. listID = gedlibpy.get_all_graph_ids()
  404. gedlibpy.set_edit_cost(cost)
  405. gedlibpy.init()
  406. gedlibpy.set_method(method, "")
  407. gedlibpy.init_method()
  408. g = listID[0]
  409. h = listID[1]
  410. if stabilizer == None:
  411. gedlibpy.run_method(g, h)
  412. pi_forward = gedlibpy.get_forward_map(g, h)
  413. pi_backward = gedlibpy.get_backward_map(g, h)
  414. upper = gedlibpy.get_upper_bound(g, h)
  415. lower = gedlibpy.get_lower_bound(g, h)
  416. elif stabilizer == 'min':
  417. upper = np.inf
  418. for itr in range(50):
  419. gedlibpy.run_method(g, h)
  420. upper_tmp = gedlibpy.get_upper_bound(g, h)
  421. if upper_tmp < upper:
  422. upper = upper_tmp
  423. pi_forward = gedlibpy.get_forward_map(g, h)
  424. pi_backward = gedlibpy.get_backward_map(g, h)
  425. lower = gedlibpy.get_lower_bound(g, h)
  426. if upper == 0:
  427. break
  428. dis = upper
  429. # make the map label correct (label remove map as np.inf)
  430. nodes1 = [n for n in g1.nodes()]
  431. nodes2 = [n for n in g2.nodes()]
  432. nb1 = nx.number_of_nodes(g1)
  433. nb2 = nx.number_of_nodes(g2)
  434. pi_forward = [nodes2[pi] if pi < nb2 else np.inf for pi in pi_forward]
  435. pi_backward = [nodes1[pi] if pi < nb1 else np.inf for pi in pi_backward]
  436. return dis, pi_forward, pi_backward
  437. def median_distance(Gn, Gn_median, measure='ged', verbose=False,
  438. ged_cost='CHEM_1', ged_method='IPFP', saveGXL='benoit'):
  439. dis_list = []
  440. pi_forward_list = []
  441. for idx, G in tqdm(enumerate(Gn), desc='computing median distances',
  442. file=sys.stdout) if verbose else enumerate(Gn):
  443. dis_sum = 0
  444. pi_forward_list.append([])
  445. for G_p in Gn_median:
  446. dis_tmp, pi_tmp_forward, pi_tmp_backward = GED(G, G_p,
  447. cost=ged_cost, method=ged_method, saveGXL=saveGXL)
  448. pi_forward_list[idx].append(pi_tmp_forward)
  449. dis_sum += dis_tmp
  450. dis_list.append(dis_sum)
  451. return dis_list, pi_forward_list
  452. ###############################################################################
  453. # Old implementations.
def iam(Gn, c_ei=3, c_er=3, c_es=1, node_label='atom', edge_label='bond_type',
        connected=True):
    """Old IAM implementation: compute a median graph of ``Gn``.

    Picks the set-median of ``Gn`` (the member with the smallest sum of
    GEDs to all others) as the start graph, then runs 10 fixed iterations
    of alternating node-label and edge updates.

    Parameters
    ----------
    Gn : list of networkx graphs
        Graphs whose median is sought; node labels are relabeled to
        integers first.
    c_ei, c_er, c_es : cost constants for edge insertion, removal and
        substitution used in the edge-update rules.
    node_label, edge_label : attribute names of symbolic node/edge labels.
    connected : bool
        NOTE(review): declared but never used in this function — kept for
        interface compatibility.

    Returns
    -------
    networkx graph
        The median graph after the fixed number of iterations.
    """
    # Gn = Gn[0:10]
    Gn = [nx.convert_node_labels_to_integers(g) for g in Gn]
    # phase 1: initilize.
    # compute set-median.
    dis_min = np.inf
    pi_p = []
    pi_all = []
    for idx1, G_p in enumerate(Gn):
        dist_sum = 0
        pi_all.append([])
        for idx2, G_p_prime in enumerate(Gn):
            dist_tmp, pi_tmp, _ = GED(G_p, G_p_prime)
            pi_all[idx1].append(pi_tmp)
            dist_sum += dist_tmp
        if dist_sum < dis_min:
            dis_min = dist_sum
            G = G_p.copy()
            idx_min = idx1
    # list of edit operations.
    pi_p = pi_all[idx_min]

    # phase 2: iteration.
    ds_attrs = get_dataset_attributes(Gn, attr_names=['edge_labeled', 'node_attr_dim'],
                                      edge_label=edge_label)
    for itr in range(0, 10):  # @todo: the convergence condition?
        G_new = G.copy()
        # update vertex labels.
        # pre-compute h_i0 for each label.
        # for label in get_node_labels(Gn, node_label):
        #     print(label)
        # for nd in G.nodes(data=True):
        #     pass
        if not ds_attrs['node_attr_dim']:  # labels are symbolic
            for nd, _ in G.nodes(data=True):
                h_i0_list = []
                label_list = []
                for label in get_node_labels(Gn, node_label):
                    h_i0 = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p[idx][nd]
                        if g.has_node(pi_i) and g.nodes[pi_i][node_label] == label:
                            h_i0 += 1
                    h_i0_list.append(h_i0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_i0_list == np.max(h_i0_list)).flatten().tolist()
                idx_rdm = random.randint(0, len(idx_max) - 1)
                G_new.nodes[nd][node_label] = label_list[idx_max[idx_rdm]]
        else:  # labels are non-symbolic
            for nd, _ in G.nodes(data=True):
                # average the continuous attribute vectors of the mapped nodes.
                Si_norm = 0
                phi_i_bar = np.array([0.0 for _ in range(ds_attrs['node_attr_dim'])])
                for idx, g in enumerate(Gn):
                    pi_i = pi_p[idx][nd]
                    if g.has_node(pi_i):  # @todo: what if no g has node? phi_i_bar = 0?
                        Si_norm += 1
                        phi_i_bar += np.array([float(itm) for itm in g.nodes[pi_i]['attributes']])
                phi_i_bar /= Si_norm
                G_new.nodes[nd]['attributes'] = phi_i_bar

        # update edge labels and adjacency matrix.
        if ds_attrs['edge_labeled']:
            for nd1, nd2, _ in G.edges(data=True):
                h_ij0_list = []
                label_list = []
                for label in get_edge_labels(Gn, edge_label):
                    h_ij0 = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p[idx][nd1]
                        pi_j = pi_p[idx][nd2]
                        h_ij0_p = (g.has_node(pi_i) and g.has_node(pi_j) and
                                   g.has_edge(pi_i, pi_j) and
                                   g.edges[pi_i, pi_j][edge_label] == label)
                        h_ij0 += h_ij0_p
                    h_ij0_list.append(h_ij0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_ij0_list == np.max(h_ij0_list)).flatten().tolist()
                h_ij0_max = h_ij0_list[idx_max[0]]
                idx_rdm = random.randint(0, len(idx_max) - 1)
                best_label = label_list[idx_max[idx_rdm]]
                # check whether a_ij is 0 or 1.
                sij_norm = 0
                for idx, g in enumerate(Gn):
                    pi_i = pi_p[idx][nd1]
                    pi_j = pi_p[idx][nd2]
                    if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                        sij_norm += 1
                # keep/insert the edge when its label support beats the
                # cost-based removal threshold.
                if h_ij0_max > len(Gn) * c_er / c_es + sij_norm * (1 - (c_er + c_ei) / c_es):
                    if not G_new.has_edge(nd1, nd2):
                        G_new.add_edge(nd1, nd2)
                    G_new.edges[nd1, nd2][edge_label] = best_label
                else:
                    if G_new.has_edge(nd1, nd2):
                        G_new.remove_edge(nd1, nd2)
        else:  # if edges are unlabeled
            for nd1, nd2, _ in G.edges(data=True):
                sij_norm = 0
                for idx, g in enumerate(Gn):
                    pi_i = pi_p[idx][nd1]
                    pi_j = pi_p[idx][nd2]
                    if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                        sij_norm += 1
                if sij_norm > len(Gn) * c_er / (c_er + c_ei):
                    if not G_new.has_edge(nd1, nd2):
                        G_new.add_edge(nd1, nd2)
                else:
                    if G_new.has_edge(nd1, nd2):
                        G_new.remove_edge(nd1, nd2)
        G = G_new.copy()

        # update pi_p
        pi_p = []
        for idx1, G_p in enumerate(Gn):
            dist_tmp, pi_tmp, _ = GED(G, G_p)
            pi_p.append(pi_tmp)
    return G
  572. # --------------------------- These are tests --------------------------------#
def test_iam_with_more_graphs_as_init(Gn, G_candidate, c_ei=3, c_er=3, c_es=1,
                                      node_label='atom', edge_label='bond_type'):
    """Test variant of IAM: compute a median graph of ``Gn`` with the
    starting graph chosen from a separate candidate set ``G_candidate``.

    Selects the candidate with the smallest sum of GEDs to ``Gn`` (ties
    broken by the last minimum, since ``<=`` is used), then runs 10 fixed
    iterations of alternating node-label and edge updates.

    Parameters
    ----------
    Gn : list of networkx graphs
        Graphs whose median is sought; node labels are relabeled to
        integers first.
    G_candidate : list of networkx graphs
        Candidate initial graphs.
    c_ei, c_er, c_es : cost constants for edge insertion, removal and
        substitution used in the edge-update rules.
    node_label, edge_label : attribute names of symbolic node/edge labels.

    Returns
    -------
    networkx graph
        The median graph after the fixed number of iterations.
    """
    # Gn = Gn[0:10]
    Gn = [nx.convert_node_labels_to_integers(g) for g in Gn]
    # phase 1: initilize.
    # compute set-median.
    dis_min = np.inf
    # pi_p = []
    pi_all_forward = []
    pi_all_backward = []
    for idx1, G_p in tqdm(enumerate(G_candidate), desc='computing GEDs', file=sys.stdout):
        dist_sum = 0
        pi_all_forward.append([])
        pi_all_backward.append([])
        for idx2, G_p_prime in enumerate(Gn):
            dist_tmp, pi_tmp_forward, pi_tmp_backward = GED(G_p, G_p_prime)
            pi_all_forward[idx1].append(pi_tmp_forward)
            pi_all_backward[idx1].append(pi_tmp_backward)
            dist_sum += dist_tmp
        if dist_sum <= dis_min:
            dis_min = dist_sum
            G = G_p.copy()
            idx_min = idx1
    # list of edit operations.
    pi_p_forward = pi_all_forward[idx_min]
    # NOTE(review): pi_p_backward is assigned here but never used below.
    pi_p_backward = pi_all_backward[idx_min]

    # phase 2: iteration.
    ds_attrs = get_dataset_attributes(Gn + [G], attr_names=['edge_labeled', 'node_attr_dim'],
                                      edge_label=edge_label)
    label_set = get_node_labels(Gn + [G], node_label)
    for itr in range(0, 10):  # @todo: the convergence condition?
        G_new = G.copy()
        # update vertex labels.
        # pre-compute h_i0 for each label.
        # for label in get_node_labels(Gn, node_label):
        #     print(label)
        # for nd in G.nodes(data=True):
        #     pass
        if not ds_attrs['node_attr_dim']:  # labels are symbolic
            for nd in G.nodes():
                h_i0_list = []
                label_list = []
                for label in label_set:
                    h_i0 = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p_forward[idx][nd]
                        if g.has_node(pi_i) and g.nodes[pi_i][node_label] == label:
                            h_i0 += 1
                    h_i0_list.append(h_i0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_i0_list == np.max(h_i0_list)).flatten().tolist()
                idx_rdm = random.randint(0, len(idx_max) - 1)
                G_new.nodes[nd][node_label] = label_list[idx_max[idx_rdm]]
        else:  # labels are non-symbolic
            for nd in G.nodes():
                # average the continuous attribute vectors of the mapped nodes.
                Si_norm = 0
                phi_i_bar = np.array([0.0 for _ in range(ds_attrs['node_attr_dim'])])
                for idx, g in enumerate(Gn):
                    pi_i = pi_p_forward[idx][nd]
                    if g.has_node(pi_i):  # @todo: what if no g has node? phi_i_bar = 0?
                        Si_norm += 1
                        phi_i_bar += np.array([float(itm) for itm in g.nodes[pi_i]['attributes']])
                phi_i_bar /= Si_norm
                G_new.nodes[nd]['attributes'] = phi_i_bar

        # update edge labels and adjacency matrix.
        if ds_attrs['edge_labeled']:
            for nd1, nd2, _ in G.edges(data=True):
                h_ij0_list = []
                label_list = []
                for label in get_edge_labels(Gn, edge_label):
                    h_ij0 = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p_forward[idx][nd1]
                        pi_j = pi_p_forward[idx][nd2]
                        h_ij0_p = (g.has_node(pi_i) and g.has_node(pi_j) and
                                   g.has_edge(pi_i, pi_j) and
                                   g.edges[pi_i, pi_j][edge_label] == label)
                        h_ij0 += h_ij0_p
                    h_ij0_list.append(h_ij0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_ij0_list == np.max(h_ij0_list)).flatten().tolist()
                h_ij0_max = h_ij0_list[idx_max[0]]
                idx_rdm = random.randint(0, len(idx_max) - 1)
                best_label = label_list[idx_max[idx_rdm]]
                # check whether a_ij is 0 or 1.
                sij_norm = 0
                for idx, g in enumerate(Gn):
                    pi_i = pi_p_forward[idx][nd1]
                    pi_j = pi_p_forward[idx][nd2]
                    if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                        sij_norm += 1
                # keep/insert the edge when its label support beats the
                # cost-based removal threshold.
                if h_ij0_max > len(Gn) * c_er / c_es + sij_norm * (1 - (c_er + c_ei) / c_es):
                    if not G_new.has_edge(nd1, nd2):
                        G_new.add_edge(nd1, nd2)
                    G_new.edges[nd1, nd2][edge_label] = best_label
                else:
                    if G_new.has_edge(nd1, nd2):
                        G_new.remove_edge(nd1, nd2)
        else:  # if edges are unlabeled
            # @todo: works only for undirected graphs.
            for nd1 in range(nx.number_of_nodes(G)):
                for nd2 in range(nd1 + 1, nx.number_of_nodes(G)):
                    sij_norm = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p_forward[idx][nd1]
                        pi_j = pi_p_forward[idx][nd2]
                        if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                            sij_norm += 1
                    if sij_norm > len(Gn) * c_er / (c_er + c_ei):
                        if not G_new.has_edge(nd1, nd2):
                            G_new.add_edge(nd1, nd2)
                    elif sij_norm < len(Gn) * c_er / (c_er + c_ei):
                        if G_new.has_edge(nd1, nd2):
                            G_new.remove_edge(nd1, nd2)
                    # do not change anything when equal.
        G = G_new.copy()

        # update pi_p
        pi_p_forward = []
        for G_p in Gn:
            dist_tmp, pi_tmp_forward, pi_tmp_backward = GED(G, G_p)
            pi_p_forward.append(pi_tmp_forward)
    return G
  699. ###############################################################################
if __name__ == '__main__':
    # Smoke test: load the MUTAG dataset and run the old IAM implementation.
    from pygraph.utils.graphfiles import loadDataset
    ds = {'name': 'MUTAG', 'dataset': '../datasets/MUTAG/MUTAG.mat',
          'extra_params': {'am_sp_al_nl_el': [0, 0, 3, 1, 2]}}  # node/edge symb
    # ds = {'name': 'Letter-high', 'dataset': '../datasets/Letter-high/Letter-high_A.txt',
    #       'extra_params': {}} # node nsymb
    # ds = {'name': 'Acyclic', 'dataset': '../datasets/monoterpenoides/trainset_9.ds',
    #       'extra_params': {}}
    Gn, y_all = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
    iam(Gn)

A Python package for graph kernels, graph edit distances and graph pre-image problem.