You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

iam.py 36 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Fri Apr 26 11:49:12 2019
  5. Iterative alternate minimizations using GED.
  6. @author: ljia
  7. """
  8. import numpy as np
  9. import random
  10. import networkx as nx
  11. from tqdm import tqdm
  12. import sys
  13. #from Cython_GedLib_2 import librariesImport, script
  14. import librariesImport, script
  15. sys.path.insert(0, "../")
  16. from pygraph.utils.graphfiles import saveDataset
  17. from pygraph.utils.graphdataset import get_dataset_attributes
  18. from pygraph.utils.utils import graph_isIdentical, get_node_labels, get_edge_labels
  19. #from pygraph.utils.utils import graph_deepcopy
  20. def iam_moreGraphsAsInit_tryAllPossibleBestGraphs(Gn_median, Gn_candidate,
  21. c_ei=3, c_er=3, c_es=1, ite_max=50, epsilon=0.001,
  22. node_label='atom', edge_label='bond_type',
  23. connected=False, removeNodes=True, allBestInit=False, allBestNodes=False,
  24. allBestEdges=False,
  25. params_ged={'ged_cost': 'CHEM_1', 'ged_method': 'IPFP', 'saveGXL': 'benoit'}):
  26. """See my name, then you know what I do.
  27. """
  28. from tqdm import tqdm
  29. # Gn_median = Gn_median[0:10]
  30. # Gn_median = [nx.convert_node_labels_to_integers(g) for g in Gn_median]
  31. if removeNodes:
  32. node_ir = np.inf # corresponding to the node remove and insertion.
  33. label_r = 'thanksdanny' # the label for node remove. # @todo: make this label unrepeatable.
  34. ds_attrs = get_dataset_attributes(Gn_median + Gn_candidate,
  35. attr_names=['edge_labeled', 'node_attr_dim', 'edge_attr_dim'],
  36. edge_label=edge_label)
  37. def generate_graph(G, pi_p_forward, label_set):
  38. G_new_list = [G.copy()] # all "best" graphs generated in this iteration.
  39. # nx.draw_networkx(G)
  40. # import matplotlib.pyplot as plt
  41. # plt.show()
  42. # print(pi_p_forward)
  43. # update vertex labels.
  44. # pre-compute h_i0 for each label.
  45. # for label in get_node_labels(Gn, node_label):
  46. # print(label)
  47. # for nd in G.nodes(data=True):
  48. # pass
  49. if not ds_attrs['node_attr_dim']: # labels are symbolic
  50. for ndi, (nd, _) in enumerate(G.nodes(data=True)):
  51. h_i0_list = []
  52. label_list = []
  53. for label in label_set:
  54. h_i0 = 0
  55. for idx, g in enumerate(Gn_median):
  56. pi_i = pi_p_forward[idx][ndi]
  57. if pi_i != node_ir and g.nodes[pi_i][node_label] == label:
  58. h_i0 += 1
  59. h_i0_list.append(h_i0)
  60. label_list.append(label)
  61. # case when the node is to be removed.
  62. if removeNodes:
  63. h_i0_remove = 0 # @todo: maybe this can be added to the label_set above.
  64. for idx, g in enumerate(Gn_median):
  65. pi_i = pi_p_forward[idx][ndi]
  66. if pi_i == node_ir:
  67. h_i0_remove += 1
  68. h_i0_list.append(h_i0_remove)
  69. label_list.append(label_r)
  70. # get the best labels.
  71. idx_max = np.argwhere(h_i0_list == np.max(h_i0_list)).flatten().tolist()
  72. if allBestNodes: # choose all best graphs.
  73. nlabel_best = [label_list[idx] for idx in idx_max]
  74. # generate "best" graphs with regard to "best" node labels.
  75. G_new_list_nd = []
  76. for g in G_new_list: # @todo: seems it can be simplified. The G_new_list will only contain 1 graph for now.
  77. for nl in nlabel_best:
  78. g_tmp = g.copy()
  79. if nl == label_r:
  80. g_tmp.remove_node(nd)
  81. else:
  82. g_tmp.nodes[nd][node_label] = nl
  83. G_new_list_nd.append(g_tmp)
  84. # nx.draw_networkx(g_tmp)
  85. # import matplotlib.pyplot as plt
  86. # plt.show()
  87. # print(g_tmp.nodes(data=True))
  88. # print(g_tmp.edges(data=True))
  89. G_new_list = [ggg.copy() for ggg in G_new_list_nd]
  90. else:
  91. # choose one of the best randomly.
  92. h_ij0_max = h_i0_list[idx_max[0]]
  93. idx_rdm = random.randint(0, len(idx_max) - 1)
  94. best_label = label_list[idx_max[idx_rdm]]
  95. # check whether a_ij is 0 or 1.
  96. g_new = G_new_list[0]
  97. if best_label == label_r:
  98. g_new.remove_node(nd)
  99. else:
  100. g_new.nodes[nd][node_label] = best_label
  101. G_new_list = [g_new]
  102. else: # labels are non-symbolic
  103. for ndi, (nd, _) in enumerate(G.nodes(data=True)):
  104. Si_norm = 0
  105. phi_i_bar = np.array([0.0 for _ in range(ds_attrs['node_attr_dim'])])
  106. for idx, g in enumerate(Gn_median):
  107. pi_i = pi_p_forward[idx][ndi]
  108. if g.has_node(pi_i): #@todo: what if no g has node? phi_i_bar = 0?
  109. Si_norm += 1
  110. phi_i_bar += np.array([float(itm) for itm in g.nodes[pi_i]['attributes']])
  111. phi_i_bar /= Si_norm
  112. G_new_list[0].nodes[nd]['attributes'] = phi_i_bar
  113. # for g in G_new_list:
  114. # import matplotlib.pyplot as plt
  115. # nx.draw(g, labels=nx.get_node_attributes(g, 'atom'), with_labels=True)
  116. # plt.show()
  117. # print(g.nodes(data=True))
  118. # print(g.edges(data=True))
  119. # update edge labels and adjacency matrix.
  120. if ds_attrs['edge_labeled']:
  121. G_new_list_edge = []
  122. for g_new in G_new_list:
  123. nd_list = [n for n in g_new.nodes()]
  124. g_tmp_list = [g_new.copy()]
  125. for nd1i in range(nx.number_of_nodes(g_new)):
  126. nd1 = nd_list[nd1i]# @todo: not just edges, but all pairs of nodes
  127. for nd2i in range(nd1i + 1, nx.number_of_nodes(g_new)):
  128. nd2 = nd_list[nd2i]
  129. # for nd1, nd2, _ in g_new.edges(data=True):
  130. h_ij0_list = []
  131. label_list = []
  132. # @todo: compute edge label set before.
  133. for label in get_edge_labels(Gn_median, edge_label):
  134. h_ij0 = 0
  135. for idx, g in enumerate(Gn_median):
  136. pi_i = pi_p_forward[idx][nd1i]
  137. pi_j = pi_p_forward[idx][nd2i]
  138. h_ij0_p = (g.has_node(pi_i) and g.has_node(pi_j) and
  139. g.has_edge(pi_i, pi_j) and
  140. g.edges[pi_i, pi_j][edge_label] == label)
  141. h_ij0 += h_ij0_p
  142. h_ij0_list.append(h_ij0)
  143. label_list.append(label)
  144. # # case when the edge is to be removed.
  145. # h_ij0_remove = 0
  146. # for idx, g in enumerate(Gn_median):
  147. # pi_i = pi_p_forward[idx][nd1i]
  148. # pi_j = pi_p_forward[idx][nd2i]
  149. # if g.has_node(pi_i) and g.has_node(pi_j) and not
  150. # g.has_edge(pi_i, pi_j):
  151. # h_ij0_remove += 1
  152. # h_ij0_list.append(h_ij0_remove)
  153. # label_list.append(label_r)
  154. # get the best labels.
  155. idx_max = np.argwhere(h_ij0_list == np.max(h_ij0_list)).flatten().tolist()
  156. if allBestEdges: # choose all best graphs.
  157. elabel_best = [label_list[idx] for idx in idx_max]
  158. h_ij0_max = [h_ij0_list[idx] for idx in idx_max]
  159. # generate "best" graphs with regard to "best" node labels.
  160. G_new_list_ed = []
  161. for g_tmp in g_tmp_list: # @todo: seems it can be simplified. The G_new_list will only contain 1 graph for now.
  162. for idxl, el in enumerate(elabel_best):
  163. g_tmp_copy = g_tmp.copy()
  164. # check whether a_ij is 0 or 1.
  165. sij_norm = 0
  166. for idx, g in enumerate(Gn_median):
  167. pi_i = pi_p_forward[idx][nd1i]
  168. pi_j = pi_p_forward[idx][nd2i]
  169. if g.has_node(pi_i) and g.has_node(pi_j) and \
  170. g.has_edge(pi_i, pi_j):
  171. sij_norm += 1
  172. if h_ij0_max[idxl] > len(Gn_median) * c_er / c_es + \
  173. sij_norm * (1 - (c_er + c_ei) / c_es):
  174. if not g_tmp_copy.has_edge(nd1, nd2):
  175. g_tmp_copy.add_edge(nd1, nd2)
  176. g_tmp_copy.edges[nd1, nd2][edge_label] = elabel_best[idxl]
  177. else:
  178. if g_tmp_copy.has_edge(nd1, nd2):
  179. g_tmp_copy.remove_edge(nd1, nd2)
  180. G_new_list_ed.append(g_tmp_copy)
  181. g_tmp_list = [ggg.copy() for ggg in G_new_list_ed]
  182. else: # choose one of the best randomly.
  183. h_ij0_max = h_ij0_list[idx_max[0]]
  184. idx_rdm = random.randint(0, len(idx_max) - 1)
  185. best_label = label_list[idx_max[idx_rdm]]
  186. # check whether a_ij is 0 or 1.
  187. sij_norm = 0
  188. for idx, g in enumerate(Gn_median):
  189. pi_i = pi_p_forward[idx][nd1i]
  190. pi_j = pi_p_forward[idx][nd2i]
  191. if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
  192. sij_norm += 1
  193. if h_ij0_max > len(Gn_median) * c_er / c_es + sij_norm * (1 - (c_er + c_ei) / c_es):
  194. if not g_new.has_edge(nd1, nd2):
  195. g_new.add_edge(nd1, nd2)
  196. g_new.edges[nd1, nd2][edge_label] = best_label
  197. else:
  198. if g_new.has_edge(nd1, nd2):
  199. g_new.remove_edge(nd1, nd2)
  200. g_tmp_list = [g_new]
  201. G_new_list_edge += g_tmp_list
  202. G_new_list = [ggg.copy() for ggg in G_new_list_edge]
  203. else: # if edges are unlabeled
  204. # @todo: is this even right? G or g_tmp? check if the new one is right
  205. # @todo: works only for undirected graphs.
  206. for g_tmp in G_new_list:
  207. nd_list = [n for n in g_tmp.nodes()]
  208. for nd1i in range(nx.number_of_nodes(g_tmp)):
  209. nd1 = nd_list[nd1i]
  210. for nd2i in range(nd1i + 1, nx.number_of_nodes(g_tmp)):
  211. nd2 = nd_list[nd2i]
  212. sij_norm = 0
  213. for idx, g in enumerate(Gn_median):
  214. pi_i = pi_p_forward[idx][nd1i]
  215. pi_j = pi_p_forward[idx][nd2i]
  216. if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
  217. sij_norm += 1
  218. if sij_norm > len(Gn_median) * c_er / (c_er + c_ei):
  219. # @todo: should we consider if nd1 and nd2 in g_tmp?
  220. # or just add the edge anyway?
  221. if g_tmp.has_node(nd1) and g_tmp.has_node(nd2) \
  222. and not g_tmp.has_edge(nd1, nd2):
  223. g_tmp.add_edge(nd1, nd2)
  224. # else: # @todo: which to use?
  225. elif sij_norm < len(Gn_median) * c_er / (c_er + c_ei):
  226. if g_tmp.has_edge(nd1, nd2):
  227. g_tmp.remove_edge(nd1, nd2)
  228. # do not change anything when equal.
  229. # for i, g in enumerate(G_new_list):
  230. # import matplotlib.pyplot as plt
  231. # nx.draw(g, labels=nx.get_node_attributes(g, 'atom'), with_labels=True)
  232. ## plt.savefig("results/gk_iam/simple_two/xx" + str(i) + ".png", format="PNG")
  233. # plt.show()
  234. # print(g.nodes(data=True))
  235. # print(g.edges(data=True))
  236. # # find the best graph generated in this iteration and update pi_p.
  237. # @todo: should we update all graphs generated or just the best ones?
  238. dis_list, pi_forward_list = median_distance(G_new_list, Gn_median,
  239. **params_ged)
  240. # @todo: should we remove the identical and connectivity check?
  241. # Don't know which is faster.
  242. if ds_attrs['node_attr_dim'] == 0 and ds_attrs['edge_attr_dim'] == 0:
  243. G_new_list, idx_list = remove_duplicates(G_new_list)
  244. pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
  245. dis_list = [dis_list[idx] for idx in idx_list]
  246. # if connected == True:
  247. # G_new_list, idx_list = remove_disconnected(G_new_list)
  248. # pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
  249. # idx_min_list = np.argwhere(dis_list == np.min(dis_list)).flatten().tolist()
  250. # dis_min = dis_list[idx_min_tmp_list[0]]
  251. # pi_forward_list = [pi_forward_list[idx] for idx in idx_min_list]
  252. # G_new_list = [G_new_list[idx] for idx in idx_min_list]
  253. # for g in G_new_list:
  254. # import matplotlib.pyplot as plt
  255. # nx.draw_networkx(g)
  256. # plt.show()
  257. # print(g.nodes(data=True))
  258. # print(g.edges(data=True))
  259. return G_new_list, pi_forward_list, dis_list
  260. def best_median_graphs(Gn_candidate, pi_all_forward, dis_all):
  261. idx_min_list = np.argwhere(dis_all == np.min(dis_all)).flatten().tolist()
  262. dis_min = dis_all[idx_min_list[0]]
  263. pi_forward_min_list = [pi_all_forward[idx] for idx in idx_min_list]
  264. G_min_list = [Gn_candidate[idx] for idx in idx_min_list]
  265. return G_min_list, pi_forward_min_list, dis_min
  266. def iteration_proc(G, pi_p_forward, cur_sod):
  267. G_list = [G]
  268. pi_forward_list = [pi_p_forward]
  269. old_sod = cur_sod * 2
  270. sod_list = [cur_sod]
  271. dis_list = [cur_sod]
  272. # iterations.
  273. itr = 0
  274. # @todo: what if difference == 0?
  275. # while itr < ite_max and (np.abs(old_sod - cur_sod) > epsilon or
  276. # np.abs(old_sod - cur_sod) == 0):
  277. while itr < ite_max and np.abs(old_sod - cur_sod) > epsilon:
  278. # for itr in range(0, 5): # the convergence condition?
  279. print('itr_iam is', itr)
  280. G_new_list = []
  281. pi_forward_new_list = []
  282. dis_new_list = []
  283. for idx, g in enumerate(G_list):
  284. label_set = get_node_labels(Gn_median + [g], node_label)
  285. G_tmp_list, pi_forward_tmp_list, dis_tmp_list = generate_graph(
  286. g, pi_forward_list[idx], label_set)
  287. G_new_list += G_tmp_list
  288. pi_forward_new_list += pi_forward_tmp_list
  289. dis_new_list += dis_tmp_list
  290. # @todo: need to remove duplicates here?
  291. G_list = [ggg.copy() for ggg in G_new_list]
  292. pi_forward_list = [pitem.copy() for pitem in pi_forward_new_list]
  293. dis_list = dis_new_list[:]
  294. old_sod = cur_sod
  295. cur_sod = np.min(dis_list)
  296. sod_list.append(cur_sod)
  297. itr += 1
  298. # @todo: do we return all graphs or the best ones?
  299. # get the best ones of the generated graphs.
  300. G_list, pi_forward_list, dis_min = best_median_graphs(
  301. G_list, pi_forward_list, dis_list)
  302. if ds_attrs['node_attr_dim'] == 0 and ds_attrs['edge_attr_dim'] == 0:
  303. G_list, idx_list = remove_duplicates(G_list)
  304. pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
  305. # dis_list = [dis_list[idx] for idx in idx_list]
  306. # import matplotlib.pyplot as plt
  307. # for g in G_list:
  308. # nx.draw_networkx(g)
  309. # plt.show()
  310. # print(g.nodes(data=True))
  311. # print(g.edges(data=True))
  312. print('\nsods:', sod_list, '\n')
  313. return G_list, pi_forward_list, dis_min
  314. def remove_duplicates(Gn):
  315. """Remove duplicate graphs from list.
  316. """
  317. Gn_new = []
  318. idx_list = []
  319. for idx, g in enumerate(Gn):
  320. dupl = False
  321. for g_new in Gn_new:
  322. if graph_isIdentical(g_new, g):
  323. dupl = True
  324. break
  325. if not dupl:
  326. Gn_new.append(g)
  327. idx_list.append(idx)
  328. return Gn_new, idx_list
  329. def remove_disconnected(Gn):
  330. """Remove disconnected graphs from list.
  331. """
  332. Gn_new = []
  333. idx_list = []
  334. for idx, g in enumerate(Gn):
  335. if nx.is_connected(g):
  336. Gn_new.append(g)
  337. idx_list.append(idx)
  338. return Gn_new, idx_list
  339. # phase 1: initilize.
  340. # compute set-median.
  341. dis_min = np.inf
  342. dis_list, pi_forward_all = median_distance(Gn_candidate, Gn_median,
  343. **params_ged)
  344. # find all smallest distances.
  345. if allBestInit: # try all best init graphs.
  346. idx_min_list = range(len(dis_list))
  347. dis_min = dis_list
  348. else:
  349. idx_min_list = np.argwhere(dis_list == np.min(dis_list)).flatten().tolist()
  350. dis_min = [dis_list[idx_min_list[0]]] * len(idx_min_list)
  351. # phase 2: iteration.
  352. G_list = []
  353. dis_list = []
  354. pi_forward_list = []
  355. for idx_tmp, idx_min in enumerate(idx_min_list):
  356. # print('idx_min is', idx_min)
  357. G = Gn_candidate[idx_min].copy()
  358. # list of edit operations.
  359. pi_p_forward = pi_forward_all[idx_min]
  360. # pi_p_backward = pi_all_backward[idx_min]
  361. Gi_list, pi_i_forward_list, dis_i_min = iteration_proc(G, pi_p_forward, dis_min[idx_tmp])
  362. G_list += Gi_list
  363. dis_list += [dis_i_min] * len(Gi_list)
  364. pi_forward_list += pi_i_forward_list
  365. if ds_attrs['node_attr_dim'] == 0 and ds_attrs['edge_attr_dim'] == 0:
  366. G_list, idx_list = remove_duplicates(G_list)
  367. dis_list = [dis_list[idx] for idx in idx_list]
  368. pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
  369. if connected == True:
  370. G_list_con, idx_list = remove_disconnected(G_list)
  371. # if there is no connected graphs at all, then remain the disconnected ones.
  372. if len(G_list_con) > 0: # @todo: ??????????????????????????
  373. G_list = G_list_con
  374. dis_list = [dis_list[idx] for idx in idx_list]
  375. pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
  376. # import matplotlib.pyplot as plt
  377. # for g in G_list:
  378. # nx.draw_networkx(g)
  379. # plt.show()
  380. # print(g.nodes(data=True))
  381. # print(g.edges(data=True))
  382. # get the best median graphs
  383. # dis_list, pi_forward_list = median_distance(G_list, Gn_median,
  384. # **params_ged)
  385. G_min_list, pi_forward_min_list, dis_min = best_median_graphs(
  386. G_list, pi_forward_list, dis_list)
  387. # for g in G_min_list:
  388. # nx.draw_networkx(g)
  389. # plt.show()
  390. # print(g.nodes(data=True))
  391. # print(g.edges(data=True))
  392. # randomly choose one graph.
  393. idx_rdm = random.randint(0, len(G_min_list) - 1)
  394. G_min_list = [G_min_list[idx_rdm]]
  395. return G_min_list, dis_min
  396. ###############################################################################
def iam(Gn, c_ei=3, c_er=3, c_es=1, node_label='atom', edge_label='bond_type',
        connected=True):
    """Compute a median graph of Gn by iterative alternate minimization.

    The set-median of Gn (graph minimizing the summed GED to all others)
    is refined for a fixed number of iterations by re-voting node labels,
    edge labels and adjacency from the GED node maps.

    Parameters
    ----------
    Gn : list of networkx graphs
        Graphs whose median is sought (node labels are re-indexed to ints).
    c_ei, c_er, c_es : numbers
        Edge insertion / removal / substitution costs used in thresholds.
    node_label, edge_label : str
        Attribute names of symbolic node / edge labels.
    connected : bool
        NOTE(review): declared but not used in this function — confirm.

    Returns
    -------
    networkx graph
        The refined median graph.
    """
#    Gn = Gn[0:10]
    Gn = [nx.convert_node_labels_to_integers(g) for g in Gn]
    # phase 1: initialize — compute set-median.
    dis_min = np.inf
    pi_p = []
    pi_all = []
    for idx1, G_p in enumerate(Gn):
        dist_sum = 0
        pi_all.append([])
        for idx2, G_p_prime in enumerate(Gn):
            dist_tmp, pi_tmp, _ = GED(G_p, G_p_prime)
            pi_all[idx1].append(pi_tmp)
            dist_sum += dist_tmp
        if dist_sum < dis_min:
            dis_min = dist_sum
            G = G_p.copy()
            idx_min = idx1
    # list of edit operations (node maps of the chosen set-median).
    pi_p = pi_all[idx_min]
    # phase 2: iteration.
    ds_attrs = get_dataset_attributes(Gn, attr_names=['edge_labeled', 'node_attr_dim'],
                                      edge_label=edge_label)
    for itr in range(0, 10):  # @todo: the convergence condition?
        G_new = G.copy()
        # update vertex labels.
        if not ds_attrs['node_attr_dim']:  # labels are symbolic
            for nd, _ in G.nodes(data=True):
                h_i0_list = []
                label_list = []
                for label in get_node_labels(Gn, node_label):
                    h_i0 = 0
                    for idx, g in enumerate(Gn):
                        # NOTE(review): nd is used directly as the map index —
                        # relies on the integer relabeling above.
                        pi_i = pi_p[idx][nd]
                        if g.has_node(pi_i) and g.nodes[pi_i][node_label] == label:
                            h_i0 += 1
                    h_i0_list.append(h_i0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_i0_list == np.max(h_i0_list)).flatten().tolist()
                idx_rdm = random.randint(0, len(idx_max) - 1)
                G_new.nodes[nd][node_label] = label_list[idx_max[idx_rdm]]
        else:  # labels are non-symbolic: average mapped attribute vectors.
            for nd, _ in G.nodes(data=True):
                Si_norm = 0
                phi_i_bar = np.array([0.0 for _ in range(ds_attrs['node_attr_dim'])])
                for idx, g in enumerate(Gn):
                    pi_i = pi_p[idx][nd]
                    if g.has_node(pi_i):  # @todo: what if no g has node? phi_i_bar = 0?
                        Si_norm += 1
                        phi_i_bar += np.array([float(itm) for itm in g.nodes[pi_i]['attributes']])
                phi_i_bar /= Si_norm
                G_new.nodes[nd]['attributes'] = phi_i_bar
        # update edge labels and adjacency matrix.
        if ds_attrs['edge_labeled']:
            for nd1, nd2, _ in G.edges(data=True):
                h_ij0_list = []
                label_list = []
                for label in get_edge_labels(Gn, edge_label):
                    h_ij0 = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p[idx][nd1]
                        pi_j = pi_p[idx][nd2]
                        h_ij0_p = (g.has_node(pi_i) and g.has_node(pi_j) and
                                   g.has_edge(pi_i, pi_j) and
                                   g.edges[pi_i, pi_j][edge_label] == label)
                        h_ij0 += h_ij0_p
                    h_ij0_list.append(h_ij0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_ij0_list == np.max(h_ij0_list)).flatten().tolist()
                h_ij0_max = h_ij0_list[idx_max[0]]
                idx_rdm = random.randint(0, len(idx_max) - 1)
                best_label = label_list[idx_max[idx_rdm]]
                # check whether a_ij is 0 or 1.
                sij_norm = 0
                for idx, g in enumerate(Gn):
                    pi_i = pi_p[idx][nd1]
                    pi_j = pi_p[idx][nd2]
                    if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                        sij_norm += 1
                if h_ij0_max > len(Gn) * c_er / c_es + sij_norm * (1 - (c_er + c_ei) / c_es):
                    if not G_new.has_edge(nd1, nd2):
                        G_new.add_edge(nd1, nd2)
                    G_new.edges[nd1, nd2][edge_label] = best_label
                else:
                    if G_new.has_edge(nd1, nd2):
                        G_new.remove_edge(nd1, nd2)
        else:  # if edges are unlabeled
            for nd1, nd2, _ in G.edges(data=True):
                sij_norm = 0
                for idx, g in enumerate(Gn):
                    pi_i = pi_p[idx][nd1]
                    pi_j = pi_p[idx][nd2]
                    if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                        sij_norm += 1
                if sij_norm > len(Gn) * c_er / (c_er + c_ei):
                    if not G_new.has_edge(nd1, nd2):
                        G_new.add_edge(nd1, nd2)
                else:
                    # NOTE(review): sibling functions use `elif sij_norm < ...`
                    # and leave the edge untouched on equality; here equality
                    # removes the edge — confirm which is intended.
                    if G_new.has_edge(nd1, nd2):
                        G_new.remove_edge(nd1, nd2)
        G = G_new.copy()
        # update pi_p with maps from the new median to each graph.
        pi_p = []
        for idx1, G_p in enumerate(Gn):
            dist_tmp, pi_tmp, _ = GED(G, G_p)
            pi_p.append(pi_tmp)
    return G
  515. def GED(g1, g2, lib='gedlib', cost='CHEM_1', method='IPFP', saveGXL='benoit',
  516. stabilizer='min'):
  517. """
  518. Compute GED.
  519. """
  520. if lib == 'gedlib':
  521. # transform dataset to the 'xml' file as the GedLib required.
  522. saveDataset([g1, g2], [None, None], group='xml', filename='ged_tmp/tmp',
  523. xparams={'method': saveGXL})
  524. # script.appel()
  525. script.PyRestartEnv()
  526. script.PyLoadGXLGraph('ged_tmp/', 'ged_tmp/tmp.xml')
  527. listID = script.PyGetGraphIds()
  528. script.PySetEditCost(cost) #("CHEM_1")
  529. script.PyInitEnv()
  530. script.PySetMethod(method, "")
  531. script.PyInitMethod()
  532. g = listID[0]
  533. h = listID[1]
  534. if stabilizer == None:
  535. script.PyRunMethod(g, h)
  536. pi_forward, pi_backward = script.PyGetAllMap(g, h)
  537. upper = script.PyGetUpperBound(g, h)
  538. lower = script.PyGetLowerBound(g, h)
  539. elif stabilizer == 'min':
  540. upper = np.inf
  541. for itr in range(50):
  542. script.PyRunMethod(g, h)
  543. upper_tmp = script.PyGetUpperBound(g, h)
  544. if upper_tmp < upper:
  545. upper = upper_tmp
  546. pi_forward, pi_backward = script.PyGetAllMap(g, h)
  547. lower = script.PyGetLowerBound(g, h)
  548. if upper == 0:
  549. break
  550. dis = upper
  551. # make the map label correct (label remove map as np.inf)
  552. nodes1 = [n for n in g1.nodes()]
  553. nodes2 = [n for n in g2.nodes()]
  554. nb1 = nx.number_of_nodes(g1)
  555. nb2 = nx.number_of_nodes(g2)
  556. pi_forward = [nodes2[pi] if pi < nb2 else np.inf for pi in pi_forward]
  557. pi_backward = [nodes1[pi] if pi < nb1 else np.inf for pi in pi_backward]
  558. return dis, pi_forward, pi_backward
  559. def median_distance(Gn, Gn_median, measure='ged', verbose=False,
  560. ged_cost='CHEM_1', ged_method='IPFP', saveGXL='benoit'):
  561. dis_list = []
  562. pi_forward_list = []
  563. for idx, G in tqdm(enumerate(Gn), desc='computing median distances',
  564. file=sys.stdout) if verbose else enumerate(Gn):
  565. dis_sum = 0
  566. pi_forward_list.append([])
  567. for G_p in Gn_median:
  568. dis_tmp, pi_tmp_forward, pi_tmp_backward = GED(G, G_p,
  569. cost=ged_cost, method=ged_method, saveGXL=saveGXL)
  570. pi_forward_list[idx].append(pi_tmp_forward)
  571. dis_sum += dis_tmp
  572. dis_list.append(dis_sum)
  573. return dis_list, pi_forward_list
  574. # --------------------------- These are tests --------------------------------#
def test_iam_with_more_graphs_as_init(Gn, G_candidate, c_ei=3, c_er=3, c_es=1,
                                      node_label='atom', edge_label='bond_type'):
    """Variant of iam() that chooses its initial graph from a separate
    candidate set G_candidate instead of from Gn itself.

    Parameters
    ----------
    Gn : list of networkx graphs
        Graphs whose median is sought.
    G_candidate : list of networkx graphs
        Candidate initial graphs; the one with minimum summed GED to Gn
        seeds the iteration.
    c_ei, c_er, c_es : numbers
        Edge insertion / removal / substitution costs used in thresholds.
    node_label, edge_label : str
        Attribute names of symbolic node / edge labels.

    Returns
    -------
    networkx graph
        The refined median graph after a fixed number of iterations.
    """
#    Gn = Gn[0:10]
    Gn = [nx.convert_node_labels_to_integers(g) for g in Gn]
    # phase 1: initialize — compute set-median over the candidates.
    dis_min = np.inf
    pi_all_forward = []
    pi_all_backward = []
    for idx1, G_p in tqdm(enumerate(G_candidate), desc='computing GEDs', file=sys.stdout):
        dist_sum = 0
        pi_all_forward.append([])
        pi_all_backward.append([])
        for idx2, G_p_prime in enumerate(Gn):
            dist_tmp, pi_tmp_forward, pi_tmp_backward = GED(G_p, G_p_prime)
            pi_all_forward[idx1].append(pi_tmp_forward)
            pi_all_backward[idx1].append(pi_tmp_backward)
            dist_sum += dist_tmp
        # NOTE(review): `<=` keeps the *last* tied-best candidate, while iam()
        # uses `<` (first) — confirm the asymmetry is intended.
        if dist_sum <= dis_min:
            dis_min = dist_sum
            G = G_p.copy()
            idx_min = idx1
    # list of edit operations (node maps) of the chosen candidate.
    pi_p_forward = pi_all_forward[idx_min]
    pi_p_backward = pi_all_backward[idx_min]
    # phase 2: iteration.
    ds_attrs = get_dataset_attributes(Gn + [G], attr_names=['edge_labeled', 'node_attr_dim'],
                                      edge_label=edge_label)
    label_set = get_node_labels(Gn + [G], node_label)
    for itr in range(0, 10):  # @todo: the convergence condition?
        G_new = G.copy()
        # update vertex labels.
        if not ds_attrs['node_attr_dim']:  # labels are symbolic
            for nd in G.nodes():
                h_i0_list = []
                label_list = []
                for label in label_set:
                    h_i0 = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p_forward[idx][nd]
                        if g.has_node(pi_i) and g.nodes[pi_i][node_label] == label:
                            h_i0 += 1
                    h_i0_list.append(h_i0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_i0_list == np.max(h_i0_list)).flatten().tolist()
                idx_rdm = random.randint(0, len(idx_max) - 1)
                G_new.nodes[nd][node_label] = label_list[idx_max[idx_rdm]]
        else:  # labels are non-symbolic: average mapped attribute vectors.
            for nd in G.nodes():
                Si_norm = 0
                phi_i_bar = np.array([0.0 for _ in range(ds_attrs['node_attr_dim'])])
                for idx, g in enumerate(Gn):
                    pi_i = pi_p_forward[idx][nd]
                    if g.has_node(pi_i):  # @todo: what if no g has node? phi_i_bar = 0?
                        Si_norm += 1
                        phi_i_bar += np.array([float(itm) for itm in g.nodes[pi_i]['attributes']])
                phi_i_bar /= Si_norm
                G_new.nodes[nd]['attributes'] = phi_i_bar
        # update edge labels and adjacency matrix.
        if ds_attrs['edge_labeled']:
            for nd1, nd2, _ in G.edges(data=True):
                h_ij0_list = []
                label_list = []
                for label in get_edge_labels(Gn, edge_label):
                    h_ij0 = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p_forward[idx][nd1]
                        pi_j = pi_p_forward[idx][nd2]
                        h_ij0_p = (g.has_node(pi_i) and g.has_node(pi_j) and
                                   g.has_edge(pi_i, pi_j) and
                                   g.edges[pi_i, pi_j][edge_label] == label)
                        h_ij0 += h_ij0_p
                    h_ij0_list.append(h_ij0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_ij0_list == np.max(h_ij0_list)).flatten().tolist()
                h_ij0_max = h_ij0_list[idx_max[0]]
                idx_rdm = random.randint(0, len(idx_max) - 1)
                best_label = label_list[idx_max[idx_rdm]]
                # check whether a_ij is 0 or 1.
                sij_norm = 0
                for idx, g in enumerate(Gn):
                    pi_i = pi_p_forward[idx][nd1]
                    pi_j = pi_p_forward[idx][nd2]
                    if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                        sij_norm += 1
                if h_ij0_max > len(Gn) * c_er / c_es + sij_norm * (1 - (c_er + c_ei) / c_es):
                    if not G_new.has_edge(nd1, nd2):
                        G_new.add_edge(nd1, nd2)
                    G_new.edges[nd1, nd2][edge_label] = best_label
                else:
                    if G_new.has_edge(nd1, nd2):
                        G_new.remove_edge(nd1, nd2)
        else:  # if edges are unlabeled
            # @todo: works only for undirected graphs.
            for nd1 in range(nx.number_of_nodes(G)):
                for nd2 in range(nd1 + 1, nx.number_of_nodes(G)):
                    sij_norm = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p_forward[idx][nd1]
                        pi_j = pi_p_forward[idx][nd2]
                        if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                            sij_norm += 1
                    if sij_norm > len(Gn) * c_er / (c_er + c_ei):
                        if not G_new.has_edge(nd1, nd2):
                            G_new.add_edge(nd1, nd2)
                    elif sij_norm < len(Gn) * c_er / (c_er + c_ei):
                        if G_new.has_edge(nd1, nd2):
                            G_new.remove_edge(nd1, nd2)
                    # do not change anything when equal.
        G = G_new.copy()
        # update pi_p with maps from the new median to each graph.
        pi_p_forward = []
        for G_p in Gn:
            dist_tmp, pi_tmp_forward, pi_tmp_backward = GED(G, G_p)
            pi_p_forward.append(pi_tmp_forward)
    return G
  701. ###############################################################################
if __name__ == '__main__':
    # Smoke-run iam() on the MUTAG dataset (symbolic node/edge labels).
    from pygraph.utils.graphfiles import loadDataset
    ds = {'name': 'MUTAG', 'dataset': '../datasets/MUTAG/MUTAG.mat',
          'extra_params': {'am_sp_al_nl_el': [0, 0, 3, 1, 2]}}  # node/edge symb
#    ds = {'name': 'Letter-high', 'dataset': '../datasets/Letter-high/Letter-high_A.txt',
#          'extra_params': {}} # node nsymb
#    ds = {'name': 'Acyclic', 'dataset': '../datasets/monoterpenoides/trainset_9.ds',
#          'extra_params': {}}
    Gn, y_all = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
    iam(Gn)

A Python package for graph kernels, graph edit distances and graph pre-image problem.