
utils.py 22 kB

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Oct 17 19:05:07 2019
Useful functions.
@author: ljia
"""
# import networkx as nx
import multiprocessing
import numpy as np
from gklearn.kernels.marginalizedKernel import marginalizedkernel
from gklearn.kernels.untilHPathKernel import untilhpathkernel
from gklearn.kernels.spKernel import spkernel
import functools
from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct, polynomialkernel
from gklearn.kernels.structuralspKernel import structuralspkernel
from gklearn.kernels.treeletKernel import treeletkernel
from gklearn.kernels.weisfeilerLehmanKernel import weisfeilerlehmankernel
from gklearn.utils import Dataset
import csv
import networkx as nx
import os


def generate_median_preimages_by_class(ds_name, mpg_options, kernel_options, ged_options, mge_options, save_results=True, save_medians=True, plot_medians=True, load_gm='auto', dir_save='', irrelevant_labels=None, edge_required=False, cut_range=None):
    import os.path
    from gklearn.preimage import MedianPreimageGenerator
    from gklearn.utils import split_dataset_by_target
    from gklearn.utils.graphfiles import saveGXL

    # 1. get dataset.
    print('1. getting dataset...')
    dataset_all = Dataset()
    dataset_all.load_predefined_dataset(ds_name)
    dataset_all.trim_dataset(edge_required=edge_required)
    if irrelevant_labels is not None:
        dataset_all.remove_labels(**irrelevant_labels)
    if cut_range is not None:
        dataset_all.cut_graphs(cut_range)
    datasets = split_dataset_by_target(dataset_all)

    if save_results:
        # create result files.
        print('creating output files...')
        fn_output_detail, fn_output_summary = _init_output_file_preimage(ds_name, kernel_options['name'], mpg_options['fit_method'], dir_save)

    sod_sm_list = []
    sod_gm_list = []
    dis_k_sm_list = []
    dis_k_gm_list = []
    dis_k_gi_min_list = []
    time_optimize_ec_list = []
    time_generate_list = []
    time_total_list = []
    itrs_list = []
    converged_list = []
    num_updates_ecc_list = []
    mge_decrease_order_list = []
    mge_increase_order_list = []
    mge_converged_order_list = []
    nb_sod_sm2gm = [0, 0, 0]
    nb_dis_k_sm2gm = [0, 0, 0]
    nb_dis_k_gi2sm = [0, 0, 0]
    nb_dis_k_gi2gm = [0, 0, 0]
    dis_k_max_list = []
    dis_k_min_list = []
    dis_k_mean_list = []
    if load_gm == 'auto':
        gm_fname = dir_save + 'gram_matrix_unnorm.' + ds_name + '.' + kernel_options['name'] + '.gm.npz'
        gmfile_exist = os.path.isfile(os.path.abspath(gm_fname))
        if gmfile_exist:
            gmfile = np.load(gm_fname, allow_pickle=True)  # @todo: may not be safe.
            gram_matrix_unnorm_list = [item for item in gmfile['gram_matrix_unnorm_list']]
            time_precompute_gm_list = gmfile['run_time_list'].tolist()
        else:
            gram_matrix_unnorm_list = []
            time_precompute_gm_list = []
    elif not load_gm:
        gram_matrix_unnorm_list = []
        time_precompute_gm_list = []
    else:
        gm_fname = dir_save + 'gram_matrix_unnorm.' + ds_name + '.' + kernel_options['name'] + '.gm.npz'
        gmfile = np.load(gm_fname, allow_pickle=True)  # @todo: may not be safe.
        gram_matrix_unnorm_list = [item for item in gmfile['gram_matrix_unnorm_list']]
        time_precompute_gm_list = gmfile['run_time_list'].tolist()
        gmfile_exist = True  # the file was loaded above; mark it as present for the checks below.
    # repeats_better_sod_sm2gm = []
    # repeats_better_dis_k_sm2gm = []
    # repeats_better_dis_k_gi2sm = []
    # repeats_better_dis_k_gi2gm = []

    print('starting generating preimage for each class of target...')
    idx_offset = 0
    for idx, dataset in enumerate(datasets):
        target = dataset.targets[0]
        print('\ntarget =', target, '\n')
        # if target != 1:
        #     continue

        num_graphs = len(dataset.graphs)
        if num_graphs < 2:
            print('\nnumber of graphs = ', num_graphs, ', skip.\n')
            idx_offset += 1
            continue

        # 2. set parameters.
        print('2. initializing mpg and setting parameters...')
        if load_gm:
            if gmfile_exist:
                mpg_options['gram_matrix_unnorm'] = gram_matrix_unnorm_list[idx - idx_offset]
                mpg_options['runtime_precompute_gm'] = time_precompute_gm_list[idx - idx_offset]
        mpg = MedianPreimageGenerator()
        mpg.dataset = dataset
        mpg.set_options(**mpg_options.copy())
        mpg.kernel_options = kernel_options.copy()
        mpg.ged_options = ged_options.copy()
        mpg.mge_options = mge_options.copy()

        # 3. compute median preimage.
        print('3. computing median preimage...')
        mpg.run()
        results = mpg.get_results()

        # 4. compute pairwise kernel distances.
        print('4. computing pairwise kernel distances...')
        _, dis_k_max, dis_k_min, dis_k_mean = mpg.graph_kernel.compute_distance_matrix()
        dis_k_max_list.append(dis_k_max)
        dis_k_min_list.append(dis_k_min)
        dis_k_mean_list.append(dis_k_mean)

        # 5. save results (and median graphs).
        print('5. saving results (and median graphs)...')
        # write result detail.
        if save_results:
            print('writing results to files...')
            sod_sm2gm = get_relations(np.sign(results['sod_gen_median'] - results['sod_set_median']))
            dis_k_sm2gm = get_relations(np.sign(results['k_dis_gen_median'] - results['k_dis_set_median']))
            dis_k_gi2sm = get_relations(np.sign(results['k_dis_set_median'] - results['k_dis_dataset']))
            dis_k_gi2gm = get_relations(np.sign(results['k_dis_gen_median'] - results['k_dis_dataset']))

            f_detail = open(dir_save + fn_output_detail, 'a')
            csv.writer(f_detail).writerow([ds_name, kernel_options['name'],
                ged_options['edit_cost'], ged_options['method'],
                ged_options['attr_distance'], mpg_options['fit_method'],
                num_graphs, target, 1,
                results['sod_set_median'], results['sod_gen_median'],
                results['k_dis_set_median'], results['k_dis_gen_median'],
                results['k_dis_dataset'], sod_sm2gm, dis_k_sm2gm,
                dis_k_gi2sm, dis_k_gi2gm, results['edit_cost_constants'],
                results['runtime_precompute_gm'], results['runtime_optimize_ec'],
                results['runtime_generate_preimage'], results['runtime_total'],
                results['itrs'], results['converged'],
                results['num_updates_ecc'],
                results['mge']['num_decrease_order'] > 0,  # @todo: not suitable for multi-start mge
                results['mge']['num_increase_order'] > 0,
                results['mge']['num_converged_descents'] > 0])
            f_detail.close()

            # compute result summary.
            sod_sm_list.append(results['sod_set_median'])
            sod_gm_list.append(results['sod_gen_median'])
            dis_k_sm_list.append(results['k_dis_set_median'])
            dis_k_gm_list.append(results['k_dis_gen_median'])
            dis_k_gi_min_list.append(results['k_dis_dataset'])
            time_precompute_gm_list.append(results['runtime_precompute_gm'])
            time_optimize_ec_list.append(results['runtime_optimize_ec'])
            time_generate_list.append(results['runtime_generate_preimage'])
            time_total_list.append(results['runtime_total'])
            itrs_list.append(results['itrs'])
            converged_list.append(results['converged'])
            num_updates_ecc_list.append(results['num_updates_ecc'])
            mge_decrease_order_list.append(results['mge']['num_decrease_order'] > 0)
            mge_increase_order_list.append(results['mge']['num_increase_order'] > 0)
            mge_converged_order_list.append(results['mge']['num_converged_descents'] > 0)
            # # SOD SM -> GM
            if results['sod_set_median'] > results['sod_gen_median']:
                nb_sod_sm2gm[0] += 1
                # repeats_better_sod_sm2gm.append(1)
            elif results['sod_set_median'] == results['sod_gen_median']:
                nb_sod_sm2gm[1] += 1
            elif results['sod_set_median'] < results['sod_gen_median']:
                nb_sod_sm2gm[2] += 1
            # # dis_k SM -> GM
            if results['k_dis_set_median'] > results['k_dis_gen_median']:
                nb_dis_k_sm2gm[0] += 1
                # repeats_better_dis_k_sm2gm.append(1)
            elif results['k_dis_set_median'] == results['k_dis_gen_median']:
                nb_dis_k_sm2gm[1] += 1
            elif results['k_dis_set_median'] < results['k_dis_gen_median']:
                nb_dis_k_sm2gm[2] += 1
            # # dis_k gi -> SM
            if results['k_dis_dataset'] > results['k_dis_set_median']:
                nb_dis_k_gi2sm[0] += 1
                # repeats_better_dis_k_gi2sm.append(1)
            elif results['k_dis_dataset'] == results['k_dis_set_median']:
                nb_dis_k_gi2sm[1] += 1
            elif results['k_dis_dataset'] < results['k_dis_set_median']:
                nb_dis_k_gi2sm[2] += 1
            # # dis_k gi -> GM
            if results['k_dis_dataset'] > results['k_dis_gen_median']:
                nb_dis_k_gi2gm[0] += 1
                # repeats_better_dis_k_gi2gm.append(1)
            elif results['k_dis_dataset'] == results['k_dis_gen_median']:
                nb_dis_k_gi2gm[1] += 1
            elif results['k_dis_dataset'] < results['k_dis_gen_median']:
                nb_dis_k_gi2gm[2] += 1

            # write result summary for each letter.
            f_summary = open(dir_save + fn_output_summary, 'a')
            csv.writer(f_summary).writerow([ds_name, kernel_options['name'],
                ged_options['edit_cost'], ged_options['method'],
                ged_options['attr_distance'], mpg_options['fit_method'],
                num_graphs, target,
                results['sod_set_median'], results['sod_gen_median'],
                results['k_dis_set_median'], results['k_dis_gen_median'],
                results['k_dis_dataset'], sod_sm2gm, dis_k_sm2gm,
                dis_k_gi2sm, dis_k_gi2gm,
                results['runtime_precompute_gm'], results['runtime_optimize_ec'],
                results['runtime_generate_preimage'], results['runtime_total'],
                results['itrs'], results['converged'],
                results['num_updates_ecc'],
                results['mge']['num_decrease_order'] > 0,  # @todo: not suitable for multi-start mge
                results['mge']['num_increase_order'] > 0,
                results['mge']['num_converged_descents'] > 0,
                nb_sod_sm2gm,
                nb_dis_k_sm2gm, nb_dis_k_gi2sm, nb_dis_k_gi2gm])
            f_summary.close()

        # save median graphs.
        if save_medians:
            os.makedirs(dir_save + 'medians/', exist_ok=True)
            print('Saving median graphs to files...')
            fn_pre_sm = dir_save + 'medians/set_median.' + mpg_options['fit_method'] + '.nbg' + str(num_graphs) + '.y' + str(target) + '.repeat' + str(1)
            saveGXL(mpg.set_median, fn_pre_sm + '.gxl', method='default',
                    node_labels=dataset.node_labels, edge_labels=dataset.edge_labels,
                    node_attrs=dataset.node_attrs, edge_attrs=dataset.edge_attrs)
            fn_pre_gm = dir_save + 'medians/gen_median.' + mpg_options['fit_method'] + '.nbg' + str(num_graphs) + '.y' + str(target) + '.repeat' + str(1)
            saveGXL(mpg.gen_median, fn_pre_gm + '.gxl', method='default',
                    node_labels=dataset.node_labels, edge_labels=dataset.edge_labels,
                    node_attrs=dataset.node_attrs, edge_attrs=dataset.edge_attrs)
            fn_best_dataset = dir_save + 'medians/g_best_dataset.' + mpg_options['fit_method'] + '.nbg' + str(num_graphs) + '.y' + str(target) + '.repeat' + str(1)
            saveGXL(mpg.best_from_dataset, fn_best_dataset + '.gxl', method='default',
                    node_labels=dataset.node_labels, edge_labels=dataset.edge_labels,
                    node_attrs=dataset.node_attrs, edge_attrs=dataset.edge_attrs)

        # plot median graphs.
        if plot_medians and save_medians:
            if ged_options['edit_cost'] == 'LETTER2' or ged_options['edit_cost'] == 'LETTER' or ds_name == 'Letter-high' or ds_name == 'Letter-med' or ds_name == 'Letter-low':
                draw_Letter_graph(mpg.set_median, fn_pre_sm)
                draw_Letter_graph(mpg.gen_median, fn_pre_gm)
                draw_Letter_graph(mpg.best_from_dataset, fn_best_dataset)

        if (load_gm == 'auto' and not gmfile_exist) or not load_gm:
            gram_matrix_unnorm_list.append(mpg.gram_matrix_unnorm)

    # write result summary for each class.
    if save_results:
        sod_sm_mean = np.mean(sod_sm_list)
        sod_gm_mean = np.mean(sod_gm_list)
        dis_k_sm_mean = np.mean(dis_k_sm_list)
        dis_k_gm_mean = np.mean(dis_k_gm_list)
        dis_k_gi_min_mean = np.mean(dis_k_gi_min_list)
        time_precompute_gm_mean = np.mean(time_precompute_gm_list)
        time_optimize_ec_mean = np.mean(time_optimize_ec_list)
        time_generate_mean = np.mean(time_generate_list)
        time_total_mean = np.mean(time_total_list)
        itrs_mean = np.mean(itrs_list)
        num_converged = np.sum(converged_list)
        num_updates_ecc_mean = np.mean(num_updates_ecc_list)
        num_mge_decrease_order = np.sum(mge_decrease_order_list)
        num_mge_increase_order = np.sum(mge_increase_order_list)
        num_mge_converged = np.sum(mge_converged_order_list)
        sod_sm2gm_mean = get_relations(np.sign(sod_gm_mean - sod_sm_mean))
        dis_k_sm2gm_mean = get_relations(np.sign(dis_k_gm_mean - dis_k_sm_mean))
        dis_k_gi2sm_mean = get_relations(np.sign(dis_k_sm_mean - dis_k_gi_min_mean))
        dis_k_gi2gm_mean = get_relations(np.sign(dis_k_gm_mean - dis_k_gi_min_mean))

        f_summary = open(dir_save + fn_output_summary, 'a')
        csv.writer(f_summary).writerow([ds_name, kernel_options['name'],
            ged_options['edit_cost'], ged_options['method'],
            ged_options['attr_distance'], mpg_options['fit_method'],
            num_graphs, 'all',
            sod_sm_mean, sod_gm_mean, dis_k_sm_mean, dis_k_gm_mean,
            dis_k_gi_min_mean, sod_sm2gm_mean, dis_k_sm2gm_mean,
            dis_k_gi2sm_mean, dis_k_gi2gm_mean,
            time_precompute_gm_mean, time_optimize_ec_mean,
            time_generate_mean, time_total_mean, itrs_mean,
            num_converged, num_updates_ecc_mean,
            num_mge_decrease_order, num_mge_increase_order,
            num_mge_converged])
        f_summary.close()

    # save total pairwise kernel distances.
    dis_k_max = np.max(dis_k_max_list)
    dis_k_min = np.min(dis_k_min_list)
    dis_k_mean = np.mean(dis_k_mean_list)
    print('The maximum pairwise distance in kernel space:', dis_k_max)
    print('The minimum pairwise distance in kernel space:', dis_k_min)
    print('The average pairwise distance in kernel space:', dis_k_mean)

    # write Gram matrices to file.
    if (load_gm == 'auto' and not gmfile_exist) or not load_gm:
        np.savez(dir_save + 'gram_matrix_unnorm.' + ds_name + '.' + kernel_options['name'] + '.gm', gram_matrix_unnorm_list=gram_matrix_unnorm_list, run_time_list=time_precompute_gm_list)

    print('\ncomplete.\n')
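

# Illustrative usage sketch (not part of the original module). The option
# dictionaries below are hypothetical placeholders: only the keys this file
# itself reads ('name', 'fit_method', 'edit_cost', 'method', 'attr_distance')
# are shown, and MedianPreimageGenerator will typically require further
# options; see its documentation for the full set.
def _example_generate_median_preimages():
    mpg_options = {'fit_method': 'expert'}          # hypothetical fit method
    kernel_options = {'name': 'untilhpathkernel'}   # graph kernel to use
    ged_options = {'edit_cost': 'LETTER2',          # hypothetical GED settings
                   'method': 'IPFP',
                   'attr_distance': 'euclidean'}
    mge_options = {}                                # median graph estimator options
    generate_median_preimages_by_class('Letter-high', mpg_options, kernel_options,
                                       ged_options, mge_options,
                                       load_gm='auto', dir_save='outputs/')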


def _init_output_file_preimage(ds_name, gkernel, fit_method, dir_output):
    os.makedirs(dir_output, exist_ok=True)
    # fn_output_detail = 'results_detail.' + ds_name + '.' + gkernel + '.' + fit_method + '.csv'
    fn_output_detail = 'results_detail.' + ds_name + '.' + gkernel + '.csv'
    f_detail = open(dir_output + fn_output_detail, 'a')
    csv.writer(f_detail).writerow(['dataset', 'graph kernel', 'edit cost',
        'GED method', 'attr distance', 'fit method', 'num graphs',
        'target', 'repeat', 'SOD SM', 'SOD GM', 'dis_k SM', 'dis_k GM',
        'min dis_k gi', 'SOD SM -> GM', 'dis_k SM -> GM', 'dis_k gi -> SM',
        'dis_k gi -> GM', 'edit cost constants', 'time precompute gm',
        'time optimize ec', 'time generate preimage', 'time total',
        'itrs', 'converged', 'num updates ecc', 'mge decrease order',
        'mge increase order', 'mge converged'])
    f_detail.close()

    # fn_output_summary = 'results_summary.' + ds_name + '.' + gkernel + '.' + fit_method + '.csv'
    fn_output_summary = 'results_summary.' + ds_name + '.' + gkernel + '.csv'
    f_summary = open(dir_output + fn_output_summary, 'a')
    csv.writer(f_summary).writerow(['dataset', 'graph kernel', 'edit cost',
        'GED method', 'attr distance', 'fit method', 'num graphs',
        'target', 'SOD SM', 'SOD GM', 'dis_k SM', 'dis_k GM',
        'min dis_k gi', 'SOD SM -> GM', 'dis_k SM -> GM', 'dis_k gi -> SM',
        'dis_k gi -> GM', 'time precompute gm', 'time optimize ec',
        'time generate preimage', 'time total', 'itrs', 'num converged',
        'num updates ecc', 'mge num decrease order', 'mge num increase order',
        'mge num converged', '# SOD SM -> GM', '# dis_k SM -> GM',
        '# dis_k gi -> SM', '# dis_k gi -> GM'])
        # 'repeats better SOD SM -> GM',
        # 'repeats better dis_k SM -> GM', 'repeats better dis_k gi -> SM',
        # 'repeats better dis_k gi -> GM'])
    f_summary.close()

    return fn_output_detail, fn_output_summary


def get_relations(sign):
    if sign == -1:
        return 'better'
    elif sign == 0:
        return 'same'
    elif sign == 1:
        return 'worse'


# Draw the current median graph (for the Letter datasets).
def draw_Letter_graph(graph, file_prefix):
    import matplotlib
    matplotlib.use('agg')
    import matplotlib.pyplot as plt
    plt.figure()
    pos = {}
    for n in graph.nodes:
        pos[n] = np.array([float(graph.nodes[n]['x']), float(graph.nodes[n]['y'])])
    nx.draw_networkx(graph, pos)
    plt.savefig(file_prefix + '.eps', format='eps', dpi=300)
    # plt.show()
    plt.clf()
    plt.close()
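

# Minimal sketch (assumption: any graph whose nodes carry numeric 'x' and 'y'
# attributes, as in the Letter datasets, can be drawn this way; the output
# file name is a hypothetical placeholder).
def _example_draw_letter_graph():
    g = nx.Graph()
    g.add_node(0, x='0.0', y='0.0')
    g.add_node(1, x='1.0', y='0.5')
    g.add_node(2, x='0.5', y='1.0')
    g.add_edges_from([(0, 1), (1, 2)])
    draw_Letter_graph(g, 'example_letter_median')  # writes example_letter_median.eps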


def remove_edges(Gn):
    for G in Gn:
        for _, _, attrs in G.edges(data=True):
            attrs.clear()
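

# Small demonstration (sketch, not from the original file): despite its name,
# remove_edges only clears edge attributes; the edges themselves remain.
def _example_remove_edges():
    g = nx.Graph()
    g.add_edge(0, 1, bond_type='1')
    remove_edges([g])
    assert g.number_of_edges() == 1   # the edge is still there
    assert g.edges[0, 1] == {}        # but its attributes are gone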


def dis_gstar(idx_g, idx_gi, alpha, Kmatrix, term3=0, withterm3=True):
    term1 = Kmatrix[idx_g, idx_g]
    term2 = 0
    for i, a in enumerate(alpha):
        term2 += a * Kmatrix[idx_g, idx_gi[i]]
    term2 *= 2
    if withterm3 == False:
        for i1, a1 in enumerate(alpha):
            for i2, a2 in enumerate(alpha):
                term3 += a1 * a2 * Kmatrix[idx_gi[i1], idx_gi[i2]]
    return np.sqrt(term1 - term2 + term3)
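

# Numeric sketch (illustrative, with a hand-made Gram matrix): dis_gstar
# evaluates || phi(g) - sum_i alpha_i * phi(g_i) || in kernel feature space,
# i.e. sqrt(k(g, g) - 2 * sum_i alpha_i * k(g, g_i)
#           + sum_{i,j} alpha_i * alpha_j * k(g_i, g_j)).
def _example_dis_gstar():
    K = np.array([[1.0, 0.5, 0.2],
                  [0.5, 1.0, 0.4],
                  [0.2, 0.4, 1.0]])   # toy normalized Gram matrix
    alpha = [0.5, 0.5]                # equal weights on graphs 1 and 2
    d = dis_gstar(0, [1, 2], alpha, K, withterm3=False)
    # term1 = 1.0, term2 = 2 * (0.5*0.5 + 0.5*0.2) = 0.7,
    # term3 = 0.25 * (1.0 + 0.4 + 0.4 + 1.0) = 0.7, so d = sqrt(1.0) = 1.0.
    return d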


def compute_k_dis(idx_g, idx_gi, alphas, Kmatrix, term3=0, withterm3=True):
    term1 = Kmatrix[idx_g, idx_g]
    term2 = 0
    for i, a in enumerate(alphas):
        term2 += a * Kmatrix[idx_g, idx_gi[i]]
    term2 *= 2
    if withterm3 == False:
        for i1, a1 in enumerate(alphas):
            for i2, a2 in enumerate(alphas):
                term3 += a1 * a2 * Kmatrix[idx_gi[i1], idx_gi[i2]]
    return np.sqrt(term1 - term2 + term3)


def compute_kernel(Gn, graph_kernel, node_label, edge_label, verbose, parallel='imap_unordered'):
    if graph_kernel == 'marginalizedkernel':
        Kmatrix, _ = marginalizedkernel(Gn, node_label=node_label, edge_label=edge_label,
                                        p_quit=0.03, n_iteration=10, remove_totters=False,
                                        n_jobs=multiprocessing.cpu_count(), verbose=verbose)
    elif graph_kernel == 'untilhpathkernel':
        Kmatrix, _ = untilhpathkernel(Gn, node_label=node_label, edge_label=edge_label,
                                      depth=7, k_func='MinMax', compute_method='trie',
                                      parallel=parallel,
                                      n_jobs=multiprocessing.cpu_count(), verbose=verbose)
    elif graph_kernel == 'spkernel':
        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        Kmatrix = np.empty((len(Gn), len(Gn)))
        # Kmatrix[:] = np.nan
        Kmatrix, _, idx = spkernel(Gn, node_label=node_label, node_kernels=
                                   {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel},
                                   n_jobs=multiprocessing.cpu_count(), verbose=verbose)
        # for i, row in enumerate(idx):
        #     for j, col in enumerate(idx):
        #         Kmatrix[row, col] = Kmatrix_tmp[i, j]
    elif graph_kernel == 'structuralspkernel':
        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}
        Kmatrix, _ = structuralspkernel(Gn, node_label=node_label,
                                        edge_label=edge_label, node_kernels=sub_kernels,
                                        edge_kernels=sub_kernels,
                                        parallel=parallel, n_jobs=multiprocessing.cpu_count(),
                                        verbose=verbose)
    elif graph_kernel == 'treeletkernel':
        pkernel = functools.partial(polynomialkernel, d=2, c=1e5)
        # pkernel = functools.partial(gaussiankernel, gamma=1e-6)
        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        Kmatrix, _ = treeletkernel(Gn, node_label=node_label, edge_label=edge_label,
                                   sub_kernel=pkernel, parallel=parallel,
                                   n_jobs=multiprocessing.cpu_count(), verbose=verbose)
    elif graph_kernel == 'weisfeilerlehmankernel':
        Kmatrix, _ = weisfeilerlehmankernel(Gn, node_label=node_label, edge_label=edge_label,
                                            height=4, base_kernel='subtree', parallel=None,
                                            n_jobs=multiprocessing.cpu_count(), verbose=verbose)
    else:
        raise Exception('The graph kernel "' + str(graph_kernel) + '" is not defined.')

    # normalization
    Kmatrix_diag = Kmatrix.diagonal().copy()
    for i in range(len(Kmatrix)):
        for j in range(i, len(Kmatrix)):
            Kmatrix[i][j] /= np.sqrt(Kmatrix_diag[i] * Kmatrix_diag[j])
            Kmatrix[j][i] = Kmatrix[i][j]
    return Kmatrix
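

# Usage sketch (assumption: `graphs` is a list of NetworkX graphs; the label
# names 'atom' and 'bond_type' are hypothetical placeholders for whatever
# symbolic node and edge labels the dataset actually carries).
def _example_compute_kernel(graphs):
    # returns the normalized Gram matrix of the until-h-path kernel
    return compute_kernel(graphs, 'untilhpathkernel',
                          node_label='atom', edge_label='bond_type', verbose=True)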


def gram2distances(Kmatrix):
    dmatrix = np.zeros((len(Kmatrix), len(Kmatrix)))
    for i1 in range(len(Kmatrix)):
        for i2 in range(len(Kmatrix)):
            dmatrix[i1, i2] = Kmatrix[i1, i1] + Kmatrix[i2, i2] - 2 * Kmatrix[i1, i2]
    dmatrix = np.sqrt(dmatrix)
    return dmatrix
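

# Tiny numeric sketch (illustrative): gram2distances turns a Gram matrix into
# the matrix of pairwise kernel-space distances
# d(i, j) = sqrt(k(i, i) + k(j, j) - 2 * k(i, j)).
def _example_gram2distances():
    K = np.array([[1.0, 0.5],
                  [0.5, 1.0]])
    D = gram2distances(K)
    # off-diagonal entries equal sqrt(1 + 1 - 2 * 0.5) = 1.0
    return D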


def kernel_distance_matrix(Gn, node_label, edge_label, Kmatrix=None,
                           gkernel=None, verbose=True):
    import warnings
    warnings.warn('gklearn.preimage.utils.kernel_distance_matrix is deprecated, use gklearn.kernels.graph_kernel.compute_distance_matrix or gklearn.utils.compute_distance_matrix instead', DeprecationWarning)
    dis_mat = np.empty((len(Gn), len(Gn)))
    if Kmatrix is None:
        Kmatrix = compute_kernel(Gn, gkernel, node_label, edge_label, verbose)
    for i in range(len(Gn)):
        for j in range(i, len(Gn)):
            dis = Kmatrix[i, i] + Kmatrix[j, j] - 2 * Kmatrix[i, j]
            if dis < 0:
                if dis > -1e-10:
                    dis = 0
                else:
                    raise ValueError('The distance is negative.')
            dis_mat[i, j] = np.sqrt(dis)
            dis_mat[j, i] = dis_mat[i, j]
    dis_max = np.max(np.max(dis_mat))
    dis_min = np.min(np.min(dis_mat[dis_mat != 0]))
    dis_mean = np.mean(np.mean(dis_mat))
    return dis_mat, dis_max, dis_min, dis_mean


def get_same_item_indices(ls):
    """Get the indices of the same items in a list. Return a dict keyed by items.
    """
    idx_dict = {}
    for idx, item in enumerate(ls):
        if item in idx_dict:
            idx_dict[item].append(idx)
        else:
            idx_dict[item] = [idx]
    return idx_dict
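

# Tiny sketch: group indices by value, e.g. for splitting a dataset by target.
def _example_get_same_item_indices():
    targets = [1, 0, 1, 1, 0]
    return get_same_item_indices(targets)  # {1: [0, 2, 3], 0: [1, 4]}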


def k_nearest_neighbors_to_median_in_kernel_space(Gn, Kmatrix=None, gkernel=None,
                                                  node_label=None, edge_label=None):
    import sys
    from tqdm import tqdm
    dis_k_all = []  # distance between g_star and each graph.
    alpha = [1 / len(Gn)] * len(Gn)
    if Kmatrix is None:
        Kmatrix = compute_kernel(Gn, gkernel, node_label, edge_label, True)
    idx_gi = list(range(len(Gn)))  # g_star is the mean of all graphs in kernel space.
    term3 = 0
    for i1, a1 in enumerate(alpha):
        for i2, a2 in enumerate(alpha):
            term3 += a1 * a2 * Kmatrix[idx_gi[i1], idx_gi[i2]]
    for ig, g in tqdm(enumerate(Gn), desc='computing distances', file=sys.stdout):
        dtemp = dis_gstar(ig, idx_gi, alpha, Kmatrix, term3=term3)
        dis_k_all.append(dtemp)
    return dis_k_all  # distances to the implicit median; sorting/top-k is left to the caller.


def normalize_distance_matrix(D):
    max_value = np.amax(D)
    min_value = np.amin(D)
    return (D - min_value) / (max_value - min_value)
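

# Tiny sketch: min-max normalization of a distance matrix into [0, 1].
def _example_normalize_distance_matrix():
    D = np.array([[0.0, 2.0],
                  [2.0, 4.0]])
    return normalize_distance_matrix(D)  # [[0., 0.5], [0.5, 1.]]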

A Python package for graph kernels, graph edit distances and the graph pre-image problem.