You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

gk_iam.py 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Tue Apr 30 17:07:43 2019
  5. A graph pre-image method combining the iterative pre-image method in reference [1]
  6. and the iterative alternate minimizations (IAM) in reference [2].
  7. @author: ljia
  8. @references:
  9. [1] Gökhan H Bakir, Alexander Zien, and Koji Tsuda. Learning to find graph
  10. pre-images. In Joint Pattern Recognition Symposium, pages 253-261. Springer, 2004.
  11. [2] Generalized median graph via iterative alternate minimization.
  12. """
  13. import numpy as np
  14. import multiprocessing
  15. from tqdm import tqdm
  16. import networkx as nx
  17. import matplotlib.pyplot as plt
  18. from iam import iam
  19. def gk_iam(Gn, alpha):
  20. """This function constructs graph pre-image by the iterative pre-image
  21. framework in reference [1], algorithm 1, where the step of generating new
  22. graphs randomly is replaced by the IAM algorithm in reference [2].
  23. notes
  24. -----
  25. Every time a better graph is acquired, the older one is replaced by it.
  26. """
  27. # compute k nearest neighbors of phi in DN.
  28. dis_list = [] # distance between g_star and each graph.
  29. for ig, g in tqdm(enumerate(Gn), desc='computing distances', file=sys.stdout):
  30. dtemp = k_list[ig] - 2 * (alpha * k_g1_list[ig] + (1 - alpha) *
  31. k_g2_list[ig]) + (alpha * alpha * k_list[idx1] + alpha *
  32. (1 - alpha) * k_g2_list[idx1] + (1 - alpha) * alpha *
  33. k_g1_list[idx2] + (1 - alpha) * (1 - alpha) * k_list[idx2])
  34. dis_list.append(dtemp)
  35. # sort
  36. sort_idx = np.argsort(dis_list)
  37. dis_gs = [dis_list[idis] for idis in sort_idx[0:k]]
  38. g0hat = Gn[sort_idx[0]] # the nearest neighbor of phi in DN
  39. if dis_gs[0] == 0: # the exact pre-image.
  40. print('The exact pre-image is found from the input dataset.')
  41. return 0, g0hat
  42. dhat = dis_gs[0] # the nearest distance
  43. Gk = [Gn[ig] for ig in sort_idx[0:k]] # the k nearest neighbors
  44. gihat_list = []
  45. # i = 1
  46. r = 1
  47. while r < r_max:
  48. print('r =', r)
  49. # found = False
  50. Gs_nearest = Gk + gihat_list
  51. g_tmp = iam(Gs_nearest)
  52. # compute distance between phi and the new generated graph.
  53. knew = marginalizedkernel([g_tmp, g1, g2], node_label='atom', edge_label=None,
  54. p_quit=lmbda, n_iteration=20, remove_totters=False,
  55. n_jobs=multiprocessing.cpu_count(), verbose=False)
  56. dnew = knew[0][0, 0] - 2 * (alpha * knew[0][0, 1] + (1 - alpha) *
  57. knew[0][0, 2]) + (alpha * alpha * k_list[idx1] + alpha *
  58. (1 - alpha) * k_g2_list[idx1] + (1 - alpha) * alpha *
  59. k_g1_list[idx2] + (1 - alpha) * (1 - alpha) * k_list[idx2])
  60. if dnew <= dhat: # the new distance is smaller
  61. print('I am smaller!')
  62. dhat = dnew
  63. g_new = g_tmp.copy() # found better graph.
  64. gihat_list = [g_new]
  65. dis_gs.append(dhat)
  66. r = 0
  67. else:
  68. r += 1
  69. ghat = ([g0hat] if len(gihat_list) == 0 else gihat_list)
  70. return dhat, ghat
  71. def gk_iam_nearest(Gn, alpha):
  72. """This function constructs graph pre-image by the iterative pre-image
  73. framework in reference [1], algorithm 1, where the step of generating new
  74. graphs randomly is replaced by the IAM algorithm in reference [2].
  75. notes
  76. -----
  77. Every time a better graph is acquired, its distance in kernel space is
  78. compared with the k nearest ones, and the k nearest distances from the k+1
  79. distances will be used as the new ones.
  80. """
  81. # compute k nearest neighbors of phi in DN.
  82. dis_list = [] # distance between g_star and each graph.
  83. for ig, g in tqdm(enumerate(Gn), desc='computing distances', file=sys.stdout):
  84. dtemp = k_list[ig] - 2 * (alpha * k_g1_list[ig] + (1 - alpha) *
  85. k_g2_list[ig]) + (alpha * alpha * k_list[idx1] + alpha *
  86. (1 - alpha) * k_g2_list[idx1] + (1 - alpha) * alpha *
  87. k_g1_list[idx2] + (1 - alpha) * (1 - alpha) * k_list[idx2])
  88. dis_list.append(dtemp)
  89. # sort
  90. sort_idx = np.argsort(dis_list)
  91. dis_gs = [dis_list[idis] for idis in sort_idx[0:k]] # the k shortest distances
  92. g0hat = Gn[sort_idx[0]] # the nearest neighbor of phi in DN
  93. if dis_gs[0] == 0: # the exact pre-image.
  94. print('The exact pre-image is found from the input dataset.')
  95. return 0, g0hat
  96. dhat = dis_gs[0] # the nearest distance
  97. ghat = g0hat
  98. Gk = [Gn[ig] for ig in sort_idx[0:k]] # the k nearest neighbors
  99. Gs_nearest = Gk
  100. # gihat_list = []
  101. # i = 1
  102. r = 1
  103. while r < r_max:
  104. print('r =', r)
  105. # found = False
  106. # Gs_nearest = Gk + gihat_list
  107. g_tmp = iam(Gs_nearest)
  108. # compute distance between phi and the new generated graph.
  109. knew = marginalizedkernel([g_tmp, g1, g2], node_label='atom', edge_label=None,
  110. p_quit=lmbda, n_iteration=20, remove_totters=False,
  111. n_jobs=multiprocessing.cpu_count(), verbose=False)
  112. dnew = knew[0][0, 0] - 2 * (alpha * knew[0][0, 1] + (1 - alpha) *
  113. knew[0][0, 2]) + (alpha * alpha * k_list[idx1] + alpha *
  114. (1 - alpha) * k_g2_list[idx1] + (1 - alpha) * alpha *
  115. k_g1_list[idx2] + (1 - alpha) * (1 - alpha) * k_list[idx2])
  116. if dnew <= dhat: # the new distance is smaller
  117. print('I am smaller!')
  118. dhat = dnew
  119. g_new = g_tmp.copy() # found better graph.
  120. ghat = g_tmp.copy()
  121. dis_gs.append(dhat) # add the new nearest distance.
  122. Gs_nearest.append(g_new) # add the corresponding graph.
  123. sort_idx = np.argsort(dis_gs)
  124. dis_gs = [dis_gs[idx] for idx in sort_idx[0:k]] # the new k nearest distances.
  125. Gs_nearest = [Gs_nearest[idx] for idx in sort_idx[0:k]]
  126. r = 0
  127. else:
  128. r += 1
  129. return dhat, ghat
  130. if __name__ == '__main__':
  131. import sys
  132. sys.path.insert(0, "../")
  133. from pygraph.kernels.marginalizedKernel import marginalizedkernel
  134. from pygraph.utils.graphfiles import loadDataset
  135. ds = {'name': 'MUTAG', 'dataset': '../datasets/MUTAG/MUTAG.mat',
  136. 'extra_params': {'am_sp_al_nl_el': [0, 0, 3, 1, 2]}} # node/edge symb
  137. Gn, y_all = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
  138. # Gn = Gn[0:10]
  139. lmbda = 0.03 # termination probalility
  140. r_max = 10 # recursions
  141. l = 500
  142. alpha_range = np.linspace(0.1, 0.9, 9)
  143. k = 5 # k nearest neighbors
  144. # randomly select two molecules
  145. np.random.seed(1)
  146. idx1, idx2 = np.random.randint(0, len(Gn), 2)
  147. g1 = Gn[idx1]
  148. g2 = Gn[idx2]
  149. # compute
  150. k_list = [] # kernel between each graph and itself.
  151. k_g1_list = [] # kernel between each graph and g1
  152. k_g2_list = [] # kernel between each graph and g2
  153. for ig, g in tqdm(enumerate(Gn), desc='computing self kernels', file=sys.stdout):
  154. ktemp = marginalizedkernel([g, g1, g2], node_label='atom', edge_label=None,
  155. p_quit=lmbda, n_iteration=20, remove_totters=False,
  156. n_jobs=multiprocessing.cpu_count(), verbose=False)
  157. k_list.append(ktemp[0][0, 0])
  158. k_g1_list.append(ktemp[0][0, 1])
  159. k_g2_list.append(ktemp[0][0, 2])
  160. g_best = []
  161. dis_best = []
  162. # for each alpha
  163. for alpha in alpha_range:
  164. print('alpha =', alpha)
  165. dhat, ghat = gk_iam_nearest(Gn, alpha)
  166. dis_best.append(dhat)
  167. g_best.append(ghat)
  168. for idx, item in enumerate(alpha_range):
  169. print('when alpha is', item, 'the shortest distance is', dis_best[idx])
  170. print('the corresponding pre-image is')
  171. nx.draw_networkx(g_best[idx])
  172. plt.show()

A Python package for graph kernels, graph edit distances and the graph pre-image problem.