You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_k_closest_graphs.py 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Mon Dec 16 11:53:54 2019
  5. @author: ljia
  6. """
import math
import multiprocessing
import os
import random
import sys
import time
from functools import partial
from itertools import combinations, islice
from multiprocessing import Pool

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from tqdm import tqdm

from gklearn.utils.graphfiles import loadDataset, loadGXL
#from gklearn.utils.logger2file import *
from gklearn.preimage.iam import iam_upgraded, iam_bash
from gklearn.preimage.utils import compute_kernel, dis_gstar, kernel_distance_matrix
from gklearn.preimage.fitDistance import fit_GED_to_kernel_distance
#from gklearn.preimage.ged import ged_median
  24. def fit_edit_cost_constants(fit_method, edit_cost_name,
  25. edit_cost_constants=None, initial_solutions=1,
  26. Gn_median=None, node_label=None, edge_label=None,
  27. gkernel=None, dataset=None, init_ecc=None,
  28. Gn=None, Kmatrix_median=None):
  29. """fit edit cost constants.
  30. """
  31. if fit_method == 'random': # random
  32. if edit_cost_name == 'LETTER':
  33. edit_cost_constants = random.sample(range(1, 10), 3)
  34. edit_cost_constants = [item * 0.1 for item in edit_cost_constants]
  35. elif edit_cost_name == 'LETTER2':
  36. random.seed(time.time())
  37. edit_cost_constants = random.sample(range(1, 10), 5)
  38. # edit_cost_constants = [item * 0.1 for item in edit_cost_constants]
  39. elif edit_cost_name == 'NON_SYMBOLIC':
  40. edit_cost_constants = random.sample(range(1, 10), 6)
  41. if Gn_median[0].graph['node_attrs'] == []:
  42. edit_cost_constants[2] = 0
  43. if Gn_median[0].graph['edge_attrs'] == []:
  44. edit_cost_constants[5] = 0
  45. else:
  46. edit_cost_constants = random.sample(range(1, 10), 6)
  47. print('edit cost constants used:', edit_cost_constants)
  48. elif fit_method == 'expert': # expert
  49. if init_ecc is None:
  50. if edit_cost_name == 'LETTER':
  51. edit_cost_constants = [0.9, 1.7, 0.75]
  52. elif edit_cost_name == 'LETTER2':
  53. edit_cost_constants = [0.675, 0.675, 0.75, 0.425, 0.425]
  54. else:
  55. edit_cost_constants = [3, 3, 1, 3, 3, 1]
  56. else:
  57. edit_cost_constants = init_ecc
  58. elif fit_method == 'k-graphs':
  59. itr_max = 6
  60. if init_ecc is None:
  61. if edit_cost_name == 'LETTER':
  62. init_costs = [0.9, 1.7, 0.75]
  63. elif edit_cost_name == 'LETTER2':
  64. init_costs = [0.675, 0.675, 0.75, 0.425, 0.425]
  65. elif edit_cost_name == 'NON_SYMBOLIC':
  66. init_costs = [0, 0, 1, 1, 1, 0]
  67. if Gn_median[0].graph['node_attrs'] == []:
  68. init_costs[2] = 0
  69. if Gn_median[0].graph['edge_attrs'] == []:
  70. init_costs[5] = 0
  71. else:
  72. init_costs = [3, 3, 1, 3, 3, 1]
  73. else:
  74. init_costs = init_ecc
  75. algo_options = '--threads 1 --initial-solutions ' \
  76. + str(initial_solutions) + ' --ratio-runs-from-initial-solutions 1'
  77. params_ged = {'lib': 'gedlibpy', 'cost': edit_cost_name, 'method': 'IPFP',
  78. 'algo_options': algo_options, 'stabilizer': None}
  79. # fit on k-graph subset
  80. edit_cost_constants, _, _, _, _, _, _ = fit_GED_to_kernel_distance(Gn_median,
  81. node_label, edge_label, gkernel, itr_max, params_ged=params_ged,
  82. init_costs=init_costs, dataset=dataset, Kmatrix=Kmatrix_median,
  83. parallel=True)
  84. elif fit_method == 'whole-dataset':
  85. itr_max = 6
  86. if init_ecc is None:
  87. if edit_cost_name == 'LETTER':
  88. init_costs = [0.9, 1.7, 0.75]
  89. elif edit_cost_name == 'LETTER2':
  90. init_costs = [0.675, 0.675, 0.75, 0.425, 0.425]
  91. else:
  92. init_costs = [3, 3, 1, 3, 3, 1]
  93. else:
  94. init_costs = init_ecc
  95. algo_options = '--threads 1 --initial-solutions ' \
  96. + str(initial_solutions) + ' --ratio-runs-from-initial-solutions 1'
  97. params_ged = {'lib': 'gedlibpy', 'cost': edit_cost_name, 'method': 'IPFP',
  98. 'algo_options': algo_options, 'stabilizer': None}
  99. # fit on all subset
  100. edit_cost_constants, _, _, _, _, _, _ = fit_GED_to_kernel_distance(Gn,
  101. node_label, edge_label, gkernel, itr_max, params_ged=params_ged,
  102. init_costs=init_costs, dataset=dataset, parallel=True)
  103. elif fit_method == 'precomputed':
  104. pass
  105. return edit_cost_constants
  106. def compute_distances_to_true_median(Gn_median, fname_sm, fname_gm,
  107. gkernel, edit_cost_name,
  108. Kmatrix_median=None):
  109. # reform graphs.
  110. set_median = loadGXL(fname_sm)
  111. gen_median = loadGXL(fname_gm)
  112. # print(gen_median.nodes(data=True))
  113. # print(gen_median.edges(data=True))
  114. if edit_cost_name == 'LETTER' or edit_cost_name == 'LETTER2' or edit_cost_name == 'NON_SYMBOLIC':
  115. # dataset == 'Fingerprint':
  116. # for g in Gn_median:
  117. # reform_attributes(g)
  118. reform_attributes(set_median, Gn_median[0].graph['node_attrs'],
  119. Gn_median[0].graph['edge_attrs'])
  120. reform_attributes(gen_median, Gn_median[0].graph['node_attrs'],
  121. Gn_median[0].graph['edge_attrs'])
  122. if edit_cost_name == 'LETTER' or edit_cost_name == 'LETTER2' or edit_cost_name == 'NON_SYMBOLIC':
  123. node_label = None
  124. edge_label = None
  125. else:
  126. node_label = 'chem'
  127. edge_label = 'valence'
  128. # compute Gram matrix for median set.
  129. if Kmatrix_median is None:
  130. Kmatrix_median = compute_kernel(Gn_median, gkernel, node_label, edge_label, False)
  131. # compute distance in kernel space for set median.
  132. kernel_sm = []
  133. for G_median in Gn_median:
  134. km_tmp = compute_kernel([set_median, G_median], gkernel, node_label, edge_label, False)
  135. kernel_sm.append(km_tmp[0, 1])
  136. Kmatrix_sm = np.concatenate((np.array([kernel_sm]), np.copy(Kmatrix_median)), axis=0)
  137. Kmatrix_sm = np.concatenate((np.array([[km_tmp[0, 0]] + kernel_sm]).T, Kmatrix_sm), axis=1)
  138. # Kmatrix_sm = compute_kernel([set_median] + Gn_median, gkernel,
  139. # node_label, edge_label, False)
  140. dis_k_sm = dis_gstar(0, range(1, 1+len(Gn_median)),
  141. [1 / len(Gn_median)] * len(Gn_median), Kmatrix_sm, withterm3=False)
  142. # print(gen_median.nodes(data=True))
  143. # print(gen_median.edges(data=True))
  144. # print(set_median.nodes(data=True))
  145. # print(set_median.edges(data=True))
  146. # compute distance in kernel space for generalized median.
  147. kernel_gm = []
  148. for G_median in Gn_median:
  149. km_tmp = compute_kernel([gen_median, G_median], gkernel, node_label, edge_label, False)
  150. kernel_gm.append(km_tmp[0, 1])
  151. Kmatrix_gm = np.concatenate((np.array([kernel_gm]), np.copy(Kmatrix_median)), axis=0)
  152. Kmatrix_gm = np.concatenate((np.array([[km_tmp[0, 0]] + kernel_gm]).T, Kmatrix_gm), axis=1)
  153. # Kmatrix_gm = compute_kernel([gen_median] + Gn_median, gkernel,
  154. # node_label, edge_label, False)
  155. dis_k_gm = dis_gstar(0, range(1, 1+len(Gn_median)),
  156. [1 / len(Gn_median)] * len(Gn_median), Kmatrix_gm, withterm3=False)
  157. # compute distance in kernel space for each graph in median set.
  158. dis_k_gi = []
  159. for idx in range(len(Gn_median)):
  160. dis_k_gi.append(dis_gstar(idx+1, range(1, 1+len(Gn_median)),
  161. [1 / len(Gn_median)] * len(Gn_median), Kmatrix_gm, withterm3=False))
  162. print('dis_k_sm:', dis_k_sm)
  163. print('dis_k_gm:', dis_k_gm)
  164. print('dis_k_gi:', dis_k_gi)
  165. idx_dis_k_gi_min = np.argmin(dis_k_gi)
  166. dis_k_gi_min = dis_k_gi[idx_dis_k_gi_min]
  167. print('min dis_k_gi:', dis_k_gi_min)
  168. return dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min, idx_dis_k_gi_min
  169. def median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel, k, fit_method,
  170. graph_dir=None, initial_solutions=1,
  171. edit_cost_constants=None, group_min=None,
  172. dataset=None, edit_cost_name=None, init_ecc=None,
  173. Kmatrix=None, parallel=True):
  174. # dataset = dataset.lower()
  175. # # compute distances in kernel space.
  176. # dis_mat, _, _, _ = kernel_distance_matrix(Gn, node_label, edge_label,
  177. # Kmatrix=None, gkernel=gkernel)
  178. # # ged.
  179. # gmfile = np.load('results/test_k_closest_graphs/ged_mat.fit_on_whole_dataset.with_medians.gm.npz')
  180. # ged_mat = gmfile['ged_mat']
  181. # dis_mat = ged_mat[0:len(Gn), 0:len(Gn)]
  182. # # choose k closest graphs
  183. # time0 = time.time()
  184. # sod_ks_min, group_min = get_closest_k_graphs(dis_mat, k, parallel)
  185. # time_spent = time.time() - time0
  186. # print('closest graphs:', sod_ks_min, group_min)
  187. # print('time spent:', time_spent)
  188. # group_min = (12, 13, 22, 29) # closest w.r.t path kernel
  189. # group_min = (77, 85, 160, 171) # closest w.r.t ged
  190. # group_min = (0,1,2,3,4,5,6,7,8,9,10,11) # closest w.r.t treelet kernel
  191. Gn_median = [Gn[g].copy() for g in group_min]
  192. if Kmatrix is not None:
  193. Kmatrix_median = np.copy(Kmatrix[group_min,:])
  194. Kmatrix_median = Kmatrix_median[:,group_min]
  195. else:
  196. Kmatrix_median = None
  197. # 1. fit edit cost constants.
  198. time0 = time.time()
  199. edit_cost_constants = fit_edit_cost_constants(fit_method, edit_cost_name,
  200. edit_cost_constants=edit_cost_constants, initial_solutions=initial_solutions,
  201. Gn_median=Gn_median, node_label=node_label, edge_label=edge_label,
  202. gkernel=gkernel, dataset=dataset, init_ecc=init_ecc,
  203. Gn=Gn, Kmatrix_median=Kmatrix_median)
  204. time_fitting = time.time() - time0
  205. # 2. compute set median and gen median using IAM (C++ through bash).
  206. print('\nstart computing set median and gen median using IAM (C++ through bash)...\n')
  207. group_fnames = [Gn[g].graph['filename'] for g in group_min]
  208. time0 = time.time()
  209. sod_sm, sod_gm, fname_sm, fname_gm = iam_bash(group_fnames, edit_cost_constants,
  210. cost=edit_cost_name, initial_solutions=initial_solutions,
  211. graph_dir=graph_dir, dataset=dataset)
  212. time_generating = time.time() - time0
  213. print('\nmedians computed.\n')
  214. # 3. compute distances to real median.
  215. print('\nstart computing distances to true median....\n')
  216. Gn_median = [Gn[g].copy() for g in group_min]
  217. dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min, idx_dis_k_gi_min = \
  218. compute_distances_to_true_median(Gn_median, fname_sm, fname_gm,
  219. gkernel, edit_cost_name,
  220. Kmatrix_median=Kmatrix_median)
  221. idx_dis_k_gi_min = group_min[idx_dis_k_gi_min]
  222. print('index min dis_k_gi:', idx_dis_k_gi_min)
  223. print('sod_sm:', sod_sm)
  224. print('sod_gm:', sod_gm)
  225. # collect return values.
  226. return (sod_sm, sod_gm), \
  227. (dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min, idx_dis_k_gi_min), \
  228. (time_fitting, time_generating)
  229. def reform_attributes(G, na_names=[], ea_names=[]):
  230. if not na_names == []:
  231. for node in G.nodes:
  232. G.nodes[node]['attributes'] = [G.node[node][a_name] for a_name in na_names]
  233. if not ea_names == []:
  234. for edge in G.edges:
  235. G.edges[edge]['attributes'] = [G.edge[edge][a_name] for a_name in ea_names]
  236. def get_closest_k_graphs(dis_mat, k, parallel):
  237. k_graph_groups = combinations(range(0, len(dis_mat)), k)
  238. sod_ks_min = np.inf
  239. if parallel:
  240. len_combination = get_combination_length(len(dis_mat), k)
  241. len_itr_max = int(len_combination if len_combination < 1e7 else 1e7)
  242. # pos_cur = 0
  243. graph_groups_slices = split_iterable(k_graph_groups, len_itr_max, len_combination)
  244. for graph_groups_cur in graph_groups_slices:
  245. # while True:
  246. # graph_groups_cur = islice(k_graph_groups, pos_cur, pos_cur + len_itr_max)
  247. graph_groups_cur_list = list(graph_groups_cur)
  248. print('current position:', graph_groups_cur_list[0])
  249. len_itr_cur = len(graph_groups_cur_list)
  250. # if len_itr_cur < len_itr_max:
  251. # break
  252. itr = zip(graph_groups_cur_list, range(0, len_itr_cur))
  253. sod_k_list = np.empty(len_itr_cur)
  254. graphs_list = [None] * len_itr_cur
  255. n_jobs = multiprocessing.cpu_count()
  256. chunksize = int(len_itr_max / n_jobs + 1)
  257. n_jobs = multiprocessing.cpu_count()
  258. def init_worker(dis_mat_toshare):
  259. global G_dis_mat
  260. G_dis_mat = dis_mat_toshare
  261. pool = Pool(processes=n_jobs, initializer=init_worker, initargs=(dis_mat,))
  262. # iterator = tqdm(pool.imap_unordered(_get_closest_k_graphs_parallel,
  263. # itr, chunksize),
  264. # desc='Choosing k closest graphs', file=sys.stdout)
  265. iterator = pool.imap_unordered(_get_closest_k_graphs_parallel, itr, chunksize)
  266. for graphs, i, sod_ks in iterator:
  267. sod_k_list[i] = sod_ks
  268. graphs_list[i] = graphs
  269. pool.close()
  270. pool.join()
  271. arg_min = np.argmin(sod_k_list)
  272. sod_ks_cur = sod_k_list[arg_min]
  273. group_cur = graphs_list[arg_min]
  274. if sod_ks_cur < sod_ks_min:
  275. sod_ks_min = sod_ks_cur
  276. group_min = group_cur
  277. print('get closer graphs:', sod_ks_min, group_min)
  278. else:
  279. for items in tqdm(k_graph_groups, desc='Choosing k closest graphs', file=sys.stdout):
  280. # if items[0] != itmp:
  281. # itmp = items[0]
  282. # print(items)
  283. k_graph_pairs = combinations(items, 2)
  284. sod_ks = 0
  285. for i1, i2 in k_graph_pairs:
  286. sod_ks += dis_mat[i1, i2]
  287. if sod_ks < sod_ks_min:
  288. sod_ks_min = sod_ks
  289. group_min = items
  290. print('get closer graphs:', sod_ks_min, group_min)
  291. return sod_ks_min, group_min
  292. def _get_closest_k_graphs_parallel(itr):
  293. k_graph_pairs = combinations(itr[0], 2)
  294. sod_ks = 0
  295. for i1, i2 in k_graph_pairs:
  296. sod_ks += G_dis_mat[i1, i2]
  297. return itr[0], itr[1], sod_ks
  298. def split_iterable(iterable, n, len_iter):
  299. it = iter(iterable)
  300. for i in range(0, len_iter, n):
  301. piece = islice(it, n)
  302. yield piece
  303. def get_combination_length(n, k):
  304. len_combination = 1
  305. for i in range(n, n - k, -1):
  306. len_combination *= i
  307. return int(len_combination / math.factorial(k))
  308. ###############################################################################
  309. def test_k_closest_graphs():
  310. ds = {'name': 'monoterpenoides',
  311. 'dataset': '../datasets/monoterpenoides/dataset_10+.ds'} # node/edge symb
  312. Gn, y_all = loadDataset(ds['dataset'])
  313. # Gn = Gn[0:50]
  314. # gkernel = 'untilhpathkernel'
  315. # gkernel = 'weisfeilerlehmankernel'
  316. gkernel = 'treeletkernel'
  317. node_label = 'atom'
  318. edge_label = 'bond_type'
  319. k = 5
  320. edit_costs = [0.16229209837639536, 0.06612870523413916, 0.04030113378793905, 0.20723547009415202, 0.3338607220394598, 0.27054392518077297]
  321. # sod_sm, sod_gm, dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min \
  322. # = median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel, k,
  323. # 'precomputed', edit_costs=edit_costs,
  324. ## 'k-graphs',
  325. # parallel=False)
  326. #
  327. # sod_sm, sod_gm, dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min \
  328. # = median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel, k,
  329. # 'expert', parallel=False)
  330. sod_sm, sod_gm, dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min \
  331. = median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel, k,
  332. 'expert', parallel=False)
  333. return
  334. def test_k_closest_graphs_with_cv():
  335. gkernel = 'untilhpathkernel'
  336. node_label = 'atom'
  337. edge_label = 'bond_type'
  338. k = 4
  339. y_all = ['3', '1', '4', '6', '7', '8', '9', '2']
  340. repeats = 50
  341. collection_path = os.path.dirname(os.path.realpath(__file__)) + '/cpp_ext/generated_datsets/monoterpenoides/'
  342. graph_dir = collection_path + 'gxl/'
  343. sod_sm_list = []
  344. sod_gm_list = []
  345. dis_k_sm_list = []
  346. dis_k_gm_list = []
  347. dis_k_gi_min_list = []
  348. for y in y_all:
  349. print('\n-------------------------------------------------------')
  350. print('class of y:', y)
  351. sod_sm_list.append([])
  352. sod_gm_list.append([])
  353. dis_k_sm_list.append([])
  354. dis_k_gm_list.append([])
  355. dis_k_gi_min_list.append([])
  356. for repeat in range(repeats):
  357. print('\nrepeat ', repeat)
  358. collection_file = collection_path + 'monoterpenoides_' + y + '_' + str(repeat) + '.xml'
  359. Gn, _ = loadDataset(collection_file, extra_params=graph_dir)
  360. sod_sm, sod_gm, dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min \
  361. = median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel,
  362. k, 'whole-dataset', graph_dir=graph_dir,
  363. parallel=False)
  364. sod_sm_list[-1].append(sod_sm)
  365. sod_gm_list[-1].append(sod_gm)
  366. dis_k_sm_list[-1].append(dis_k_sm)
  367. dis_k_gm_list[-1].append(dis_k_gm)
  368. dis_k_gi_min_list[-1].append(dis_k_gi_min)
  369. print('\nsods of the set median for this class:', sod_sm_list[-1])
  370. print('\nsods of the gen median for this class:', sod_gm_list[-1])
  371. print('\ndistances in kernel space of set median for this class:',
  372. dis_k_sm_list[-1])
  373. print('\ndistances in kernel space of gen median for this class:',
  374. dis_k_gm_list[-1])
  375. print('\ndistances in kernel space of min graph for this class:',
  376. dis_k_gi_min_list[-1])
  377. sod_sm_list[-1] = np.mean(sod_sm_list[-1])
  378. sod_gm_list[-1] = np.mean(sod_gm_list[-1])
  379. dis_k_sm_list[-1] = np.mean(dis_k_sm_list[-1])
  380. dis_k_gm_list[-1] = np.mean(dis_k_gm_list[-1])
  381. dis_k_gi_min_list[-1] = np.mean(dis_k_gi_min_list[-1])
  382. print()
  383. print('\nmean sods of the set median for each class:', sod_sm_list)
  384. print('\nmean sods of the gen median for each class:', sod_gm_list)
  385. print('\nmean distance in kernel space of set median for each class:',
  386. dis_k_sm_list)
  387. print('\nmean distances in kernel space of gen median for each class:',
  388. dis_k_gm_list)
  389. print('\nmean distances in kernel space of min graph for each class:',
  390. dis_k_gi_min_list)
  391. print('\nmean sods of the set median of all:', np.mean(sod_sm_list))
  392. print('\nmean sods of the gen median of all:', np.mean(sod_gm_list))
  393. print('\nmean distances in kernel space of set median of all:',
  394. np.mean(dis_k_sm_list))
  395. print('\nmean distances in kernel space of gen median of all:',
  396. np.mean(dis_k_gm_list))
  397. print('\nmean distances in kernel space of min graph of all:',
  398. np.mean(dis_k_gi_min_list))
  399. return
  400. if __name__ == '__main__':
  401. test_k_closest_graphs()
  402. # test_k_closest_graphs_with_cv()

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.