You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

knn.py 4.4 kB

5 years ago
Lines 1–116
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Fri Jan 10 13:22:04 2020
  5. @author: ljia
  6. """
  7. import numpy as np
  8. #import matplotlib.pyplot as plt
  9. from tqdm import tqdm
  10. import random
  11. #import csv
  12. from shutil import copyfile
  13. import sys
  14. sys.path.insert(0, "../")
  15. from preimage.iam import iam_bash
  16. from gklearn.utils.graphfiles import loadDataset, loadGXL
  17. from preimage.ged import GED
  18. from preimage.utils import get_same_item_indices
  19. def test_knn():
  20. ds = {'name': 'monoterpenoides',
  21. 'dataset': '../datasets/monoterpenoides/dataset_10+.ds'} # node/edge symb
  22. Gn, y_all = loadDataset(ds['dataset'])
  23. # Gn = Gn[0:50]
  24. # gkernel = 'treeletkernel'
  25. # node_label = 'atom'
  26. # edge_label = 'bond_type'
  27. # ds_name = 'mono'
  28. dir_output = 'results/knn/'
  29. graph_dir='/media/ljia/DATA/research-repo/codes/Linlin/graphkit-learn/datasets/monoterpenoides/'
  30. k_nn = 1
  31. percent = 0.1
  32. repeats = 50
  33. edit_cost_constant = [3, 3, 1, 3, 3, 1]
  34. # get indices by classes.
  35. y_idx = get_same_item_indices(y_all)
  36. sod_sm_list_list
  37. for repeat in range(0, repeats):
  38. print('\n---------------------------------')
  39. print('repeat =', repeat)
  40. accuracy_sm_list = []
  41. accuracy_gm_list = []
  42. sod_sm_list = []
  43. sod_gm_list = []
  44. random.seed(repeat)
  45. set_median_list = []
  46. gen_median_list = []
  47. train_y_set = []
  48. for y, values in y_idx.items():
  49. print('\ny =', y)
  50. size_median_set = int(len(values) * percent)
  51. median_set_idx = random.sample(values, size_median_set)
  52. print('median set: ', median_set_idx)
  53. # compute set median and gen median using IAM (C++ through bash).
  54. # Gn_median = [Gn[idx] for idx in median_set_idx]
  55. group_fnames = [Gn[g].graph['filename'] for g in median_set_idx]
  56. sod_sm, sod_gm, fname_sm, fname_gm = iam_bash(group_fnames, edit_cost_constant,
  57. graph_dir=graph_dir)
  58. print('sod_sm, sod_gm:', sod_sm, sod_gm)
  59. sod_sm_list.append(sod_sm)
  60. sod_gm_list.append(sod_gm)
  61. fname_sm_new = dir_output + 'medians/set_median.y' + str(int(y)) + '.repeat' + str(repeat) + '.gxl'
  62. copyfile(fname_sm, fname_sm_new)
  63. fname_gm_new = dir_output + 'medians/gen_median.y' + str(int(y)) + '.repeat' + str(repeat) + '.gxl'
  64. copyfile(fname_gm, fname_gm_new)
  65. set_median_list.append(loadGXL(fname_sm_new))
  66. gen_median_list.append(loadGXL(fname_gm_new))
  67. train_y_set.append(int(y))
  68. print(sod_sm, sod_gm)
  69. # do 1-nn.
  70. test_y_set = [int(y) for y in y_all]
  71. accuracy_sm = knn(set_median_list, train_y_set, Gn, test_y_set, k=k_nn, distance='ged')
  72. accuracy_gm = knn(set_median_list, train_y_set, Gn, test_y_set, k=k_nn, distance='ged')
  73. accuracy_sm_list.append(accuracy_sm)
  74. accuracy_gm_list.append(accuracy_gm)
  75. print('current accuracy sm and gm:', accuracy_sm, accuracy_gm)
  76. # output
  77. accuracy_sm_mean = np.mean(accuracy_sm_list)
  78. accuracy_gm_mean = np.mean(accuracy_gm_list)
  79. print('\ntotal average accuracy sm and gm:', accuracy_sm_mean, accuracy_gm_mean)
  80. def knn(train_set, train_y_set, test_set, test_y_set, k=1, distance='ged'):
  81. if k == 1 and distance == 'ged':
  82. algo_options = '--threads 8 --initial-solutions 40 --ratio-runs-from-initial-solutions 1'
  83. params_ged = {'lib': 'gedlibpy', 'cost': 'CONSTANT', 'method': 'IPFP',
  84. 'algo_options': algo_options, 'stabilizer': None}
  85. accuracy = 0
  86. for idx_test, g_test in tqdm(enumerate(test_set), desc='computing 1-nn',
  87. file=sys.stdout):
  88. dis = np.inf
  89. for idx_train, g_train in enumerate(train_set):
  90. dis_cur, _, _ = GED(g_test, g_train, **params_ged)
  91. if dis_cur < dis:
  92. dis = dis_cur
  93. test_y_cur = train_y_set[idx_train]
  94. if test_y_cur == test_y_set[idx_test]:
  95. accuracy += 1
  96. accuracy = accuracy / len(test_set)
  97. return accuracy
  98. if __name__ == '__main__':
  99. test_knn()

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.