You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

find_best_k.py 7.0 kB

(Line-number gutter of the file view: lines 1–170.)
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Thu Jan 9 11:54:32 2020
  5. @author: ljia
  6. """
  7. import numpy as np
  8. import random
  9. import csv
  10. from gklearn.utils.graphfiles import loadDataset
  11. from gklearn.preimage.test_k_closest_graphs import median_on_k_closest_graphs
  12. def find_best_k():
  13. ds = {'name': 'monoterpenoides',
  14. 'dataset': '../datasets/monoterpenoides/dataset_10+.ds'} # node/edge symb
  15. Gn, y_all = loadDataset(ds['dataset'])
  16. # Gn = Gn[0:50]
  17. gkernel = 'treeletkernel'
  18. node_label = 'atom'
  19. edge_label = 'bond_type'
  20. ds_name = 'mono'
  21. dir_output = 'results/test_find_best_k/'
  22. repeats = 50
  23. k_list = range(2, 11)
  24. fit_method = 'k-graphs'
  25. # fitted on the whole dataset - treelet - mono
  26. edit_costs = [0.1268873773592978, 0.004084633224249829, 0.0897581955378986, 0.15328856114451297, 0.3109956881625734, 0.0]
  27. # create result files.
  28. fn_output_detail = 'results_detail.' + fit_method + '.csv'
  29. f_detail = open(dir_output + fn_output_detail, 'a')
  30. csv.writer(f_detail).writerow(['dataset', 'graph kernel', 'fit method', 'k',
  31. 'repeat', 'median set', 'SOD SM', 'SOD GM', 'dis_k SM', 'dis_k GM',
  32. 'min dis_k gi', 'SOD SM -> GM', 'dis_k SM -> GM', 'dis_k gi -> SM',
  33. 'dis_k gi -> GM'])
  34. f_detail.close()
  35. fn_output_summary = 'results_summary.csv'
  36. f_summary = open(dir_output + fn_output_summary, 'a')
  37. csv.writer(f_summary).writerow(['dataset', 'graph kernel', 'fit method', 'k',
  38. 'SOD SM', 'SOD GM', 'dis_k SM', 'dis_k GM',
  39. 'min dis_k gi', 'SOD SM -> GM', 'dis_k SM -> GM', 'dis_k gi -> SM',
  40. 'dis_k gi -> GM', '# SOD SM -> GM', '# dis_k SM -> GM',
  41. '# dis_k gi -> SM', '# dis_k gi -> GM', 'repeats better SOD SM -> GM',
  42. 'repeats better dis_k SM -> GM', 'repeats better dis_k gi -> SM',
  43. 'repeats better dis_k gi -> GM'])
  44. f_summary.close()
  45. random.seed(1)
  46. rdn_seed_list = random.sample(range(0, repeats * 100), repeats)
  47. for k in k_list:
  48. print('\n--------- k =', k, '----------')
  49. sod_sm_list = []
  50. sod_gm_list = []
  51. dis_k_sm_list = []
  52. dis_k_gm_list = []
  53. dis_k_gi_min_list = []
  54. nb_sod_sm2gm = [0, 0, 0]
  55. nb_dis_k_sm2gm = [0, 0, 0]
  56. nb_dis_k_gi2sm = [0, 0, 0]
  57. nb_dis_k_gi2gm = [0, 0, 0]
  58. repeats_better_sod_sm2gm = []
  59. repeats_better_dis_k_sm2gm = []
  60. repeats_better_dis_k_gi2sm = []
  61. repeats_better_dis_k_gi2gm = []
  62. for repeat in range(repeats):
  63. print('\nrepeat =', repeat)
  64. random.seed(rdn_seed_list[repeat])
  65. median_set_idx = random.sample(range(0, len(Gn)), k)
  66. print('median set: ', median_set_idx)
  67. sod_sm, sod_gm, dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min \
  68. = median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel, k,
  69. fit_method='k-graphs',
  70. edit_costs=edit_costs,
  71. group_min=median_set_idx,
  72. parallel=False)
  73. # write result detail.
  74. sod_sm2gm = getRelations(np.sign(sod_gm - sod_sm))
  75. dis_k_sm2gm = getRelations(np.sign(dis_k_gm - dis_k_sm))
  76. dis_k_gi2sm = getRelations(np.sign(dis_k_sm - dis_k_gi_min))
  77. dis_k_gi2gm = getRelations(np.sign(dis_k_gm - dis_k_gi_min))
  78. f_detail = open(dir_output + fn_output_detail, 'a')
  79. csv.writer(f_detail).writerow([ds_name, gkernel, fit_method, k, repeat,
  80. median_set_idx, sod_sm, sod_gm, dis_k_sm, dis_k_gm,
  81. dis_k_gi_min, sod_sm2gm, dis_k_sm2gm, dis_k_gi2sm,
  82. dis_k_gi2gm])
  83. f_detail.close()
  84. # compute result summary.
  85. sod_sm_list.append(sod_sm)
  86. sod_gm_list.append(sod_gm)
  87. dis_k_sm_list.append(dis_k_sm)
  88. dis_k_gm_list.append(dis_k_gm)
  89. dis_k_gi_min_list.append(dis_k_gi_min)
  90. # # SOD SM -> GM
  91. if sod_sm > sod_gm:
  92. nb_sod_sm2gm[0] += 1
  93. repeats_better_sod_sm2gm.append(repeat)
  94. elif sod_sm == sod_gm:
  95. nb_sod_sm2gm[1] += 1
  96. elif sod_sm < sod_gm:
  97. nb_sod_sm2gm[2] += 1
  98. # # dis_k SM -> GM
  99. if dis_k_sm > dis_k_gm:
  100. nb_dis_k_sm2gm[0] += 1
  101. repeats_better_dis_k_sm2gm.append(repeat)
  102. elif dis_k_sm == dis_k_gm:
  103. nb_dis_k_sm2gm[1] += 1
  104. elif dis_k_sm < dis_k_gm:
  105. nb_dis_k_sm2gm[2] += 1
  106. # # dis_k gi -> SM
  107. if dis_k_gi_min > dis_k_sm:
  108. nb_dis_k_gi2sm[0] += 1
  109. repeats_better_dis_k_gi2sm.append(repeat)
  110. elif dis_k_gi_min == dis_k_sm:
  111. nb_dis_k_gi2sm[1] += 1
  112. elif dis_k_gi_min < dis_k_sm:
  113. nb_dis_k_gi2sm[2] += 1
  114. # # dis_k gi -> GM
  115. if dis_k_gi_min > dis_k_gm:
  116. nb_dis_k_gi2gm[0] += 1
  117. repeats_better_dis_k_gi2gm.append(repeat)
  118. elif dis_k_gi_min == dis_k_gm:
  119. nb_dis_k_gi2gm[1] += 1
  120. elif dis_k_gi_min < dis_k_gm:
  121. nb_dis_k_gi2gm[2] += 1
  122. # write result summary.
  123. sod_sm_mean = np.mean(sod_sm_list)
  124. sod_gm_mean = np.mean(sod_gm_list)
  125. dis_k_sm_mean = np.mean(dis_k_sm_list)
  126. dis_k_gm_mean = np.mean(dis_k_gm_list)
  127. dis_k_gi_min_mean = np.mean(dis_k_gi_min_list)
  128. sod_sm2gm_mean = getRelations(np.sign(sod_gm_mean - sod_sm_mean))
  129. dis_k_sm2gm_mean = getRelations(np.sign(dis_k_gm_mean - dis_k_sm_mean))
  130. dis_k_gi2sm_mean = getRelations(np.sign(dis_k_sm_mean - dis_k_gi_min_mean))
  131. dis_k_gi2gm_mean = getRelations(np.sign(dis_k_gm_mean - dis_k_gi_min_mean))
  132. f_summary = open(dir_output + fn_output_summary, 'a')
  133. csv.writer(f_summary).writerow([ds_name, gkernel, fit_method, k,
  134. sod_sm_mean, sod_gm_mean, dis_k_sm_mean, dis_k_gm_mean,
  135. dis_k_gi_min_mean, sod_sm2gm_mean, dis_k_sm2gm_mean,
  136. dis_k_gi2sm_mean, dis_k_gi2gm_mean, nb_sod_sm2gm,
  137. nb_dis_k_sm2gm, nb_dis_k_gi2sm, nb_dis_k_gi2gm,
  138. repeats_better_sod_sm2gm, repeats_better_dis_k_sm2gm,
  139. repeats_better_dis_k_gi2sm, repeats_better_dis_k_gi2gm])
  140. f_summary.close()
  141. print('\ncomplete.')
  142. return
  143. def getRelations(sign):
  144. if sign == -1:
  145. return 'better'
  146. elif sign == 0:
  147. return 'same'
  148. elif sign == 1:
  149. return 'worse'
  150. if __name__ == '__main__':
  151. find_best_k()

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.