
generate_random_preimages_by_class.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jun 1 17:02:51 2020

@author: ljia
"""
import numpy as np
from gklearn.utils import Dataset
import csv
import os
import os.path
from gklearn.preimage import RandomPreimageGenerator
from gklearn.utils import split_dataset_by_target
from gklearn.utils.graphfiles import saveGXL


def generate_random_preimages_by_class(ds_name, rpg_options, kernel_options, save_results=True, save_preimages=True, load_gm='auto', dir_save='', irrelevant_labels=None, edge_required=False, cut_range=None):
    # 1. get dataset.
    print('1. getting dataset...')
    dataset_all = Dataset()
    dataset_all.load_predefined_dataset(ds_name)
    dataset_all.trim_dataset(edge_required=edge_required)
    if irrelevant_labels is not None:
        dataset_all.remove_labels(**irrelevant_labels)
    if cut_range is not None:
        dataset_all.cut_graphs(cut_range)
    datasets = split_dataset_by_target(dataset_all)

    if save_results:
        # create result files.
        print('creating output files...')
        fn_output_detail, fn_output_summary = _init_output_file_preimage(ds_name, kernel_options['name'], dir_save)

    dis_k_dataset_list = []
    dis_k_preimage_list = []
    time_precompute_gm_list = []
    time_generate_list = []
    time_total_list = []
    itrs_list = []
    num_updates_list = []

    if load_gm == 'auto':
        gm_fname = dir_save + 'gram_matrix_unnorm.' + ds_name + '.' + kernel_options['name'] + '.gm.npz'
        gmfile_exist = os.path.isfile(os.path.abspath(gm_fname))
        if gmfile_exist:
            gmfile = np.load(gm_fname, allow_pickle=True)  # @todo: may not be safe.
            gram_matrix_unnorm_list = [item for item in gmfile['gram_matrix_unnorm_list']]
            time_precompute_gm_list = gmfile['run_time_list'].tolist()
        else:
            gram_matrix_unnorm_list = []
            time_precompute_gm_list = []
    elif not load_gm:
        gram_matrix_unnorm_list = []
        time_precompute_gm_list = []
    else:
        gm_fname = dir_save + 'gram_matrix_unnorm.' + ds_name + '.' + kernel_options['name'] + '.gm.npz'
        gmfile = np.load(gm_fname, allow_pickle=True)  # @todo: may not be safe.
        gram_matrix_unnorm_list = [item for item in gmfile['gram_matrix_unnorm_list']]
        time_precompute_gm_list = gmfile['run_time_list'].tolist()

    print('starting generating preimage for each class of target...')
    idx_offset = 0
    for idx, dataset in enumerate(datasets):
        target = dataset.targets[0]
        print('\ntarget =', target, '\n')
        # if target != 1:
        #     continue

        num_graphs = len(dataset.graphs)
        if num_graphs < 2:
            print('\nnumber of graphs = ', num_graphs, ', skip.\n')
            idx_offset += 1
            continue

        # 2. set parameters.
        print('2. initializing mpg and setting parameters...')
        if load_gm:
            if gmfile_exist:
                rpg_options['gram_matrix_unnorm'] = gram_matrix_unnorm_list[idx - idx_offset]
                rpg_options['runtime_precompute_gm'] = time_precompute_gm_list[idx - idx_offset]
        rpg = RandomPreimageGenerator()
        rpg.dataset = dataset
        rpg.set_options(**rpg_options.copy())
        rpg.kernel_options = kernel_options.copy()

        # 3. compute preimage.
        print('3. computing preimage...')
        rpg.run()
        results = rpg.get_results()

        # 4. save results (and median graphs).
        print('4. saving results (and preimages)...')
        # write result detail.
        if save_results:
            print('writing results to files...')
            f_detail = open(dir_save + fn_output_detail, 'a')
            csv.writer(f_detail).writerow([ds_name, kernel_options['name'],
                                           num_graphs, target, 1,
                                           results['k_dis_dataset'], results['k_dis_preimage'],
                                           results['runtime_precompute_gm'],
                                           results['runtime_generate_preimage'], results['runtime_total'],
                                           results['itrs'], results['num_updates']])
            f_detail.close()

            # compute result summary.
            dis_k_dataset_list.append(results['k_dis_dataset'])
            dis_k_preimage_list.append(results['k_dis_preimage'])
            time_precompute_gm_list.append(results['runtime_precompute_gm'])
            time_generate_list.append(results['runtime_generate_preimage'])
            time_total_list.append(results['runtime_total'])
            itrs_list.append(results['itrs'])
            num_updates_list.append(results['num_updates'])

            # write result summary for each letter.
            f_summary = open(dir_save + fn_output_summary, 'a')
            csv.writer(f_summary).writerow([ds_name, kernel_options['name'],
                                            num_graphs, target,
                                            results['k_dis_dataset'], results['k_dis_preimage'],
                                            results['runtime_precompute_gm'],
                                            results['runtime_generate_preimage'], results['runtime_total'],
                                            results['itrs'], results['num_updates']])
            f_summary.close()

        # save median graphs.
        if save_preimages:
            os.makedirs(dir_save + 'preimages/', exist_ok=True)
            print('Saving preimages to files...')
            fn_best_dataset = dir_save + 'preimages/g_best_dataset.' + 'nbg' + str(num_graphs) + '.y' + str(target) + '.repeat' + str(1)
            saveGXL(rpg.best_from_dataset, fn_best_dataset + '.gxl', method='default',
                    node_labels=dataset.node_labels, edge_labels=dataset.edge_labels,
                    node_attrs=dataset.node_attrs, edge_attrs=dataset.edge_attrs)
            fn_preimage = dir_save + 'preimages/g_preimage.' + 'nbg' + str(num_graphs) + '.y' + str(target) + '.repeat' + str(1)
            saveGXL(rpg.preimage, fn_preimage + '.gxl', method='default',
                    node_labels=dataset.node_labels, edge_labels=dataset.edge_labels,
                    node_attrs=dataset.node_attrs, edge_attrs=dataset.edge_attrs)

        if (load_gm == 'auto' and not gmfile_exist) or not load_gm:
            gram_matrix_unnorm_list.append(rpg.gram_matrix_unnorm)

    # write result summary for each class.
    if save_results:
        dis_k_dataset_mean = np.mean(dis_k_dataset_list)
        dis_k_preimage_mean = np.mean(dis_k_preimage_list)
        time_precompute_gm_mean = np.mean(time_precompute_gm_list)
        time_generate_mean = np.mean(time_generate_list)
        time_total_mean = np.mean(time_total_list)
        itrs_mean = np.mean(itrs_list)
        num_updates_mean = np.mean(num_updates_list)
        f_summary = open(dir_save + fn_output_summary, 'a')
        csv.writer(f_summary).writerow([ds_name, kernel_options['name'],
                                        num_graphs, 'all',
                                        dis_k_dataset_mean, dis_k_preimage_mean,
                                        time_precompute_gm_mean,
                                        time_generate_mean, time_total_mean, itrs_mean,
                                        num_updates_mean])
        f_summary.close()

    # write Gram matrices to file.
    if (load_gm == 'auto' and not gmfile_exist) or not load_gm:
        np.savez(dir_save + 'gram_matrix_unnorm.' + ds_name + '.' + kernel_options['name'] + '.gm', gram_matrix_unnorm_list=gram_matrix_unnorm_list, run_time_list=time_precompute_gm_list)

    print('\ncomplete.\n')


def _init_output_file_preimage(ds_name, gkernel, dir_output):
    os.makedirs(dir_output, exist_ok=True)
    fn_output_detail = 'results_detail.' + ds_name + '.' + gkernel + '.csv'
    f_detail = open(dir_output + fn_output_detail, 'a')
    csv.writer(f_detail).writerow(['dataset', 'graph kernel', 'num graphs',
                                   'target', 'repeat', 'dis_k best from dataset', 'dis_k preimage',
                                   'time precompute gm', 'time generate preimage', 'time total',
                                   'itrs', 'num updates'])
    f_detail.close()

    fn_output_summary = 'results_summary.' + ds_name + '.' + gkernel + '.csv'
    f_summary = open(dir_output + fn_output_summary, 'a')
    csv.writer(f_summary).writerow(['dataset', 'graph kernel', 'num graphs',
                                    'target', 'dis_k best from dataset', 'dis_k preimage',
                                    'time precompute gm', 'time generate preimage', 'time total',
                                    'itrs', 'num updates'])
    f_summary.close()

    return fn_output_detail, fn_output_summary
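A minimal sketch of how this helper might be invoked. The dataset name 'MUTAG', the kernel name 'ShortestPath', and all option keys other than 'name' (the only kernel_options key this script itself reads) are illustrative assumptions about the RandomPreimageGenerator and kernel interfaces; the exact keys accepted depend on the gklearn version.

if __name__ == '__main__':
    # Illustrative options only; key names below (other than 'name') are
    # assumptions about the gklearn API and may differ across versions.
    kernel_options = {'name': 'ShortestPath',  # assumed example graph kernel
                      'parallel': None,        # hypothetical key
                      'verbose': True}         # hypothetical key
    rpg_options = {'k': 5,        # hypothetical: number of nearest neighbors to keep
                   'r_max': 10,   # hypothetical: number of generation rounds
                   'l': 500,      # hypothetical: candidates generated per round
                   'verbose': 2}  # hypothetical verbosity level
    generate_random_preimages_by_class('MUTAG',  # assumed predefined dataset name
                                       rpg_options,
                                       kernel_options,
                                       save_results=True,
                                       save_preimages=True,
                                       load_gm='auto',
                                       dir_save='outputs/MUTAG/',
                                       irrelevant_labels=None,
                                       edge_required=False)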

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.
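For later inspection, the Gram matrices written by the np.savez call in the script can be reloaded through the same keys it uses ('gram_matrix_unnorm_list' and 'run_time_list'). The file path below assumes ds_name='MUTAG', kernel name 'ShortestPath', and dir_save='outputs/MUTAG/' as in the usage sketch above; np.savez appends the '.npz' extension.

import numpy as np

# Path pattern follows the np.savez call in generate_random_preimages_by_class.
gm_path = 'outputs/MUTAG/gram_matrix_unnorm.MUTAG.ShortestPath.gm.npz'  # assumed names
gmfile = np.load(gm_path, allow_pickle=True)

gram_matrices = list(gmfile['gram_matrix_unnorm_list'])  # one unnormalized Gram matrix per class
run_times = gmfile['run_time_list'].tolist()             # per-class Gram-matrix computation times

print('classes with a stored Gram matrix:', len(gram_matrices))
print('shape of the first Gram matrix:', gram_matrices[0].shape)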