You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

find_best_k.py 7.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Thu Jan 9 11:54:32 2020
  5. @author: ljia
  6. """
  7. import numpy as np
  8. import random
  9. import csv
  10. import sys
  11. sys.path.insert(0, "../")
  12. from pygraph.utils.graphfiles import loadDataset
  13. from preimage.test_k_closest_graphs import median_on_k_closest_graphs
  14. def find_best_k():
  15. ds = {'name': 'monoterpenoides',
  16. 'dataset': '../datasets/monoterpenoides/dataset_10+.ds'} # node/edge symb
  17. Gn, y_all = loadDataset(ds['dataset'])
  18. # Gn = Gn[0:50]
  19. gkernel = 'treeletkernel'
  20. node_label = 'atom'
  21. edge_label = 'bond_type'
  22. ds_name = 'mono'
  23. dir_output = 'results/test_find_best_k/'
  24. repeats = 50
  25. k_list = range(2, 11)
  26. fit_method = 'k-graphs'
  27. # fitted on the whole dataset - treelet - mono
  28. edit_costs = [0.1268873773592978, 0.004084633224249829, 0.0897581955378986, 0.15328856114451297, 0.3109956881625734, 0.0]
  29. # create result files.
  30. fn_output_detail = 'results_detail.' + fit_method + '.csv'
  31. f_detail = open(dir_output + fn_output_detail, 'a')
  32. csv.writer(f_detail).writerow(['dataset', 'graph kernel', 'fit method', 'k',
  33. 'repeat', 'median set', 'SOD SM', 'SOD GM', 'dis_k SM', 'dis_k GM',
  34. 'min dis_k gi', 'SOD SM -> GM', 'dis_k SM -> GM', 'dis_k gi -> SM',
  35. 'dis_k gi -> GM'])
  36. f_detail.close()
  37. fn_output_summary = 'results_summary.csv'
  38. f_summary = open(dir_output + fn_output_summary, 'a')
  39. csv.writer(f_summary).writerow(['dataset', 'graph kernel', 'fit method', 'k',
  40. 'SOD SM', 'SOD GM', 'dis_k SM', 'dis_k GM',
  41. 'min dis_k gi', 'SOD SM -> GM', 'dis_k SM -> GM', 'dis_k gi -> SM',
  42. 'dis_k gi -> GM', '# SOD SM -> GM', '# dis_k SM -> GM',
  43. '# dis_k gi -> SM', '# dis_k gi -> GM', 'repeats better SOD SM -> GM',
  44. 'repeats better dis_k SM -> GM', 'repeats better dis_k gi -> SM',
  45. 'repeats better dis_k gi -> GM'])
  46. f_summary.close()
  47. random.seed(1)
  48. rdn_seed_list = random.sample(range(0, repeats * 100), repeats)
  49. for k in k_list:
  50. print('\n--------- k =', k, '----------')
  51. sod_sm_list = []
  52. sod_gm_list = []
  53. dis_k_sm_list = []
  54. dis_k_gm_list = []
  55. dis_k_gi_min_list = []
  56. nb_sod_sm2gm = [0, 0, 0]
  57. nb_dis_k_sm2gm = [0, 0, 0]
  58. nb_dis_k_gi2sm = [0, 0, 0]
  59. nb_dis_k_gi2gm = [0, 0, 0]
  60. repeats_better_sod_sm2gm = []
  61. repeats_better_dis_k_sm2gm = []
  62. repeats_better_dis_k_gi2sm = []
  63. repeats_better_dis_k_gi2gm = []
  64. for repeat in range(repeats):
  65. print('\nrepeat =', repeat)
  66. random.seed(rdn_seed_list[repeat])
  67. median_set_idx = random.sample(range(0, len(Gn)), k)
  68. print('median set: ', median_set_idx)
  69. sod_sm, sod_gm, dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min \
  70. = median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel, k,
  71. fit_method='k-graphs',
  72. edit_costs=edit_costs,
  73. group_min=median_set_idx,
  74. parallel=False)
  75. # write result detail.
  76. sod_sm2gm = getRelations(np.sign(sod_gm - sod_sm))
  77. dis_k_sm2gm = getRelations(np.sign(dis_k_gm - dis_k_sm))
  78. dis_k_gi2sm = getRelations(np.sign(dis_k_sm - dis_k_gi_min))
  79. dis_k_gi2gm = getRelations(np.sign(dis_k_gm - dis_k_gi_min))
  80. f_detail = open(dir_output + fn_output_detail, 'a')
  81. csv.writer(f_detail).writerow([ds_name, gkernel, fit_method, k, repeat,
  82. median_set_idx, sod_sm, sod_gm, dis_k_sm, dis_k_gm,
  83. dis_k_gi_min, sod_sm2gm, dis_k_sm2gm, dis_k_gi2sm,
  84. dis_k_gi2gm])
  85. f_detail.close()
  86. # compute result summary.
  87. sod_sm_list.append(sod_sm)
  88. sod_gm_list.append(sod_gm)
  89. dis_k_sm_list.append(dis_k_sm)
  90. dis_k_gm_list.append(dis_k_gm)
  91. dis_k_gi_min_list.append(dis_k_gi_min)
  92. # # SOD SM -> GM
  93. if sod_sm > sod_gm:
  94. nb_sod_sm2gm[0] += 1
  95. repeats_better_sod_sm2gm.append(repeat)
  96. elif sod_sm == sod_gm:
  97. nb_sod_sm2gm[1] += 1
  98. elif sod_sm < sod_gm:
  99. nb_sod_sm2gm[2] += 1
  100. # # dis_k SM -> GM
  101. if dis_k_sm > dis_k_gm:
  102. nb_dis_k_sm2gm[0] += 1
  103. repeats_better_dis_k_sm2gm.append(repeat)
  104. elif dis_k_sm == dis_k_gm:
  105. nb_dis_k_sm2gm[1] += 1
  106. elif dis_k_sm < dis_k_gm:
  107. nb_dis_k_sm2gm[2] += 1
  108. # # dis_k gi -> SM
  109. if dis_k_gi_min > dis_k_sm:
  110. nb_dis_k_gi2sm[0] += 1
  111. repeats_better_dis_k_gi2sm.append(repeat)
  112. elif dis_k_gi_min == dis_k_sm:
  113. nb_dis_k_gi2sm[1] += 1
  114. elif dis_k_gi_min < dis_k_sm:
  115. nb_dis_k_gi2sm[2] += 1
  116. # # dis_k gi -> GM
  117. if dis_k_gi_min > dis_k_gm:
  118. nb_dis_k_gi2gm[0] += 1
  119. repeats_better_dis_k_gi2gm.append(repeat)
  120. elif dis_k_gi_min == dis_k_gm:
  121. nb_dis_k_gi2gm[1] += 1
  122. elif dis_k_gi_min < dis_k_gm:
  123. nb_dis_k_gi2gm[2] += 1
  124. # write result summary.
  125. sod_sm_mean = np.mean(sod_sm_list)
  126. sod_gm_mean = np.mean(sod_gm_list)
  127. dis_k_sm_mean = np.mean(dis_k_sm_list)
  128. dis_k_gm_mean = np.mean(dis_k_gm_list)
  129. dis_k_gi_min_mean = np.mean(dis_k_gi_min_list)
  130. sod_sm2gm_mean = getRelations(np.sign(sod_gm_mean - sod_sm_mean))
  131. dis_k_sm2gm_mean = getRelations(np.sign(dis_k_gm_mean - dis_k_sm_mean))
  132. dis_k_gi2sm_mean = getRelations(np.sign(dis_k_sm_mean - dis_k_gi_min_mean))
  133. dis_k_gi2gm_mean = getRelations(np.sign(dis_k_gm_mean - dis_k_gi_min_mean))
  134. f_summary = open(dir_output + fn_output_summary, 'a')
  135. csv.writer(f_summary).writerow([ds_name, gkernel, fit_method, k,
  136. sod_sm_mean, sod_gm_mean, dis_k_sm_mean, dis_k_gm_mean,
  137. dis_k_gi_min_mean, sod_sm2gm_mean, dis_k_sm2gm_mean,
  138. dis_k_gi2sm_mean, dis_k_gi2gm_mean, nb_sod_sm2gm,
  139. nb_dis_k_sm2gm, nb_dis_k_gi2sm, nb_dis_k_gi2gm,
  140. repeats_better_sod_sm2gm, repeats_better_dis_k_sm2gm,
  141. repeats_better_dis_k_gi2sm, repeats_better_dis_k_gi2gm])
  142. f_summary.close()
  143. print('\ncomplete.')
  144. return
  145. def getRelations(sign):
  146. if sign == -1:
  147. return 'better'
  148. elif sign == 0:
  149. return 'same'
  150. elif sign == 1:
  151. return 'worse'
  152. if __name__ == '__main__':
  153. find_best_k()

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.