
utils.py 6.9 kB

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Oct 29 19:17:36 2020

@author: ljia
"""
import os
import pickle
import sys

import numpy as np
from tqdm import tqdm

from gklearn.dataset import Dataset
from gklearn.experiments import DATASET_ROOT


def get_dataset(ds_name):
    # The node/edge labels that will not be used in the computation.
    # if ds_name == 'MAO':
    #     irrelevant_labels = {'node_attrs': ['x', 'y', 'z'], 'edge_labels': ['bond_stereo']}
    # if ds_name == 'Monoterpenoides':
    #     irrelevant_labels = {'edge_labels': ['valence']}
    # elif ds_name == 'MUTAG':
    #     irrelevant_labels = {'edge_labels': ['label_0']}
    if ds_name == 'AIDS_symb':
        irrelevant_labels = {'node_attrs': ['chem', 'charge', 'x', 'y'], 'edge_labels': ['valence']}
        ds_name = 'AIDS'
    else:
        irrelevant_labels = {}

    # Load the predefined dataset.
    dataset = Dataset(ds_name, root=DATASET_ROOT)
    # Remove irrelevant labels.
    dataset.remove_labels(**irrelevant_labels)
    print('dataset size:', len(dataset.graphs))
    return dataset
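
# Illustrative usage (not part of the original module); 'MUTAG' is only an
# example name and must be available under DATASET_ROOT:
# dataset = get_dataset('MUTAG')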

def set_edit_cost_consts(ratio, node_labeled=True, edge_labeled=True, mode='uniform'):
    if mode == 'uniform':
        edit_cost_constants = [i * ratio for i in [1, 1, 1]] + [1, 1, 1]
        if not node_labeled:
            edit_cost_constants[2] = 0
        if not edge_labeled:
            edit_cost_constants[5] = 0
    return edit_cost_constants
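
# Illustrative example (not part of the original module): from the indices
# zeroed above, index 2 appears to be the node substitution cost and index 5
# the edge substitution cost; the first three entries are scaled by `ratio`.
def _example_edit_costs():
    assert set_edit_cost_consts(2) == [2, 2, 2, 1, 1, 1]
    assert set_edit_cost_consts(2, node_labeled=False) == [2, 2, 0, 1, 1, 1]
    assert set_edit_cost_consts(2, edge_labeled=False) == [2, 2, 2, 1, 1, 0]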

def nested_keys_exists(element, *keys):
    '''
    Check if the (nested) *keys exist in `element` (a dict).
    '''
    if not isinstance(element, dict):
        raise AttributeError('nested_keys_exists() expects a dict as its first argument.')
    if len(keys) == 0:
        raise AttributeError('nested_keys_exists() expects at least two arguments, one given.')

    _element = element
    for key in keys:
        try:
            _element = _element[key]
        except KeyError:
            return False
    return True
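
# Illustrative example (not part of the original module); the dict keys below
# are made up:
def _example_nested_keys():
    settings = {'fit': {'optim': {'max_iters': 100}}}
    assert nested_keys_exists(settings, 'fit', 'optim', 'max_iters')
    assert not nested_keys_exists(settings, 'fit', 'solver')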

# Compute the average relative error over the elements of two GED matrices.
def matrices_ave_relative_error(m1, m2):
    error = 0
    base = 0
    for i in range(m1.shape[0]):
        for j in range(m1.shape[1]):
            error += np.abs(m1[i, j] - m2[i, j])
            base += (np.abs(m1[i, j]) + np.abs(m2[i, j])) / 2
    return error / base


def compute_relative_error(ged_mats):
    if len(ged_mats) != 0:
        # Get the smallest "correct" GED matrix (element-wise minimum).
        ged_mat_s = np.ones(ged_mats[0].shape) * np.inf
        for i in range(ged_mats[0].shape[0]):
            for j in range(ged_mats[0].shape[1]):
                ged_mat_s[i, j] = np.min([mat[i, j] for mat in ged_mats])

        # Compute the average error.
        errors = []
        for i, mat in enumerate(ged_mats):
            err = matrices_ave_relative_error(mat, ged_mat_s)
            # if not per_correct:
            #     print('matrix # ', str(i))
            #     pass
            errors.append(err)
    else:
        errors = [0]
    return np.mean(errors)
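
# Illustrative example (not part of the original module), on two made-up
# 2x2 GED matrices:
def _example_relative_error():
    m1 = np.array([[0., 4.], [4., 0.]])
    m2 = np.array([[0., 2.], [6., 0.]])
    # error = |4-2| + |4-6| = 4; base = (4+2)/2 + (4+6)/2 = 8  ->  0.5
    assert matrices_ave_relative_error(m1, m2) == 0.5
    # compute_relative_error() measures each matrix against the element-wise
    # minimum of all given matrices and averages the resulting errors.
    print(compute_relative_error([m1, m2]))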

def parse_group_file_name(fn):
    splits_all = fn.split('.')
    key1 = splits_all[1]

    pos2 = splits_all[2].rfind('_')
    # key2 = splits_all[2][:pos2]
    val2 = splits_all[2][pos2 + 1:]

    pos3 = splits_all[3].rfind('_')
    # key3 = splits_all[3][:pos3]
    val3 = splits_all[3][pos3 + 1:] + '.' + splits_all[4]

    return key1, val2, val3
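
# Illustrative example (not part of the original module): the parser expects
# names shaped like 'ged_mats.<key1>.<key2>_<val2>.<key3>_<val3>.npy', where
# <val3> contains a decimal point; the concrete names below are made up.
def _example_parse_name():
    keys = parse_group_file_name('ged_mats.MUTAG.num_sols_10.ratio_0.3.npy')
    assert keys == ('MUTAG', '10', '0.3')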

def get_all_errors(save_dir, errors):
    # Loop over each GED matrix file.
    for file in tqdm(sorted(os.listdir(save_dir)), desc='Getting errors', file=sys.stdout):
        if os.path.isfile(os.path.join(save_dir, file)) and file.startswith('ged_mats.'):
            keys = parse_group_file_name(file)

            # Check whether the result is already in `errors`.
            if not keys[0] in errors:
                errors[keys[0]] = {}
            if not keys[1] in errors[keys[0]]:
                errors[keys[0]][keys[1]] = {}

            # Compute the error if it does not exist yet.
            if not keys[2] in errors[keys[0]][keys[1]]:
                ged_mats = np.load(os.path.join(save_dir, file))
                errors[keys[0]][keys[1]][keys[2]] = compute_relative_error(ged_mats)
    return errors

def get_relative_errors(save_dir, overwrite=False):
    """Read relative errors from a previously computed and saved file. Create
    the file, compute the errors, or add newly computed errors to the file as
    needed.

    Parameters
    ----------
    save_dir : str
        Directory containing the saved GED matrix files ('ged_mats.*').
    overwrite : bool, optional
        If True, recompute all errors instead of loading previously saved
        ones. The default is False.

    Returns
    -------
    errors : dict
        Nested dict of relative errors, keyed by the values parsed from the
        file names.
    """
    fn_err = os.path.join(save_dir, 'relative_errors.pkl')
    if not overwrite:
        # If the error file exists, load it and fill in missing entries.
        if os.path.isfile(fn_err):
            with open(fn_err, 'rb') as f:
                errors = pickle.load(f)
            errors = get_all_errors(save_dir, errors)
        else:
            errors = get_all_errors(save_dir, {})
    else:
        errors = get_all_errors(save_dir, {})

    with open(fn_err, 'wb') as f:
        pickle.dump(errors, f)

    return errors
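
# Illustrative usage (not part of the original module); the directory below is
# a placeholder for any folder holding 'ged_mats.*.npy' files:
# errors = get_relative_errors('outputs/CRIANN/<experiment>/groups/')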

def interpolate_result(Z, method='linear'):
    values = Z.copy()
    for i in range(Z.shape[0]):
        for j in range(Z.shape[1]):
            if np.isnan(Z[i, j]):
                # Get the nearest non-nan values to the left and to the right
                # in the same row.
                x_neg = np.nan
                for idx, val in enumerate(Z[i, :][j::-1]):
                    if not np.isnan(val):
                        x_neg = val
                        x_neg_off = idx
                        break
                x_pos = np.nan
                for idx, val in enumerate(Z[i, :][j:]):
                    if not np.isnan(val):
                        x_pos = val
                        x_pos_off = idx
                        break

                # Interpolate along the row if both neighbours exist.
                if not np.isnan(x_neg) and not np.isnan(x_pos):
                    val_int = (x_pos_off / (x_neg_off + x_pos_off)) * (x_neg - x_pos) + x_pos
                    values[i, j] = val_int
                    continue  # This cell is filled; move on to the next one.

                # Otherwise, get the nearest non-nan values above and below in
                # the same column.
                y_neg = np.nan
                for idx, val in enumerate(Z[:, j][i::-1]):
                    if not np.isnan(val):
                        y_neg = val
                        y_neg_off = idx
                        break
                y_pos = np.nan
                for idx, val in enumerate(Z[:, j][i:]):
                    if not np.isnan(val):
                        y_pos = val
                        y_pos_off = idx
                        break

                # Interpolate along the column if both neighbours exist.
                if not np.isnan(y_neg) and not np.isnan(y_pos):
                    val_int = (y_pos_off / (y_neg_off + y_pos_off)) * (y_neg - y_pos) + y_pos
                    values[i, j] = val_int
    return values
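
# Illustrative example (not part of the original module): a nan cell is filled
# by linear interpolation between its nearest non-nan neighbours, first along
# the row, then along the column.
def _example_interpolate():
    Z = np.array([[1., np.nan, 3.],
                  [4., 5., 6.]])
    # The nan at (0, 1) sits halfway between 1 and 3, so it becomes 2.
    print(interpolate_result(Z))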

def set_axis_style(ax):
    ax.set_axisbelow(True)
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.xaxis.set_ticks_position('none')
    ax.yaxis.set_ticks_position('none')
    ax.tick_params(labelsize=8, color='w', pad=1, grid_color='w')
    ax.tick_params(axis='x', pad=-2)
    ax.tick_params(axis='y', labelrotation=-40, pad=-2)
    # ax.zaxis._axinfo['juggled'] = (1, 2, 0)
    ax.set_xlabel(ax.get_xlabel(), fontsize=10, labelpad=-3)
    ax.set_ylabel(ax.get_ylabel(), fontsize=10, labelpad=-2, rotation=50)
    ax.set_zlabel(ax.get_zlabel(), fontsize=10, labelpad=-2)
    ax.set_title(ax.get_title(), pad=30, fontsize=15)
    return
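
# Illustrative usage (not part of the original module): a z-label is set above,
# so the function assumes a 3-d Axes; the labels below are placeholders.
def _example_axis_style():
    import matplotlib.pyplot as plt
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    set_axis_style(ax)
    return fig, ax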

if __name__ == '__main__':
    root_dir = 'outputs/CRIANN/'
    # for dir_ in sorted(os.listdir(root_dir)):
    #     if os.path.isdir(root_dir):
    #         full_dir = os.path.join(root_dir, dir_)
    #         print('---', full_dir, ':')
    #         save_dir = os.path.join(full_dir, 'groups/')
    #         if os.path.exists(save_dir):
    #             try:
    #                 get_relative_errors(save_dir)
    #             except Exception as exp:
    #                 print('An exception occurred when running this experiment:')
    #                 print(repr(exp))

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.