You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

preimage.py 9.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Wed Mar 6 16:03:11 2019
  5. pre-image
  6. @author: ljia
  7. """
  8. import sys
  9. import numpy as np
  10. import random
  11. import multiprocessing
  12. from tqdm import tqdm
  13. import networkx as nx
  14. import matplotlib.pyplot as plt
  15. sys.path.insert(0, "../")
  16. from pygraph.utils.graphfiles import loadDataset
  17. from pygraph.kernels.marginalizedKernel import marginalizedkernel
  18. from pygraph.kernels.untilHPathKernel import untilhpathkernel
  19. from pygraph.kernels.spKernel import spkernel
  20. import functools
  21. from pygraph.utils.kernels import deltakernel, gaussiankernel, kernelproduct
  22. from pygraph.kernels.structuralspKernel import structuralspkernel
  23. def compute_kernel(Gn, graph_kernel, verbose):
  24. if graph_kernel == 'marginalizedkernel':
  25. Kmatrix, _ = marginalizedkernel(Gn, node_label='atom', edge_label=None,
  26. p_quit=0.03, n_iteration=20, remove_totters=False,
  27. n_jobs=multiprocessing.cpu_count(), verbose=verbose)
  28. elif graph_kernel == 'untilhpathkernel':
  29. Kmatrix, _ = untilhpathkernel(Gn, node_label='atom', edge_label='bond_type',
  30. depth=10, k_func='MinMax', compute_method='trie',
  31. n_jobs=multiprocessing.cpu_count(), verbose=verbose)
  32. elif graph_kernel == 'spkernel':
  33. mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
  34. Kmatrix, _, _ = spkernel(Gn, node_label='atom', node_kernels=
  35. {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel},
  36. n_jobs=multiprocessing.cpu_count(), verbose=verbose)
  37. elif graph_kernel == 'structuralspkernel':
  38. mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
  39. Kmatrix, _ = structuralspkernel(Gn, node_label='atom', node_kernels=
  40. {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel},
  41. n_jobs=multiprocessing.cpu_count(), verbose=verbose)
  42. # normalization
  43. # Kmatrix_diag = Kmatrix.diagonal().copy()
  44. # for i in range(len(Kmatrix)):
  45. # for j in range(i, len(Kmatrix)):
  46. # Kmatrix[i][j] /= np.sqrt(Kmatrix_diag[i] * Kmatrix_diag[j])
  47. # Kmatrix[j][i] = Kmatrix[i][j]
  48. return Kmatrix
  49. if __name__ == '__main__':
  50. # ds = {'name': 'MUTAG', 'dataset': '../datasets/MUTAG/MUTAG_A.txt',
  51. # 'extra_params': {}} # node/edge symb
  52. # ds = {'name': 'Letter-high', 'dataset': '../datasets/Letter-high/Letter-high_A.txt',
  53. # 'extra_params': {}} # node nsymb
  54. # ds = {'name': 'Acyclic', 'dataset': '../datasets/monoterpenoides/trainset_9.ds',
  55. # 'extra_params': {}}
  56. ds = {'name': 'Acyclic', 'dataset': '../datasets/acyclic/dataset_bps.ds',
  57. 'extra_params': {}} # node symb
  58. DN, y_all = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
  59. #DN = DN[0:10]
  60. lmbda = 0.03 # termination probalility
  61. r_max = 10 # recursions
  62. l = 500
  63. alpha_range = np.linspace(0.5, 0.5, 1)
  64. #alpha_range = np.linspace(0.1, 0.9, 9)
  65. k = 5 # k nearest neighbors
  66. # randomly select two molecules
  67. #np.random.seed(1)
  68. #idx1, idx2 = np.random.randint(0, len(DN), 2)
  69. #g1 = DN[idx1]
  70. #g2 = DN[idx2]
  71. idx1 = 0
  72. idx2 = 6
  73. g1 = DN[idx1]
  74. g2 = DN[idx2]
  75. # compute
  76. k_list = [] # kernel between each graph and itself.
  77. k_g1_list = [] # kernel between each graph and g1
  78. k_g2_list = [] # kernel between each graph and g2
  79. for ig, g in tqdm(enumerate(DN), desc='computing self kernels', file=sys.stdout):
  80. # ktemp = marginalizedkernel([g, g1, g2], node_label='atom', edge_label=None,
  81. # p_quit=lmbda, n_iteration=20, remove_totters=False,
  82. # n_jobs=multiprocessing.cpu_count(), verbose=False)
  83. ktemp = compute_kernel([g, g1, g2], 'untilhpathkernel', verbose=False)
  84. k_list.append(ktemp[0, 0])
  85. k_g1_list.append(ktemp[0, 1])
  86. k_g2_list.append(ktemp[0, 2])
  87. g_best = []
  88. dis_best = []
  89. # for each alpha
  90. for alpha in alpha_range:
  91. print('alpha =', alpha)
  92. # compute k nearest neighbors of phi in DN.
  93. dis_list = [] # distance between g_star and each graph.
  94. for ig, g in tqdm(enumerate(DN), desc='computing distances', file=sys.stdout):
  95. dtemp = k_list[ig] - 2 * (alpha * k_g1_list[ig] + (1 - alpha) *
  96. k_g2_list[ig]) + (alpha * alpha * k_list[idx1] + alpha *
  97. (1 - alpha) * k_g2_list[idx1] + (1 - alpha) * alpha *
  98. k_g1_list[idx2] + (1 - alpha) * (1 - alpha) * k_list[idx2])
  99. dis_list.append(np.sqrt(dtemp))
  100. # sort
  101. sort_idx = np.argsort(dis_list)
  102. dis_gs = [dis_list[idis] for idis in sort_idx[0:k]]
  103. g0hat = DN[sort_idx[0]] # the nearest neighbor of phi in DN
  104. if dis_gs[0] == 0: # the exact pre-image.
  105. print('The exact pre-image is found from the input dataset.')
  106. g_pimg = g0hat
  107. break
  108. dhat = dis_gs[0] # the nearest distance
  109. Dk = [DN[ig] for ig in sort_idx[0:k]] # the k nearest neighbors
  110. gihat_list = []
  111. i = 1
  112. r = 1
  113. while r < r_max:
  114. print('r =', r)
  115. found = False
  116. for ig, gs in enumerate(Dk + gihat_list):
  117. # nx.draw_networkx(gs)
  118. # plt.show()
  119. # @todo what if the log is negetive?
  120. fdgs = int(np.abs(np.ceil(np.log(alpha * dis_gs[ig]))))
  121. for trail in tqdm(range(0, l), desc='l loop', file=sys.stdout):
  122. # add and delete edges.
  123. gtemp = gs.copy()
  124. np.random.seed()
  125. # which edges to change.
  126. # @todo: should we use just half of the adjacency matrix for undirected graphs?
  127. nb_vpairs = nx.number_of_nodes(gs) * (nx.number_of_nodes(gs) - 1)
  128. # @todo: what if fdgs is bigger than nb_vpairs?
  129. idx_change = random.sample(range(nb_vpairs), fdgs if fdgs < nb_vpairs else nb_vpairs)
  130. # idx_change = np.random.randint(0, nx.number_of_nodes(gs) *
  131. # (nx.number_of_nodes(gs) - 1), fdgs)
  132. for item in idx_change:
  133. node1 = int(item / (nx.number_of_nodes(gs) - 1))
  134. node2 = (item - node1 * (nx.number_of_nodes(gs) - 1))
  135. if node2 >= node1: # skip the self pair.
  136. node2 += 1
  137. # @todo: is the randomness correct?
  138. if not gtemp.has_edge(node1, node2):
  139. # @todo: how to update the bond_type? 0 or 1?
  140. gtemp.add_edges_from([(node1, node2, {'bond_type': 1})])
  141. # nx.draw_networkx(gs)
  142. # plt.show()
  143. # nx.draw_networkx(gtemp)
  144. # plt.show()
  145. else:
  146. gtemp.remove_edge(node1, node2)
  147. # nx.draw_networkx(gs)
  148. # plt.show()
  149. # nx.draw_networkx(gtemp)
  150. # plt.show()
  151. # nx.draw_networkx(gtemp)
  152. # plt.show()
  153. # compute distance between phi and the new generated graph.
  154. # knew = marginalizedkernel([gtemp, g1, g2], node_label='atom', edge_label=None,
  155. # p_quit=lmbda, n_iteration=20, remove_totters=False,
  156. # n_jobs=multiprocessing.cpu_count(), verbose=False)
  157. knew = compute_kernel([gtemp, g1, g2], 'untilhpathkernel', verbose=False)
  158. dnew = np.sqrt(knew[0, 0] - 2 * (alpha * knew[0, 1] + (1 - alpha) *
  159. knew[0, 2]) + (alpha * alpha * k_list[idx1] + alpha *
  160. (1 - alpha) * k_g2_list[idx1] + (1 - alpha) * alpha *
  161. k_g1_list[idx2] + (1 - alpha) * (1 - alpha) * k_list[idx2]))
  162. if dnew < dhat: # @todo: the new distance is smaller or also equal?
  163. print('I am smaller!')
  164. print(dhat, '->', dnew)
  165. nx.draw_networkx(gtemp)
  166. plt.show()
  167. print(gtemp.nodes(data=True))
  168. print(gtemp.edges(data=True))
  169. dhat = dnew
  170. gnew = gtemp.copy()
  171. found = True # found better graph.
  172. r = 0
  173. elif dnew == dhat:
  174. print('I am equal!')
  175. if found:
  176. gihat_list = [gnew]
  177. dis_gs.append(dhat)
  178. else:
  179. r += 1
  180. dis_best.append(dhat)
  181. g_best += ([g0hat] if len(gihat_list) == 0 else gihat_list)
  182. for idx, item in enumerate(alpha_range):
  183. print('when alpha is', item, 'the shortest distance is', dis_best[idx])
  184. print('the corresponding pre-image is')
  185. nx.draw_networkx(g_best[idx])
  186. plt.show()

A Python package for graph kernels, graph edit distances and the graph pre-image problem.