You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

preimage.py 5.5 kB

(line numbers 1–141)
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Wed Mar 6 16:03:11 2019
  5. pre-image
  6. @author: ljia
  7. """
  8. import sys
  9. import numpy as np
  10. import multiprocessing
  11. from tqdm import tqdm
  12. import networkx as nx
  13. import matplotlib.pyplot as plt
  14. sys.path.insert(0, "../")
  15. from pygraph.kernels.marginalizedKernel import marginalizedkernel
  16. from pygraph.utils.graphfiles import loadDataset
  17. ds = {'name': 'MUTAG', 'dataset': '../datasets/MUTAG/MUTAG.mat',
  18. 'extra_params': {'am_sp_al_nl_el': [0, 0, 3, 1, 2]}} # node/edge symb
  19. DN, y_all = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
  20. DN = DN[0:10]
  21. lmbda = 0.03 # termination probalility
  22. r_max = 10 # recursions
  23. l = 500
  24. alpha_range = np.linspace(0.1, 0.9, 9)
  25. k = 5 # k nearest neighbors
  26. # randomly select two molecules
  27. np.random.seed(1)
  28. idx1, idx2 = np.random.randint(0, len(DN), 2)
  29. g1 = DN[idx1]
  30. g2 = DN[idx2]
  31. # compute
  32. k_list = [] # kernel between each graph and itself.
  33. k_g1_list = [] # kernel between each graph and g1
  34. k_g2_list = [] # kernel between each graph and g2
  35. for ig, g in tqdm(enumerate(DN), desc='computing self kernels', file=sys.stdout):
  36. ktemp = marginalizedkernel([g, g1, g2], node_label='atom', edge_label=None,
  37. p_quit=lmbda, n_iteration=20, remove_totters=False,
  38. n_jobs=multiprocessing.cpu_count(), verbose=False)
  39. k_list.append(ktemp[0][0, 0])
  40. k_g1_list.append(ktemp[0][0, 1])
  41. k_g2_list.append(ktemp[0][0, 2])
  42. g_best = []
  43. dis_best = []
  44. # for each alpha
  45. for alpha in alpha_range:
  46. print('alpha =', alpha)
  47. # compute k nearest neighbors of phi in DN.
  48. dis_list = [] # distance between g_star and each graph.
  49. for ig, g in tqdm(enumerate(DN), desc='computing distances', file=sys.stdout):
  50. dtemp = k_list[ig] - 2 * (alpha * k_g1_list[ig] + (1 - alpha) *
  51. k_g2_list[ig]) + (alpha * alpha * k_list[idx1] + alpha *
  52. (1 - alpha) * k_g2_list[idx1] + (1 - alpha) * alpha *
  53. k_g1_list[idx2] + (1 - alpha) * (1 - alpha) * k_list[idx2])
  54. dis_list.append(dtemp)
  55. # sort
  56. sort_idx = np.argsort(dis_list)
  57. dis_gs = [dis_list[idis] for idis in sort_idx[0:k]]
  58. g0hat = DN[sort_idx[0]] # the nearest neighbor of phi in DN
  59. if dis_gs[0] == 0: # the exact pre-image.
  60. print('The exact pre-image is found from the input dataset.')
  61. g_pimg = g0hat
  62. break
  63. dhat = dis_gs[0] # the nearest distance
  64. Dk = [DN[ig] for ig in sort_idx[0:k]] # the k nearest neighbors
  65. gihat_list = []
  66. i = 1
  67. r = 1
  68. while r < r_max:
  69. print('r =', r)
  70. found = False
  71. for ig, gs in enumerate(Dk + gihat_list):
  72. # nx.draw_networkx(gs)
  73. # plt.show()
  74. fdgs = int(np.abs(np.ceil(np.log(alpha * dis_gs[ig])))) # @todo ???
  75. for trail in tqdm(range(0, l), desc='l loop', file=sys.stdout):
  76. # add and delete edges.
  77. gtemp = gs.copy()
  78. np.random.seed()
  79. # which edges to change.
  80. idx_change = np.random.randint(0, nx.number_of_nodes(gs) *
  81. (nx.number_of_nodes(gs) - 1), fdgs)
  82. for item in idx_change:
  83. node1 = int(item / (nx.number_of_nodes(gs) - 1))
  84. node2 = (item - node1 * (nx.number_of_nodes(gs) - 1))
  85. if node2 >= node1:
  86. node2 += 1
  87. # @todo: is the randomness correct?
  88. if not gtemp.has_edge(node1, node2):
  89. gtemp.add_edges_from([(node1, node2, {'bond_type': 0})])
  90. # nx.draw_networkx(gs)
  91. # plt.show()
  92. # nx.draw_networkx(gtemp)
  93. # plt.show()
  94. else:
  95. gtemp.remove_edge(node1, node2)
  96. # nx.draw_networkx(gs)
  97. # plt.show()
  98. # nx.draw_networkx(gtemp)
  99. # plt.show()
  100. # nx.draw_networkx(gtemp)
  101. # plt.show()
  102. # compute distance between phi and the new generated graph.
  103. knew = marginalizedkernel([gtemp, g1, g2], node_label='atom', edge_label=None,
  104. p_quit=lmbda, n_iteration=20, remove_totters=False,
  105. n_jobs=multiprocessing.cpu_count(), verbose=False)
  106. dnew = knew[0][0, 0] - 2 * (alpha * knew[0][0, 1] + (1 - alpha) *
  107. knew[0][0, 2]) + (alpha * alpha * k_list[idx1] + alpha *
  108. (1 - alpha) * k_g2_list[idx1] + (1 - alpha) * alpha *
  109. k_g1_list[idx2] + (1 - alpha) * (1 - alpha) * k_list[idx2])
  110. if dnew <= dhat: # the new distance is smaller
  111. print('I am smaller!')
  112. dhat = dnew
  113. gnew = gtemp.copy()
  114. found = True # found better graph.
  115. if found:
  116. gihat_list = [gnew]
  117. dis_gs.append(dhat)
  118. else:
  119. r += 1
  120. dis_best.append(dhat)
  121. g_best += ([g0hat] if len(gihat_list) == 0 else gihat_list)
  122. for idx, item in enumerate(alpha_range):
  123. print('when alpha is', item, 'the shortest distance is', dis_best[idx])
  124. print('the corresponding pre-image is')
  125. nx.draw_networkx(g_best[idx])
  126. plt.show()

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.