You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

iam.py 33 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Fri Apr 26 11:49:12 2019
  5. Iterative alternate minimizations using GED.
  6. @author: ljia
  7. """
  8. import numpy as np
  9. import random
  10. import networkx as nx
  11. from tqdm import tqdm
  12. import sys
  13. sys.path.insert(0, "../")
  14. from pygraph.utils.graphdataset import get_dataset_attributes
  15. from pygraph.utils.utils import graph_isIdentical, get_node_labels, get_edge_labels
  16. from ged import GED, ged_median
def iam_upgraded(Gn_median, Gn_candidate, c_ei=3, c_er=3, c_es=1, ite_max=50,
                 epsilon=0.001, node_label='atom', edge_label='bond_type',
                 connected=False, removeNodes=True, allBestInit=False, allBestNodes=False,
                 allBestEdges=False, allBestOutput=False,
                 params_ged={'lib': 'gedlibpy', 'cost': 'CHEM_1', 'method': 'IPFP',
                             'edit_cost_constant': [], 'stabilizer': 'min', 'repeat': 50}):
    """Compute (generalized) median graph(s) of ``Gn_median`` by Iterative
    Alternate Minimization (IAM) based on graph edit distance (GED).

    The set-median of ``Gn_candidate`` (candidate with smallest sum of GEDs
    to ``Gn_median``) is taken as initialization; node labels/attributes and
    edge labels/adjacency are then alternately re-estimated from the current
    node mappings, and GEDs recomputed, until the change of the sum of
    distances (SOD) drops to ``epsilon`` or below, or ``ite_max`` iterations
    are reached.

    Parameters
    ----------
    Gn_median : list of networkx.Graph
        Graphs whose median is sought.
    Gn_candidate : list of networkx.Graph
        Candidates among which the initial set-median is chosen.
    c_ei, c_er, c_es : float
        Edit costs of edge insertion, removal and substitution used in the
        label/adjacency update rules.
    ite_max : int
        Maximum number of IAM iterations per initialization.
    epsilon : float
        Convergence threshold on the SOD change.
    node_label, edge_label : str
        Attribute names of the symbolic node / edge labels.
    connected : bool
        If True, drop disconnected result graphs (kept only if no connected
        one remains).
    removeNodes : bool
        If True, a node whose best "label" is the removal marker ``label_r``
        is deleted from the graph.
    allBestInit, allBestNodes, allBestEdges, allBestOutput : bool
        If True, keep *all* tied-best initializations / node labelings /
        edge labelings / output graphs instead of picking one at random.
    params_ged : dict
        Options forwarded to the GED backend (see ``ged.ged_median``).
        # NOTE(review): mutable default argument — safe only while callers
        # never mutate it; consider ``None`` + in-function default.

    Returns
    -------
    (G_gen_median_list, sod_gen_median, sod_list, G_set_median_list,
     sod_set_median) : generated median graph(s), their SOD, the per-
        iteration SOD trace (of the last processed initialization), the
        set-median graph(s), and the set-median SOD.
    """
#	Gn_median = Gn_median[0:10]
#	Gn_median = [nx.convert_node_labels_to_integers(g) for g in Gn_median]
    node_ir = np.inf  # mapping value marking node removal/insertion.
    label_r = 'thanksdanny'  # the label for node remove. # @todo: make this label unrepeatable.
    ds_attrs = get_dataset_attributes(Gn_median + Gn_candidate,
                                      attr_names=['edge_labeled', 'node_attr_dim', 'edge_attr_dim'],
                                      edge_label=edge_label)
    # Candidate label alphabets are taken from the median set only.
    node_label_set = get_node_labels(Gn_median, node_label)
    edge_label_set = get_edge_labels(Gn_median, edge_label)

    def generate_graph(G, pi_p_forward):
        """One IAM update step: re-estimate node labels/attributes, then edge
        labels/adjacency of ``G`` from the node mappings ``pi_p_forward``
        (one mapping per graph in ``Gn_median``), recompute GEDs and return
        (new graphs, their mappings, their SODs)."""
        G_new_list = [G.copy()]  # all "best" graphs generated in this iteration.
        # update vertex labels.
        # pre-compute h_i0 for each label.
        if not ds_attrs['node_attr_dim']:  # labels are symbolic
            for ndi, (nd, _) in enumerate(G.nodes(data=True)):
                h_i0_list = []  # votes per candidate label.
                label_list = []
                for label in node_label_set:
                    # h_i0 = number of median graphs mapping node nd to this label.
                    h_i0 = 0
                    for idx, g in enumerate(Gn_median):
                        pi_i = pi_p_forward[idx][ndi]
                        if pi_i != node_ir and g.nodes[pi_i][node_label] == label:
                            h_i0 += 1
                    h_i0_list.append(h_i0)
                    label_list.append(label)
                # case when the node is to be removed.
                if removeNodes:
                    h_i0_remove = 0  # @todo: maybe this can be added to the node_label_set above.
                    for idx, g in enumerate(Gn_median):
                        pi_i = pi_p_forward[idx][ndi]
                        if pi_i == node_ir:
                            h_i0_remove += 1
                    h_i0_list.append(h_i0_remove)
                    label_list.append(label_r)
                # get the best labels.
                idx_max = np.argwhere(h_i0_list == np.max(h_i0_list)).flatten().tolist()
                if allBestNodes:  # keep one graph per tied-best label.
                    nlabel_best = [label_list[idx] for idx in idx_max]
                    # generate "best" graphs with regard to "best" node labels.
                    G_new_list_nd = []
                    for g in G_new_list:  # @todo: seems it can be simplified. The G_new_list will only contain 1 graph for now.
                        for nl in nlabel_best:
                            g_tmp = g.copy()
                            if nl == label_r:
                                g_tmp.remove_node(nd)
                            else:
                                g_tmp.nodes[nd][node_label] = nl
                            G_new_list_nd.append(g_tmp)
                    G_new_list = [ggg.copy() for ggg in G_new_list_nd]
                else:
                    # choose one of the best randomly.
                    idx_rdm = random.randint(0, len(idx_max) - 1)
                    best_label = label_list[idx_max[idx_rdm]]
                    h_i0_max = h_i0_list[idx_max[idx_rdm]]
                    g_new = G_new_list[0]
                    if best_label == label_r:
                        g_new.remove_node(nd)
                    else:
                        g_new.nodes[nd][node_label] = best_label
                    G_new_list = [g_new]
        else:  # labels are non-symbolic: average the mapped attribute vectors.
            for ndi, (nd, _) in enumerate(G.nodes(data=True)):
                Si_norm = 0
                phi_i_bar = np.array([0.0 for _ in range(ds_attrs['node_attr_dim'])])
                for idx, g in enumerate(Gn_median):
                    pi_i = pi_p_forward[idx][ndi]
                    if g.has_node(pi_i):  # @todo: what if no g has node? phi_i_bar = 0?
                        Si_norm += 1
                        phi_i_bar += np.array([float(itm) for itm in g.nodes[pi_i]['attributes']])
                phi_i_bar /= Si_norm
                G_new_list[0].nodes[nd]['attributes'] = phi_i_bar

        # update edge labels and adjacency matrix.
        if ds_attrs['edge_labeled']:
            G_new_list_edge = []
            for g_new in G_new_list:
                nd_list = [n for n in g_new.nodes()]
                g_tmp_list = [g_new.copy()]
                for nd1i in range(nx.number_of_nodes(g_new)):
                    nd1 = nd_list[nd1i]  # @todo: not just edges, but all pairs of nodes
                    for nd2i in range(nd1i + 1, nx.number_of_nodes(g_new)):
                        nd2 = nd_list[nd2i]
                        h_ij0_list = []  # votes per candidate edge label.
                        label_list = []
                        for label in edge_label_set:
                            h_ij0 = 0
                            for idx, g in enumerate(Gn_median):
                                pi_i = pi_p_forward[idx][nd1i]
                                pi_j = pi_p_forward[idx][nd2i]
                                h_ij0_p = (g.has_node(pi_i) and g.has_node(pi_j) and
                                           g.has_edge(pi_i, pi_j) and
                                           g.edges[pi_i, pi_j][edge_label] == label)
                                h_ij0 += h_ij0_p
                            h_ij0_list.append(h_ij0)
                            label_list.append(label)
                        # get the best labels.
                        idx_max = np.argwhere(h_ij0_list == np.max(h_ij0_list)).flatten().tolist()
                        if allBestEdges:  # keep one graph per tied-best edge label.
                            elabel_best = [label_list[idx] for idx in idx_max]
                            h_ij0_max = [h_ij0_list[idx] for idx in idx_max]
                            # generate "best" graphs with regard to "best" edge labels.
                            G_new_list_ed = []
                            for g_tmp in g_tmp_list:  # @todo: seems it can be simplified. The G_new_list will only contain 1 graph for now.
                                for idxl, el in enumerate(elabel_best):
                                    g_tmp_copy = g_tmp.copy()
                                    # check whether a_ij is 0 or 1.
                                    sij_norm = 0  # how many medians have the mapped edge.
                                    for idx, g in enumerate(Gn_median):
                                        pi_i = pi_p_forward[idx][nd1i]
                                        pi_j = pi_p_forward[idx][nd2i]
                                        if g.has_node(pi_i) and g.has_node(pi_j) and \
                                            g.has_edge(pi_i, pi_j):
                                            sij_norm += 1
                                    # cost-based decision rule: keep/insert the edge
                                    # only if its label support outweighs removal.
                                    if h_ij0_max[idxl] > len(Gn_median) * c_er / c_es + \
                                        sij_norm * (1 - (c_er + c_ei) / c_es):
                                        if not g_tmp_copy.has_edge(nd1, nd2):
                                            g_tmp_copy.add_edge(nd1, nd2)
                                        g_tmp_copy.edges[nd1, nd2][edge_label] = elabel_best[idxl]
                                    else:
                                        if g_tmp_copy.has_edge(nd1, nd2):
                                            g_tmp_copy.remove_edge(nd1, nd2)
                                    G_new_list_ed.append(g_tmp_copy)
                            g_tmp_list = [ggg.copy() for ggg in G_new_list_ed]
                        else:  # choose one of the best randomly.
                            idx_rdm = random.randint(0, len(idx_max) - 1)
                            best_label = label_list[idx_max[idx_rdm]]
                            h_ij0_max = h_ij0_list[idx_max[idx_rdm]]
                            # check whether a_ij is 0 or 1.
                            sij_norm = 0
                            for idx, g in enumerate(Gn_median):
                                pi_i = pi_p_forward[idx][nd1i]
                                pi_j = pi_p_forward[idx][nd2i]
                                if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                                    sij_norm += 1
                            if h_ij0_max > len(Gn_median) * c_er / c_es + sij_norm * (1 - (c_er + c_ei) / c_es):
                                if not g_new.has_edge(nd1, nd2):
                                    g_new.add_edge(nd1, nd2)
                                g_new.edges[nd1, nd2][edge_label] = best_label
                            else:
#							elif h_ij0_max < len(Gn_median) * c_er / c_es + sij_norm * (1 - (c_er + c_ei) / c_es):
                                if g_new.has_edge(nd1, nd2):
                                    g_new.remove_edge(nd1, nd2)
                            g_tmp_list = [g_new]
                G_new_list_edge += g_tmp_list
            G_new_list = [ggg.copy() for ggg in G_new_list_edge]
        else:  # if edges are unlabeled
            # @todo: is this even right? G or g_tmp? check if the new one is right
            # @todo: works only for undirected graphs.
            for g_tmp in G_new_list:
                nd_list = [n for n in g_tmp.nodes()]
                for nd1i in range(nx.number_of_nodes(g_tmp)):
                    nd1 = nd_list[nd1i]
                    for nd2i in range(nd1i + 1, nx.number_of_nodes(g_tmp)):
                        nd2 = nd_list[nd2i]
                        sij_norm = 0  # how many medians have the mapped edge.
                        for idx, g in enumerate(Gn_median):
                            pi_i = pi_p_forward[idx][nd1i]
                            pi_j = pi_p_forward[idx][nd2i]
                            if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                                sij_norm += 1
                        if sij_norm > len(Gn_median) * c_er / (c_er + c_ei):
                            # @todo: should we consider if nd1 and nd2 in g_tmp?
                            # or just add the edge anyway?
                            if g_tmp.has_node(nd1) and g_tmp.has_node(nd2) \
                                and not g_tmp.has_edge(nd1, nd2):
                                g_tmp.add_edge(nd1, nd2)
                        else:  # @todo: which to use?
#						elif sij_norm < len(Gn_median) * c_er / (c_er + c_ei):
                            if g_tmp.has_edge(nd1, nd2):
                                g_tmp.remove_edge(nd1, nd2)
                        # do not change anything when equal.

        # find the best graph generated in this iteration and update pi_p.
        # @todo: should we update all graphs generated or just the best ones?
        dis_list, pi_forward_list = ged_median(G_new_list, Gn_median,
                                               params_ged=params_ged)
        # @todo: should we remove the identical and connectivity check?
        # Don't know which is faster.
        if ds_attrs['node_attr_dim'] == 0 and ds_attrs['edge_attr_dim'] == 0:
            G_new_list, idx_list = remove_duplicates(G_new_list)
            pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
            dis_list = [dis_list[idx] for idx in idx_list]
#		if connected == True:
#			G_new_list, idx_list = remove_disconnected(G_new_list)
#			pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
        return G_new_list, pi_forward_list, dis_list

    def best_median_graphs(Gn_candidate, pi_all_forward, dis_all):
        """Return the candidate(s) with minimum SOD, their mappings and that
        minimum SOD."""
        idx_min_list = np.argwhere(dis_all == np.min(dis_all)).flatten().tolist()
        dis_min = dis_all[idx_min_list[0]]
        pi_forward_min_list = [pi_all_forward[idx] for idx in idx_min_list]
        G_min_list = [Gn_candidate[idx] for idx in idx_min_list]
        return G_min_list, pi_forward_min_list, dis_min

    def iteration_proc(G, pi_p_forward, cur_sod):
        """Run the IAM loop from one initialization ``G`` until the SOD change
        is <= ``epsilon`` or ``ite_max`` iterations are done; return the best
        graphs, their mappings, the minimum SOD and the SOD trace."""
        G_list = [G]
        pi_forward_list = [pi_p_forward]
        old_sod = cur_sod * 2  # ensures the loop condition holds initially.
        sod_list = [cur_sod]
        dis_list = [cur_sod]
        # iterations.
        itr = 0
        # @todo: what if difference == 0?
#		while itr < ite_max and (np.abs(old_sod - cur_sod) > epsilon or
#								 np.abs(old_sod - cur_sod) == 0):
        while itr < ite_max and np.abs(old_sod - cur_sod) > epsilon:
            print('itr_iam is', itr)
            G_new_list = []
            pi_forward_new_list = []
            dis_new_list = []
            for idx, g in enumerate(G_list):
#				label_set = get_node_labels(Gn_median + [g], node_label)
                G_tmp_list, pi_forward_tmp_list, dis_tmp_list = generate_graph(
                    g, pi_forward_list[idx])
                G_new_list += G_tmp_list
                pi_forward_new_list += pi_forward_tmp_list
                dis_new_list += dis_tmp_list
            # @todo: need to remove duplicates here?
            G_list = [ggg.copy() for ggg in G_new_list]
            pi_forward_list = [pitem.copy() for pitem in pi_forward_new_list]
            dis_list = dis_new_list[:]
            old_sod = cur_sod
            cur_sod = np.min(dis_list)
            sod_list.append(cur_sod)
            itr += 1
        # @todo: do we return all graphs or the best ones?
        # get the best ones of the generated graphs.
        G_list, pi_forward_list, dis_min = best_median_graphs(
            G_list, pi_forward_list, dis_list)
        if ds_attrs['node_attr_dim'] == 0 and ds_attrs['edge_attr_dim'] == 0:
            G_list, idx_list = remove_duplicates(G_list)
            pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
#			dis_list = [dis_list[idx] for idx in idx_list]
        print('\nsods:', sod_list, '\n')
        return G_list, pi_forward_list, dis_min, sod_list

    def remove_duplicates(Gn):
        """Remove duplicate graphs from list.

        Returns the de-duplicated list plus the kept indices, so parallel
        lists (distances, mappings) can be filtered the same way."""
        Gn_new = []
        idx_list = []
        for idx, g in enumerate(Gn):
            dupl = False
            for g_new in Gn_new:
                if graph_isIdentical(g_new, g):
                    dupl = True
                    break
            if not dupl:
                Gn_new.append(g)
                idx_list.append(idx)
        return Gn_new, idx_list

    def remove_disconnected(Gn):
        """Remove disconnected graphs from list; return kept graphs and their
        indices."""
        Gn_new = []
        idx_list = []
        for idx, g in enumerate(Gn):
            if nx.is_connected(g):
                Gn_new.append(g)
                idx_list.append(idx)
        return Gn_new, idx_list

    ###########################################################################
    # phase 1: initilize.
    # compute set-median.
    dis_min = np.inf
    dis_list, pi_forward_all = ged_median(Gn_candidate, Gn_median,
                                          params_ged=params_ged, parallel=True)
    print('finish computing GEDs.')
    # find all smallest distances.
    if allBestInit:  # try all best init graphs.
        idx_min_list = range(len(dis_list))
        dis_min = dis_list
    else:
        idx_min_list = np.argwhere(dis_list == np.min(dis_list)).flatten().tolist()
        dis_min = [dis_list[idx_min_list[0]]] * len(idx_min_list)
        # break ties randomly: keep exactly one initialization.
        idx_min_rdm = random.randint(0, len(idx_min_list) - 1)
        idx_min_list = [idx_min_list[idx_min_rdm]]
    sod_set_median = np.min(dis_min)

    # phase 2: iteration.
    G_list = []
    dis_list = []
    pi_forward_list = []
    G_set_median_list = []
#	sod_list = []
    for idx_tmp, idx_min in enumerate(idx_min_list):
#		print('idx_min is', idx_min)
        G = Gn_candidate[idx_min].copy()
        G_set_median_list.append(G.copy())
        # list of edit operations.
        pi_p_forward = pi_forward_all[idx_min]
#		pi_p_backward = pi_all_backward[idx_min]
        # NOTE(review): sod_list is overwritten each pass; only the trace of
        # the last initialization is returned.
        Gi_list, pi_i_forward_list, dis_i_min, sod_list = iteration_proc(G,
            pi_p_forward, dis_min[idx_tmp])
        G_list += Gi_list
        dis_list += [dis_i_min] * len(Gi_list)
        pi_forward_list += pi_i_forward_list

    if ds_attrs['node_attr_dim'] == 0 and ds_attrs['edge_attr_dim'] == 0:
        G_list, idx_list = remove_duplicates(G_list)
        dis_list = [dis_list[idx] for idx in idx_list]
        pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
    if connected == True:
        G_list_con, idx_list = remove_disconnected(G_list)
        # if there is no connected graphs at all, then remain the disconnected ones.
        if len(G_list_con) > 0:  # @todo: ??????????????????????????
            G_list = G_list_con
            dis_list = [dis_list[idx] for idx in idx_list]
            pi_forward_list = [pi_forward_list[idx] for idx in idx_list]

    # get the best median graphs
    G_gen_median_list, pi_forward_min_list, sod_gen_median = best_median_graphs(
        G_list, pi_forward_list, dis_list)
    if not allBestOutput:
        # randomly choose one graph.
        idx_rdm = random.randint(0, len(G_gen_median_list) - 1)
        G_gen_median_list = [G_gen_median_list[idx_rdm]]

    return G_gen_median_list, sod_gen_median, sod_list, G_set_median_list, sod_set_median
  391. ###############################################################################
  392. # Old implementations.
def iam(Gn, c_ei=3, c_er=3, c_es=1, node_label='atom', edge_label='bond_type',
        connected=True):
    """Old IAM implementation: compute a median graph of ``Gn`` by iterating
    label/adjacency re-estimation and GED mapping updates.

    Uses the set-median of ``Gn`` itself as initialization and runs a fixed
    10 iterations (no convergence test). Returns the final median graph.

    Parameters: ``c_ei``/``c_er``/``c_es`` are edge insertion/removal/
    substitution costs; ``node_label``/``edge_label`` name the symbolic
    attributes; ``connected`` is accepted but not used in this version.
    """
#	Gn = Gn[0:10]
    # relabel nodes 0..n-1 so mappings can be indexed by node id.
    Gn = [nx.convert_node_labels_to_integers(g) for g in Gn]

    # phase 1: initilize.
    # compute set-median: the graph with the smallest sum of GEDs to all others.
    dis_min = np.inf
    pi_p = []
    pi_all = []
    for idx1, G_p in enumerate(Gn):
        dist_sum = 0
        pi_all.append([])
        for idx2, G_p_prime in enumerate(Gn):
            dist_tmp, pi_tmp, _ = GED(G_p, G_p_prime)
            pi_all[idx1].append(pi_tmp)
            dist_sum += dist_tmp
        if dist_sum < dis_min:
            dis_min = dist_sum
            G = G_p.copy()
            idx_min = idx1
    # list of edit operations.
    pi_p = pi_all[idx_min]

    # phase 2: iteration.
    ds_attrs = get_dataset_attributes(Gn, attr_names=['edge_labeled', 'node_attr_dim'],
                                      edge_label=edge_label)
    for itr in range(0, 10):  # @todo: the convergence condition?
        G_new = G.copy()
        # update vertex labels.
        # pre-compute h_i0 for each label.
        if not ds_attrs['node_attr_dim']:  # labels are symbolic
            for nd, _ in G.nodes(data=True):
                h_i0_list = []  # votes per candidate label.
                label_list = []
                for label in get_node_labels(Gn, node_label):
                    h_i0 = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p[idx][nd]
                        if g.has_node(pi_i) and g.nodes[pi_i][node_label] == label:
                            h_i0 += 1
                    h_i0_list.append(h_i0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_i0_list == np.max(h_i0_list)).flatten().tolist()
                idx_rdm = random.randint(0, len(idx_max) - 1)
                G_new.nodes[nd][node_label] = label_list[idx_max[idx_rdm]]
        else:  # labels are non-symbolic: average mapped attribute vectors.
            for nd, _ in G.nodes(data=True):
                Si_norm = 0
                phi_i_bar = np.array([0.0 for _ in range(ds_attrs['node_attr_dim'])])
                for idx, g in enumerate(Gn):
                    pi_i = pi_p[idx][nd]
                    if g.has_node(pi_i):  # @todo: what if no g has node? phi_i_bar = 0?
                        Si_norm += 1
                        phi_i_bar += np.array([float(itm) for itm in g.nodes[pi_i]['attributes']])
                phi_i_bar /= Si_norm
                G_new.nodes[nd]['attributes'] = phi_i_bar

        # update edge labels and adjacency matrix.
        # NOTE(review): only existing edges of G are visited here, so edges
        # are never *inserted* between currently non-adjacent nodes.
        if ds_attrs['edge_labeled']:
            for nd1, nd2, _ in G.edges(data=True):
                h_ij0_list = []  # votes per candidate edge label.
                label_list = []
                for label in get_edge_labels(Gn, edge_label):
                    h_ij0 = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p[idx][nd1]
                        pi_j = pi_p[idx][nd2]
                        h_ij0_p = (g.has_node(pi_i) and g.has_node(pi_j) and
                                   g.has_edge(pi_i, pi_j) and
                                   g.edges[pi_i, pi_j][edge_label] == label)
                        h_ij0 += h_ij0_p
                    h_ij0_list.append(h_ij0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_ij0_list == np.max(h_ij0_list)).flatten().tolist()
                h_ij0_max = h_ij0_list[idx_max[0]]
                idx_rdm = random.randint(0, len(idx_max) - 1)
                best_label = label_list[idx_max[idx_rdm]]
                # check whether a_ij is 0 or 1.
                sij_norm = 0  # how many graphs have the mapped edge.
                for idx, g in enumerate(Gn):
                    pi_i = pi_p[idx][nd1]
                    pi_j = pi_p[idx][nd2]
                    if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                        sij_norm += 1
                # cost-based keep/remove decision for the edge.
                if h_ij0_max > len(Gn) * c_er / c_es + sij_norm * (1 - (c_er + c_ei) / c_es):
                    if not G_new.has_edge(nd1, nd2):
                        G_new.add_edge(nd1, nd2)
                    G_new.edges[nd1, nd2][edge_label] = best_label
                else:
                    if G_new.has_edge(nd1, nd2):
                        G_new.remove_edge(nd1, nd2)
        else:  # if edges are unlabeled
            for nd1, nd2, _ in G.edges(data=True):
                sij_norm = 0
                for idx, g in enumerate(Gn):
                    pi_i = pi_p[idx][nd1]
                    pi_j = pi_p[idx][nd2]
                    if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                        sij_norm += 1
                if sij_norm > len(Gn) * c_er / (c_er + c_ei):
                    if not G_new.has_edge(nd1, nd2):
                        G_new.add_edge(nd1, nd2)
                else:
                    if G_new.has_edge(nd1, nd2):
                        G_new.remove_edge(nd1, nd2)
        G = G_new.copy()

        # update pi_p
        pi_p = []
        for idx1, G_p in enumerate(Gn):
            dist_tmp, pi_tmp, _ = GED(G, G_p)
            pi_p.append(pi_tmp)

    return G
  511. # --------------------------- These are tests --------------------------------#
def test_iam_with_more_graphs_as_init(Gn, G_candidate, c_ei=3, c_er=3, c_es=1,
                                      node_label='atom', edge_label='bond_type'):
    """Old test variant of IAM that chooses the initialization among a
    separate candidate set ``G_candidate`` instead of ``Gn`` itself.

    Computes forward and backward GED mappings for each candidate, starts
    from the candidate with the smallest (tie: last) sum of distances, and
    runs a fixed 10 iterations of label/adjacency re-estimation. Returns the
    final median graph.
    """
#	Gn = Gn[0:10]
    # relabel nodes 0..n-1 so mappings can be indexed by node id.
    Gn = [nx.convert_node_labels_to_integers(g) for g in Gn]

    # phase 1: initilize.
    # compute set-median over the candidate set.
    dis_min = np.inf
#	pi_p = []
    pi_all_forward = []
    pi_all_backward = []
    for idx1, G_p in tqdm(enumerate(G_candidate), desc='computing GEDs', file=sys.stdout):
        dist_sum = 0
        pi_all_forward.append([])
        pi_all_backward.append([])
        for idx2, G_p_prime in enumerate(Gn):
            dist_tmp, pi_tmp_forward, pi_tmp_backward = GED(G_p, G_p_prime)
            pi_all_forward[idx1].append(pi_tmp_forward)
            pi_all_backward[idx1].append(pi_tmp_backward)
            dist_sum += dist_tmp
        # <= keeps the *last* minimizer on ties (unlike iam(), which uses <).
        if dist_sum <= dis_min:
            dis_min = dist_sum
            G = G_p.copy()
            idx_min = idx1
    # list of edit operations.
    pi_p_forward = pi_all_forward[idx_min]
    pi_p_backward = pi_all_backward[idx_min]

    # phase 2: iteration.
    ds_attrs = get_dataset_attributes(Gn + [G], attr_names=['edge_labeled', 'node_attr_dim'],
                                      edge_label=edge_label)
    label_set = get_node_labels(Gn + [G], node_label)
    for itr in range(0, 10):  # @todo: the convergence condition?
        G_new = G.copy()
        # update vertex labels.
        # pre-compute h_i0 for each label.
        if not ds_attrs['node_attr_dim']:  # labels are symbolic
            for nd in G.nodes():
                h_i0_list = []  # votes per candidate label.
                label_list = []
                for label in label_set:
                    h_i0 = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p_forward[idx][nd]
                        if g.has_node(pi_i) and g.nodes[pi_i][node_label] == label:
                            h_i0 += 1
                    h_i0_list.append(h_i0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_i0_list == np.max(h_i0_list)).flatten().tolist()
                idx_rdm = random.randint(0, len(idx_max) - 1)
                G_new.nodes[nd][node_label] = label_list[idx_max[idx_rdm]]
        else:  # labels are non-symbolic: average mapped attribute vectors.
            for nd in G.nodes():
                Si_norm = 0
                phi_i_bar = np.array([0.0 for _ in range(ds_attrs['node_attr_dim'])])
                for idx, g in enumerate(Gn):
                    pi_i = pi_p_forward[idx][nd]
                    if g.has_node(pi_i):  # @todo: what if no g has node? phi_i_bar = 0?
                        Si_norm += 1
                        phi_i_bar += np.array([float(itm) for itm in g.nodes[pi_i]['attributes']])
                phi_i_bar /= Si_norm
                G_new.nodes[nd]['attributes'] = phi_i_bar

        # update edge labels and adjacency matrix.
        if ds_attrs['edge_labeled']:
            for nd1, nd2, _ in G.edges(data=True):
                h_ij0_list = []  # votes per candidate edge label.
                label_list = []
                for label in get_edge_labels(Gn, edge_label):
                    h_ij0 = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p_forward[idx][nd1]
                        pi_j = pi_p_forward[idx][nd2]
                        h_ij0_p = (g.has_node(pi_i) and g.has_node(pi_j) and
                                   g.has_edge(pi_i, pi_j) and
                                   g.edges[pi_i, pi_j][edge_label] == label)
                        h_ij0 += h_ij0_p
                    h_ij0_list.append(h_ij0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_ij0_list == np.max(h_ij0_list)).flatten().tolist()
                h_ij0_max = h_ij0_list[idx_max[0]]
                idx_rdm = random.randint(0, len(idx_max) - 1)
                best_label = label_list[idx_max[idx_rdm]]
                # check whether a_ij is 0 or 1.
                sij_norm = 0  # how many graphs have the mapped edge.
                for idx, g in enumerate(Gn):
                    pi_i = pi_p_forward[idx][nd1]
                    pi_j = pi_p_forward[idx][nd2]
                    if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                        sij_norm += 1
                # cost-based keep/remove decision for the edge.
                if h_ij0_max > len(Gn) * c_er / c_es + sij_norm * (1 - (c_er + c_ei) / c_es):
                    if not G_new.has_edge(nd1, nd2):
                        G_new.add_edge(nd1, nd2)
                    G_new.edges[nd1, nd2][edge_label] = best_label
                else:
                    if G_new.has_edge(nd1, nd2):
                        G_new.remove_edge(nd1, nd2)
        else:  # if edges are unlabeled
            # @todo: works only for undirected graphs.
            # Unlike the labeled branch, this one visits all node pairs,
            # so edges can also be inserted.
            for nd1 in range(nx.number_of_nodes(G)):
                for nd2 in range(nd1 + 1, nx.number_of_nodes(G)):
                    sij_norm = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p_forward[idx][nd1]
                        pi_j = pi_p_forward[idx][nd2]
                        if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                            sij_norm += 1
                    if sij_norm > len(Gn) * c_er / (c_er + c_ei):
                        if not G_new.has_edge(nd1, nd2):
                            G_new.add_edge(nd1, nd2)
                    elif sij_norm < len(Gn) * c_er / (c_er + c_ei):
                        if G_new.has_edge(nd1, nd2):
                            G_new.remove_edge(nd1, nd2)
                    # do not change anything when equal.
        G = G_new.copy()

        # update pi_p
        pi_p_forward = []
        for G_p in Gn:
            dist_tmp, pi_tmp_forward, pi_tmp_backward = GED(G, G_p)
            pi_p_forward.append(pi_tmp_forward)

    return G
  638. ###############################################################################
  639. if __name__ == '__main__':
  640. from pygraph.utils.graphfiles import loadDataset
  641. ds = {'name': 'MUTAG', 'dataset': '../datasets/MUTAG/MUTAG.mat',
  642. 'extra_params': {'am_sp_al_nl_el': [0, 0, 3, 1, 2]}} # node/edge symb
  643. # ds = {'name': 'Letter-high', 'dataset': '../datasets/Letter-high/Letter-high_A.txt',
  644. # 'extra_params': {}} # node nsymb
  645. # ds = {'name': 'Acyclic', 'dataset': '../datasets/monoterpenoides/trainset_9.ds',
  646. # 'extra_params': {}}
  647. Gn, y_all = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
  648. iam(Gn)

A Python package for graph kernels, graph edit distances and the graph pre-image problem.