You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

edit_costs.real_data.nums_sols.ratios.bipartite.py 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Mon Nov 2 16:17:01 2020
  5. @author: ljia
  6. """
  7. # This script tests the influence of the ratios between node costs and edge costs on the stability of the GED computation, where the base edit costs are [1, 1, 1, 1, 1, 1]. The minimum solution over the given number of repeats is computed.
  8. import os
  9. import multiprocessing
  10. import pickle
  11. import logging
  12. from gklearn.ged.util import compute_geds
  13. import time
  14. from utils import get_dataset, set_edit_cost_consts, dichotomous_permutation, mix_param_grids
  15. import sys
  16. from group_results import group_trials, check_group_existence, update_group_marker
  17. def xp_compute_ged_matrix(dataset, ds_name, num_solutions, ratio, trial):
  18. save_file_suffix = '.' + ds_name + '.num_sols_' + str(num_solutions) + '.ratio_' + "{:.2f}".format(ratio) + '.trial_' + str(trial)
  19. # Return if the file exists.
  20. if os.path.isfile(save_dir + 'ged_matrix' + save_file_suffix + '.pkl'):
  21. return None, None
  22. """**2. Set parameters.**"""
  23. # Parameters for GED computation.
  24. ged_options = {'method': 'BIPARTITE', # use BIPARTITE huristic.
  25. # 'initialization_method': 'RANDOM', # or 'NODE', etc. (for GEDEnv)
  26. 'lsape_model': 'ECBP', #
  27. # ??when bigger than 1, then the method is considered mIPFP.
  28. # the actual number of computed solutions might be smaller than the specified value
  29. 'max_num_solutions': 1, # @ max_num_solutions,
  30. 'edit_cost': 'CONSTANT', # use CONSTANT cost.
  31. 'greedy_method': 'BASIC', #
  32. # the distance between non-symbolic node/edge labels is computed by euclidean distance.
  33. 'attr_distance': 'euclidean',
  34. 'optimal': True, # if TRUE, the option --greedy-method has no effect
  35. # parallel threads. Do not work if mpg_options['parallel'] = False.
  36. 'threads': multiprocessing.cpu_count(),
  37. 'centrality_method': 'NONE',
  38. 'centrality_weight': 0.7,
  39. 'init_option': 'EAGER_WITHOUT_SHUFFLED_COPIES'
  40. }
  41. edit_cost_constants = set_edit_cost_consts(ratio,
  42. node_labeled=len(dataset.node_labels),
  43. edge_labeled=len(dataset.edge_labels),
  44. mode='uniform')
  45. # edit_cost_constants = [item * 0.01 for item in edit_cost_constants]
  46. # pickle.dump(edit_cost_constants, open(save_dir + "edit_costs" + save_file_suffix + ".pkl", "wb"))
  47. options = ged_options.copy()
  48. options['edit_cost_constants'] = edit_cost_constants
  49. options['node_labels'] = dataset.node_labels
  50. options['edge_labels'] = dataset.edge_labels
  51. options['node_attrs'] = dataset.node_attrs
  52. options['edge_attrs'] = dataset.edge_attrs
  53. parallel = True # if num_solutions == 1 else False
  54. """**5. Compute GED matrix.**"""
  55. ged_mat = 'error'
  56. runtime = 0
  57. try:
  58. time0 = time.time()
  59. ged_vec_init, ged_mat, n_edit_operations = compute_geds(dataset.graphs,
  60. options=options,
  61. repeats=num_solutions,
  62. permute_nodes=True,
  63. random_state=None,
  64. parallel=parallel,
  65. verbose=True)
  66. runtime = time.time() - time0
  67. except Exception as exp:
  68. print('An exception occured when running this experiment:')
  69. LOG_FILENAME = save_dir + 'error.txt'
  70. logging.basicConfig(filename=LOG_FILENAME, level=logging.DEBUG)
  71. logging.exception(save_file_suffix)
  72. print(repr(exp))
  73. """**6. Get results.**"""
  74. with open(save_dir + 'ged_matrix' + save_file_suffix + '.pkl', 'wb') as f:
  75. pickle.dump(ged_mat, f)
  76. with open(save_dir + 'runtime' + save_file_suffix + '.pkl', 'wb') as f:
  77. pickle.dump(runtime, f)
  78. return ged_mat, runtime
  79. def save_trials_as_group(dataset, ds_name, num_solutions, ratio):
  80. # Return if the group file exists.
  81. name_middle = '.' + ds_name + '.num_sols_' + str(num_solutions) + '.ratio_' + "{:.2f}".format(ratio) + '.'
  82. name_group = save_dir + 'groups/ged_mats' + name_middle + 'npy'
  83. if check_group_existence(name_group):
  84. return
  85. ged_mats = []
  86. runtimes = []
  87. num_trials = 100
  88. for trial in range(1, num_trials + 1):
  89. print()
  90. print('Trial:', trial)
  91. ged_mat, runtime = xp_compute_ged_matrix(dataset, ds_name, num_solutions, ratio, trial)
  92. ged_mats.append(ged_mat)
  93. runtimes.append(runtime)
  94. # Group trials and remove single files.
  95. # @todo: if the program stops between the following lines, then there may be errors.
  96. name_prefix = 'ged_matrix' + name_middle
  97. group_trials(save_dir, name_prefix, True, True, False, num_trials=num_trials)
  98. name_prefix = 'runtime' + name_middle
  99. group_trials(save_dir, name_prefix, True, True, False, num_trials=num_trials)
  100. update_group_marker(name_group)
  101. def results_for_a_dataset(ds_name):
  102. """**1. Get dataset.**"""
  103. dataset = get_dataset(ds_name)
  104. for params in list(param_grid):
  105. print()
  106. print(params)
  107. save_trials_as_group(dataset, ds_name, params['num_solutions'], params['ratio'])
  108. def get_param_lists(ds_name, mode='test'):
  109. if mode == 'test':
  110. num_solutions_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30]
  111. ratio_list = [10]
  112. return num_solutions_list, ratio_list
  113. elif mode == 'simple':
  114. from sklearn.model_selection import ParameterGrid
  115. param_grid = mix_param_grids([list(ParameterGrid([
  116. {'num_solutions': dichotomous_permutation([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 40, 50, 60, 70, 80, 90, 100]), 'ratio': [10]}])),
  117. list(ParameterGrid([
  118. {'num_solutions': [10], 'ratio': dichotomous_permutation([0.1, 0.3, 0.5, 0.7, 0.9, 1, 3, 5, 7, 9, 10])}]))])
  119. # print(list(param_grid))
  120. if ds_name == 'AIDS_symb':
  121. num_solutions_list = [1, 20, 40, 60, 80, 100]
  122. ratio_list = [0.1, 0.3, 0.5, 0.7, 0.9, 1, 3, 5, 7, 9]
  123. else:
  124. num_solutions_list = [1, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100] # [1, 20, 40, 60, 80, 100]
  125. ratio_list = [0.1, 0.3, 0.5, 0.7, 0.9, 1, 3, 5, 7, 9, 10][::-1]
  126. return param_grid
  127. if __name__ == '__main__':
  128. if len(sys.argv) > 1:
  129. ds_name_list = sys.argv[1:]
  130. else:
  131. ds_name_list = ['Acyclic', 'Alkane_unlabeled', 'MAO_lite', 'Monoterpenoides', 'MUTAG']
  132. # ds_name_list = ['MUTAG'] # 'Alkane_unlabeled']
  133. # ds_name_list = ['Acyclic', 'MAO', 'Monoterpenoides', 'MUTAG', 'AIDS_symb']
  134. save_dir = 'outputs/CRIANN/edit_costs.real_data.nums_sols.ratios.bipartite/'
  135. os.makedirs(save_dir, exist_ok=True)
  136. os.makedirs(save_dir + 'groups/', exist_ok=True)
  137. for ds_name in ds_name_list:
  138. print()
  139. print('Dataset:', ds_name)
  140. param_grid = get_param_lists(ds_name, mode='simple')
  141. results_for_a_dataset(ds_name)

A Python package for graph kernels, graph edit distances and the graph pre-image problem.