You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

iam.py 27 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Fri Apr 26 11:49:12 2019
  5. Iterative alternate minimizations using GED.
  6. @author: ljia
  7. """
  8. import numpy as np
  9. import random
  10. import networkx as nx
  11. import sys
  12. #from Cython_GedLib_2 import librariesImport, script
  13. import librariesImport, script
  14. sys.path.insert(0, "../")
  15. from pygraph.utils.graphfiles import saveDataset
  16. from pygraph.utils.graphdataset import get_dataset_attributes
  17. from pygraph.utils.utils import graph_isIdentical, get_node_labels, get_edge_labels
  18. #from pygraph.utils.utils import graph_deepcopy
def iam(Gn, c_ei=3, c_er=3, c_es=1, node_label='atom', edge_label='bond_type',
        connected=True):
    """See my name, then you know what I do.

    Iterative Alternate Minimization: approximate a generalized median graph
    of the collection ``Gn``. Phase 1 picks the set-median (the graph with the
    smallest summed GED to all others) as the starting point; phase 2
    repeatedly re-estimates node labels/attributes and the adjacency/edge
    labels from the GED node maps, then recomputes the maps.

    Parameters
    ----------
    Gn : list of networkx.Graph
        Graphs whose median is sought.
    c_ei, c_er, c_es : numbers
        Edit costs for edge insertion, removal and substitution, used in the
        update thresholds below.
    node_label, edge_label : str
        Attribute keys holding symbolic node / edge labels.
    connected : bool
        NOTE(review): never read inside this function — presumably intended to
        filter disconnected results as in the sibling test functions; confirm.

    Returns
    -------
    networkx.Graph
        The refined median candidate after a fixed number of iterations.
    """
#    Gn = Gn[0:10]
    # Relabel nodes to 0..n-1 so node ids can index the GED maps directly.
    Gn = [nx.convert_node_labels_to_integers(g) for g in Gn]

    # phase 1: initilize.
    # compute set-median: all-pairs GED, keep the graph minimizing the row sum.
    dis_min = np.inf
    pi_p = []
    pi_all = []
    for idx1, G_p in enumerate(Gn):
        dist_sum = 0
        pi_all.append([])
        for idx2, G_p_prime in enumerate(Gn):
            dist_tmp, pi_tmp, _ = GED(G_p, G_p_prime)
            pi_all[idx1].append(pi_tmp)
            dist_sum += dist_tmp
        if dist_sum < dis_min:
            dis_min = dist_sum
            G = G_p.copy()
            idx_min = idx1

    # list of edit operations (node maps from the set-median to every graph).
    pi_p = pi_all[idx_min]

    # phase 2: iteration.
    ds_attrs = get_dataset_attributes(Gn, attr_names=['edge_labeled', 'node_attr_dim'],
                                      edge_label=edge_label)
    for itr in range(0, 10): # @todo: the convergence condition?
        G_new = G.copy()
        # update vertex labels.
        # pre-compute h_i0 for each label.
#        for label in get_node_labels(Gn, node_label):
#            print(label)
#        for nd in G.nodes(data=True):
#            pass
        if not ds_attrs['node_attr_dim']: # labels are symbolic
            for nd, _ in G.nodes(data=True):
                h_i0_list = []
                label_list = []
                for label in get_node_labels(Gn, node_label):
                    # h_i0 = number of graphs whose mapped node carries `label`.
                    h_i0 = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p[idx][nd]
                        if g.has_node(pi_i) and g.nodes[pi_i][node_label] == label:
                            h_i0 += 1
                    h_i0_list.append(h_i0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_i0_list == np.max(h_i0_list)).flatten().tolist()
                idx_rdm = random.randint(0, len(idx_max) - 1)
                G_new.nodes[nd][node_label] = label_list[idx_max[idx_rdm]]
        else: # labels are non-symbolic
            for nd, _ in G.nodes(data=True):
                # Average the 'attributes' vectors of all mapped nodes.
                Si_norm = 0
                phi_i_bar = np.array([0.0 for _ in range(ds_attrs['node_attr_dim'])])
                for idx, g in enumerate(Gn):
                    pi_i = pi_p[idx][nd]
                    if g.has_node(pi_i): #@todo: what if no g has node? phi_i_bar = 0?
                        Si_norm += 1
                        phi_i_bar += np.array([float(itm) for itm in g.nodes[pi_i]['attributes']])
                # NOTE(review): ZeroDivisionError if no graph maps this node.
                phi_i_bar /= Si_norm
                G_new.nodes[nd]['attributes'] = phi_i_bar

        # update edge labels and adjacency matrix.
        if ds_attrs['edge_labeled']:
            for nd1, nd2, _ in G.edges(data=True):
                h_ij0_list = []
                label_list = []
                for label in get_edge_labels(Gn, edge_label):
                    # h_ij0 = number of graphs with a matching mapped edge+label.
                    h_ij0 = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p[idx][nd1]
                        pi_j = pi_p[idx][nd2]
                        h_ij0_p = (g.has_node(pi_i) and g.has_node(pi_j) and
                                   g.has_edge(pi_i, pi_j) and
                                   g.edges[pi_i, pi_j][edge_label] == label)
                        h_ij0 += h_ij0_p
                    h_ij0_list.append(h_ij0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_ij0_list == np.max(h_ij0_list)).flatten().tolist()
                h_ij0_max = h_ij0_list[idx_max[0]]
                idx_rdm = random.randint(0, len(idx_max) - 1)
                best_label = label_list[idx_max[idx_rdm]]

                # check whether a_ij is 0 or 1: count how many graphs contain
                # the mapped edge at all.
                sij_norm = 0
                for idx, g in enumerate(Gn):
                    pi_i = pi_p[idx][nd1]
                    pi_j = pi_p[idx][nd2]
                    if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                        sij_norm += 1
                # Keep the edge iff the best label support beats the cost-based
                # threshold; otherwise remove it.
                if h_ij0_max > len(Gn) * c_er / c_es + sij_norm * (1 - (c_er + c_ei) / c_es):
                    if not G_new.has_edge(nd1, nd2):
                        G_new.add_edge(nd1, nd2)
                    G_new.edges[nd1, nd2][edge_label] = best_label
                else:
                    if G_new.has_edge(nd1, nd2):
                        G_new.remove_edge(nd1, nd2)
        else: # if edges are unlabeled
            for nd1, nd2, _ in G.edges(data=True):
                sij_norm = 0
                for idx, g in enumerate(Gn):
                    pi_i = pi_p[idx][nd1]
                    pi_j = pi_p[idx][nd2]
                    if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                        sij_norm += 1
                # Majority-style rule weighted by removal/insertion costs.
                if sij_norm > len(Gn) * c_er / (c_er + c_ei):
                    if not G_new.has_edge(nd1, nd2):
                        G_new.add_edge(nd1, nd2)
                else:
                    if G_new.has_edge(nd1, nd2):
                        G_new.remove_edge(nd1, nd2)

        G = G_new.copy()

        # update pi_p: recompute node maps from the new candidate to each graph.
        pi_p = []
        for idx1, G_p in enumerate(Gn):
            dist_tmp, pi_tmp, _ = GED(G, G_p)
            pi_p.append(pi_tmp)

    return G
  137. def GED(g1, g2, lib='gedlib'):
  138. """
  139. Compute GED.
  140. """
  141. if lib == 'gedlib':
  142. # transform dataset to the 'xml' file as the GedLib required.
  143. saveDataset([g1, g2], [None, None], group='xml', filename='ged_tmp/tmp')
  144. # script.appel()
  145. script.PyRestartEnv()
  146. script.PyLoadGXLGraph('ged_tmp/', 'ged_tmp/tmp.xml')
  147. listID = script.PyGetGraphIds()
  148. script.PySetEditCost("LETTER") #("CHEM_1")
  149. script.PyInitEnv()
  150. script.PySetMethod("IPFP", "")
  151. script.PyInitMethod()
  152. g = listID[0]
  153. h = listID[1]
  154. script.PyRunMethod(g, h)
  155. pi_forward, pi_backward = script.PyGetAllMap(g, h)
  156. upper = script.PyGetUpperBound(g, h)
  157. lower = script.PyGetLowerBound(g, h)
  158. dis = upper
  159. # make the map label correct (label remove map as np.inf)
  160. nodes1 = [n for n in g1.nodes()]
  161. nodes2 = [n for n in g2.nodes()]
  162. nb1 = nx.number_of_nodes(g1)
  163. nb2 = nx.number_of_nodes(g2)
  164. pi_forward = [nodes2[pi] if pi < nb2 else np.inf for pi in pi_forward]
  165. pi_backward = [nodes1[pi] if pi < nb1 else np.inf for pi in pi_backward]
  166. return dis, pi_forward, pi_backward
  167. # --------------------------- These are tests --------------------------------#
def test_iam_with_more_graphs_as_init(Gn, G_candidate, c_ei=3, c_er=3, c_es=1,
                                      node_label='atom', edge_label='bond_type'):
    """See my name, then you know what I do.

    Variant of :func:`iam` whose starting graph is chosen from a separate
    candidate pool ``G_candidate`` rather than from ``Gn`` itself; only the
    forward node maps (candidate -> each graph in ``Gn``) drive the updates.

    Parameters
    ----------
    Gn : list of networkx.Graph
        Graphs whose median is sought.
    G_candidate : list of networkx.Graph
        Candidate initial graphs; the one with the smallest summed GED wins.
    c_ei, c_er, c_es : numbers
        Edge insertion / removal / substitution costs.
    node_label, edge_label : str
        Attribute keys for symbolic node / edge labels.

    Returns
    -------
    networkx.Graph
        The refined candidate after a fixed number of iterations.
    """
    from tqdm import tqdm
#    Gn = Gn[0:10]
    # Relabel nodes to 0..n-1 so node ids can index the GED maps directly.
    Gn = [nx.convert_node_labels_to_integers(g) for g in Gn]

    # phase 1: initilize.
    # compute set-median over the candidate pool.
    dis_min = np.inf
#    pi_p = []
    pi_all_forward = []
    pi_all_backward = []
    for idx1, G_p in tqdm(enumerate(G_candidate), desc='computing GEDs', file=sys.stdout):
        dist_sum = 0
        pi_all_forward.append([])
        pi_all_backward.append([])
        for idx2, G_p_prime in enumerate(Gn):
            dist_tmp, pi_tmp_forward, pi_tmp_backward = GED(G_p, G_p_prime)
            pi_all_forward[idx1].append(pi_tmp_forward)
            pi_all_backward[idx1].append(pi_tmp_backward)
            dist_sum += dist_tmp
        # '<=' (unlike iam's '<') keeps the LAST candidate among ties.
        if dist_sum <= dis_min:
            dis_min = dist_sum
            G = G_p.copy()
            idx_min = idx1

    # list of edit operations.
    pi_p_forward = pi_all_forward[idx_min]
    pi_p_backward = pi_all_backward[idx_min]

    # phase 2: iteration.
    ds_attrs = get_dataset_attributes(Gn + [G], attr_names=['edge_labeled', 'node_attr_dim'],
                                      edge_label=edge_label)
    # Label universe is fixed once, including the initial candidate's labels.
    label_set = get_node_labels(Gn + [G], node_label)
    for itr in range(0, 10): # @todo: the convergence condition?
        G_new = G.copy()
        # update vertex labels.
        # pre-compute h_i0 for each label.
#        for label in get_node_labels(Gn, node_label):
#            print(label)
#        for nd in G.nodes(data=True):
#            pass
        if not ds_attrs['node_attr_dim']: # labels are symbolic
            for nd in G.nodes():
                h_i0_list = []
                label_list = []
                for label in label_set:
                    # h_i0 = number of graphs whose mapped node carries `label`.
                    h_i0 = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p_forward[idx][nd]
                        if g.has_node(pi_i) and g.nodes[pi_i][node_label] == label:
                            h_i0 += 1
                    h_i0_list.append(h_i0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_i0_list == np.max(h_i0_list)).flatten().tolist()
                idx_rdm = random.randint(0, len(idx_max) - 1)
                G_new.nodes[nd][node_label] = label_list[idx_max[idx_rdm]]
        else: # labels are non-symbolic
            for nd in G.nodes():
                # Average the 'attributes' vectors of all mapped nodes.
                Si_norm = 0
                phi_i_bar = np.array([0.0 for _ in range(ds_attrs['node_attr_dim'])])
                for idx, g in enumerate(Gn):
                    pi_i = pi_p_forward[idx][nd]
                    if g.has_node(pi_i): #@todo: what if no g has node? phi_i_bar = 0?
                        Si_norm += 1
                        phi_i_bar += np.array([float(itm) for itm in g.nodes[pi_i]['attributes']])
                # NOTE(review): ZeroDivisionError if no graph maps this node.
                phi_i_bar /= Si_norm
                G_new.nodes[nd]['attributes'] = phi_i_bar

        # update edge labels and adjacency matrix.
        if ds_attrs['edge_labeled']:
            for nd1, nd2, _ in G.edges(data=True):
                h_ij0_list = []
                label_list = []
                for label in get_edge_labels(Gn, edge_label):
                    # h_ij0 = number of graphs with a matching mapped edge+label.
                    h_ij0 = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p_forward[idx][nd1]
                        pi_j = pi_p_forward[idx][nd2]
                        h_ij0_p = (g.has_node(pi_i) and g.has_node(pi_j) and
                                   g.has_edge(pi_i, pi_j) and
                                   g.edges[pi_i, pi_j][edge_label] == label)
                        h_ij0 += h_ij0_p
                    h_ij0_list.append(h_ij0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_ij0_list == np.max(h_ij0_list)).flatten().tolist()
                h_ij0_max = h_ij0_list[idx_max[0]]
                idx_rdm = random.randint(0, len(idx_max) - 1)
                best_label = label_list[idx_max[idx_rdm]]

                # check whether a_ij is 0 or 1: count graphs containing the
                # mapped edge at all.
                sij_norm = 0
                for idx, g in enumerate(Gn):
                    pi_i = pi_p_forward[idx][nd1]
                    pi_j = pi_p_forward[idx][nd2]
                    if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                        sij_norm += 1
                # Keep the edge iff label support beats the cost threshold.
                if h_ij0_max > len(Gn) * c_er / c_es + sij_norm * (1 - (c_er + c_ei) / c_es):
                    if not G_new.has_edge(nd1, nd2):
                        G_new.add_edge(nd1, nd2)
                    G_new.edges[nd1, nd2][edge_label] = best_label
                else:
                    if G_new.has_edge(nd1, nd2):
                        G_new.remove_edge(nd1, nd2)
        else: # if edges are unlabeled
            # @todo: works only for undirected graphs.
            # Unlike iam, this scans ALL node pairs, so edges can be inserted
            # where the current candidate has none.
            for nd1 in range(nx.number_of_nodes(G)):
                for nd2 in range(nd1 + 1, nx.number_of_nodes(G)):
                    sij_norm = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p_forward[idx][nd1]
                        pi_j = pi_p_forward[idx][nd2]
                        if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                            sij_norm += 1
                    if sij_norm > len(Gn) * c_er / (c_er + c_ei):
                        if not G_new.has_edge(nd1, nd2):
                            G_new.add_edge(nd1, nd2)
                    elif sij_norm < len(Gn) * c_er / (c_er + c_ei):
                        if G_new.has_edge(nd1, nd2):
                            G_new.remove_edge(nd1, nd2)
                    # do not change anything when equal.

        G = G_new.copy()

        # update pi_p: recompute forward maps from the new candidate.
        pi_p_forward = []
        for G_p in Gn:
            dist_tmp, pi_tmp_forward, pi_tmp_backward = GED(G, G_p)
            pi_p_forward.append(pi_tmp_forward)

    return G
def test_iam_moreGraphsAsInit_tryAllPossibleBestGraphs_deleteNodesInIterations(
        Gn_median, Gn_candidate, c_ei=3, c_er=3, c_es=1, node_label='atom',
        edge_label='bond_type', connected=True):
    """See my name, then you know what I do.

    Most elaborate IAM variant: starts from every best candidate in
    ``Gn_candidate`` (all ties), branches on every tied-best node label
    (including node removal via the sentinel label ``label_r``), de-duplicates
    the generated graphs each iteration, and finally returns all graphs with
    the minimum summed GED to ``Gn_median``.

    Parameters
    ----------
    Gn_median : list of networkx.Graph
        Graphs whose median is sought.
    Gn_candidate : list of networkx.Graph
        Pool of initial candidates.
    c_ei, c_er, c_es : numbers
        Edge insertion / removal / substitution costs.
    node_label, edge_label : str
        Attribute keys for symbolic node / edge labels.
    connected : bool
        If True, prefer connected result graphs when any exist.

    Returns
    -------
    list of networkx.Graph
        All best median graphs found.
    """
    from tqdm import tqdm
#    Gn_median = Gn_median[0:10]
#    Gn_median = [nx.convert_node_labels_to_integers(g) for g in Gn_median]
    node_ir = np.inf # corresponding to the node remove and insertion.
    label_r = 'thanksdanny' # the label for node remove. # @todo: make this label unrepeatable.
    ds_attrs = get_dataset_attributes(Gn_median + Gn_candidate,
                                      attr_names=['edge_labeled', 'node_attr_dim'],
                                      edge_label=edge_label)

    def generate_graph(G, pi_p_forward, label_set):
        # One refinement step: from candidate G and its forward maps, produce
        # every "best" successor graph (branching on tied node labels and on
        # node removal) plus the updated forward maps for each.
        G_new_list = [G.copy()] # all "best" graphs generated in this iteration.
#        nx.draw_networkx(G)
#        import matplotlib.pyplot as plt
#        plt.show()
#        print(pi_p_forward)

        # update vertex labels.
        # pre-compute h_i0 for each label.
#        for label in get_node_labels(Gn, node_label):
#            print(label)
#        for nd in G.nodes(data=True):
#            pass
        if not ds_attrs['node_attr_dim']: # labels are symbolic
            # ndi is the positional index used by the GED maps; nd is the id.
            for ndi, (nd, _) in enumerate(G.nodes(data=True)):
                h_i0_list = []
                label_list = []
                for label in label_set:
                    h_i0 = 0
                    for idx, g in enumerate(Gn_median):
                        pi_i = pi_p_forward[idx][ndi]
                        if pi_i != node_ir and g.nodes[pi_i][node_label] == label:
                            h_i0 += 1
                    h_i0_list.append(h_i0)
                    label_list.append(label)
                # case when the node is to be removed: count graphs where this
                # node maps to the removal sentinel.
                h_i0_remove = 0
                for idx, g in enumerate(Gn_median):
                    pi_i = pi_p_forward[idx][ndi]
                    if pi_i == node_ir:
                        h_i0_remove += 1
                h_i0_list.append(h_i0_remove)
                label_list.append(label_r)
                # get the best labels (ALL ties, unlike the random pick in iam).
                idx_max = np.argwhere(h_i0_list == np.max(h_i0_list)).flatten().tolist()
                nlabel_best = [label_list[idx] for idx in idx_max]
                # generate "best" graphs with regard to "best" node labels.
                G_new_list_nd = []
                for g in G_new_list: # @todo: seems it can be simplified. The G_new_list will only contain 1 graph for now.
                    for nl in nlabel_best:
                        g_tmp = g.copy()
                        if nl == label_r:
                            g_tmp.remove_node(nd)
                        else:
                            g_tmp.nodes[nd][node_label] = nl
                        G_new_list_nd.append(g_tmp)
#                        nx.draw_networkx(g_tmp)
#                        import matplotlib.pyplot as plt
#                        plt.show()
#                        print(g_tmp.nodes(data=True))
#                        print(g_tmp.edges(data=True))
                G_new_list = G_new_list_nd[:]
        else: # labels are non-symbolic
            for ndi, (nd, _) in enumerate(G.nodes(data=True)):
                # Average the 'attributes' vectors of all mapped nodes.
                Si_norm = 0
                phi_i_bar = np.array([0.0 for _ in range(ds_attrs['node_attr_dim'])])
                for idx, g in enumerate(Gn_median):
                    pi_i = pi_p_forward[idx][ndi]
                    if g.has_node(pi_i): #@todo: what if no g has node? phi_i_bar = 0?
                        Si_norm += 1
                        phi_i_bar += np.array([float(itm) for itm in g.nodes[pi_i]['attributes']])
                # NOTE(review): ZeroDivisionError if no graph maps this node.
                phi_i_bar /= Si_norm
                G_new_list[0].nodes[nd]['attributes'] = phi_i_bar

        # update edge labels and adjacency matrix.
        if ds_attrs['edge_labeled']:
            # NOTE(review): this branch references `G_new`, which is never
            # defined in generate_graph (only `G_new_list` exists) — it will
            # raise NameError on edge-labeled datasets. Looks like it was
            # copied from iam() and not adapted to the list-of-graphs design;
            # confirm and rewrite against G_new_list before using such data.
            for nd1, nd2, _ in G.edges(data=True):
                h_ij0_list = []
                label_list = []
                for label in get_edge_labels(Gn_median, edge_label):
                    h_ij0 = 0
                    for idx, g in enumerate(Gn_median):
                        pi_i = pi_p_forward[idx][nd1]
                        pi_j = pi_p_forward[idx][nd2]
                        h_ij0_p = (g.has_node(pi_i) and g.has_node(pi_j) and
                                   g.has_edge(pi_i, pi_j) and
                                   g.edges[pi_i, pi_j][edge_label] == label)
                        h_ij0 += h_ij0_p
                    h_ij0_list.append(h_ij0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_ij0_list == np.max(h_ij0_list)).flatten().tolist()
                h_ij0_max = h_ij0_list[idx_max[0]]
                idx_rdm = random.randint(0, len(idx_max) - 1)
                best_label = label_list[idx_max[idx_rdm]]

                # check whether a_ij is 0 or 1.
                sij_norm = 0
                for idx, g in enumerate(Gn_median):
                    pi_i = pi_p_forward[idx][nd1]
                    pi_j = pi_p_forward[idx][nd2]
                    if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                        sij_norm += 1
                if h_ij0_max > len(Gn_median) * c_er / c_es + sij_norm * (1 - (c_er + c_ei) / c_es):
                    if not G_new.has_edge(nd1, nd2):
                        G_new.add_edge(nd1, nd2)
                    G_new.edges[nd1, nd2][edge_label] = best_label
                else:
                    if G_new.has_edge(nd1, nd2):
                        G_new.remove_edge(nd1, nd2)
        else: # if edges are unlabeled
            # @todo: works only for undirected graphs.
            nd_list = [n for n in G.nodes()]
            for g_tmp in G_new_list:
                for nd1i in range(nx.number_of_nodes(G)):
                    nd1 = nd_list[nd1i]
                    for nd2i in range(nd1i + 1, nx.number_of_nodes(G)):
                        nd2 = nd_list[nd2i]
                        sij_norm = 0
                        for idx, g in enumerate(Gn_median):
                            pi_i = pi_p_forward[idx][nd1i]
                            pi_j = pi_p_forward[idx][nd2i]
                            if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                                sij_norm += 1
                        if sij_norm > len(Gn_median) * c_er / (c_er + c_ei):
                            # @todo: should we consider if nd1 and nd2 in g_tmp?
                            # or just add the edge anyway?
                            if g_tmp.has_node(nd1) and g_tmp.has_node(nd2) \
                                and not g_tmp.has_edge(nd1, nd2):
                                g_tmp.add_edge(nd1, nd2)
                        elif sij_norm < len(Gn_median) * c_er / (c_er + c_ei):
                            if g_tmp.has_edge(nd1, nd2):
                                g_tmp.remove_edge(nd1, nd2)
                        # do not change anything when equal.

        # find the best graph generated in this iteration and update pi_p.
        # @todo: should we update all graphs generated or just the best ones?
        dis_list, pi_forward_list = median_distance(G_new_list, Gn_median)
        # @todo: should we remove the identical and connectivity check?
        # Don't know which is faster.
        G_new_list, idx_list = remove_duplicates(G_new_list)
        pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
#        if connected == True:
#            G_new_list, idx_list = remove_disconnected(G_new_list)
#            pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
#        idx_min_list = np.argwhere(dis_list == np.min(dis_list)).flatten().tolist()
#        dis_min = dis_list[idx_min_tmp_list[0]]
#        pi_forward_list = [pi_forward_list[idx] for idx in idx_min_list]
#        G_new_list = [G_new_list[idx] for idx in idx_min_list]
#        for g in G_new_list:
#            import matplotlib.pyplot as plt
#            nx.draw_networkx(g)
#            plt.show()
#            print(g.nodes(data=True))
#            print(g.edges(data=True))
        return G_new_list, pi_forward_list

    def median_distance(Gn, Gn_median, measure='ged', verbose=False):
        # Summed GED from each graph in Gn to every graph in Gn_median, plus
        # the forward node maps. `measure` is currently unused.
        dis_list = []
        pi_forward_list = []
        for idx, G in tqdm(enumerate(Gn), desc='computing median distances',
                           file=sys.stdout) if verbose else enumerate(Gn):
            dis_sum = 0
            pi_forward_list.append([])
            for G_p in Gn_median:
                dis_tmp, pi_tmp_forward, pi_tmp_backward = GED(G, G_p)
                pi_forward_list[idx].append(pi_tmp_forward)
                dis_sum += dis_tmp
            dis_list.append(dis_sum)
        return dis_list, pi_forward_list

    def best_median_graphs(Gn_candidate, dis_all, pi_all_forward):
        # Select all candidates achieving the minimum summed distance.
        idx_min_list = np.argwhere(dis_all == np.min(dis_all)).flatten().tolist()
        dis_min = dis_all[idx_min_list[0]]
        pi_forward_min_list = [pi_all_forward[idx] for idx in idx_min_list]
        G_min_list = [Gn_candidate[idx] for idx in idx_min_list]
        return G_min_list, pi_forward_min_list, dis_min

    def iteration_proc(G, pi_p_forward):
        # Run a fixed number of generate_graph rounds starting from G,
        # carrying ALL generated graphs forward (breadth-first branching).
        G_list = [G]
        pi_forward_list = [pi_p_forward]
        # iterations.
        for itr in range(0, 5): # @todo: the convergence condition?
#            print('itr is', itr)
            G_new_list = []
            pi_forward_new_list = []
            for idx, G in enumerate(G_list):
                # Label universe is recomputed per graph since nodes may have
                # been removed/relabelled.
                label_set = get_node_labels(Gn_median + [G], node_label)
                G_tmp_list, pi_forward_tmp_list = generate_graph(
                    G, pi_forward_list[idx], label_set)
                G_new_list += G_tmp_list
                pi_forward_new_list += pi_forward_tmp_list
            G_list = G_new_list[:]
            pi_forward_list = pi_forward_new_list[:]

        G_list, idx_list = remove_duplicates(G_list)
        pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
#        import matplotlib.pyplot as plt
#        for g in G_list:
#            nx.draw_networkx(g)
#            plt.show()
#            print(g.nodes(data=True))
#            print(g.edges(data=True))
        return G_list, pi_forward_list # do we return all graphs or the best ones?

    def remove_duplicates(Gn):
        """Remove duplicate graphs from list.

        Returns the de-duplicated list and the indices kept (first occurrence
        of each graph wins, per graph_isIdentical).
        """
        Gn_new = []
        idx_list = []
        for idx, g in enumerate(Gn):
            dupl = False
            for g_new in Gn_new:
                if graph_isIdentical(g_new, g):
                    dupl = True
                    break
            if not dupl:
                Gn_new.append(g)
                idx_list.append(idx)
        return Gn_new, idx_list

    def remove_disconnected(Gn):
        """Remove disconnected graphs from list.

        Returns the connected graphs and the indices kept.
        """
        Gn_new = []
        idx_list = []
        for idx, g in enumerate(Gn):
            if nx.is_connected(g):
                Gn_new.append(g)
                idx_list.append(idx)
        return Gn_new, idx_list

    # phase 1: initilize.
    # compute set-median over the candidate pool.
    dis_min = np.inf
    dis_all, pi_all_forward = median_distance(Gn_candidate, Gn_median)
    # find all smallest distances.
    idx_min_list = np.argwhere(dis_all == np.min(dis_all)).flatten().tolist()
    dis_min = dis_all[idx_min_list[0]]

    # phase 2: iteration. Refine EVERY tied-best candidate and pool results.
    G_list = []
    for idx_min in idx_min_list[::-1]:
#        print('idx_min is', idx_min)
        G = Gn_candidate[idx_min].copy()
        # list of edit operations.
        pi_p_forward = pi_all_forward[idx_min]
#        pi_p_backward = pi_all_backward[idx_min]
        Gi_list, pi_i_forward_list = iteration_proc(G, pi_p_forward)
        G_list += Gi_list

    G_list, _ = remove_duplicates(G_list)
    if connected == True:
        G_list_con, _ = remove_disconnected(G_list)
        # if there is no connected graphs at all, then remain the disconnected ones.
        if len(G_list_con) > 0: # @todo: ??????????????????????????
            G_list = G_list_con
#    import matplotlib.pyplot as plt
#    for g in G_list:
#        nx.draw_networkx(g)
#        plt.show()
#        print(g.nodes(data=True))
#        print(g.edges(data=True))

    # get the best median graphs
    dis_all, pi_all_forward = median_distance(G_list, Gn_median)
    G_min_list, pi_forward_min_list, dis_min = best_median_graphs(
        G_list, dis_all, pi_all_forward)
#    for g in G_min_list:
#        nx.draw_networkx(g)
#        plt.show()
#        print(g.nodes(data=True))
#        print(g.edges(data=True))
    return G_min_list
if __name__ == '__main__':
    from pygraph.utils.graphfiles import loadDataset

    # Smoke-run iam() on MUTAG (symbolic node/edge labels). The commented
    # alternatives are datasets with non-symbolic node attributes.
    ds = {'name': 'MUTAG', 'dataset': '../datasets/MUTAG/MUTAG.mat',
          'extra_params': {'am_sp_al_nl_el': [0, 0, 3, 1, 2]}} # node/edge symb
#    ds = {'name': 'Letter-high', 'dataset': '../datasets/Letter-high/Letter-high_A.txt',
#          'extra_params': {}} # node nsymb
#    ds = {'name': 'Acyclic', 'dataset': '../datasets/monoterpenoides/trainset_9.ds',
#          'extra_params': {}}
    Gn, y_all = loadDataset(ds['dataset'], extra_params=ds['extra_params'])

    iam(Gn)

A Python package for graph kernels, graph edit distances and graph pre-image problem.