You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_k_closest_graphs.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Mon Dec 16 11:53:54 2019
  5. @author: ljia
  6. """
  7. import numpy as np
  8. import math
  9. import networkx as nx
  10. import matplotlib.pyplot as plt
  11. import time
  12. import random
  13. from tqdm import tqdm
  14. from itertools import combinations, islice
  15. import multiprocessing
  16. from multiprocessing import Pool
  17. from functools import partial
  18. #import os
  19. import sys
  20. sys.path.insert(0, "../")
  21. from pygraph.utils.graphfiles import loadDataset, loadGXL
  22. #from pygraph.utils.logger2file import *
  23. from iam import iam_upgraded, iam_bash
  24. from utils import compute_kernel, dis_gstar, kernel_distance_matrix
  25. from fitDistance import fit_GED_to_kernel_distance
  26. #from ged import ged_median
  27. def median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel, k, fit_method,
  28. graph_dir='/media/ljia/DATA/research-repo/codes/Linlin/py-graph/datasets/monoterpenoides/',
  29. edit_costs=None, group_min=None, dataset='monoterpenoides',
  30. cost='CONSTANT', parallel=True):
  31. dataset = dataset.lower()
  32. # # compute distances in kernel space.
  33. # dis_mat, _, _, _ = kernel_distance_matrix(Gn, node_label, edge_label,
  34. # Kmatrix=None, gkernel=gkernel)
  35. # # ged.
  36. # gmfile = np.load('results/test_k_closest_graphs/ged_mat.fit_on_whole_dataset.with_medians.gm.npz')
  37. # ged_mat = gmfile['ged_mat']
  38. # dis_mat = ged_mat[0:len(Gn), 0:len(Gn)]
  39. # # choose k closest graphs
  40. # time0 = time.time()
  41. # sod_ks_min, group_min = get_closest_k_graphs(dis_mat, k, parallel)
  42. # time_spent = time.time() - time0
  43. # print('closest graphs:', sod_ks_min, group_min)
  44. # print('time spent:', time_spent)
  45. # group_min = (12, 13, 22, 29) # closest w.r.t path kernel
  46. # group_min = (77, 85, 160, 171) # closest w.r.t ged
  47. # group_min = (0,1,2,3,4,5,6,7,8,9,10,11) # closest w.r.t treelet kernel
  48. Gn_median = [Gn[g].copy() for g in group_min]
  49. # fit edit costs.
  50. if fit_method == 'random': # random
  51. if cost == 'LETTER':
  52. edit_cost_constant = random.sample(range(1, 10), 3)
  53. edit_cost_constant = [item * 0.1 for item in edit_cost_constant]
  54. elif cost == 'LETTER2':
  55. random.seed(time.time())
  56. edit_cost_constant = random.sample(range(1, 10), 5)
  57. # edit_cost_constant = [item * 0.1 for item in edit_cost_constant]
  58. else:
  59. edit_cost_constant = random.sample(range(1, 10), 6)
  60. print('edit costs used:', edit_cost_constant)
  61. elif fit_method == 'expert': # expert
  62. edit_cost_constant = [3, 3, 1, 3, 3, 1]
  63. elif fit_method == 'k-graphs':
  64. itr_max = 6
  65. if cost == 'LETTER':
  66. init_costs = [0.9, 1.7, 0.75]
  67. elif cost == 'LETTER2':
  68. init_costs = [0.675, 0.675, 0.75, 0.425, 0.425]
  69. else:
  70. init_costs = [3, 3, 1, 3, 3, 1]
  71. algo_options = '--threads 1 --initial-solutions 40 --ratio-runs-from-initial-solutions 1'
  72. params_ged = {'lib': 'gedlibpy', 'cost': cost, 'method': 'IPFP',
  73. 'algo_options': algo_options, 'stabilizer': None}
  74. # fit on k-graph subset
  75. edit_cost_constant, _, _, _, _, _, _ = fit_GED_to_kernel_distance(Gn_median,
  76. node_label, edge_label, gkernel, itr_max, params_ged=params_ged,
  77. init_costs=init_costs, dataset=dataset, parallel=True)
  78. elif fit_method == 'whole-dataset':
  79. itr_max = 6
  80. if cost == 'LETTER':
  81. init_costs = [0.9, 1.7, 0.75]
  82. elif cost == 'LETTER2':
  83. init_costs = [0.675, 0.675, 0.75, 0.425, 0.425]
  84. else:
  85. init_costs = [3, 3, 1, 3, 3, 1]
  86. algo_options = '--threads 1 --initial-solutions 40 --ratio-runs-from-initial-solutions 1'
  87. params_ged = {'lib': 'gedlibpy', 'cost': cost, 'method': 'IPFP',
  88. 'algo_options': algo_options, 'stabilizer': None}
  89. # fit on all subset
  90. edit_cost_constant, _, _, _, _, _, _ = fit_GED_to_kernel_distance(Gn,
  91. node_label, edge_label, gkernel, itr_max, params_ged=params_ged,
  92. init_costs=init_costs, dataset=dataset, parallel=True)
  93. elif fit_method == 'precomputed':
  94. edit_cost_constant = edit_costs
  95. # compute set median and gen median using IAM (C++ through bash).
  96. group_fnames = [Gn[g].graph['filename'] for g in group_min]
  97. sod_sm, sod_gm, fname_sm, fname_gm = iam_bash(group_fnames, edit_cost_constant,
  98. cost=cost, graph_dir=graph_dir,
  99. dataset=dataset)
  100. # compute distances in kernel space.
  101. Gn_median = [Gn[g].copy() for g in group_min]
  102. set_median = loadGXL(fname_sm)
  103. gen_median = loadGXL(fname_gm)
  104. # print(gen_median.nodes(data=True))
  105. # print(gen_median.edges(data=True))
  106. if dataset == 'letter':
  107. for g in Gn_median:
  108. reform_attributes(g)
  109. reform_attributes(set_median)
  110. reform_attributes(gen_median)
  111. # compute distance in kernel space for set median.
  112. Kmatrix_sm = compute_kernel([set_median] + Gn_median, gkernel,
  113. None if dataset == 'letter' else 'chem',
  114. None if dataset == 'letter' else 'valence',
  115. False)
  116. dis_k_sm = dis_gstar(0, range(1, 1+len(Gn_median)),
  117. [1 / len(Gn_median)] * len(Gn_median), Kmatrix_sm, withterm3=False)
  118. # print(gen_median.nodes(data=True))
  119. # print(gen_median.edges(data=True))
  120. # print(set_median.nodes(data=True))
  121. # print(set_median.edges(data=True))
  122. # compute distance in kernel space for generalized median.
  123. Kmatrix_gm = compute_kernel([gen_median] + Gn_median, gkernel,
  124. None if dataset == 'letter' else 'chem',
  125. None if dataset == 'letter' else 'valence',
  126. False)
  127. dis_k_gm = dis_gstar(0, range(1, 1+len(Gn_median)),
  128. [1 / len(Gn_median)] * len(Gn_median), Kmatrix_gm, withterm3=False)
  129. # compute distance in kernel space for each graph in median set.
  130. dis_k_gi = []
  131. for idx in range(len(Gn_median)):
  132. dis_k_gi.append(dis_gstar(idx+1, range(1, 1+len(Gn_median)),
  133. [1 / len(Gn_median)] * len(Gn_median), Kmatrix_gm, withterm3=False))
  134. print('sod_sm:', sod_sm)
  135. print('sod_gm:', sod_gm)
  136. print('dis_k_sm:', dis_k_sm)
  137. print('dis_k_gm:', dis_k_gm)
  138. print('dis_k_gi:', dis_k_gi)
  139. idx_dis_k_gi_min = np.argmin(dis_k_gi)
  140. dis_k_gi_min = dis_k_gi[idx_dis_k_gi_min]
  141. print('index min dis_k_gi:', group_min[idx_dis_k_gi_min])
  142. print('min dis_k_gi:', dis_k_gi_min)
  143. return sod_sm, sod_gm, dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min, group_min[idx_dis_k_gi_min]
  144. def reform_attributes(G):
  145. for node in G.nodes:
  146. G.nodes[node]['attributes'] = [G.nodes[node]['x'], G.nodes[node]['y']]
  147. def get_closest_k_graphs(dis_mat, k, parallel):
  148. k_graph_groups = combinations(range(0, len(dis_mat)), k)
  149. sod_ks_min = np.inf
  150. if parallel:
  151. len_combination = get_combination_length(len(dis_mat), k)
  152. len_itr_max = int(len_combination if len_combination < 1e7 else 1e7)
  153. # pos_cur = 0
  154. graph_groups_slices = split_iterable(k_graph_groups, len_itr_max, len_combination)
  155. for graph_groups_cur in graph_groups_slices:
  156. # while True:
  157. # graph_groups_cur = islice(k_graph_groups, pos_cur, pos_cur + len_itr_max)
  158. graph_groups_cur_list = list(graph_groups_cur)
  159. print('current position:', graph_groups_cur_list[0])
  160. len_itr_cur = len(graph_groups_cur_list)
  161. # if len_itr_cur < len_itr_max:
  162. # break
  163. itr = zip(graph_groups_cur_list, range(0, len_itr_cur))
  164. sod_k_list = np.empty(len_itr_cur)
  165. graphs_list = [None] * len_itr_cur
  166. n_jobs = multiprocessing.cpu_count()
  167. chunksize = int(len_itr_max / n_jobs + 1)
  168. n_jobs = multiprocessing.cpu_count()
  169. def init_worker(dis_mat_toshare):
  170. global G_dis_mat
  171. G_dis_mat = dis_mat_toshare
  172. pool = Pool(processes=n_jobs, initializer=init_worker, initargs=(dis_mat,))
  173. # iterator = tqdm(pool.imap_unordered(_get_closest_k_graphs_parallel,
  174. # itr, chunksize),
  175. # desc='Choosing k closest graphs', file=sys.stdout)
  176. iterator = pool.imap_unordered(_get_closest_k_graphs_parallel, itr, chunksize)
  177. for graphs, i, sod_ks in iterator:
  178. sod_k_list[i] = sod_ks
  179. graphs_list[i] = graphs
  180. pool.close()
  181. pool.join()
  182. arg_min = np.argmin(sod_k_list)
  183. sod_ks_cur = sod_k_list[arg_min]
  184. group_cur = graphs_list[arg_min]
  185. if sod_ks_cur < sod_ks_min:
  186. sod_ks_min = sod_ks_cur
  187. group_min = group_cur
  188. print('get closer graphs:', sod_ks_min, group_min)
  189. else:
  190. for items in tqdm(k_graph_groups, desc='Choosing k closest graphs', file=sys.stdout):
  191. # if items[0] != itmp:
  192. # itmp = items[0]
  193. # print(items)
  194. k_graph_pairs = combinations(items, 2)
  195. sod_ks = 0
  196. for i1, i2 in k_graph_pairs:
  197. sod_ks += dis_mat[i1, i2]
  198. if sod_ks < sod_ks_min:
  199. sod_ks_min = sod_ks
  200. group_min = items
  201. print('get closer graphs:', sod_ks_min, group_min)
  202. return sod_ks_min, group_min
  203. def _get_closest_k_graphs_parallel(itr):
  204. k_graph_pairs = combinations(itr[0], 2)
  205. sod_ks = 0
  206. for i1, i2 in k_graph_pairs:
  207. sod_ks += G_dis_mat[i1, i2]
  208. return itr[0], itr[1], sod_ks
  209. def split_iterable(iterable, n, len_iter):
  210. it = iter(iterable)
  211. for i in range(0, len_iter, n):
  212. piece = islice(it, n)
  213. yield piece
  214. def get_combination_length(n, k):
  215. len_combination = 1
  216. for i in range(n, n - k, -1):
  217. len_combination *= i
  218. return int(len_combination / math.factorial(k))
  219. ###############################################################################
  220. def test_k_closest_graphs():
  221. ds = {'name': 'monoterpenoides',
  222. 'dataset': '../datasets/monoterpenoides/dataset_10+.ds'} # node/edge symb
  223. Gn, y_all = loadDataset(ds['dataset'])
  224. # Gn = Gn[0:50]
  225. # gkernel = 'untilhpathkernel'
  226. # gkernel = 'weisfeilerlehmankernel'
  227. gkernel = 'treeletkernel'
  228. node_label = 'atom'
  229. edge_label = 'bond_type'
  230. k = 5
  231. edit_costs = [0.16229209837639536, 0.06612870523413916, 0.04030113378793905, 0.20723547009415202, 0.3338607220394598, 0.27054392518077297]
  232. # sod_sm, sod_gm, dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min \
  233. # = median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel, k,
  234. # 'precomputed', edit_costs=edit_costs,
  235. ## 'k-graphs',
  236. # parallel=False)
  237. #
  238. # sod_sm, sod_gm, dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min \
  239. # = median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel, k,
  240. # 'expert', parallel=False)
  241. sod_sm, sod_gm, dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min \
  242. = median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel, k,
  243. 'expert', parallel=False)
  244. return
  245. def test_k_closest_graphs_with_cv():
  246. gkernel = 'untilhpathkernel'
  247. node_label = 'atom'
  248. edge_label = 'bond_type'
  249. k = 4
  250. y_all = ['3', '1', '4', '6', '7', '8', '9', '2']
  251. repeats = 50
  252. collection_path = '/media/ljia/DATA/research-repo/codes/others/gedlib/tests_linlin/generated_datsets/monoterpenoides/'
  253. graph_dir = collection_path + 'gxl/'
  254. sod_sm_list = []
  255. sod_gm_list = []
  256. dis_k_sm_list = []
  257. dis_k_gm_list = []
  258. dis_k_gi_min_list = []
  259. for y in y_all:
  260. print('\n-------------------------------------------------------')
  261. print('class of y:', y)
  262. sod_sm_list.append([])
  263. sod_gm_list.append([])
  264. dis_k_sm_list.append([])
  265. dis_k_gm_list.append([])
  266. dis_k_gi_min_list.append([])
  267. for repeat in range(repeats):
  268. print('\nrepeat ', repeat)
  269. collection_file = collection_path + 'monoterpenoides_' + y + '_' + str(repeat) + '.xml'
  270. Gn, _ = loadDataset(collection_file, extra_params=graph_dir)
  271. sod_sm, sod_gm, dis_k_sm, dis_k_gm, dis_k_gi, dis_k_gi_min \
  272. = median_on_k_closest_graphs(Gn, node_label, edge_label, gkernel,
  273. k, 'whole-dataset', graph_dir=graph_dir,
  274. parallel=False)
  275. sod_sm_list[-1].append(sod_sm)
  276. sod_gm_list[-1].append(sod_gm)
  277. dis_k_sm_list[-1].append(dis_k_sm)
  278. dis_k_gm_list[-1].append(dis_k_gm)
  279. dis_k_gi_min_list[-1].append(dis_k_gi_min)
  280. print('\nsods of the set median for this class:', sod_sm_list[-1])
  281. print('\nsods of the gen median for this class:', sod_gm_list[-1])
  282. print('\ndistances in kernel space of set median for this class:',
  283. dis_k_sm_list[-1])
  284. print('\ndistances in kernel space of gen median for this class:',
  285. dis_k_gm_list[-1])
  286. print('\ndistances in kernel space of min graph for this class:',
  287. dis_k_gi_min_list[-1])
  288. sod_sm_list[-1] = np.mean(sod_sm_list[-1])
  289. sod_gm_list[-1] = np.mean(sod_gm_list[-1])
  290. dis_k_sm_list[-1] = np.mean(dis_k_sm_list[-1])
  291. dis_k_gm_list[-1] = np.mean(dis_k_gm_list[-1])
  292. dis_k_gi_min_list[-1] = np.mean(dis_k_gi_min_list[-1])
  293. print()
  294. print('\nmean sods of the set median for each class:', sod_sm_list)
  295. print('\nmean sods of the gen median for each class:', sod_gm_list)
  296. print('\nmean distance in kernel space of set median for each class:',
  297. dis_k_sm_list)
  298. print('\nmean distances in kernel space of gen median for each class:',
  299. dis_k_gm_list)
  300. print('\nmean distances in kernel space of min graph for each class:',
  301. dis_k_gi_min_list)
  302. print('\nmean sods of the set median of all:', np.mean(sod_sm_list))
  303. print('\nmean sods of the gen median of all:', np.mean(sod_gm_list))
  304. print('\nmean distances in kernel space of set median of all:',
  305. np.mean(dis_k_sm_list))
  306. print('\nmean distances in kernel space of gen median of all:',
  307. np.mean(dis_k_gm_list))
  308. print('\nmean distances in kernel space of min graph of all:',
  309. np.mean(dis_k_gi_min_list))
  310. return
  311. if __name__ == '__main__':
  312. test_k_closest_graphs()
  313. # test_k_closest_graphs_with_cv()

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.