You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_iam.py 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Thu Sep 5 15:59:00 2019
  5. @author: ljia
  6. """
  7. import numpy as np
  8. import networkx as nx
  9. import matplotlib.pyplot as plt
  10. import time
  11. import random
  12. #from tqdm import tqdm
  13. #import os
  14. import sys
  15. sys.path.insert(0, "../")
  16. from pygraph.utils.graphfiles import loadDataset
  17. from iam import iam_upgraded
  18. from utils import remove_edges, compute_kernel, get_same_item_indices
  19. from ged import ged_median
  20. ###############################################################################
  21. # tests on different numbers of median-sets.
  22. def test_iam_median_nb():
  23. ds = {'name': 'MUTAG', 'dataset': '../datasets/MUTAG/MUTAG_A.txt',
  24. 'extra_params': {}} # node/edge symb
  25. Gn, y_all = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
  26. # Gn = Gn[0:50]
  27. remove_edges(Gn)
  28. gkernel = 'marginalizedkernel'
  29. # lmbda = 0.03 # termination probalility
  30. # r_max = 10 # iteration limit for pre-image.
  31. # alpha_range = np.linspace(0.5, 0.5, 1)
  32. # k = 5 # k nearest neighbors
  33. # epsilon = 1e-6
  34. # InitIAMWithAllDk = True
  35. # parameters for GED function
  36. ged_cost='CHEM_1'
  37. ged_method='IPFP'
  38. saveGXL='gedlib'
  39. # parameters for IAM function
  40. c_ei=1
  41. c_er=1
  42. c_es=1
  43. ite_max_iam = 50
  44. epsilon_iam = 0.001
  45. removeNodes = False
  46. connected_iam = False
  47. # number of graphs; we what to compute the median of these graphs.
  48. nb_median_range = [2, 3, 4, 5, 10, 20, 30, 40, 50, 100]
  49. # find out all the graphs classified to positive group 1.
  50. idx_dict = get_same_item_indices(y_all)
  51. Gn = [Gn[i] for i in idx_dict[1]]
  52. # # compute Gram matrix.
  53. # time0 = time.time()
  54. # km = compute_kernel(Gn, gkernel, True)
  55. # time_km = time.time() - time0
  56. # # write Gram matrix to file.
  57. # np.savez('results/gram_matrix_marg_itr10_pq0.03_mutag_positive.gm', gm=km, gmtime=time_km)
  58. time_list = []
  59. dis_ks_min_list = []
  60. sod_gs_list = []
  61. sod_gs_min_list = []
  62. nb_updated_list = []
  63. nb_updated_k_list = []
  64. g_best = []
  65. for nb_median in nb_median_range:
  66. print('\n-------------------------------------------------------')
  67. print('number of median graphs =', nb_median)
  68. random.seed(1)
  69. idx_rdm = random.sample(range(len(Gn)), nb_median)
  70. print('graphs chosen:', idx_rdm)
  71. Gn_median = [Gn[idx].copy() for idx in idx_rdm]
  72. Gn_candidate = [g.copy() for g in Gn_median]
  73. # for g in Gn_median:
  74. # nx.draw(g, labels=nx.get_node_attributes(g, 'atom'), with_labels=True)
  75. ## plt.savefig("results/preimage_mix/mutag.png", format="PNG")
  76. # plt.show()
  77. # plt.clf()
  78. ###################################################################
  79. gmfile = np.load('results/gram_matrix_marg_itr10_pq0.03_mutag_positive.gm.npz')
  80. km_tmp = gmfile['gm']
  81. time_km = gmfile['gmtime']
  82. # modify mixed gram matrix.
  83. km = np.zeros((len(Gn) + nb_median, len(Gn) + nb_median))
  84. for i in range(len(Gn)):
  85. for j in range(i, len(Gn)):
  86. km[i, j] = km_tmp[i, j]
  87. km[j, i] = km[i, j]
  88. for i in range(len(Gn)):
  89. for j, idx in enumerate(idx_rdm):
  90. km[i, len(Gn) + j] = km[i, idx]
  91. km[len(Gn) + j, i] = km[i, idx]
  92. for i, idx1 in enumerate(idx_rdm):
  93. for j, idx2 in enumerate(idx_rdm):
  94. km[len(Gn) + i, len(Gn) + j] = km[idx1, idx2]
  95. ###################################################################
  96. alpha_range = [1 / nb_median] * nb_median
  97. time0 = time.time()
  98. ghat_new_list, dis_min = iam_upgraded(Gn_median, Gn_candidate,
  99. c_ei=c_ei, c_er=c_er, c_es=c_es, ite_max=ite_max_iam,
  100. epsilon=epsilon_iam, removeNodes=removeNodes,
  101. connected=connected_iam,
  102. params_ged={'ged_cost': ged_cost, 'ged_method': ged_method,
  103. 'saveGXL': saveGXL})
  104. time_total = time.time() - time0
  105. print('\ntime: ', time_total)
  106. time_list.append(time_total)
  107. print('\nsmallest distance in kernel space: ', dhat)
  108. dis_ks_min_list.append(dhat)
  109. g_best.append(ghat_list)
  110. print('\nnumber of updates of the best graph: ', nb_updated)
  111. nb_updated_list.append(nb_updated)
  112. print('\nnumber of updates of k nearest graphs: ', nb_updated_k)
  113. nb_updated_k_list.append(nb_updated_k)
  114. # show the best graph and save it to file.
  115. print('the shortest distance is', dhat)
  116. print('one of the possible corresponding pre-images is')
  117. nx.draw(ghat_list[0], labels=nx.get_node_attributes(ghat_list[0], 'atom'),
  118. with_labels=True)
  119. plt.show()
  120. plt.savefig('results/preimage_iam/mutag_median_nb' + str(nb_median) +
  121. '.png', format="PNG")
  122. plt.clf()
  123. # print(ghat_list[0].nodes(data=True))
  124. # print(ghat_list[0].edges(data=True))
  125. # compute the corresponding sod in graph space.
  126. sod_tmp, _ = ged_median([ghat_list[0]], Gn_median, ged_cost=ged_cost,
  127. ged_method=ged_method, saveGXL=saveGXL)
  128. sod_gs_list.append(sod_tmp)
  129. sod_gs_min_list.append(np.min(sod_tmp))
  130. print('\nsmallest sod in graph space: ', np.min(sod_tmp))
  131. print('\nsods in graph space: ', sod_gs_list)
  132. print('\nsmallest sod in graph space for each set of median graphs: ', sod_gs_min_list)
  133. print('\nsmallest distance in kernel space for each set of median graphs: ',
  134. dis_ks_min_list)
  135. print('\nnumber of updates of the best graph for each set of median graphs by IAM: ',
  136. nb_updated_list)
  137. print('\nnumber of updates of k nearest graphs for each set of median graphs by IAM: ',
  138. nb_updated_k_list)
  139. print('\ntimes:', time_list)
  140. ###############################################################################
  141. if __name__ == '__main__':
  142. ###############################################################################
  143. # tests on different numbers of median-sets.
  144. test_iam_median_nb()

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.