You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

utils.py 9.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Thu Oct 29 19:17:36 2020
  5. @author: ljia
  6. """
  7. import os
  8. import pickle
  9. import numpy as np
  10. from tqdm import tqdm
  11. import sys
  12. from gklearn.dataset import Dataset
  13. from gklearn.experiments import DATASET_ROOT
  14. def get_dataset(ds_name):
  15. # The node/edge labels that will not be used in the computation.
  16. # if ds_name == 'MAO':
  17. # irrelevant_labels = {'node_attrs': ['x', 'y', 'z'], 'edge_labels': ['bond_stereo']}
  18. # if ds_name == 'Monoterpenoides':
  19. # irrelevant_labels = {'edge_labels': ['valence']}
  20. # elif ds_name == 'MUTAG':
  21. # irrelevant_labels = {'edge_labels': ['label_0']}
  22. if ds_name == 'AIDS_symb':
  23. irrelevant_labels = {'node_attrs': ['chem', 'charge', 'x', 'y'], 'edge_labels': ['valence']}
  24. ds_name = 'AIDS'
  25. else:
  26. irrelevant_labels = {}
  27. # Load predefined dataset.
  28. dataset = Dataset(ds_name, root=DATASET_ROOT)
  29. # Remove irrelevant labels.
  30. dataset.remove_labels(**irrelevant_labels)
  31. print('dataset size:', len(dataset.graphs))
  32. return dataset
  33. def set_edit_cost_consts(ratio, node_labeled=True, edge_labeled=True, mode='uniform'):
  34. if mode == 'uniform':
  35. edit_cost_constants = [i * ratio for i in [1, 1, 1]] + [1, 1, 1]
  36. if not node_labeled:
  37. edit_cost_constants[2] = 0
  38. if not edge_labeled:
  39. edit_cost_constants[5] = 0
  40. return edit_cost_constants
  41. def nested_keys_exists(element, *keys):
  42. '''
  43. Check if *keys (nested) exists in `element` (dict).
  44. '''
  45. if not isinstance(element, dict):
  46. raise AttributeError('keys_exists() expects dict as first argument.')
  47. if len(keys) == 0:
  48. raise AttributeError('keys_exists() expects at least two arguments, one given.')
  49. _element = element
  50. for key in keys:
  51. try:
  52. _element = _element[key]
  53. except KeyError:
  54. return False
  55. return True
  56. # Check average relative error along elements in two ged matrices.
  57. def matrices_ave_relative_error(m1, m2):
  58. error = 0
  59. base = 0
  60. for i in range(m1.shape[0]):
  61. for j in range(m1.shape[1]):
  62. error += np.abs(m1[i, j] - m2[i, j])
  63. # base += (np.abs(m1[i, j]) + np.abs(m2[i, j]))
  64. base += (m1[i, j] + m2[i, j]) # Require only 25% of the time of "base += (np.abs(m1[i, j]) + np.abs(m2[i, j]))".
  65. base = base / 2
  66. return error / base
  67. def compute_relative_error(ged_mats):
  68. if len(ged_mats) != 0:
  69. # get the smallest "correct" GED matrix.
  70. ged_mat_s = np.ones(ged_mats[0].shape) * np.inf
  71. for i in range(ged_mats[0].shape[0]):
  72. for j in range(ged_mats[0].shape[1]):
  73. ged_mat_s[i, j] = np.min([mat[i, j] for mat in ged_mats])
  74. # compute average error.
  75. errors = []
  76. for i, mat in enumerate(ged_mats):
  77. err = matrices_ave_relative_error(mat, ged_mat_s)
  78. # if not per_correct:
  79. # print('matrix # ', str(i))
  80. # pass
  81. errors.append(err)
  82. else:
  83. errors = [0]
  84. return np.mean(errors)
  85. def parse_group_file_name(fn):
  86. splits_all = fn.split('.')
  87. key1 = splits_all[1]
  88. pos2 = splits_all[2].rfind('_')
  89. # key2 = splits_all[2][:pos2]
  90. val2 = splits_all[2][pos2+1:]
  91. pos3 = splits_all[3].rfind('_')
  92. # key3 = splits_all[3][:pos3]
  93. val3 = splits_all[3][pos3+1:] + '.' + splits_all[4]
  94. return key1, val2, val3
  95. def get_all_errors(save_dir, errors):
  96. # Loop for each GED matrix file.
  97. for file in tqdm(sorted(os.listdir(save_dir)), desc='Getting errors', file=sys.stdout):
  98. if os.path.isfile(os.path.join(save_dir, file)) and file.startswith('ged_mats.'):
  99. keys = parse_group_file_name(file)
  100. # Check if the results is in the errors.
  101. if not keys[0] in errors:
  102. errors[keys[0]] = {}
  103. if not keys[1] in errors[keys[0]]:
  104. errors[keys[0]][keys[1]] = {}
  105. # Compute the error if not exist.
  106. if not keys[2] in errors[keys[0]][keys[1]]:
  107. ged_mats = np.load(os.path.join(save_dir, file))
  108. errors[keys[0]][keys[1]][keys[2]] = compute_relative_error(ged_mats)
  109. return errors
  110. def get_relative_errors(save_dir, overwrite=False):
  111. """ # Read relative errors from previous computed and saved file. Create the
  112. file, compute the errors, or add and save the new computed errors to the
  113. file if necessary.
  114. Parameters
  115. ----------
  116. save_dir : TYPE
  117. DESCRIPTION.
  118. overwrite : TYPE, optional
  119. DESCRIPTION. The default is False.
  120. Returns
  121. -------
  122. None.
  123. """
  124. if not overwrite:
  125. fn_err = save_dir + '/relative_errors.pkl'
  126. # If error file exists.
  127. if os.path.isfile(fn_err):
  128. with open(fn_err, 'rb') as f:
  129. errors = pickle.load(f)
  130. errors = get_all_errors(save_dir, errors)
  131. else:
  132. errors = get_all_errors(save_dir, {})
  133. else:
  134. errors = get_all_errors(save_dir, {})
  135. with open(fn_err, 'wb') as f:
  136. pickle.dump(errors, f)
  137. return errors
  138. def interpolate_result(Z, method='linear'):
  139. values = Z.copy()
  140. for i in range(Z.shape[0]):
  141. for j in range(Z.shape[1]):
  142. if np.isnan(Z[i, j]):
  143. # Get the nearest non-nan values.
  144. x_neg = np.nan
  145. for idx, val in enumerate(Z[i, :][j::-1]):
  146. if not np.isnan(val):
  147. x_neg = val
  148. x_neg_off = idx
  149. break
  150. x_pos = np.nan
  151. for idx, val in enumerate(Z[i, :][j:]):
  152. if not np.isnan(val):
  153. x_pos = val
  154. x_pos_off = idx
  155. break
  156. # Interpolate.
  157. if not np.isnan(x_neg) and not np.isnan(x_pos):
  158. val_int = (x_pos_off / (x_neg_off + x_pos_off)) * (x_neg - x_pos) + x_pos
  159. values[i, j] = val_int
  160. break
  161. y_neg = np.nan
  162. for idx, val in enumerate(Z[:, j][i::-1]):
  163. if not np.isnan(val):
  164. y_neg = val
  165. y_neg_off = idx
  166. break
  167. y_pos = np.nan
  168. for idx, val in enumerate(Z[:, j][i:]):
  169. if not np.isnan(val):
  170. y_pos = val
  171. y_pos_off = idx
  172. break
  173. # Interpolate.
  174. if not np.isnan(y_neg) and not np.isnan(y_pos):
  175. val_int = (y_pos_off / (y_neg_off + y_neg_off)) * (y_neg - y_pos) + y_pos
  176. values[i, j] = val_int
  177. break
  178. return values
  179. def set_axis_style(ax):
  180. ax.set_axisbelow(True)
  181. ax.spines['top'].set_visible(False)
  182. ax.spines['bottom'].set_visible(False)
  183. ax.spines['right'].set_visible(False)
  184. ax.spines['left'].set_visible(False)
  185. ax.xaxis.set_ticks_position('none')
  186. ax.yaxis.set_ticks_position('none')
  187. ax.tick_params(labelsize=8, color='w', pad=1, grid_color='w')
  188. ax.tick_params(axis='x', pad=-2)
  189. ax.tick_params(axis='y', labelrotation=-40, pad=-2)
  190. # ax.zaxis._axinfo['juggled'] = (1, 2, 0)
  191. ax.set_xlabel(ax.get_xlabel(), fontsize=10, labelpad=-3)
  192. ax.set_ylabel(ax.get_ylabel(), fontsize=10, labelpad=-2, rotation=50)
  193. ax.set_zlabel(ax.get_zlabel(), fontsize=10, labelpad=-2)
  194. ax.set_title(ax.get_title(), pad=30, fontsize=15)
  195. return
  196. def dichotomous_permutation(arr, layer=0):
  197. import math
  198. # def seperate_arr(arr, new_arr):
  199. # if (length % 2) == 0:
  200. # half = int(length / 2)
  201. # new_arr += [arr[half - 1], arr[half]]
  202. # subarr1 = [arr[i] for i in range(1, half - 1)]
  203. # else:
  204. # half = math.floor(length / 2)
  205. # new_arr.append(arr[half])
  206. # subarr1 = [arr[i] for i in range(1, half)]
  207. # subarr2 = [arr[i] for i in range(half + 1, length - 1)]
  208. # subarrs = [subarr1, subarr2]
  209. # return subarrs
  210. if layer == 0:
  211. length = len(arr)
  212. if length <= 2:
  213. return arr
  214. new_arr = [arr[0], arr[-1]]
  215. if (length % 2) == 0:
  216. half = int(length / 2)
  217. new_arr += [arr[half - 1], arr[half]]
  218. subarr1 = [arr[i] for i in range(1, half - 1)]
  219. else:
  220. half = math.floor(length / 2)
  221. new_arr.append(arr[half])
  222. subarr1 = [arr[i] for i in range(1, half)]
  223. subarr2 = [arr[i] for i in range(half + 1, length - 1)]
  224. subarrs = [subarr1, subarr2]
  225. # subarrs = seperate_arr(arr, new_arr)
  226. new_arr += dichotomous_permutation(subarrs, layer=layer+1)
  227. else:
  228. new_arr = []
  229. subarrs = []
  230. for a in arr:
  231. length = len(a)
  232. if length <= 2:
  233. new_arr += a
  234. else:
  235. # subarrs += seperate_arr(a, new_arr)
  236. if (length % 2) == 0:
  237. half = int(length / 2)
  238. new_arr += [a[half - 1], a[half]]
  239. subarr1 = [a[i] for i in range(0, half - 1)]
  240. else:
  241. half = math.floor(length / 2)
  242. new_arr.append(a[half])
  243. subarr1 = [a[i] for i in range(0, half)]
  244. subarr2 = [a[i] for i in range(half + 1, length)]
  245. subarrs += [subarr1, subarr2]
  246. if len(subarrs) > 0:
  247. new_arr += dichotomous_permutation(subarrs, layer=layer+1)
  248. return new_arr
  249. # length = len(arr)
  250. # if length <= 2:
  251. # return arr
  252. # new_arr = [arr[0], arr[-1]]
  253. # if (length % 2) == 0:
  254. # half = int(length / 2)
  255. # new_arr += [arr[half - 1], arr[half]]
  256. # subarr1 = [arr[i] for i in range(1, half - 1)]
  257. # else:
  258. # half = math.floor(length / 2)
  259. # new_arr.append(arr[half])
  260. # subarr1 = [arr[i] for i in range(1, half)]
  261. # subarr2 = [arr[i] for i in range(half + 1, length - 1)]
  262. # if len(subarr1) > 0:
  263. # new_arr += dichotomous_permutation(subarr1)
  264. # if len(subarr2) > 0:
  265. # new_arr += dichotomous_permutation(subarr2)
  266. # return new_arr
if __name__ == '__main__':
	# Root directory holding the experiment outputs (presumably computed on
	# the CRIANN cluster — TODO confirm).
	root_dir = 'outputs/CRIANN/'
	# The batch loop below is disabled; it walked every sub-directory's
	# 'groups/' folder and computed/saved relative errors, logging failures.
	# for dir_ in sorted(os.listdir(root_dir)):
	# if os.path.isdir(root_dir):
	# full_dir = os.path.join(root_dir, dir_)
	# print('---', full_dir,':')
	# save_dir = os.path.join(full_dir, 'groups/')
	# if os.path.exists(save_dir):
	# try:
	# get_relative_errors(save_dir)
	# except Exception as exp:
	# print('An exception occured when running this experiment:')
	# print(repr(exp))

A Python package for graph kernels, graph edit distances and the graph pre-image problem.