You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_k_closest_graphs.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Mon Dec 16 11:53:54 2019
  5. @author: ljia
  6. """
  7. import numpy as np
  8. import math
  9. import networkx as nx
  10. import matplotlib.pyplot as plt
  11. import time
  12. import random
  13. from tqdm import tqdm
  14. from itertools import combinations, islice
  15. import multiprocessing
  16. from multiprocessing import Pool
  17. from functools import partial
  18. #import os
  19. import sys
  20. sys.path.insert(0, "../")
  21. from pygraph.utils.graphfiles import loadDataset, loadGXL
  22. #from pygraph.utils.logger2file import *
  23. from iam import iam_upgraded, iam_bash
  24. from utils import compute_kernel, dis_gstar, kernel_distance_matrix
  25. from fitDistance import fit_GED_to_kernel_distance
  26. #from ged import ged_median
  27. def median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel, k, fit_method,
  28. graph_dir='/media/ljia/DATA/research-repo/codes/Linlin/py-graph/datasets/monoterpenoides/',
  29. edit_costs=None, group_min=None, dataset='monoterpenoides',
  30. parallel=True):
  31. # # compute distances in kernel space.
  32. # dis_mat, _, _, _ = kernel_distance_matrix(Gn, node_label, edge_label,
  33. # Kmatrix=None, gkernel=gkernel)
  34. # # ged.
  35. # gmfile = np.load('results/test_k_closest_graphs/ged_mat.fit_on_whole_dataset.with_medians.gm.npz')
  36. # ged_mat = gmfile['ged_mat']
  37. # dis_mat = ged_mat[0:len(Gn), 0:len(Gn)]
  38. # # choose k closest graphs
  39. # time0 = time.time()
  40. # sod_ks_min, group_min = get_closest_k_graphs(dis_mat, k, parallel)
  41. # time_spent = time.time() - time0
  42. # print('closest graphs:', sod_ks_min, group_min)
  43. # print('time spent:', time_spent)
  44. # group_min = (12, 13, 22, 29) # closest w.r.t path kernel
  45. # group_min = (77, 85, 160, 171) # closest w.r.t ged
  46. # group_min = (0,1,2,3,4,5,6,7,8,9,10,11) # closest w.r.t treelet kernel
  47. Gn_median = [Gn[g].copy() for g in group_min]
  48. # fit edit costs.
  49. if fit_method == 'random': # random
  50. edit_cost_constant = random.sample(range(1, 10), 6)
  51. print('edit costs used:', edit_cost_constant)
  52. elif fit_method == 'expert': # expert
  53. edit_cost_constant = [3, 3, 1, 3, 3, 1]
  54. elif fit_method == 'k-graphs':
  55. itr_max = 6
  56. algo_options = '--threads 8 --initial-solutions 40 --ratio-runs-from-initial-solutions 1'
  57. params_ged = {'lib': 'gedlibpy', 'cost': 'CONSTANT', 'method': 'IPFP',
  58. 'algo_options': algo_options, 'stabilizer': None}
  59. # fit on k-graph subset
  60. edit_cost_constant, _, _, _, _, _, _ = fit_GED_to_kernel_distance(Gn_median,
  61. node_label, edge_label, gkernel, itr_max, params_ged=params_ged, parallel=True)
  62. elif fit_method == 'whole-dataset':
  63. itr_max = 6
  64. algo_options = '--threads 8 --initial-solutions 40 --ratio-runs-from-initial-solutions 1'
  65. params_ged = {'lib': 'gedlibpy', 'cost': 'CONSTANT', 'method': 'IPFP',
  66. 'algo_options': algo_options, 'stabilizer': None}
  67. # fit on all subset
  68. edit_cost_constant, _, _, _, _, _, _ = fit_GED_to_kernel_distance(Gn,
  69. node_label, edge_label, gkernel, itr_max, params_ged=params_ged, parallel=True)
  70. elif fit_method == 'precomputed':
  71. edit_cost_constant = edit_costs
  72. # compute set median and gen median using IAM (C++ through bash).
  73. group_fnames = [Gn[g].graph['filename'] for g in group_min]
  74. sod_sm, sod_gm, fname_sm, fname_gm = iam_bash(group_fnames, edit_cost_constant,
  75. graph_dir=graph_dir, dataset=dataset)
  76. # compute distances in kernel space.
  77. Gn_median = [Gn[g].copy() for g in group_min]
  78. set_median = loadGXL(fname_sm)
  79. gen_median = loadGXL(fname_gm)
  80. if dataset == 'Letter':
  81. for g in Gn_median:
  82. reform_attributes(g)
  83. reform_attributes(set_median)
  84. reform_attributes(gen_median)
  85. # compute distance in kernel space for set median.
  86. Kmatrix_sm = compute_kernel([set_median] + Gn_median, gkernel,
  87. None if dataset == 'Letter' else 'chem',
  88. None if dataset == 'Letter' else 'valence',
  89. False)
  90. dis_k_sm = dis_gstar(0, range(1, 1+len(Gn_median)),
  91. [1 / len(Gn_median)] * len(Gn_median), Kmatrix_sm, withterm3=False)
  92. # compute distance in kernel space for generalized median.
  93. Kmatrix_gm = compute_kernel([gen_median] + Gn_median, gkernel,
  94. None if dataset == 'Letter' else 'chem',
  95. None if dataset == 'Letter' else 'valence',
  96. False)
  97. dis_k_gm = dis_gstar(0, range(1, 1+len(Gn_median)),
  98. [1 / len(Gn_median)] * len(Gn_median), Kmatrix_gm, withterm3=False)
  99. # compute distance in kernel space for each graph in median set.
  100. dis_k_gi = []
  101. for idx in range(len(Gn_median)):
  102. dis_k_gi.append(dis_gstar(idx+1, range(1, 1+len(Gn_median)),
  103. [1 / len(Gn_median)] * len(Gn_median), Kmatrix_gm, withterm3=False))
  104. print('sod_sm:', sod_sm)
  105. print('sod_gm:', sod_gm)
  106. print('dis_k_sm:', dis_k_sm)
  107. print('dis_k_gm:', dis_k_gm)
  108. print('dis_k_gi:', dis_k_gi)
  109. idx_dis_k_gi_min = np.argmin(dis_k_gi)
  110. dis_k_gi_min = dis_k_gi[idx_dis_k_gi_min]
  111. print('index min dis_k_gi:', group_min[idx_dis_k_gi_min])
  112. print('min dis_k_gi:', dis_k_gi_min)
  113. return sod_sm, sod_gm, dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min, group_min[idx_dis_k_gi_min]
  114. def reform_attributes(G):
  115. for node in G.nodes:
  116. G.nodes[node]['attributes'] = [G.nodes[node]['x'], G.nodes[node]['y']]
  117. def get_closest_k_graphs(dis_mat, k, parallel):
  118. k_graph_groups = combinations(range(0, len(dis_mat)), k)
  119. sod_ks_min = np.inf
  120. if parallel:
  121. len_combination = get_combination_length(len(dis_mat), k)
  122. len_itr_max = int(len_combination if len_combination < 1e7 else 1e7)
  123. # pos_cur = 0
  124. graph_groups_slices = split_iterable(k_graph_groups, len_itr_max, len_combination)
  125. for graph_groups_cur in graph_groups_slices:
  126. # while True:
  127. # graph_groups_cur = islice(k_graph_groups, pos_cur, pos_cur + len_itr_max)
  128. graph_groups_cur_list = list(graph_groups_cur)
  129. print('current position:', graph_groups_cur_list[0])
  130. len_itr_cur = len(graph_groups_cur_list)
  131. # if len_itr_cur < len_itr_max:
  132. # break
  133. itr = zip(graph_groups_cur_list, range(0, len_itr_cur))
  134. sod_k_list = np.empty(len_itr_cur)
  135. graphs_list = [None] * len_itr_cur
  136. n_jobs = multiprocessing.cpu_count()
  137. chunksize = int(len_itr_max / n_jobs + 1)
  138. n_jobs = multiprocessing.cpu_count()
  139. def init_worker(dis_mat_toshare):
  140. global G_dis_mat
  141. G_dis_mat = dis_mat_toshare
  142. pool = Pool(processes=n_jobs, initializer=init_worker, initargs=(dis_mat,))
  143. # iterator = tqdm(pool.imap_unordered(_get_closest_k_graphs_parallel,
  144. # itr, chunksize),
  145. # desc='Choosing k closest graphs', file=sys.stdout)
  146. iterator = pool.imap_unordered(_get_closest_k_graphs_parallel, itr, chunksize)
  147. for graphs, i, sod_ks in iterator:
  148. sod_k_list[i] = sod_ks
  149. graphs_list[i] = graphs
  150. pool.close()
  151. pool.join()
  152. arg_min = np.argmin(sod_k_list)
  153. sod_ks_cur = sod_k_list[arg_min]
  154. group_cur = graphs_list[arg_min]
  155. if sod_ks_cur < sod_ks_min:
  156. sod_ks_min = sod_ks_cur
  157. group_min = group_cur
  158. print('get closer graphs:', sod_ks_min, group_min)
  159. else:
  160. for items in tqdm(k_graph_groups, desc='Choosing k closest graphs', file=sys.stdout):
  161. # if items[0] != itmp:
  162. # itmp = items[0]
  163. # print(items)
  164. k_graph_pairs = combinations(items, 2)
  165. sod_ks = 0
  166. for i1, i2 in k_graph_pairs:
  167. sod_ks += dis_mat[i1, i2]
  168. if sod_ks < sod_ks_min:
  169. sod_ks_min = sod_ks
  170. group_min = items
  171. print('get closer graphs:', sod_ks_min, group_min)
  172. return sod_ks_min, group_min
  173. def _get_closest_k_graphs_parallel(itr):
  174. k_graph_pairs = combinations(itr[0], 2)
  175. sod_ks = 0
  176. for i1, i2 in k_graph_pairs:
  177. sod_ks += G_dis_mat[i1, i2]
  178. return itr[0], itr[1], sod_ks
  179. def split_iterable(iterable, n, len_iter):
  180. it = iter(iterable)
  181. for i in range(0, len_iter, n):
  182. piece = islice(it, n)
  183. yield piece
  184. def get_combination_length(n, k):
  185. len_combination = 1
  186. for i in range(n, n - k, -1):
  187. len_combination *= i
  188. return int(len_combination / math.factorial(k))
  189. ###############################################################################
  190. def test_k_closest_graphs():
  191. ds = {'name': 'monoterpenoides',
  192. 'dataset': '../datasets/monoterpenoides/dataset_10+.ds'} # node/edge symb
  193. Gn, y_all = loadDataset(ds['dataset'])
  194. # Gn = Gn[0:50]
  195. # gkernel = 'untilhpathkernel'
  196. # gkernel = 'weisfeilerlehmankernel'
  197. gkernel = 'treeletkernel'
  198. node_label = 'atom'
  199. edge_label = 'bond_type'
  200. k = 5
  201. edit_costs = [0.16229209837639536, 0.06612870523413916, 0.04030113378793905, 0.20723547009415202, 0.3338607220394598, 0.27054392518077297]
  202. # sod_sm, sod_gm, dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min \
  203. # = median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel, k,
  204. # 'precomputed', edit_costs=edit_costs,
  205. ## 'k-graphs',
  206. # parallel=False)
  207. #
  208. # sod_sm, sod_gm, dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min \
  209. # = median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel, k,
  210. # 'expert', parallel=False)
  211. sod_sm, sod_gm, dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min \
  212. = median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel, k,
  213. 'expert', parallel=False)
  214. return
  215. def test_k_closest_graphs_with_cv():
  216. gkernel = 'untilhpathkernel'
  217. node_label = 'atom'
  218. edge_label = 'bond_type'
  219. k = 4
  220. y_all = ['3', '1', '4', '6', '7', '8', '9', '2']
  221. repeats = 50
  222. collection_path = '/media/ljia/DATA/research-repo/codes/others/gedlib/tests_linlin/generated_datsets/monoterpenoides/'
  223. graph_dir = collection_path + 'gxl/'
  224. sod_sm_list = []
  225. sod_gm_list = []
  226. dis_k_sm_list = []
  227. dis_k_gm_list = []
  228. dis_k_gi_min_list = []
  229. for y in y_all:
  230. print('\n-------------------------------------------------------')
  231. print('class of y:', y)
  232. sod_sm_list.append([])
  233. sod_gm_list.append([])
  234. dis_k_sm_list.append([])
  235. dis_k_gm_list.append([])
  236. dis_k_gi_min_list.append([])
  237. for repeat in range(repeats):
  238. print('\nrepeat ', repeat)
  239. collection_file = collection_path + 'monoterpenoides_' + y + '_' + str(repeat) + '.xml'
  240. Gn, _ = loadDataset(collection_file, extra_params=graph_dir)
  241. sod_sm, sod_gm, dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min \
  242. = median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel,
  243. k, 'whole-dataset', graph_dir=graph_dir,
  244. parallel=False)
  245. sod_sm_list[-1].append(sod_sm)
  246. sod_gm_list[-1].append(sod_gm)
  247. dis_k_sm_list[-1].append(dis_k_sm)
  248. dis_k_gm_list[-1].append(dis_k_gm)
  249. dis_k_gi_min_list[-1].append(dis_k_gi_min)
  250. print('\nsods of the set median for this class:', sod_sm_list[-1])
  251. print('\nsods of the gen median for this class:', sod_gm_list[-1])
  252. print('\ndistances in kernel space of set median for this class:',
  253. dis_k_sm_list[-1])
  254. print('\ndistances in kernel space of gen median for this class:',
  255. dis_k_gm_list[-1])
  256. print('\ndistances in kernel space of min graph for this class:',
  257. dis_k_gi_min_list[-1])
  258. sod_sm_list[-1] = np.mean(sod_sm_list[-1])
  259. sod_gm_list[-1] = np.mean(sod_gm_list[-1])
  260. dis_k_sm_list[-1] = np.mean(dis_k_sm_list[-1])
  261. dis_k_gm_list[-1] = np.mean(dis_k_gm_list[-1])
  262. dis_k_gi_min_list[-1] = np.mean(dis_k_gi_min_list[-1])
  263. print()
  264. print('\nmean sods of the set median for each class:', sod_sm_list)
  265. print('\nmean sods of the gen median for each class:', sod_gm_list)
  266. print('\nmean distance in kernel space of set median for each class:',
  267. dis_k_sm_list)
  268. print('\nmean distances in kernel space of gen median for each class:',
  269. dis_k_gm_list)
  270. print('\nmean distances in kernel space of min graph for each class:',
  271. dis_k_gi_min_list)
  272. print('\nmean sods of the set median of all:', np.mean(sod_sm_list))
  273. print('\nmean sods of the gen median of all:', np.mean(sod_gm_list))
  274. print('\nmean distances in kernel space of set median of all:',
  275. np.mean(dis_k_sm_list))
  276. print('\nmean distances in kernel space of gen median of all:',
  277. np.mean(dis_k_gm_list))
  278. print('\nmean distances in kernel space of min graph of all:',
  279. np.mean(dis_k_gi_min_list))
  280. return
  281. if __name__ == '__main__':
  282. test_k_closest_graphs()
  283. # test_k_closest_graphs_with_cv()

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.