You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

iam.py 27 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Fri Apr 26 11:49:12 2019
  5. Iterative alternate minimizations using GED.
  6. @author: ljia
  7. """
  8. import numpy as np
  9. import random
  10. import networkx as nx
  11. import sys
  12. #from Cython_GedLib_2 import librariesImport, script
  13. import librariesImport, script
  14. sys.path.insert(0, "../")
  15. from pygraph.utils.graphfiles import saveDataset
  16. from pygraph.utils.graphdataset import get_dataset_attributes
  17. from pygraph.utils.utils import graph_isIdentical
  18. #from pygraph.utils.utils import graph_deepcopy
def iam(Gn, c_ei=3, c_er=3, c_es=1, node_label='atom', edge_label='bond_type',
        connected=True):
    """Estimate a median graph of the graphs in ``Gn`` by Iterative Alternate
    Minimization (IAM) based on graph edit distance (GED).

    Phase 1 picks the set-median (the input graph with the smallest sum of
    GEDs to all inputs) as the initial estimate; phase 2 alternately updates
    node labels, edge labels and the adjacency structure for a fixed number
    of iterations, recomputing the GED node maps after each pass.

    Parameters
    ----------
    Gn : list of networkx graphs
        The graphs whose median is to be estimated.
    c_ei, c_er, c_es : number
        Edit costs for edge insertion, edge removal and edge substitution.
    node_label : string
        Node attribute name used as the symbolic node label.
    edge_label : string
        Edge attribute name used as the symbolic edge label.
    connected : bool
        Unused in this function (kept for interface compatibility).

    Returns
    -------
    networkx graph
        The estimate of the median graph after the iterations.
    """
#    Gn = Gn[0:10]
    # Relabel nodes to consecutive integers so a node id can be used as an
    # index into the GED node maps (pi_p) below.
    Gn = [nx.convert_node_labels_to_integers(g) for g in Gn]

    # phase 1: initialize.
    # compute set-median.
    dis_min = np.inf
    pi_p = []
    pi_all = []
    for idx1, G_p in enumerate(Gn):
        dist_sum = 0
        pi_all.append([])
        for idx2, G_p_prime in enumerate(Gn):
            dist_tmp, pi_tmp, _ = GED(G_p, G_p_prime)
            pi_all[idx1].append(pi_tmp)
            dist_sum += dist_tmp
        if dist_sum < dis_min:
            dis_min = dist_sum
            G = G_p.copy()
            idx_min = idx1
    # list of edit operations: node maps from the set-median to every graph.
    pi_p = pi_all[idx_min]

    # phase 2: iteration.
    ds_attrs = get_dataset_attributes(Gn, attr_names=['edge_labeled', 'node_attr_dim'],
                                      edge_label=edge_label)
    for itr in range(0, 10):  # @todo: the convergence condition?
        G_new = G.copy()
        # update vertex labels.
        # pre-compute h_i0 for each label.
#        for label in get_node_labels(Gn, node_label):
#            print(label)
#        for nd in G.nodes(data=True):
#            pass
        if not ds_attrs['node_attr_dim']:  # labels are symbolic
            for nd, _ in G.nodes(data=True):
                h_i0_list = []
                label_list = []
                for label in get_node_labels(Gn, node_label):
                    # h_i0: number of input graphs in which the node mapped
                    # from nd exists and carries this label.
                    h_i0 = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p[idx][nd]
                        if g.has_node(pi_i) and g.nodes[pi_i][node_label] == label:
                            h_i0 += 1
                    h_i0_list.append(h_i0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_i0_list == np.max(h_i0_list)).flatten().tolist()
                idx_rdm = random.randint(0, len(idx_max) - 1)
                G_new.nodes[nd][node_label] = label_list[idx_max[idx_rdm]]
        else:  # labels are non-symbolic
            for nd, _ in G.nodes(data=True):
                Si_norm = 0
                phi_i_bar = np.array([0.0 for _ in range(ds_attrs['node_attr_dim'])])
                for idx, g in enumerate(Gn):
                    pi_i = pi_p[idx][nd]
                    if g.has_node(pi_i):  # @todo: what if no g has node? phi_i_bar = 0?
                        Si_norm += 1
                        phi_i_bar += np.array([float(itm) for itm in g.nodes[pi_i]['attributes']])
                # mean attribute vector over the graphs containing the mapped node.
                phi_i_bar /= Si_norm
                G_new.nodes[nd]['attributes'] = phi_i_bar

        # update edge labels and adjacency matrix.
        if ds_attrs['edge_labeled']:
            for nd1, nd2, _ in G.edges(data=True):
                h_ij0_list = []
                label_list = []
                for label in get_edge_labels(Gn, edge_label):
                    # h_ij0: number of graphs in which the mapped edge exists
                    # and carries this label.
                    h_ij0 = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p[idx][nd1]
                        pi_j = pi_p[idx][nd2]
                        h_ij0_p = (g.has_node(pi_i) and g.has_node(pi_j) and
                                   g.has_edge(pi_i, pi_j) and
                                   g.edges[pi_i, pi_j][edge_label] == label)
                        h_ij0 += h_ij0_p
                    h_ij0_list.append(h_ij0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_ij0_list == np.max(h_ij0_list)).flatten().tolist()
                h_ij0_max = h_ij0_list[idx_max[0]]
                idx_rdm = random.randint(0, len(idx_max) - 1)
                best_label = label_list[idx_max[idx_rdm]]

                # check whether a_ij is 0 or 1. sij_norm: number of graphs in
                # which the mapped edge exists at all.
                sij_norm = 0
                for idx, g in enumerate(Gn):
                    pi_i = pi_p[idx][nd1]
                    pi_j = pi_p[idx][nd2]
                    if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                        sij_norm += 1
                # cost-based decision rule: keep (and relabel) the edge only
                # when substitution beats removal+insertion.
                if h_ij0_max > len(Gn) * c_er / c_es + sij_norm * (1 - (c_er + c_ei) / c_es):
                    if not G_new.has_edge(nd1, nd2):
                        G_new.add_edge(nd1, nd2)
                    G_new.edges[nd1, nd2][edge_label] = best_label
                else:
                    if G_new.has_edge(nd1, nd2):
                        G_new.remove_edge(nd1, nd2)
        else:  # if edges are unlabeled
            # NOTE(review): only edges already present in G are visited, so
            # this branch can remove edges but never insert new ones (unlike
            # the all-pairs scan in the other variants below) — confirm this
            # is intended.
            for nd1, nd2, _ in G.edges(data=True):
                sij_norm = 0
                for idx, g in enumerate(Gn):
                    pi_i = pi_p[idx][nd1]
                    pi_j = pi_p[idx][nd2]
                    if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                        sij_norm += 1
                if sij_norm > len(Gn) * c_er / (c_er + c_ei):
                    if not G_new.has_edge(nd1, nd2):
                        G_new.add_edge(nd1, nd2)
                else:
                    if G_new.has_edge(nd1, nd2):
                        G_new.remove_edge(nd1, nd2)
        G = G_new.copy()

        # update pi_p: recompute the node maps from the current estimate.
        pi_p = []
        for idx1, G_p in enumerate(Gn):
            dist_tmp, pi_tmp, _ = GED(G, G_p)
            pi_p.append(pi_tmp)

    return G
  137. def GED(g1, g2, lib='gedlib'):
  138. """
  139. Compute GED.
  140. """
  141. if lib == 'gedlib':
  142. # transform dataset to the 'xml' file as the GedLib required.
  143. saveDataset([g1, g2], [None, None], group='xml', filename='ged_tmp/tmp')
  144. # script.appel()
  145. script.PyRestartEnv()
  146. script.PyLoadGXLGraph('ged_tmp/', 'ged_tmp/tmp.xml')
  147. listID = script.PyGetGraphIds()
  148. script.PySetEditCost("CHEM_2")
  149. script.PyInitEnv()
  150. script.PySetMethod("BIPARTITE", "")
  151. script.PyInitMethod()
  152. g = listID[0]
  153. h = listID[1]
  154. script.PyRunMethod(g, h)
  155. pi_forward, pi_backward = script.PyGetAllMap(g, h)
  156. upper = script.PyGetUpperBound(g, h)
  157. lower = script.PyGetLowerBound(g, h)
  158. dis = (upper + lower) / 2
  159. return dis, pi_forward, pi_backward
  160. def get_node_labels(Gn, node_label):
  161. nl = set()
  162. for G in Gn:
  163. nl = nl | set(nx.get_node_attributes(G, node_label).values())
  164. return nl
  165. def get_edge_labels(Gn, edge_label):
  166. el = set()
  167. for G in Gn:
  168. el = el | set(nx.get_edge_attributes(G, edge_label).values())
  169. return el
  170. # --------------------------- These are tests --------------------------------#
def test_iam_with_more_graphs_as_init(Gn, G_candidate, c_ei=3, c_er=3, c_es=1,
                                      node_label='atom', edge_label='bond_type'):
    """Variant of :func:`iam` whose initial set-median is chosen from a
    separate candidate set ``G_candidate`` rather than from ``Gn`` itself.

    Parameters
    ----------
    Gn : list of networkx graphs
        The graphs whose median is to be estimated.
    G_candidate : list of networkx graphs
        Candidate graphs among which the initialization is chosen.
    c_ei, c_er, c_es : number
        Edit costs for edge insertion, edge removal and edge substitution.
    node_label, edge_label : string
        Attribute names used as symbolic node / edge labels.

    Returns
    -------
    networkx graph
        The estimate of the median graph after the iterations.
    """
    from tqdm import tqdm
#    Gn = Gn[0:10]
    # Relabel nodes to consecutive integers so node ids can index the GED
    # node maps below.
    Gn = [nx.convert_node_labels_to_integers(g) for g in Gn]

    # phase 1: initialize.
    # compute set-median over the candidate set.
    dis_min = np.inf
#    pi_p = []
    pi_all_forward = []
    pi_all_backward = []
    for idx1, G_p in tqdm(enumerate(G_candidate), desc='computing GEDs', file=sys.stdout):
        dist_sum = 0
        pi_all_forward.append([])
        pi_all_backward.append([])
        for idx2, G_p_prime in enumerate(Gn):
            dist_tmp, pi_tmp_forward, pi_tmp_backward = GED(G_p, G_p_prime)
            pi_all_forward[idx1].append(pi_tmp_forward)
            pi_all_backward[idx1].append(pi_tmp_backward)
            dist_sum += dist_tmp
        # <= (not <) so a later candidate with an equal distance sum wins.
        if dist_sum <= dis_min:
            dis_min = dist_sum
            G = G_p.copy()
            idx_min = idx1
    # list of edit operations (node maps) for the chosen candidate.
    pi_p_forward = pi_all_forward[idx_min]
    pi_p_backward = pi_all_backward[idx_min]

    # phase 2: iteration.
    # Include the chosen G so its labels/attributes are also represented.
    ds_attrs = get_dataset_attributes(Gn + [G], attr_names=['edge_labeled', 'node_attr_dim'],
                                      edge_label=edge_label)
    label_set = get_node_labels(Gn + [G], node_label)
    for itr in range(0, 10):  # @todo: the convergence condition?
        G_new = G.copy()
        # update vertex labels.
        # pre-compute h_i0 for each label.
#        for label in get_node_labels(Gn, node_label):
#            print(label)
#        for nd in G.nodes(data=True):
#            pass
        if not ds_attrs['node_attr_dim']:  # labels are symbolic
            for nd in G.nodes():
                h_i0_list = []
                label_list = []
                for label in label_set:
                    # h_i0: number of graphs whose node mapped from nd
                    # carries this label.
                    h_i0 = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p_forward[idx][nd]
                        if g.has_node(pi_i) and g.nodes[pi_i][node_label] == label:
                            h_i0 += 1
                    h_i0_list.append(h_i0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_i0_list == np.max(h_i0_list)).flatten().tolist()
                idx_rdm = random.randint(0, len(idx_max) - 1)
                G_new.nodes[nd][node_label] = label_list[idx_max[idx_rdm]]
        else:  # labels are non-symbolic
            for nd in G.nodes():
                Si_norm = 0
                phi_i_bar = np.array([0.0 for _ in range(ds_attrs['node_attr_dim'])])
                for idx, g in enumerate(Gn):
                    pi_i = pi_p_forward[idx][nd]
                    if g.has_node(pi_i):  # @todo: what if no g has node? phi_i_bar = 0?
                        Si_norm += 1
                        phi_i_bar += np.array([float(itm) for itm in g.nodes[pi_i]['attributes']])
                # mean attribute vector over graphs containing the mapped node.
                phi_i_bar /= Si_norm
                G_new.nodes[nd]['attributes'] = phi_i_bar

        # update edge labels and adjacency matrix.
        if ds_attrs['edge_labeled']:
            for nd1, nd2, _ in G.edges(data=True):
                h_ij0_list = []
                label_list = []
                for label in get_edge_labels(Gn, edge_label):
                    # h_ij0: number of graphs in which the mapped edge exists
                    # and carries this label.
                    h_ij0 = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p_forward[idx][nd1]
                        pi_j = pi_p_forward[idx][nd2]
                        h_ij0_p = (g.has_node(pi_i) and g.has_node(pi_j) and
                                   g.has_edge(pi_i, pi_j) and
                                   g.edges[pi_i, pi_j][edge_label] == label)
                        h_ij0 += h_ij0_p
                    h_ij0_list.append(h_ij0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_ij0_list == np.max(h_ij0_list)).flatten().tolist()
                h_ij0_max = h_ij0_list[idx_max[0]]
                idx_rdm = random.randint(0, len(idx_max) - 1)
                best_label = label_list[idx_max[idx_rdm]]

                # check whether a_ij is 0 or 1. sij_norm: number of graphs in
                # which the mapped edge exists at all.
                sij_norm = 0
                for idx, g in enumerate(Gn):
                    pi_i = pi_p_forward[idx][nd1]
                    pi_j = pi_p_forward[idx][nd2]
                    if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                        sij_norm += 1
                # cost-based decision rule for keeping/removing the edge.
                if h_ij0_max > len(Gn) * c_er / c_es + sij_norm * (1 - (c_er + c_ei) / c_es):
                    if not G_new.has_edge(nd1, nd2):
                        G_new.add_edge(nd1, nd2)
                    G_new.edges[nd1, nd2][edge_label] = best_label
                else:
                    if G_new.has_edge(nd1, nd2):
                        G_new.remove_edge(nd1, nd2)
        else:  # if edges are unlabeled
            # @todo: works only for undirected graphs.
            # All node pairs are scanned, so edges can be both inserted and
            # removed here (unlike the existing-edges-only loop in iam()).
            for nd1 in range(nx.number_of_nodes(G)):
                for nd2 in range(nd1 + 1, nx.number_of_nodes(G)):
                    sij_norm = 0
                    for idx, g in enumerate(Gn):
                        pi_i = pi_p_forward[idx][nd1]
                        pi_j = pi_p_forward[idx][nd2]
                        if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                            sij_norm += 1
                    if sij_norm > len(Gn) * c_er / (c_er + c_ei):
                        if not G_new.has_edge(nd1, nd2):
                            G_new.add_edge(nd1, nd2)
                    elif sij_norm < len(Gn) * c_er / (c_er + c_ei):
                        if G_new.has_edge(nd1, nd2):
                            G_new.remove_edge(nd1, nd2)
                    # do not change anything when equal.
        G = G_new.copy()

        # update pi_p: recompute the forward node maps from the new estimate.
        pi_p_forward = []
        for G_p in Gn:
            dist_tmp, pi_tmp_forward, pi_tmp_backward = GED(G, G_p)
            pi_p_forward.append(pi_tmp_forward)

    return G
def test_iam_moreGraphsAsInit_tryAllPossibleBestGraphs_deleteNodesInIterations(
        Gn_median, Gn_candidate, c_ei=3, c_er=3, c_es=1, node_label='atom',
        edge_label='bond_type', connected=True):
    """IAM variant that starts from ALL best candidate graphs, keeps every
    "best" graph generated in each iteration, and allows node removal during
    the iterations (a node mapped to the dummy id ``node_ir`` in enough
    graphs gets deleted).

    Parameters
    ----------
    Gn_median : list of networkx graphs
        The graphs whose median is to be estimated.
    Gn_candidate : list of networkx graphs
        Candidate graphs used as initializations.
    c_ei, c_er, c_es : number
        Edit costs for edge insertion, edge removal and edge substitution.
    node_label, edge_label : string
        Attribute names used as symbolic node / edge labels.
    connected : bool
        If True, disconnected intermediate graphs are discarded.

    Returns
    -------
    list of networkx graphs
        The best (minimum summed GED) median-graph candidates found.
    """
    from tqdm import tqdm
#    Gn_median = Gn_median[0:10]
#    Gn_median = [nx.convert_node_labels_to_integers(g) for g in Gn_median]
    # Dummy node id used by the GED backend for node removal/insertion.
    node_ir = sys.maxsize * 2  # Max number for c++, corresponding to the node remove and insertion.
    label_r = 'thanksdanny'  # the label for node remove. # @todo: make this label unrepeatable.
    ds_attrs = get_dataset_attributes(Gn_median + Gn_candidate,
                                      attr_names=['edge_labeled', 'node_attr_dim'],
                                      edge_label=edge_label)

    def generate_graph(G, pi_p_forward, label_set):
        # One update pass: from G, generate every "best" graph w.r.t. node
        # labels (including node removal) and edge structure.
        G_new_list = [G.copy()]  # all "best" graphs generated in this iteration.
#        nx.draw_networkx(G)
#        import matplotlib.pyplot as plt
#        plt.show()
#        print(pi_p_forward)
        # update vertex labels.
        # pre-compute h_i0 for each label.
#        for label in get_node_labels(Gn, node_label):
#            print(label)
#        for nd in G.nodes(data=True):
#            pass
        if not ds_attrs['node_attr_dim']:  # labels are symbolic
            # ndi is the positional index of the node (used to index the GED
            # node maps); nd is the actual node id.
            for ndi, (nd, _) in enumerate(G.nodes(data=True)):
                h_i0_list = []
                label_list = []
                for label in label_set:
                    h_i0 = 0
                    for idx, g in enumerate(Gn_median):
                        pi_i = pi_p_forward[idx][ndi]
                        if g.has_node(pi_i) and g.nodes[pi_i][node_label] == label:
                            h_i0 += 1
                    h_i0_list.append(h_i0)
                    label_list.append(label)
                # case when the node is to be removed: count how many graphs
                # map this node to the dummy removal id.
                h_i0_remove = 0
                for idx, g in enumerate(Gn_median):
                    pi_i = pi_p_forward[idx][ndi]
                    if pi_i == node_ir:
                        h_i0_remove += 1
                h_i0_list.append(h_i0_remove)
                label_list.append(label_r)
                # get the best labels.
                idx_max = np.argwhere(h_i0_list == np.max(h_i0_list)).flatten().tolist()
                nlabel_best = [label_list[idx] for idx in idx_max]
                # generate "best" graphs with regard to "best" node labels:
                # one copy per (current graph, best label) combination.
                G_new_list_nd = []
                for g in G_new_list:
                    for nl in nlabel_best:
                        g_tmp = g.copy()
                        if nl == label_r:
                            g_tmp.remove_node(nd)
                        else:
                            g_tmp.nodes[nd][node_label] = nl
                        G_new_list_nd.append(g_tmp)
#                        nx.draw_networkx(g_tmp)
#                        import matplotlib.pyplot as plt
#                        plt.show()
#                        print(g_tmp.nodes(data=True))
#                        print(g_tmp.edges(data=True))
                G_new_list = G_new_list_nd[:]
        else:  # labels are non-symbolic
            for nd in G.nodes():
                Si_norm = 0
                phi_i_bar = np.array([0.0 for _ in range(ds_attrs['node_attr_dim'])])
                for idx, g in enumerate(Gn_median):
                    pi_i = pi_p_forward[idx][nd]
                    if g.has_node(pi_i):  # @todo: what if no g has node? phi_i_bar = 0?
                        Si_norm += 1
                        phi_i_bar += np.array([float(itm) for itm in g.nodes[pi_i]['attributes']])
                phi_i_bar /= Si_norm
                # NOTE(review): BUG — G_new is never defined in this function;
                # executing this branch raises NameError. It was presumably
                # copied from iam() and should operate on the graphs in
                # G_new_list instead.
                G_new.nodes[nd]['attributes'] = phi_i_bar

        # update edge labels and adjacency matrix.
        if ds_attrs['edge_labeled']:
            # NOTE(review): BUG — this branch also references the undefined
            # G_new (see above) and would raise NameError if reached.
            for nd1, nd2, _ in G.edges(data=True):
                h_ij0_list = []
                label_list = []
                for label in get_edge_labels(Gn_median, edge_label):
                    h_ij0 = 0
                    for idx, g in enumerate(Gn_median):
                        pi_i = pi_p_forward[idx][nd1]
                        pi_j = pi_p_forward[idx][nd2]
                        h_ij0_p = (g.has_node(pi_i) and g.has_node(pi_j) and
                                   g.has_edge(pi_i, pi_j) and
                                   g.edges[pi_i, pi_j][edge_label] == label)
                        h_ij0 += h_ij0_p
                    h_ij0_list.append(h_ij0)
                    label_list.append(label)
                # choose one of the best randomly.
                idx_max = np.argwhere(h_ij0_list == np.max(h_ij0_list)).flatten().tolist()
                h_ij0_max = h_ij0_list[idx_max[0]]
                idx_rdm = random.randint(0, len(idx_max) - 1)
                best_label = label_list[idx_max[idx_rdm]]

                # check whether a_ij is 0 or 1.
                sij_norm = 0
                for idx, g in enumerate(Gn_median):
                    pi_i = pi_p_forward[idx][nd1]
                    pi_j = pi_p_forward[idx][nd2]
                    if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                        sij_norm += 1
                if h_ij0_max > len(Gn_median) * c_er / c_es + sij_norm * (1 - (c_er + c_ei) / c_es):
                    if not G_new.has_edge(nd1, nd2):
                        G_new.add_edge(nd1, nd2)
                    G_new.edges[nd1, nd2][edge_label] = best_label
                else:
                    if G_new.has_edge(nd1, nd2):
                        G_new.remove_edge(nd1, nd2)
        else:  # if edges are unlabeled
            # @todo: works only for undirected graphs.
            nd_list = [n for n in G.nodes()]
            for g_tmp in G_new_list:
                for nd1i in range(nx.number_of_nodes(G)):
                    nd1 = nd_list[nd1i]
                    for nd2i in range(nd1i + 1, nx.number_of_nodes(G)):
                        nd2 = nd_list[nd2i]
                        sij_norm = 0
                        for idx, g in enumerate(Gn_median):
                            pi_i = pi_p_forward[idx][nd1i]
                            pi_j = pi_p_forward[idx][nd2i]
                            if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
                                sij_norm += 1
                        if sij_norm > len(Gn_median) * c_er / (c_er + c_ei):
                            # @todo: should we consider if nd1 and nd2 in g_tmp?
                            # or just add the edge anyway?
                            if g_tmp.has_node(nd1) and g_tmp.has_node(nd2) \
                                    and not g_tmp.has_edge(nd1, nd2):
                                g_tmp.add_edge(nd1, nd2)
                        elif sij_norm < len(Gn_median) * c_er / (c_er + c_ei):
                            if g_tmp.has_edge(nd1, nd2):
                                g_tmp.remove_edge(nd1, nd2)
                        # do not change anything when equal.

        # find the best graph generated in this iteration and update pi_p.
        # @todo: should we update all graphs generated or just the best ones?
        dis_list, pi_forward_list = median_distance(G_new_list, Gn_median)
        # @todo: should we remove the identical and connectivity check?
        # Don't know which is faster.
        G_new_list, idx_list = remove_duplicates(G_new_list)
        pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
#        if connected == True:
#            G_new_list, idx_list = remove_disconnected(G_new_list)
#            pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
#        idx_min_list = np.argwhere(dis_list == np.min(dis_list)).flatten().tolist()
#        dis_min = dis_list[idx_min_tmp_list[0]]
#        pi_forward_list = [pi_forward_list[idx] for idx in idx_min_list]
#        G_new_list = [G_new_list[idx] for idx in idx_min_list]
        # Debug visualization of every generated graph.
        for g in G_new_list:
            import matplotlib.pyplot as plt
            nx.draw_networkx(g)
            plt.show()
            print(g.nodes(data=True))
            print(g.edges(data=True))
        return G_new_list, pi_forward_list

    def median_distance(Gn, Gn_median, measure='ged', verbose=False):
        # Sum of GEDs from each graph in Gn to all graphs in Gn_median, plus
        # the corresponding forward node maps.
        dis_list = []
        pi_forward_list = []
        for idx, G in tqdm(enumerate(Gn), desc='computing median distances',
                           file=sys.stdout) if verbose else enumerate(Gn):
            dis_sum = 0
            pi_forward_list.append([])
            for G_p in Gn_median:
                dis_tmp, pi_tmp_forward, pi_tmp_backward = GED(G, G_p)
                pi_forward_list[idx].append(pi_tmp_forward)
                dis_sum += dis_tmp
            dis_list.append(dis_sum)
        return dis_list, pi_forward_list

    def best_median_graphs(Gn_candidate, dis_all, pi_all_forward):
        # Select all candidates attaining the minimum summed distance.
        idx_min_list = np.argwhere(dis_all == np.min(dis_all)).flatten().tolist()
        dis_min = dis_all[idx_min_list[0]]
        pi_forward_min_list = [pi_all_forward[idx] for idx in idx_min_list]
        G_min_list = [Gn_candidate[idx] for idx in idx_min_list]
        return G_min_list, pi_forward_min_list, dis_min

    def iteration_proc(G, pi_p_forward):
        # Run the update iterations from one initial graph, carrying every
        # "best" graph (after dedup) forward to the next iteration.
        G_list = [G]
        pi_forward_list = [pi_p_forward]
        # iterations.
        for itr in range(0, 10):  # @todo: the convergence condition?
#            print('itr is', itr)
            G_new_list = []
            pi_forward_new_list = []
            for idx, G in enumerate(G_list):
                label_set = get_node_labels(Gn_median + [G], node_label)
                G_tmp_list, pi_forward_tmp_list = generate_graph(
                    G, pi_forward_list[idx], label_set)
                G_new_list += G_tmp_list
                pi_forward_new_list += pi_forward_tmp_list
            G_list = G_new_list[:]
            pi_forward_list = pi_forward_new_list[:]
            G_list, idx_list = remove_duplicates(G_list)
            pi_forward_list = [pi_forward_list[idx] for idx in idx_list]
#        import matplotlib.pyplot as plt
#        for g in G_list:
#            nx.draw_networkx(g)
#            plt.show()
#            print(g.nodes(data=True))
#            print(g.edges(data=True))
        return G_list, pi_forward_list  # do we return all graphs or the best ones?

    def remove_duplicates(Gn):
        """Remove duplicate graphs from list.
        """
        Gn_new = []
        idx_list = []
        for idx, g in enumerate(Gn):
            dupl = False
            for g_new in Gn_new:
                if graph_isIdentical(g_new, g):
                    dupl = True
                    break
            if not dupl:
                Gn_new.append(g)
                idx_list.append(idx)
        return Gn_new, idx_list

    def remove_disconnected(Gn):
        """Remove disconnected graphs from list.
        """
        Gn_new = []
        idx_list = []
        for idx, g in enumerate(Gn):
            if nx.is_connected(g):
                Gn_new.append(g)
                idx_list.append(idx)
        return Gn_new, idx_list

    # phase 1: initialize.
    # compute set-median.
    dis_min = np.inf
    # NOTE(review): distances are computed over the REVERSED candidate list,
    # but idx_min below indexes the UNREVERSED Gn_candidate — the indices do
    # not correspond unless the list is symmetric. Confirm intended.
    dis_all, pi_all_forward = median_distance(Gn_candidate[::-1], Gn_median)
    # find all smallest distances.
    idx_min_list = np.argwhere(dis_all == np.min(dis_all)).flatten().tolist()
    dis_min = dis_all[idx_min_list[0]]

    # phase 2: iteration. Run the iterative updates from every best candidate.
    G_list = []
    for idx_min in idx_min_list[::-1]:
#        print('idx_min is', idx_min)
        G = Gn_candidate[idx_min].copy()
        # list of edit operations.
        pi_p_forward = pi_all_forward[idx_min]
#        pi_p_backward = pi_all_backward[idx_min]
        Gi_list, pi_i_forward_list = iteration_proc(G, pi_p_forward)
        G_list += Gi_list

    G_list, _ = remove_duplicates(G_list)
    if connected == True:
        G_list, _ = remove_disconnected(G_list)

    # Debug visualization of every surviving graph.
    import matplotlib.pyplot as plt
    for g in G_list:
        nx.draw_networkx(g)
        plt.show()
        print(g.nodes(data=True))
        print(g.edges(data=True))

    # get the best median graphs
    dis_all, pi_all_forward = median_distance(G_list, Gn_median)
    G_min_list, pi_forward_min_list, dis_min = best_median_graphs(
        G_list, dis_all, pi_all_forward)
    for g in G_min_list:
        nx.draw_networkx(g)
        plt.show()
        print(g.nodes(data=True))
        print(g.edges(data=True))

    return G_min_list
if __name__ == '__main__':
    from pygraph.utils.graphfiles import loadDataset
    # MUTAG: symbolic node and edge labels.
    ds = {'name': 'MUTAG', 'dataset': '../datasets/MUTAG/MUTAG.mat',
          'extra_params': {'am_sp_al_nl_el': [0, 0, 3, 1, 2]}}  # node/edge symb
#    ds = {'name': 'Letter-high', 'dataset': '../datasets/Letter-high/Letter-high_A.txt',
#          'extra_params': {}} # node nsymb
#    ds = {'name': 'Acyclic', 'dataset': '../datasets/monoterpenoides/trainset_9.ds',
#          'extra_params': {}}
    Gn, y_all = loadDataset(ds['dataset'], extra_params=ds['extra_params'])

    # Smoke-run the basic IAM median estimation on the loaded dataset.
    iam(Gn)

A Python package for graph kernels, graph edit distances and graph pre-image problem.