You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

edit_costs.nums_sols.ratios.IPFP.py 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Wed Oct 20 11:48:02 2020
  5. @author: ljia
  6. """
  7. # This script tests the influence of the ratios between node costs and edge costs on the stability of the GED computation, where the base edit costs are [1, 1, 1, 1, 1, 1].
  8. import os
  9. import multiprocessing
  10. import pickle
  11. import logging
  12. from gklearn.utils import Dataset
  13. from gklearn.ged.util import compute_geds
  14. def get_dataset(ds_name):
  15. # The node/edge labels that will not be used in the computation.
  16. if ds_name == 'MAO':
  17. irrelevant_labels = {'node_attrs': ['x', 'y', 'z'], 'edge_labels': ['bond_stereo']}
  18. elif ds_name == 'Monoterpenoides':
  19. irrelevant_labels = {'edge_labels': ['valence']}
  20. elif ds_name == 'MUTAG':
  21. irrelevant_labels = {'edge_labels': ['label_0']}
  22. elif ds_name == 'AIDS_symb':
  23. irrelevant_labels = {'node_attrs': ['chem', 'charge', 'x', 'y'], 'edge_labels': ['valence']}
  24. # Initialize a Dataset.
  25. dataset = Dataset()
  26. # Load predefined dataset.
  27. dataset.load_predefined_dataset(ds_name)
  28. # Remove irrelevant labels.
  29. dataset.remove_labels(**irrelevant_labels)
  30. print('dataset size:', len(dataset.graphs))
  31. return dataset
  32. def xp_compute_ged_matrix(ds_name, num_solutions, ratio, trial):
  33. save_dir = 'outputs/edit_costs.num_sols.ratios.IPFP/'
  34. if not os.path.exists(save_dir):
  35. os.makedirs(save_dir)
  36. save_file_suffix = '.' + ds_name + '.num_sols_' + str(num_solutions) + '.ratio_' + "{:.2f}".format(ratio) + '.trial_' + str(trial)
  37. """**1. Get dataset.**"""
  38. dataset = get_dataset(ds_name)
  39. """**2. Set parameters.**"""
  40. # Parameters for GED computation.
  41. ged_options = {'method': 'IPFP', # use IPFP huristic.
  42. 'initialization_method': 'RANDOM', # or 'NODE', etc.
  43. # when bigger than 1, then the method is considered mIPFP.
  44. 'initial_solutions': int(num_solutions * 4),
  45. 'edit_cost': 'CONSTANT', # use CONSTANT cost.
  46. # the distance between non-symbolic node/edge labels is computed by euclidean distance.
  47. 'attr_distance': 'euclidean',
  48. 'ratio_runs_from_initial_solutions': 0.25,
  49. # parallel threads. Do not work if mpg_options['parallel'] = False.
  50. 'threads': multiprocessing.cpu_count(),
  51. 'init_option': 'EAGER_WITHOUT_SHUFFLED_COPIES'
  52. }
  53. edit_cost_constants = [i * ratio for i in [1, 1, 1]] + [1, 1, 1]
  54. # edit_cost_constants = [item * 0.01 for item in edit_cost_constants]
  55. # pickle.dump(edit_cost_constants, open(save_dir + "edit_costs" + save_file_suffix + ".pkl", "wb"))
  56. options = ged_options.copy()
  57. options['edit_cost_constants'] = edit_cost_constants
  58. options['node_labels'] = dataset.node_labels
  59. options['edge_labels'] = dataset.edge_labels
  60. options['node_attrs'] = dataset.node_attrs
  61. options['edge_attrs'] = dataset.edge_attrs
  62. parallel = True # if num_solutions == 1 else False
  63. """**5. Compute GED matrix.**"""
  64. ged_mat = 'error'
  65. try:
  66. ged_vec_init, ged_mat, n_edit_operations = compute_geds(dataset.graphs, options=options, parallel=parallel, verbose=True)
  67. except Exception as exp:
  68. print('An exception occured when running this experiment:')
  69. LOG_FILENAME = save_dir + 'error.txt'
  70. logging.basicConfig(filename=LOG_FILENAME, level=logging.DEBUG)
  71. logging.exception('save_file_suffix')
  72. print(repr(exp))
  73. """**6. Get results.**"""
  74. pickle.dump(ged_mat, open(save_dir + 'ged_matrix' + save_file_suffix + '.pkl', 'wb'))
  75. if __name__ == '__main__':
  76. for ds_name in ['MAO', 'Monoterpenoides', 'MUTAG', 'AIDS_symb']:
  77. print()
  78. print('Dataset:', ds_name)
  79. for num_solutions in [1, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]:
  80. print()
  81. print('# of solutions:', num_solutions)
  82. for ratio in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]:
  83. print()
  84. print('Ratio:', ratio)
  85. for trial in range(1, 101):
  86. print()
  87. print('Trial:', trial)
  88. xp_compute_ged_matrix(ds_name, num_solutions, ratio, trial)

A Python package for graph kernels, graph edit distances and graph pre-image problem.