
treelet.py 28 kB

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 13 18:02:46 2020

@author: ljia

@references:

    [1] Gaüzère B, Brun L, Villemin D. Two new graphs kernels in
    chemoinformatics. Pattern Recognition Letters. 2012 Nov 1;33(15):2038-47.
"""
import sys
from multiprocessing import Pool
from gklearn.utils import get_iters
import numpy as np
import networkx as nx
from collections import Counter
from itertools import chain
from sklearn.utils.validation import check_is_fitted
from sklearn.exceptions import NotFittedError
from gklearn.utils import SpecialLabel
from gklearn.utils.parallel import parallel_gm, parallel_me
from gklearn.utils.utils import find_all_paths, get_mlti_dim_node_attrs
from gklearn.kernels import GraphKernel


class Treelet(GraphKernel):

    def __init__(self, **kwargs):
        """Initialise a treelet kernel."""
        GraphKernel.__init__(self, **{k: kwargs.get(k) for k in ['parallel', 'n_jobs', 'chunksize', 'normalize', 'copy_graphs', 'verbose'] if k in kwargs})
        self.node_labels = kwargs.get('node_labels', [])
        self.edge_labels = kwargs.get('edge_labels', [])
        self.sub_kernel = kwargs.get('sub_kernel', None)
        self.ds_infos = kwargs.get('ds_infos', {})
        self.precompute_canonkeys = kwargs.get('precompute_canonkeys', True)
        self.save_canonkeys = kwargs.get('save_canonkeys', True)

    ##########################################################################
    # The following is the 1st paradigm to compute the kernel matrix, which is
    # compatible with `scikit-learn`.
    # -------------------------------------------------------------------
    # Special thanks to the "GraKeL" library for providing an excellent template!
    ##########################################################################
    def clear_attributes(self):
        super().clear_attributes()
        if hasattr(self, '_canonkeys'):
            delattr(self, '_canonkeys')
        if hasattr(self, '_Y_canonkeys'):
            delattr(self, '_Y_canonkeys')
        if hasattr(self, '_dummy_labels_considered'):
            delattr(self, '_dummy_labels_considered')

    def validate_parameters(self):
        """Validate all parameters for the transformer.

        Returns
        -------
        None.
        """
        super().validate_parameters()
        if self.sub_kernel is None:
            raise ValueError('Sub-kernel not set.')
    def _compute_kernel_matrix_series(self, Y, X=None, load_canonkeys=True):
        """Compute the kernel matrix between given target graphs (Y) and
        the fitted graphs (X / self._graphs) without parallelization.

        Parameters
        ----------
        Y : list of graphs
            The target graphs.

        Returns
        -------
        kernel_matrix : numpy array, shape = [n_targets, n_inputs]
            The computed kernel matrix.
        """
        if_comp_X_canonkeys = True

        # If requested, load the saved canonkeys of X from the instance:
        if load_canonkeys:
            # Canonical keys for self._graphs.
            try:
                check_is_fitted(self, ['_canonkeys'])
                canonkeys_list1 = self._canonkeys
                if_comp_X_canonkeys = False
            except NotFittedError:
                import warnings
                warnings.warn('The canonkeys of self._graphs are not computed/saved. The canonkeys of `X` are computed instead.')
                if_comp_X_canonkeys = True

        # Get all canonical keys of all graphs before computing kernels to
        # save time, but this may cost a lot of memory for large datasets.

        # Compute the canonical keys of X.
        if if_comp_X_canonkeys:
            if X is None:
                raise ValueError('X cannot be None.')
            # self._add_dummy_labels will modify the input in place.
            self._add_dummy_labels(X)  # for X
            canonkeys_list1 = []
            iterator = get_iters(X, desc='Getting canonkeys for X', file=sys.stdout, verbose=(self.verbose >= 2))
            for g in iterator:
                canonkeys_list1.append(self._get_canonkeys(g))

        # Canonical keys for Y.
        # Y = [g.copy() for g in Y]  # @todo: ?
        self._add_dummy_labels(Y)
        canonkeys_list2 = []
        iterator = get_iters(Y, desc='Getting canonkeys for Y', file=sys.stdout, verbose=(self.verbose >= 2))
        for g in iterator:
            canonkeys_list2.append(self._get_canonkeys(g))

        # if self.save_canonkeys:
        #     self._Y_canonkeys = canonkeys_list2

        # Compute the kernel matrix.
        kernel_matrix = np.zeros((len(Y), len(canonkeys_list1)))

        from itertools import product
        itr = product(range(len(Y)), range(len(canonkeys_list1)))
        len_itr = int(len(Y) * len(canonkeys_list1))
        iterator = get_iters(itr, desc='Computing kernels', file=sys.stdout,
                             length=len_itr, verbose=(self.verbose >= 2))
        for i_y, i_x in iterator:
            kernel = self._kernel_do(canonkeys_list2[i_y], canonkeys_list1[i_x])
            kernel_matrix[i_y][i_x] = kernel

        return kernel_matrix
    def _compute_kernel_matrix_imap_unordered(self, Y):
        """Compute the kernel matrix between given target graphs (Y) and
        the fitted graphs (X / self._graphs) using imap unordered parallelization.

        Parameters
        ----------
        Y : list of graphs
            The target graphs.

        Returns
        -------
        kernel_matrix : numpy array, shape = [n_targets, n_inputs]
            The computed kernel matrix.
        """
        raise NotImplementedError('Parallelization for the kernel matrix is not implemented.')
    def pairwise_kernel(self, x, y, are_keys=False):
        """Compute the pairwise kernel between two graphs.

        Parameters
        ----------
        x, y : NetworkX Graph
            Graphs between which the kernel is computed.

        are_keys : boolean, optional
            If `True`, `x` and `y` are canonical keys; otherwise they are
            graphs. The default is False.

        Returns
        -------
        kernel : float
            The computed kernel.
        """
        if are_keys:
            # x, y are canonical keys.
            kernel = self._kernel_do(x, y)
        else:
            # x, y are graphs.
            kernel = self._compute_single_kernel_series(x, y)
        return kernel
    def diagonals(self):
        """Compute the kernel matrix diagonals of the fitted/transformed data.

        Returns
        -------
        X_diag : numpy array
            The diagonal of the kernel matrix between the fitted data.
            This consists of each element calculated with itself.

        Y_diag : numpy array
            The diagonal of the kernel matrix between the transformed data.
            This consists of each element calculated with itself.
        """
        # Check if method "fit" has been called.
        check_is_fitted(self, ['_graphs'])

        # Check if the diagonals of X exist.
        try:
            check_is_fitted(self, ['_X_diag'])
        except NotFittedError:
            # Compute diagonals of X.
            self._X_diag = np.empty(shape=(len(self._graphs),))
            try:
                check_is_fitted(self, ['_canonkeys'])
                for i, x in enumerate(self._canonkeys):
                    self._X_diag[i] = self.pairwise_kernel(x, x, are_keys=True)  # @todo: parallel?
            except NotFittedError:
                for i, x in enumerate(self._graphs):
                    self._X_diag[i] = self.pairwise_kernel(x, x, are_keys=False)  # @todo: parallel?

        try:
            # If transform has happened, return both diagonals.
            check_is_fitted(self, ['_Y'])
            self._Y_diag = np.empty(shape=(len(self._Y),))
            try:
                check_is_fitted(self, ['_Y_canonkeys'])
                for (i, y) in enumerate(self._Y_canonkeys):
                    self._Y_diag[i] = self.pairwise_kernel(y, y, are_keys=True)  # @todo: parallel?
            except NotFittedError:
                for (i, y) in enumerate(self._Y):
                    self._Y_diag[i] = self.pairwise_kernel(y, y, are_keys=False)  # @todo: parallel?
            return self._X_diag, self._Y_diag
        except NotFittedError:
            # Otherwise return only X_diag.
            return self._X_diag
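
    # A comment-only usage sketch of the scikit-learn-style paradigm above
    # (hypothetical call sequence, assuming the fit/transform API of the
    # GraphKernel base class):
    #     kernel.fit(graphs_train)
    #     kernel_matrix = kernel.transform(graphs_test)  # shape [n_test, n_train]
    #     X_diag, Y_diag = kernel.diagonals()  # self-kernels of both sets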
    ##########################################################################
    # The following is the 2nd paradigm to compute the kernel matrix. It is
    # simplified and not compatible with `scikit-learn`.
    ##########################################################################

    def _compute_gm_series(self, graphs):
        self._add_dummy_labels(graphs)

        # Get all canonical keys of all graphs before computing kernels to
        # save time, but this may cost a lot of memory for large datasets.
        canonkeys = []
        iterator = get_iters(graphs, desc='getting canonkeys', file=sys.stdout,
                             verbose=(self.verbose >= 2))
        for g in iterator:
            canonkeys.append(self._get_canonkeys(g))

        if self.save_canonkeys:
            self._canonkeys = canonkeys

        # Compute the Gram matrix.
        gram_matrix = np.zeros((len(graphs), len(graphs)))

        from itertools import combinations_with_replacement
        itr = combinations_with_replacement(range(0, len(graphs)), 2)
        len_itr = int(len(graphs) * (len(graphs) + 1) / 2)
        iterator = get_iters(itr, desc='Computing kernels', file=sys.stdout,
                             length=len_itr, verbose=(self.verbose >= 2))
        for i, j in iterator:
            kernel = self._kernel_do(canonkeys[i], canonkeys[j])
            gram_matrix[i][j] = kernel
            gram_matrix[j][i] = kernel  # @todo: no directed graph considered?

        return gram_matrix
    def _compute_gm_imap_unordered(self):
        self._add_dummy_labels(self._graphs)

        # Get all canonical keys of all graphs before computing kernels to
        # save time, but this may cost a lot of memory for large datasets.
        pool = Pool(self.n_jobs)
        itr = zip(self._graphs, range(0, len(self._graphs)))
        if len(self._graphs) < 100 * self.n_jobs:
            chunksize = int(len(self._graphs) / self.n_jobs) + 1
        else:
            chunksize = 100
        canonkeys = [[] for _ in range(len(self._graphs))]
        get_fun = self._wrapper_get_canonkeys
        iterator = get_iters(pool.imap_unordered(get_fun, itr, chunksize),
                             desc='getting canonkeys', file=sys.stdout,
                             length=len(self._graphs), verbose=(self.verbose >= 2))
        for i, ck in iterator:
            canonkeys[i] = ck
        pool.close()
        pool.join()

        if self.save_canonkeys:
            self._canonkeys = canonkeys

        # Compute the Gram matrix.
        gram_matrix = np.zeros((len(self._graphs), len(self._graphs)))

        def init_worker(canonkeys_toshare):
            global G_canonkeys
            G_canonkeys = canonkeys_toshare

        do_fun = self._wrapper_kernel_do
        parallel_gm(do_fun, gram_matrix, self._graphs, init_worker=init_worker,
                    glbv=(canonkeys,), n_jobs=self.n_jobs, verbose=self.verbose)

        return gram_matrix
    def _compute_kernel_list_series(self, g1, g_list):
        # self._add_dummy_labels(g_list + [g1])

        # Get all canonical keys of all graphs before computing kernels to
        # save time, but this may cost a lot of memory for large datasets.
        canonkeys_1 = self._get_canonkeys(g1)
        canonkeys_list = []
        iterator = get_iters(g_list, desc='getting canonkeys', file=sys.stdout, verbose=(self.verbose >= 2))
        for g in iterator:
            canonkeys_list.append(self._get_canonkeys(g))

        # Compute the kernel list.
        kernel_list = [None] * len(g_list)
        iterator = get_iters(range(len(g_list)), desc='Computing kernels', file=sys.stdout, length=len(g_list), verbose=(self.verbose >= 2))
        for i in iterator:
            kernel = self._kernel_do(canonkeys_1, canonkeys_list[i])
            kernel_list[i] = kernel

        return kernel_list

    def _compute_kernel_list_imap_unordered(self, g1, g_list):
        self._add_dummy_labels(g_list + [g1])

        # Get all canonical keys of all graphs before computing kernels to
        # save time, but this may cost a lot of memory for large datasets.
        canonkeys_1 = self._get_canonkeys(g1)
        canonkeys_list = [[] for _ in range(len(g_list))]
        pool = Pool(self.n_jobs)
        itr = zip(g_list, range(0, len(g_list)))
        if len(g_list) < 100 * self.n_jobs:
            chunksize = int(len(g_list) / self.n_jobs) + 1
        else:
            chunksize = 100
        get_fun = self._wrapper_get_canonkeys
        iterator = get_iters(pool.imap_unordered(get_fun, itr, chunksize),
                             desc='getting canonkeys', file=sys.stdout,
                             length=len(g_list), verbose=(self.verbose >= 2))
        for i, ck in iterator:
            canonkeys_list[i] = ck
        pool.close()
        pool.join()

        # Compute the kernel list.
        kernel_list = [None] * len(g_list)

        def init_worker(ck_1_toshare, ck_list_toshare):
            global G_ck_1, G_ck_list
            G_ck_1 = ck_1_toshare
            G_ck_list = ck_list_toshare

        do_fun = self._wrapper_kernel_list_do

        def func_assign(result, var_to_assign):
            var_to_assign[result[0]] = result[1]

        itr = range(len(g_list))
        len_itr = len(g_list)
        parallel_me(do_fun, func_assign, kernel_list, itr, len_itr=len_itr,
                    init_worker=init_worker, glbv=(canonkeys_1, canonkeys_list),
                    method='imap_unordered', n_jobs=self.n_jobs,
                    itr_desc='Computing kernels', verbose=self.verbose)

        return kernel_list
    def _wrapper_kernel_list_do(self, itr):
        return itr, self._kernel_do(G_ck_1, G_ck_list[itr])

    def _compute_single_kernel_series(self, g1, g2):
        # self._add_dummy_labels([g1] + [g2])
        canonkeys_1 = self._get_canonkeys(g1)
        canonkeys_2 = self._get_canonkeys(g2)
        kernel = self._kernel_do(canonkeys_1, canonkeys_2)
        return kernel

    # @profile
    def _kernel_do(self, canonkey1, canonkey2):
        """Compute the treelet graph kernel between 2 graphs.

        Parameters
        ----------
        canonkey1, canonkey2 : dict
            Dictionaries mapping the canonical keys of the treelets in the 2
            graphs to their counts.

        Returns
        -------
        kernel : float
            The treelet kernel between the 2 graphs.
        """
        keys = set(canonkey1.keys()) & set(canonkey2.keys())  # find canonical keys present in both graphs
        if len(keys) == 0:  # There is nothing in common...
            return 0

        vector1 = np.array([canonkey1[key] for key in keys])
        vector2 = np.array([canonkey2[key] for key in keys])

        # vector1, vector2 = [], []
        # keys1, keys2 = canonkey1, canonkey2
        # keys_searched = {}
        # for k, v in canonkey1.items():
        #     if k in keys2:
        #         vector1.append(v)
        #         vector2.append(canonkey2[k])
        #         keys_searched[k] = v
        # for k, v in canonkey2.items():
        #     if k in keys1 and k not in keys_searched:
        #         vector1.append(canonkey1[k])
        #         vector2.append(v)
        # vector1, vector2 = np.array(vector1), np.array(vector2)

        kernel = self.sub_kernel(vector1, vector2)
        return kernel
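
    # For intuition, a comment-only sketch of `_kernel_do` (illustrative
    # values, not from the source): with
    #     canonkey1 = {'0': 5, '1': 4, '6': 1}
    #     canonkey2 = {'0': 4, '1': 3}
    # the shared keys are {'0', '1'}, so vector1 = [5, 4] and vector2 = [4, 3]
    # (up to set ordering), and the result is self.sub_kernel(vector1, vector2),
    # e.g. the dot product 5 * 4 + 4 * 3 = 32 for a linear sub-kernel.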
    def _wrapper_kernel_do(self, itr):
        i = itr[0]
        j = itr[1]
        return i, j, self._kernel_do(G_canonkeys[i], G_canonkeys[j])

    def _get_canonkeys(self, G):
        """Generate canonical keys of all treelets in a graph.

        Parameters
        ----------
        G : NetworkX graph
            The graph in which keys are generated.

        Returns
        -------
        canonkey/canonkey_l : dict
            For unlabeled graphs, canonkey is a dictionary which records the
            count of every tree pattern. For labeled graphs, canonkey_l is one
            which keeps track of the count of every treelet.
        """
        patterns = {}  # a dictionary which consists of lists of patterns for all graphlets.

        canonkey = {}  # canonical key, a dictionary which records the count of every tree pattern.

        ### structural analysis ###
        ### In this section, a list of patterns is generated for each graphlet,
        ### where every pattern is represented by nodes ordered by Morgan's
        ### extended labeling.

        # linear patterns
        patterns['0'] = list(G.nodes())
        canonkey['0'] = nx.number_of_nodes(G)
        for i in range(1, 6):
            patterns[str(i)] = find_all_paths(G, i, self.ds_infos['directed'])
            canonkey[str(i)] = len(patterns[str(i)])

        # n-star patterns
        patterns['3star'] = [[node] + [neighbor for neighbor in G[node]] for node in G.nodes() if G.degree(node) == 3]
        patterns['4star'] = [[node] + [neighbor for neighbor in G[node]] for node in G.nodes() if G.degree(node) == 4]
        patterns['5star'] = [[node] + [neighbor for neighbor in G[node]] for node in G.nodes() if G.degree(node) == 5]
        canonkey['6'] = len(patterns['3star'])
        canonkey['8'] = len(patterns['4star'])
        canonkey['d'] = len(patterns['5star'])

        # pattern 7
        patterns['7'] = []  # the 1st line of Table 1 in Ref [1]
        for pattern in patterns['3star']:
            for i in range(1, len(pattern)):  # for each neighbor of node 0
                if G.degree(pattern[i]) >= 2:
                    pattern_t = pattern[:]
                    # set the node with degree >= 2 as the 4th node
                    pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
                    for neighborx in G[pattern[i]]:
                        if neighborx != pattern[0]:
                            new_pattern = pattern_t + [neighborx]
                            patterns['7'].append(new_pattern)
        canonkey['7'] = len(patterns['7'])

        # pattern 11
        patterns['11'] = []  # the 4th line of Table 1 in Ref [1]
        for pattern in patterns['4star']:
            for i in range(1, len(pattern)):
                if G.degree(pattern[i]) >= 2:
                    pattern_t = pattern[:]
                    pattern_t[i], pattern_t[4] = pattern_t[4], pattern_t[i]
                    for neighborx in G[pattern[i]]:
                        if neighborx != pattern[0]:
                            new_pattern = pattern_t + [neighborx]
                            patterns['11'].append(new_pattern)
        canonkey['b'] = len(patterns['11'])

        # pattern 12
        patterns['12'] = []  # the 5th line of Table 1 in Ref [1]
        rootlist = []  # a list of root nodes, whose extended labels are 3
        for pattern in patterns['3star']:
            if pattern[0] not in rootlist:  # prevent counting the same pattern twice from each of the two root nodes
                rootlist.append(pattern[0])
                for i in range(1, len(pattern)):
                    if G.degree(pattern[i]) >= 3:
                        rootlist.append(pattern[i])
                        pattern_t = pattern[:]
                        pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
                        for neighborx1 in G[pattern[i]]:
                            if neighborx1 != pattern[0]:
                                for neighborx2 in G[pattern[i]]:
                                    if neighborx1 > neighborx2 and neighborx2 != pattern[0]:
                                        new_pattern = pattern_t + [neighborx1] + [neighborx2]
                                        # new_patterns = [pattern + [neighborx1] + [neighborx2] for neighborx1 in G[pattern[i]] if neighborx1 != pattern[0] for neighborx2 in G[pattern[i]] if (neighborx1 > neighborx2 and neighborx2 != pattern[0])]
                                        patterns['12'].append(new_pattern)
        canonkey['c'] = int(len(patterns['12']) / 2)

        # pattern 9
        patterns['9'] = []  # the 2nd line of Table 1 in Ref [1]
        for pattern in patterns['3star']:
            for pairs in [[neighbor1, neighbor2] for neighbor1 in G[pattern[0]] if G.degree(neighbor1) >= 2
                          for neighbor2 in G[pattern[0]] if G.degree(neighbor2) >= 2 if neighbor1 > neighbor2]:
                pattern_t = pattern[:]
                # move nodes with extended label 4 to specific positions so they correspond to their children
                pattern_t[pattern_t.index(pairs[0])], pattern_t[2] = pattern_t[2], pattern_t[pattern_t.index(pairs[0])]
                pattern_t[pattern_t.index(pairs[1])], pattern_t[3] = pattern_t[3], pattern_t[pattern_t.index(pairs[1])]
                for neighborx1 in G[pairs[0]]:
                    if neighborx1 != pattern[0]:
                        for neighborx2 in G[pairs[1]]:
                            if neighborx2 != pattern[0]:
                                new_pattern = pattern_t + [neighborx1] + [neighborx2]
                                patterns['9'].append(new_pattern)
        canonkey['9'] = len(patterns['9'])

        # pattern 10
        patterns['10'] = []  # the 3rd line of Table 1 in Ref [1]
        for pattern in patterns['3star']:
            for i in range(1, len(pattern)):
                if G.degree(pattern[i]) >= 2:
                    for neighborx in G[pattern[i]]:
                        if neighborx != pattern[0] and G.degree(neighborx) >= 2:
                            pattern_t = pattern[:]
                            pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
                            new_patterns = [pattern_t + [neighborx] + [neighborxx] for neighborxx in G[neighborx] if neighborxx != pattern[i]]
                            patterns['10'].extend(new_patterns)
        canonkey['a'] = len(patterns['10'])

        ### labeling information ###
        ### In this section, a list of canonical keys is generated for every
        ### pattern obtained in the structural analysis section above, which is
        ### a string corresponding to a unique treelet. A dictionary is built
        ### to keep track of the count of every treelet.
        if len(self.node_labels) > 0 or len(self.edge_labels) > 0:
            canonkey_l = {}  # canonical key, a dictionary which keeps track of the count of every treelet.

            # linear patterns
            canonkey_t = Counter(get_mlti_dim_node_attrs(G, self.node_labels))
            for key in canonkey_t:
                canonkey_l[('0', key)] = canonkey_t[key]

            for i in range(1, 6):
                treelet = []
                for pattern in patterns[str(i)]:
                    canonlist = []
                    for idx, node in enumerate(pattern[:-1]):
                        canonlist.append(tuple(G.nodes[node][nl] for nl in self.node_labels))
                        canonlist.append(tuple(G[node][pattern[idx + 1]][el] for el in self.edge_labels))
                    canonlist.append(tuple(G.nodes[pattern[-1]][nl] for nl in self.node_labels))
                    canonkey_t = canonlist if canonlist < canonlist[::-1] else canonlist[::-1]
                    treelet.append(tuple([str(i)] + canonkey_t))
                canonkey_l.update(Counter(treelet))

            # n-star patterns
            for i in range(3, 6):
                treelet = []
                for pattern in patterns[str(i) + 'star']:
                    canonlist = []
                    for leaf in pattern[1:]:
                        nlabels = tuple(G.nodes[leaf][nl] for nl in self.node_labels)
                        elabels = tuple(G[leaf][pattern[0]][el] for el in self.edge_labels)
                        canonlist.append(tuple((nlabels, elabels)))
                    canonlist.sort()
                    canonlist = list(chain.from_iterable(canonlist))
                    canonkey_t = tuple(['d' if i == 5 else str(i * 2)]
                        + [tuple(G.nodes[pattern[0]][nl] for nl in self.node_labels)]
                        + canonlist)
                    treelet.append(canonkey_t)
                canonkey_l.update(Counter(treelet))

            # pattern 7
            treelet = []
            for pattern in patterns['7']:
                canonlist = []
                for leaf in pattern[1:3]:
                    nlabels = tuple(G.nodes[leaf][nl] for nl in self.node_labels)
                    elabels = tuple(G[leaf][pattern[0]][el] for el in self.edge_labels)
                    canonlist.append(tuple((nlabels, elabels)))
                canonlist.sort()
                canonlist = list(chain.from_iterable(canonlist))
                canonkey_t = tuple(['7']
                    + [tuple(G.nodes[pattern[0]][nl] for nl in self.node_labels)] + canonlist
                    + [tuple(G.nodes[pattern[3]][nl] for nl in self.node_labels)]
                    + [tuple(G[pattern[3]][pattern[0]][el] for el in self.edge_labels)]
                    + [tuple(G.nodes[pattern[4]][nl] for nl in self.node_labels)]
                    + [tuple(G[pattern[4]][pattern[3]][el] for el in self.edge_labels)])
                treelet.append(canonkey_t)
            canonkey_l.update(Counter(treelet))

            # pattern 11
            treelet = []
            for pattern in patterns['11']:
                canonlist = []
                for leaf in pattern[1:4]:
                    nlabels = tuple(G.nodes[leaf][nl] for nl in self.node_labels)
                    elabels = tuple(G[leaf][pattern[0]][el] for el in self.edge_labels)
                    canonlist.append(tuple((nlabels, elabels)))
                canonlist.sort()
                canonlist = list(chain.from_iterable(canonlist))
                canonkey_t = tuple(['b']
                    + [tuple(G.nodes[pattern[0]][nl] for nl in self.node_labels)] + canonlist
                    + [tuple(G.nodes[pattern[4]][nl] for nl in self.node_labels)]
                    + [tuple(G[pattern[4]][pattern[0]][el] for el in self.edge_labels)]
                    + [tuple(G.nodes[pattern[5]][nl] for nl in self.node_labels)]
                    + [tuple(G[pattern[5]][pattern[4]][el] for el in self.edge_labels)])
                treelet.append(canonkey_t)
            canonkey_l.update(Counter(treelet))

            # pattern 10
            treelet = []
            for pattern in patterns['10']:
                canonkey4 = [tuple(G.nodes[pattern[5]][nl] for nl in self.node_labels),
                             tuple(G[pattern[5]][pattern[4]][el] for el in self.edge_labels)]
                canonlist = []
                for leaf in pattern[1:3]:
                    nlabels = tuple(G.nodes[leaf][nl] for nl in self.node_labels)
                    elabels = tuple(G[leaf][pattern[0]][el] for el in self.edge_labels)
                    canonlist.append(tuple((nlabels, elabels)))
                canonlist.sort()
                canonkey0 = list(chain.from_iterable(canonlist))
                canonkey_t = tuple(['a']
                    + [tuple(G.nodes[pattern[3]][nl] for nl in self.node_labels)]
                    + [tuple(G.nodes[pattern[4]][nl] for nl in self.node_labels)]
                    + [tuple(G[pattern[4]][pattern[3]][el] for el in self.edge_labels)]
                    + [tuple(G.nodes[pattern[0]][nl] for nl in self.node_labels)]
                    + [tuple(G[pattern[0]][pattern[3]][el] for el in self.edge_labels)]
                    + canonkey4 + canonkey0)
                treelet.append(canonkey_t)
            canonkey_l.update(Counter(treelet))

            # pattern 12
            treelet = []
            for pattern in patterns['12']:
                canonlist0 = []
                for leaf in pattern[1:3]:
                    nlabels = tuple(G.nodes[leaf][nl] for nl in self.node_labels)
                    elabels = tuple(G[leaf][pattern[0]][el] for el in self.edge_labels)
                    canonlist0.append(tuple((nlabels, elabels)))
                canonlist0.sort()
                canonlist0 = list(chain.from_iterable(canonlist0))
                canonlist3 = []
                for leaf in pattern[4:6]:
                    nlabels = tuple(G.nodes[leaf][nl] for nl in self.node_labels)
                    elabels = tuple(G[leaf][pattern[3]][el] for el in self.edge_labels)
                    canonlist3.append(tuple((nlabels, elabels)))
                canonlist3.sort()
                canonlist3 = list(chain.from_iterable(canonlist3))

                # 2 possible keys can be generated from the 2 nodes with
                # extended label 3; select the one with the lower
                # lexicographic order.
                canonkey_t1 = tuple(['c']
                    + [tuple(G.nodes[pattern[0]][nl] for nl in self.node_labels)] + canonlist0
                    + [tuple(G.nodes[pattern[3]][nl] for nl in self.node_labels)]
                    + [tuple(G[pattern[3]][pattern[0]][el] for el in self.edge_labels)]
                    + canonlist3)
                canonkey_t2 = tuple(['c']
                    + [tuple(G.nodes[pattern[3]][nl] for nl in self.node_labels)] + canonlist3
                    + [tuple(G.nodes[pattern[0]][nl] for nl in self.node_labels)]
                    + [tuple(G[pattern[0]][pattern[3]][el] for el in self.edge_labels)]
                    + canonlist0)
                treelet.append(canonkey_t1 if canonkey_t1 < canonkey_t2 else canonkey_t2)
            canonkey_l.update(Counter(treelet))

            # pattern 9
            treelet = []
            for pattern in patterns['9']:
                canonkey2 = [tuple(G.nodes[pattern[4]][nl] for nl in self.node_labels),
                             tuple(G[pattern[4]][pattern[2]][el] for el in self.edge_labels)]
                canonkey3 = [tuple(G.nodes[pattern[5]][nl] for nl in self.node_labels),
                             tuple(G[pattern[5]][pattern[3]][el] for el in self.edge_labels)]
                prekey2 = [tuple(G.nodes[pattern[2]][nl] for nl in self.node_labels),
                           tuple(G[pattern[2]][pattern[0]][el] for el in self.edge_labels)]
                prekey3 = [tuple(G.nodes[pattern[3]][nl] for nl in self.node_labels),
                           tuple(G[pattern[3]][pattern[0]][el] for el in self.edge_labels)]
                if prekey2 + canonkey2 < prekey3 + canonkey3:
                    canonkey_t = [tuple(G.nodes[pattern[1]][nl] for nl in self.node_labels)] \
                        + [tuple(G[pattern[1]][pattern[0]][el] for el in self.edge_labels)] \
                        + prekey2 + prekey3 + canonkey2 + canonkey3
                else:
                    canonkey_t = [tuple(G.nodes[pattern[1]][nl] for nl in self.node_labels)] \
                        + [tuple(G[pattern[1]][pattern[0]][el] for el in self.edge_labels)] \
                        + prekey3 + prekey2 + canonkey3 + canonkey2
                treelet.append(tuple(['9']
                    + [tuple(G.nodes[pattern[0]][nl] for nl in self.node_labels)]
                    + canonkey_t))
            canonkey_l.update(Counter(treelet))

            return canonkey_l

        return canonkey
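
    # A comment-only illustration (hypothetical labels, not from the source):
    # for a single labeled edge a -x- b, the linear pattern of length 1 yields
    # the key ('1', ('a',), ('x',), ('b',)), i.e. the pattern identifier
    # followed by alternating node- and edge-label tuples, taken in the
    # lexicographically smaller of the two traversal directions so that both
    # orientations of the same treelet map to one canonical key.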
    def _wrapper_get_canonkeys(self, itr_item):
        g = itr_item[0]
        i = itr_item[1]
        return i, self._get_canonkeys(g)

    def _add_dummy_labels(self, Gn=None):
        def _add_dummy(Gn):
            if len(self.node_labels) == 0 or (len(self.node_labels) == 1 and self.node_labels[0] == SpecialLabel.DUMMY):
                for i in range(len(Gn)):
                    nx.set_node_attributes(Gn[i], '0', SpecialLabel.DUMMY)
                self.node_labels = [SpecialLabel.DUMMY]
            if len(self.edge_labels) == 0 or (len(self.edge_labels) == 1 and self.edge_labels[0] == SpecialLabel.DUMMY):
                for i in range(len(Gn)):
                    nx.set_edge_attributes(Gn[i], '0', SpecialLabel.DUMMY)
                self.edge_labels = [SpecialLabel.DUMMY]

        if Gn is None or Gn is self._graphs:
            # Add dummy labels for the copy of self._graphs.
            try:
                check_is_fitted(self, ['_dummy_labels_considered'])
                if not self._dummy_labels_considered:
                    Gn = self._graphs  # @todo: ?[g.copy() for g in self._graphs]
                    _add_dummy(Gn)
                    self._graphs = Gn
                    self._dummy_labels_considered = True
            except NotFittedError:
                Gn = self._graphs  # @todo: ?[g.copy() for g in self._graphs]
                _add_dummy(Gn)
                self._graphs = Gn
                self._dummy_labels_considered = True
        else:
            # Add dummy labels for the input.
            _add_dummy(Gn)
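

# A minimal usage sketch (illustrative, not part of the original file): build
# two toy labeled graphs and compute their treelet Gram matrix with a
# polynomial sub-kernel, following the pattern of gklearn's example scripts.
# The graphs, label names, and sub-kernel parameters here are assumptions.
if __name__ == '__main__':
    import functools
    from gklearn.utils.kernels import polynomialkernel

    # Two small molecule-like graphs with node label 'atom' and edge label
    # 'bond_type'.
    g1 = nx.Graph()
    g1.add_nodes_from([(0, {'atom': 'C'}), (1, {'atom': 'C'}), (2, {'atom': 'O'})])
    g1.add_edges_from([(0, 1, {'bond_type': '1'}), (1, 2, {'bond_type': '1'})])
    g2 = nx.Graph()
    g2.add_nodes_from([(0, {'atom': 'C'}), (1, {'atom': 'O'})])
    g2.add_edges_from([(0, 1, {'bond_type': '1'})])

    pkernel = functools.partial(polynomialkernel, d=2, c=1e5)
    graph_kernel = Treelet(node_labels=['atom'], edge_labels=['bond_type'],
                           ds_infos={'directed': False}, sub_kernel=pkernel)
    # `compute` (from the GraphKernel base class) is assumed to return the
    # Gram matrix and the run time when given a list of graphs.
    gram_matrix, run_time = graph_kernel.compute([g1, g2], parallel=None, verbose=0)
    print(gram_matrix)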

A Python package for graph kernels, graph edit distances and graph pre-image problem.