You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

knn.py 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Fri Jan 10 13:22:04 2020
  5. @author: ljia
  6. """
  7. import numpy as np
  8. #import matplotlib.pyplot as plt
  9. from tqdm import tqdm
  10. import random
  11. #import csv
  12. from shutil import copyfile
  13. import os
  14. from gklearn.preimage.iam import iam_bash
  15. from gklearn.utils.graphfiles import loadDataset, loadGXL
  16. from gklearn.preimage.ged import GED
  17. from gklearn.preimage.utils import get_same_item_indices
  def test_knn():
      """Benchmark k-NN classification using per-class median graphs as prototypes.

      For each of ``repeats`` random seeds, sample ``percent`` of the graphs of
      every class of the monoterpenoides dataset. Compute the set median and the
      generalized median of each sample with IAM (C++ through bash, via
      ``iam_bash``), and copy the resulting GXL files into
      ``results/knn/medians/``. Then classify the whole dataset with ``knn``,
      once against the set medians and once against the generalized medians.
      Per-repeat accuracies are printed, followed by their mean over all
      repeats.
      """
      ds = {'name': 'monoterpenoides',
            'dataset': '../datasets/monoterpenoides/dataset_10+.ds'}  # node/edge symb
      Gn, y_all = loadDataset(ds['dataset'])
      dir_output = 'results/knn/'
      # os.path.join inserts the separator that plain concatenation omitted
      # (dirname() has no trailing slash); the trailing '/' is preserved.
      graph_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                               '../../datasets/monoterpenoides/')
      k_nn = 1
      percent = 0.1
      repeats = 50
      edit_cost_constant = [3, 3, 1, 3, 3, 1]

      # copyfile() below does not create missing directories.
      os.makedirs(dir_output + 'medians/', exist_ok=True)

      # Indices of the graphs belonging to each class, keyed by class label.
      y_idx = get_same_item_indices(y_all)
      # True labels of the whole dataset; the same for every repeat.
      test_y_set = [int(y) for y in y_all]

      # Accumulated across ALL repeats so the final mean covers every repeat.
      accuracy_sm_list = []
      accuracy_gm_list = []
      sod_sm_list = []
      sod_gm_list = []
      for repeat in range(repeats):
          print('\n---------------------------------')
          print('repeat =', repeat)
          random.seed(repeat)
          set_median_list = []
          gen_median_list = []
          train_y_set = []
          for y, values in y_idx.items():
              print('\ny =', y)
              # Never hand IAM an empty set, even for very small classes.
              size_median_set = max(1, int(len(values) * percent))
              median_set_idx = random.sample(values, size_median_set)
              print('median set: ', median_set_idx)
              # compute set median and gen median using IAM (C++ through bash).
              group_fnames = [Gn[g].graph['filename'] for g in median_set_idx]
              sod_sm, sod_gm, fname_sm, fname_gm = iam_bash(
                  group_fnames, edit_cost_constant, graph_dir=graph_dir)
              print('sod_sm, sod_gm:', sod_sm, sod_gm)
              sod_sm_list.append(sod_sm)
              sod_gm_list.append(sod_gm)
              suffix = '.y' + str(int(y)) + '.repeat' + str(repeat) + '.gxl'
              fname_sm_new = dir_output + 'medians/set_median' + suffix
              copyfile(fname_sm, fname_sm_new)
              fname_gm_new = dir_output + 'medians/gen_median' + suffix
              copyfile(fname_gm, fname_gm_new)
              set_median_list.append(loadGXL(fname_sm_new))
              gen_median_list.append(loadGXL(fname_gm_new))
              train_y_set.append(int(y))

          # do k-nn, once with set medians and once with generalized medians
          # as the training prototypes.
          accuracy_sm = knn(set_median_list, train_y_set, Gn, test_y_set,
                            k=k_nn, distance='ged')
          accuracy_gm = knn(gen_median_list, train_y_set, Gn, test_y_set,
                            k=k_nn, distance='ged')
          accuracy_sm_list.append(accuracy_sm)
          accuracy_gm_list.append(accuracy_gm)
          print('current accuracy sm and gm:', accuracy_sm, accuracy_gm)

      # output
      accuracy_sm_mean = np.mean(accuracy_sm_list)
      accuracy_gm_mean = np.mean(accuracy_gm_list)
      print('\ntotal average accuracy sm and gm:', accuracy_sm_mean, accuracy_gm_mean)
  def knn(train_set, train_y_set, test_set, test_y_set, k=1, distance='ged'):
      """Classify ``test_set`` by k-nearest-neighbour vote and return the accuracy.

      Each test graph is compared to every graph of ``train_set`` with ``GED``
      (gedlibpy, constant costs, IPFP). It receives the majority label among
      its ``k`` closest training graphs. Ties in the vote go to the label whose
      neighbour is closest. Equal distances keep training-set order, so
      ``k=1`` picks the first nearest training graph.

      Parameters
      ----------
      train_set : list
          Graphs used as neighbours (test_knn passes per-class medians).
      train_y_set : list
          Labels of ``train_set``, aligned by index.
      test_set : list
          Graphs to classify.
      test_y_set : list
          True labels of ``test_set``, aligned by index.
      k : int
          Number of neighbours that vote, ``1 <= k <= len(train_set)``.
      distance : str
          Distance measure; only ``'ged'`` is supported.

      Returns
      -------
      float
          Fraction of test graphs whose predicted label equals the true label.

      Raises
      ------
      NotImplementedError
          If ``distance`` is not ``'ged'``.
      ValueError
          If a set and its labels differ in length, either set is empty, or
          ``k`` is out of range.
      """
      import sys
      from collections import Counter

      # Validate everything up front so misuse fails before any costly GED call.
      if distance != 'ged':
          raise NotImplementedError('unsupported distance: %r' % (distance,))
      if len(train_set) != len(train_y_set) or len(test_set) != len(test_y_set):
          raise ValueError('each set must have exactly one label per graph.')
      if not train_set or not test_set:
          raise ValueError('train_set and test_set must both be non-empty.')
      if not 1 <= k <= len(train_set):
          raise ValueError('k must be between 1 and len(train_set), got %r.' % (k,))

      algo_options = '--threads 8 --initial-solutions 40 --ratio-runs-from-initial-solutions 1'
      params_ged = {'lib': 'gedlibpy', 'cost': 'CONSTANT', 'method': 'IPFP',
                    'algo_options': algo_options, 'stabilizer': None}

      n_correct = 0
      for idx_test, g_test in tqdm(enumerate(test_set), desc='computing ' + str(k) + '-nn',
                                   file=sys.stdout):
          dis_list = []
          for g_train in train_set:
              dis_cur, _, _ = GED(g_test, g_train, **params_ged)
              dis_list.append(dis_cur)
          # sorted() is stable: equal distances keep training order, so for k=1
          # the first minimum wins, exactly like a strict '<' scan.
          nn_idx = sorted(range(len(train_set)), key=lambda i: dis_list[i])[:k]
          # most_common() orders equal counts by first insertion, i.e. the tied
          # label with the closest neighbour wins.
          y_pred = Counter(train_y_set[i] for i in nn_idx).most_common(1)[0][0]
          if y_pred == test_y_set[idx_test]:
              n_correct += 1
      return n_correct / len(test_set)
  # Script entry point: run the k-NN median experiment only when this file is
  # executed directly, never when it is imported.
  if __name__ == '__main__':
      test_knn()

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.