You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_k_closest_graphs.py 20 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Mon Dec 16 11:53:54 2019
  5. @author: ljia
  6. """
  7. import numpy as np
  8. import math
  9. import networkx as nx
  10. import matplotlib.pyplot as plt
  11. import time
  12. import random
  13. from tqdm import tqdm
  14. from itertools import combinations, islice
  15. import multiprocessing
  16. from multiprocessing import Pool
  17. from functools import partial
  18. #import os
  19. import sys
  20. sys.path.insert(0, "../")
  21. from gklearn.utils.graphfiles import loadDataset, loadGXL
  22. #from gklearn.utils.logger2file import *
  23. from iam import iam_upgraded, iam_bash
  24. from utils import compute_kernel, dis_gstar, kernel_distance_matrix
  25. from fitDistance import fit_GED_to_kernel_distance
  26. #from ged import ged_median
  27. def fit_edit_cost_constants(fit_method, edit_cost_name,
  28. edit_cost_constants=None, initial_solutions=1,
  29. Gn_median=None, node_label=None, edge_label=None,
  30. gkernel=None, dataset=None,
  31. Gn=None, Kmatrix_median=None):
  32. """fit edit cost constants.
  33. """
  34. if fit_method == 'random': # random
  35. if edit_cost_name == 'LETTER':
  36. edit_cost_constants = random.sample(range(1, 10), 3)
  37. edit_cost_constants = [item * 0.1 for item in edit_cost_constants]
  38. elif edit_cost_name == 'LETTER2':
  39. random.seed(time.time())
  40. edit_cost_constants = random.sample(range(1, 10), 5)
  41. # edit_cost_constants = [item * 0.1 for item in edit_cost_constants]
  42. elif edit_cost_name == 'NON_SYMBOLIC':
  43. edit_cost_constants = random.sample(range(1, 10), 6)
  44. if Gn_median[0].graph['node_attrs'] == []:
  45. edit_cost_constants[2] = 0
  46. if Gn_median[0].graph['edge_attrs'] == []:
  47. edit_cost_constants[5] = 0
  48. else:
  49. edit_cost_constants = random.sample(range(1, 10), 6)
  50. print('edit cost constants used:', edit_cost_constants)
  51. elif fit_method == 'expert': # expert
  52. if edit_cost_name == 'LETTER':
  53. edit_cost_constants = [0.9, 1.7, 0.75]
  54. elif edit_cost_name == 'LETTER2':
  55. edit_cost_constants = [0.675, 0.675, 0.75, 0.425, 0.425]
  56. else:
  57. edit_cost_constants = [3, 3, 1, 3, 3, 1]
  58. elif fit_method == 'k-graphs':
  59. itr_max = 6
  60. if edit_cost_name == 'LETTER':
  61. init_costs = [0.9, 1.7, 0.75]
  62. elif edit_cost_name == 'LETTER2':
  63. init_costs = [0.675, 0.675, 0.75, 0.425, 0.425]
  64. elif edit_cost_name == 'NON_SYMBOLIC':
  65. init_costs = [0, 0, 1, 1, 1, 0]
  66. if Gn_median[0].graph['node_attrs'] == []:
  67. init_costs[2] = 0
  68. if Gn_median[0].graph['edge_attrs'] == []:
  69. init_costs[5] = 0
  70. else:
  71. init_costs = [3, 3, 1, 3, 3, 1]
  72. algo_options = '--threads 1 --initial-solutions ' \
  73. + str(initial_solutions) + ' --ratio-runs-from-initial-solutions 1'
  74. params_ged = {'lib': 'gedlibpy', 'cost': edit_cost_name, 'method': 'IPFP',
  75. 'algo_options': algo_options, 'stabilizer': None}
  76. # fit on k-graph subset
  77. edit_cost_constants, _, _, _, _, _, _ = fit_GED_to_kernel_distance(Gn_median,
  78. node_label, edge_label, gkernel, itr_max, params_ged=params_ged,
  79. init_costs=init_costs, dataset=dataset, Kmatrix=Kmatrix_median,
  80. parallel=True)
  81. elif fit_method == 'whole-dataset':
  82. itr_max = 6
  83. if edit_cost_name == 'LETTER':
  84. init_costs = [0.9, 1.7, 0.75]
  85. elif edit_cost_name == 'LETTER2':
  86. init_costs = [0.675, 0.675, 0.75, 0.425, 0.425]
  87. else:
  88. init_costs = [3, 3, 1, 3, 3, 1]
  89. algo_options = '--threads 1 --initial-solutions ' \
  90. + str(initial_solutions) + ' --ratio-runs-from-initial-solutions 1'
  91. params_ged = {'lib': 'gedlibpy', 'cost': edit_cost_name, 'method': 'IPFP',
  92. 'algo_options': algo_options, 'stabilizer': None}
  93. # fit on all subset
  94. edit_cost_constants, _, _, _, _, _, _ = fit_GED_to_kernel_distance(Gn,
  95. node_label, edge_label, gkernel, itr_max, params_ged=params_ged,
  96. init_costs=init_costs, dataset=dataset, parallel=True)
  97. elif fit_method == 'precomputed':
  98. pass
  99. return edit_cost_constants
  100. def compute_distances_to_true_median(Gn_median, fname_sm, fname_gm,
  101. gkernel, edit_cost_name,
  102. Kmatrix_median=None):
  103. # reform graphs.
  104. set_median = loadGXL(fname_sm)
  105. gen_median = loadGXL(fname_gm)
  106. # print(gen_median.nodes(data=True))
  107. # print(gen_median.edges(data=True))
  108. if edit_cost_name == 'LETTER' or edit_cost_name == 'LETTER2' or edit_cost_name == 'NON_SYMBOLIC':
  109. # dataset == 'Fingerprint':
  110. # for g in Gn_median:
  111. # reform_attributes(g)
  112. reform_attributes(set_median, Gn_median[0].graph['node_attrs'],
  113. Gn_median[0].graph['edge_attrs'])
  114. reform_attributes(gen_median, Gn_median[0].graph['node_attrs'],
  115. Gn_median[0].graph['edge_attrs'])
  116. if edit_cost_name == 'LETTER' or edit_cost_name == 'LETTER2' or edit_cost_name == 'NON_SYMBOLIC':
  117. node_label = None
  118. edge_label = None
  119. else:
  120. node_label = 'chem'
  121. edge_label = 'valence'
  122. # compute Gram matrix for median set.
  123. if Kmatrix_median is None:
  124. Kmatrix_median = compute_kernel(Gn_median, gkernel, node_label, edge_label, False)
  125. # compute distance in kernel space for set median.
  126. kernel_sm = []
  127. for G_median in Gn_median:
  128. km_tmp = compute_kernel([set_median, G_median], gkernel, node_label, edge_label, False)
  129. kernel_sm.append(km_tmp[0, 1])
  130. Kmatrix_sm = np.concatenate((np.array([kernel_sm]), np.copy(Kmatrix_median)), axis=0)
  131. Kmatrix_sm = np.concatenate((np.array([[km_tmp[0, 0]] + kernel_sm]).T, Kmatrix_sm), axis=1)
  132. # Kmatrix_sm = compute_kernel([set_median] + Gn_median, gkernel,
  133. # node_label, edge_label, False)
  134. dis_k_sm = dis_gstar(0, range(1, 1+len(Gn_median)),
  135. [1 / len(Gn_median)] * len(Gn_median), Kmatrix_sm, withterm3=False)
  136. # print(gen_median.nodes(data=True))
  137. # print(gen_median.edges(data=True))
  138. # print(set_median.nodes(data=True))
  139. # print(set_median.edges(data=True))
  140. # compute distance in kernel space for generalized median.
  141. kernel_gm = []
  142. for G_median in Gn_median:
  143. km_tmp = compute_kernel([gen_median, G_median], gkernel, node_label, edge_label, False)
  144. kernel_gm.append(km_tmp[0, 1])
  145. Kmatrix_gm = np.concatenate((np.array([kernel_gm]), np.copy(Kmatrix_median)), axis=0)
  146. Kmatrix_gm = np.concatenate((np.array([[km_tmp[0, 0]] + kernel_gm]).T, Kmatrix_gm), axis=1)
  147. # Kmatrix_gm = compute_kernel([gen_median] + Gn_median, gkernel,
  148. # node_label, edge_label, False)
  149. dis_k_gm = dis_gstar(0, range(1, 1+len(Gn_median)),
  150. [1 / len(Gn_median)] * len(Gn_median), Kmatrix_gm, withterm3=False)
  151. # compute distance in kernel space for each graph in median set.
  152. dis_k_gi = []
  153. for idx in range(len(Gn_median)):
  154. dis_k_gi.append(dis_gstar(idx+1, range(1, 1+len(Gn_median)),
  155. [1 / len(Gn_median)] * len(Gn_median), Kmatrix_gm, withterm3=False))
  156. print('dis_k_sm:', dis_k_sm)
  157. print('dis_k_gm:', dis_k_gm)
  158. print('dis_k_gi:', dis_k_gi)
  159. idx_dis_k_gi_min = np.argmin(dis_k_gi)
  160. dis_k_gi_min = dis_k_gi[idx_dis_k_gi_min]
  161. print('min dis_k_gi:', dis_k_gi_min)
  162. return dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min, idx_dis_k_gi_min
  163. def median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel, k, fit_method,
  164. graph_dir=None, initial_solutions=1,
  165. edit_cost_constants=None, group_min=None,
  166. dataset=None, edit_cost_name=None,
  167. Kmatrix=None, parallel=True):
  168. # dataset = dataset.lower()
  169. # # compute distances in kernel space.
  170. # dis_mat, _, _, _ = kernel_distance_matrix(Gn, node_label, edge_label,
  171. # Kmatrix=None, gkernel=gkernel)
  172. # # ged.
  173. # gmfile = np.load('results/test_k_closest_graphs/ged_mat.fit_on_whole_dataset.with_medians.gm.npz')
  174. # ged_mat = gmfile['ged_mat']
  175. # dis_mat = ged_mat[0:len(Gn), 0:len(Gn)]
  176. # # choose k closest graphs
  177. # time0 = time.time()
  178. # sod_ks_min, group_min = get_closest_k_graphs(dis_mat, k, parallel)
  179. # time_spent = time.time() - time0
  180. # print('closest graphs:', sod_ks_min, group_min)
  181. # print('time spent:', time_spent)
  182. # group_min = (12, 13, 22, 29) # closest w.r.t path kernel
  183. # group_min = (77, 85, 160, 171) # closest w.r.t ged
  184. # group_min = (0,1,2,3,4,5,6,7,8,9,10,11) # closest w.r.t treelet kernel
  185. Gn_median = [Gn[g].copy() for g in group_min]
  186. if Kmatrix is not None:
  187. Kmatrix_median = np.copy(Kmatrix[group_min,:])
  188. Kmatrix_median = Kmatrix_median[:,group_min]
  189. # 1. fit edit cost constants.
  190. time0 = time.time()
  191. edit_cost_constants = fit_edit_cost_constants(fit_method, edit_cost_name,
  192. edit_cost_constants=edit_cost_constants, initial_solutions=initial_solutions,
  193. Gn_median=Gn_median, node_label=node_label, edge_label=edge_label,
  194. gkernel=gkernel, dataset=dataset,
  195. Gn=Gn, Kmatrix_median=Kmatrix_median)
  196. time_fitting = time.time() - time0
  197. # 2. compute set median and gen median using IAM (C++ through bash).
  198. print('\nstart computing set median and gen median using IAM (C++ through bash)...\n')
  199. group_fnames = [Gn[g].graph['filename'] for g in group_min]
  200. time0 = time.time()
  201. sod_sm, sod_gm, fname_sm, fname_gm = iam_bash(group_fnames, edit_cost_constants,
  202. cost=edit_cost_name, initial_solutions=initial_solutions,
  203. graph_dir=graph_dir, dataset=dataset)
  204. time_generating = time.time() - time0
  205. print('\nmedians computed.\n')
  206. # 3. compute distances to real median.
  207. print('\nstart computing distances to true median....\n')
  208. Gn_median = [Gn[g].copy() for g in group_min]
  209. dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min, idx_dis_k_gi_min = \
  210. compute_distances_to_true_median(Gn_median, fname_sm, fname_gm,
  211. gkernel, edit_cost_name,
  212. Kmatrix_median=Kmatrix_median)
  213. idx_dis_k_gi_min = group_min[idx_dis_k_gi_min]
  214. print('index min dis_k_gi:', idx_dis_k_gi_min)
  215. print('sod_sm:', sod_sm)
  216. print('sod_gm:', sod_gm)
  217. # collect return values.
  218. return (sod_sm, sod_gm), \
  219. (dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min, idx_dis_k_gi_min), \
  220. (time_fitting, time_generating)
  221. def reform_attributes(G, na_names=[], ea_names=[]):
  222. if not na_names == []:
  223. for node in G.nodes:
  224. G.nodes[node]['attributes'] = [G.node[node][a_name] for a_name in na_names]
  225. if not ea_names == []:
  226. for edge in G.edges:
  227. G.edges[edge]['attributes'] = [G.edge[edge][a_name] for a_name in ea_names]
  228. def get_closest_k_graphs(dis_mat, k, parallel):
  229. k_graph_groups = combinations(range(0, len(dis_mat)), k)
  230. sod_ks_min = np.inf
  231. if parallel:
  232. len_combination = get_combination_length(len(dis_mat), k)
  233. len_itr_max = int(len_combination if len_combination < 1e7 else 1e7)
  234. # pos_cur = 0
  235. graph_groups_slices = split_iterable(k_graph_groups, len_itr_max, len_combination)
  236. for graph_groups_cur in graph_groups_slices:
  237. # while True:
  238. # graph_groups_cur = islice(k_graph_groups, pos_cur, pos_cur + len_itr_max)
  239. graph_groups_cur_list = list(graph_groups_cur)
  240. print('current position:', graph_groups_cur_list[0])
  241. len_itr_cur = len(graph_groups_cur_list)
  242. # if len_itr_cur < len_itr_max:
  243. # break
  244. itr = zip(graph_groups_cur_list, range(0, len_itr_cur))
  245. sod_k_list = np.empty(len_itr_cur)
  246. graphs_list = [None] * len_itr_cur
  247. n_jobs = multiprocessing.cpu_count()
  248. chunksize = int(len_itr_max / n_jobs + 1)
  249. n_jobs = multiprocessing.cpu_count()
  250. def init_worker(dis_mat_toshare):
  251. global G_dis_mat
  252. G_dis_mat = dis_mat_toshare
  253. pool = Pool(processes=n_jobs, initializer=init_worker, initargs=(dis_mat,))
  254. # iterator = tqdm(pool.imap_unordered(_get_closest_k_graphs_parallel,
  255. # itr, chunksize),
  256. # desc='Choosing k closest graphs', file=sys.stdout)
  257. iterator = pool.imap_unordered(_get_closest_k_graphs_parallel, itr, chunksize)
  258. for graphs, i, sod_ks in iterator:
  259. sod_k_list[i] = sod_ks
  260. graphs_list[i] = graphs
  261. pool.close()
  262. pool.join()
  263. arg_min = np.argmin(sod_k_list)
  264. sod_ks_cur = sod_k_list[arg_min]
  265. group_cur = graphs_list[arg_min]
  266. if sod_ks_cur < sod_ks_min:
  267. sod_ks_min = sod_ks_cur
  268. group_min = group_cur
  269. print('get closer graphs:', sod_ks_min, group_min)
  270. else:
  271. for items in tqdm(k_graph_groups, desc='Choosing k closest graphs', file=sys.stdout):
  272. # if items[0] != itmp:
  273. # itmp = items[0]
  274. # print(items)
  275. k_graph_pairs = combinations(items, 2)
  276. sod_ks = 0
  277. for i1, i2 in k_graph_pairs:
  278. sod_ks += dis_mat[i1, i2]
  279. if sod_ks < sod_ks_min:
  280. sod_ks_min = sod_ks
  281. group_min = items
  282. print('get closer graphs:', sod_ks_min, group_min)
  283. return sod_ks_min, group_min
  284. def _get_closest_k_graphs_parallel(itr):
  285. k_graph_pairs = combinations(itr[0], 2)
  286. sod_ks = 0
  287. for i1, i2 in k_graph_pairs:
  288. sod_ks += G_dis_mat[i1, i2]
  289. return itr[0], itr[1], sod_ks
  290. def split_iterable(iterable, n, len_iter):
  291. it = iter(iterable)
  292. for i in range(0, len_iter, n):
  293. piece = islice(it, n)
  294. yield piece
  295. def get_combination_length(n, k):
  296. len_combination = 1
  297. for i in range(n, n - k, -1):
  298. len_combination *= i
  299. return int(len_combination / math.factorial(k))
  300. ###############################################################################
  301. def test_k_closest_graphs():
  302. ds = {'name': 'monoterpenoides',
  303. 'dataset': '../datasets/monoterpenoides/dataset_10+.ds'} # node/edge symb
  304. Gn, y_all = loadDataset(ds['dataset'])
  305. # Gn = Gn[0:50]
  306. # gkernel = 'untilhpathkernel'
  307. # gkernel = 'weisfeilerlehmankernel'
  308. gkernel = 'treeletkernel'
  309. node_label = 'atom'
  310. edge_label = 'bond_type'
  311. k = 5
  312. edit_costs = [0.16229209837639536, 0.06612870523413916, 0.04030113378793905, 0.20723547009415202, 0.3338607220394598, 0.27054392518077297]
  313. # sod_sm, sod_gm, dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min \
  314. # = median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel, k,
  315. # 'precomputed', edit_costs=edit_costs,
  316. ## 'k-graphs',
  317. # parallel=False)
  318. #
  319. # sod_sm, sod_gm, dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min \
  320. # = median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel, k,
  321. # 'expert', parallel=False)
  322. sod_sm, sod_gm, dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min \
  323. = median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel, k,
  324. 'expert', parallel=False)
  325. return
  326. def test_k_closest_graphs_with_cv():
  327. gkernel = 'untilhpathkernel'
  328. node_label = 'atom'
  329. edge_label = 'bond_type'
  330. k = 4
  331. y_all = ['3', '1', '4', '6', '7', '8', '9', '2']
  332. repeats = 50
  333. collection_path = '/media/ljia/DATA/research-repo/codes/others/gedlib/tests_linlin/generated_datsets/monoterpenoides/'
  334. graph_dir = collection_path + 'gxl/'
  335. sod_sm_list = []
  336. sod_gm_list = []
  337. dis_k_sm_list = []
  338. dis_k_gm_list = []
  339. dis_k_gi_min_list = []
  340. for y in y_all:
  341. print('\n-------------------------------------------------------')
  342. print('class of y:', y)
  343. sod_sm_list.append([])
  344. sod_gm_list.append([])
  345. dis_k_sm_list.append([])
  346. dis_k_gm_list.append([])
  347. dis_k_gi_min_list.append([])
  348. for repeat in range(repeats):
  349. print('\nrepeat ', repeat)
  350. collection_file = collection_path + 'monoterpenoides_' + y + '_' + str(repeat) + '.xml'
  351. Gn, _ = loadDataset(collection_file, extra_params=graph_dir)
  352. sod_sm, sod_gm, dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min \
  353. = median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel,
  354. k, 'whole-dataset', graph_dir=graph_dir,
  355. parallel=False)
  356. sod_sm_list[-1].append(sod_sm)
  357. sod_gm_list[-1].append(sod_gm)
  358. dis_k_sm_list[-1].append(dis_k_sm)
  359. dis_k_gm_list[-1].append(dis_k_gm)
  360. dis_k_gi_min_list[-1].append(dis_k_gi_min)
  361. print('\nsods of the set median for this class:', sod_sm_list[-1])
  362. print('\nsods of the gen median for this class:', sod_gm_list[-1])
  363. print('\ndistances in kernel space of set median for this class:',
  364. dis_k_sm_list[-1])
  365. print('\ndistances in kernel space of gen median for this class:',
  366. dis_k_gm_list[-1])
  367. print('\ndistances in kernel space of min graph for this class:',
  368. dis_k_gi_min_list[-1])
  369. sod_sm_list[-1] = np.mean(sod_sm_list[-1])
  370. sod_gm_list[-1] = np.mean(sod_gm_list[-1])
  371. dis_k_sm_list[-1] = np.mean(dis_k_sm_list[-1])
  372. dis_k_gm_list[-1] = np.mean(dis_k_gm_list[-1])
  373. dis_k_gi_min_list[-1] = np.mean(dis_k_gi_min_list[-1])
  374. print()
  375. print('\nmean sods of the set median for each class:', sod_sm_list)
  376. print('\nmean sods of the gen median for each class:', sod_gm_list)
  377. print('\nmean distance in kernel space of set median for each class:',
  378. dis_k_sm_list)
  379. print('\nmean distances in kernel space of gen median for each class:',
  380. dis_k_gm_list)
  381. print('\nmean distances in kernel space of min graph for each class:',
  382. dis_k_gi_min_list)
  383. print('\nmean sods of the set median of all:', np.mean(sod_sm_list))
  384. print('\nmean sods of the gen median of all:', np.mean(sod_gm_list))
  385. print('\nmean distances in kernel space of set median of all:',
  386. np.mean(dis_k_sm_list))
  387. print('\nmean distances in kernel space of gen median of all:',
  388. np.mean(dis_k_gm_list))
  389. print('\nmean distances in kernel space of min graph of all:',
  390. np.mean(dis_k_gi_min_list))
  391. return
  392. if __name__ == '__main__':
  393. test_k_closest_graphs()
  394. # test_k_closest_graphs_with_cv()

A Python package for graph kernels, graph edit distances and graph pre-image problem.