You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

iam.py 7.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Fri Apr 26 11:49:12 2019
  5. Iterative alternate minimizations using GED.
  6. @author: ljia
  7. """
  8. import numpy as np
  9. import random
  10. import networkx as nx
  11. import sys
  12. #from Cython_GedLib_2 import librariesImport, script
  13. import librariesImport, script
  14. sys.path.insert(0, "../")
  15. from pygraph.utils.graphfiles import saveDataset
  16. from pygraph.utils.graphdataset import get_dataset_attributes
  17. def iam(Gn, node_label='atom', edge_label='bond_type'):
  18. """See my name, then you know what I do.
  19. """
  20. # Gn = Gn[0:10]
  21. Gn = [nx.convert_node_labels_to_integers(g) for g in Gn]
  22. c_er = 1
  23. c_es = 1
  24. c_ei = 1
  25. # phase 1: initilize.
  26. # compute set-median.
  27. dis_min = np.inf
  28. pi_p = []
  29. pi_all = []
  30. for idx1, G_p in enumerate(Gn):
  31. dist_sum = 0
  32. pi_all.append([])
  33. for idx2, G_p_prime in enumerate(Gn):
  34. dist_tmp, pi_tmp = GED(G_p, G_p_prime)
  35. pi_all[idx1].append(pi_tmp)
  36. dist_sum += dist_tmp
  37. if dist_sum < dis_min:
  38. dis_min = dist_sum
  39. G = G_p.copy()
  40. idx_min = idx1
  41. # list of edit operations.
  42. pi_p = pi_all[idx_min]
  43. # phase 2: iteration.
  44. ds_attrs = get_dataset_attributes(Gn, attr_names=['edge_labeled', 'node_attr_dim'],
  45. edge_label=edge_label)
  46. for itr in range(0, 10):
  47. G_new = G.copy()
  48. # update vertex labels.
  49. # pre-compute h_i0 for each label.
  50. # for label in get_node_labels(Gn, node_label):
  51. # print(label)
  52. # for nd in G.nodes(data=True):
  53. # pass
  54. if not ds_attrs['node_attr_dim']: # labels are symbolic
  55. for nd, _ in G.nodes(data=True):
  56. h_i0_list = []
  57. label_list = []
  58. for label in get_node_labels(Gn, node_label):
  59. h_i0 = 0
  60. for idx, g in enumerate(Gn):
  61. pi_i = pi_p[idx][nd]
  62. if g.has_node(pi_i) and g.nodes[pi_i][node_label] == label:
  63. h_i0 += 1
  64. h_i0_list.append(h_i0)
  65. label_list.append(label)
  66. # choose one of the best randomly.
  67. idx_max = np.argwhere(h_i0_list == np.max(h_i0_list)).flatten().tolist()
  68. idx_rdm = random.randint(0, len(idx_max) - 1)
  69. G_new.nodes[nd][node_label] = label_list[idx_max[idx_rdm]]
  70. else: # labels are non-symbolic
  71. for nd, _ in G.nodes(data=True):
  72. Si_norm = 0
  73. phi_i_bar = np.array([0.0 for _ in range(ds_attrs['node_attr_dim'])])
  74. for idx, g in enumerate(Gn):
  75. pi_i = pi_p[idx][nd]
  76. if g.has_node(pi_i): #@todo: what if no g has node? phi_i_bar = 0?
  77. Si_norm += 1
  78. phi_i_bar += np.array([float(itm) for itm in g.nodes[pi_i]['attributes']])
  79. phi_i_bar /= Si_norm
  80. G_new.nodes[nd]['attributes'] = phi_i_bar
  81. # update edge labels and adjacency matrix.
  82. if ds_attrs['edge_labeled']:
  83. for nd1, nd2, _ in G.edges(data=True):
  84. h_ij0_list = []
  85. label_list = []
  86. for label in get_edge_labels(Gn, edge_label):
  87. h_ij0 = 0
  88. for idx, g in enumerate(Gn):
  89. pi_i = pi_p[idx][nd1]
  90. pi_j = pi_p[idx][nd2]
  91. h_ij0_p = (g.has_node(pi_i) and g.has_node(pi_j) and
  92. g.has_edge(pi_i, pi_j) and
  93. g.edges[pi_i, pi_j][edge_label] == label)
  94. h_ij0 += h_ij0_p
  95. h_ij0_list.append(h_ij0)
  96. label_list.append(label)
  97. # choose one of the best randomly.
  98. idx_max = np.argwhere(h_ij0_list == np.max(h_ij0_list)).flatten().tolist()
  99. h_ij0_max = h_ij0_list[idx_max[0]]
  100. idx_rdm = random.randint(0, len(idx_max) - 1)
  101. best_label = label_list[idx_max[idx_rdm]]
  102. # check whether a_ij is 0 or 1.
  103. sij_norm = 0
  104. for idx, g in enumerate(Gn):
  105. pi_i = pi_p[idx][nd1]
  106. pi_j = pi_p[idx][nd2]
  107. if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
  108. sij_norm += 1
  109. if h_ij0_max > len(Gn) * c_er / c_es + sij_norm * (1 - (c_er + c_ei) / c_es):
  110. if not G_new.has_edge(nd1, nd2):
  111. G_new.add_edge(nd1, nd2)
  112. G_new.edges[nd1, nd2][edge_label] = best_label
  113. else:
  114. if G_new.has_edge(nd1, nd2):
  115. G_new.remove_edge(nd1, nd2)
  116. else: # if edges are unlabeled
  117. for nd1, nd2, _ in G.edges(data=True):
  118. sij_norm = 0
  119. for idx, g in enumerate(Gn):
  120. pi_i = pi_p[idx][nd1]
  121. pi_j = pi_p[idx][nd2]
  122. if g.has_node(pi_i) and g.has_node(pi_j) and g.has_edge(pi_i, pi_j):
  123. sij_norm += 1
  124. if sij_norm > len(Gn) * c_er / (c_er + c_ei):
  125. if not G_new.has_edge(nd1, nd2):
  126. G_new.add_edge(nd1, nd2)
  127. else:
  128. if G_new.has_edge(nd1, nd2):
  129. G_new.remove_edge(nd1, nd2)
  130. G = G_new.copy()
  131. return G
def GED(g1, g2, lib='gedlib'):
    """
    Compute GED. It is a dummy function for now.

    Computes an approximate graph edit distance between ``g1`` and ``g2``
    through the gedlib Python binding (``script`` module), together with the
    node map of the first returned assignment.

    Parameters
    ----------
    g1, g2 : networkx.Graph
        The two graphs to compare.
    lib : str
        Backend selector; only ``'gedlib'`` is implemented.

    Returns
    -------
    dis : float
        Sum of the upper and lower GED bounds returned by gedlib.
        NOTE(review): ``upper + lower`` looks unusual — a mean
        ``(upper + lower) / 2`` may have been intended; confirm against the
        gedlib API before relying on absolute values.
    pi : list
        The first node map returned by ``PyGetAllMap``.
    """
    if lib == 'gedlib':
        # gedlib loads graphs from disk, so serialize the pair to a temporary
        # GXL dataset first.
        saveDataset([g1, g2], [None, None], group='xml', filename='ged_tmp/tmp')
        script.appel()
        # Reset the gedlib environment and load the two graphs back in.
        script.PyRestartEnv()
        script.PyLoadGXLGraph('ged_tmp/', 'collections/tmp.xml')
        listID = script.PyGetGraphIds()
        # Use the CHEM_1 edit-cost model with the BIPARTITE approximation.
        script.PySetEditCost("CHEM_1")
        script.PyInitEnv()
        script.PySetMethod("BIPARTITE", "")
        script.PyInitMethod()
        g = listID[0]
        h = listID[1]
        script.PyRunMethod(g, h)
        # Retrieve all node maps plus the distance bounds for the pair.
        liste = script.PyGetAllMap(g, h)
        upper = script.PyGetUpperBound(g, h)
        lower = script.PyGetLowerBound(g, h)
        dis = upper + lower
        pi = liste[0]

    return dis, pi
  155. def get_node_labels(Gn, node_label):
  156. nl = set()
  157. for G in Gn:
  158. nl = nl | set(nx.get_node_attributes(G, node_label).values())
  159. return nl
  160. def get_edge_labels(Gn, edge_label):
  161. el = set()
  162. for G in Gn:
  163. el = el | set(nx.get_edge_attributes(G, edge_label).values())
  164. return el
if __name__ == '__main__':
    # Demo entry point: load the MUTAG dataset and run IAM on it.
    from pygraph.utils.graphfiles import loadDataset
    ds = {'name': 'MUTAG', 'dataset': '../datasets/MUTAG/MUTAG.mat',
          'extra_params': {'am_sp_al_nl_el': [0, 0, 3, 1, 2]}}  # node/edge symb
    # Alternative datasets (kept for quick switching during experiments):
    # ds = {'name': 'Letter-high', 'dataset': '../datasets/Letter-high/Letter-high_A.txt',
    #       'extra_params': {}}  # node nsymb
    # ds = {'name': 'Acyclic', 'dataset': '../datasets/monoterpenoides/trainset_9.ds',
    #       'extra_params': {}}
    Gn, y_all = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
    iam(Gn)

A Python package for graph kernels, graph edit distances and graph pre-image problem.