
treelet.py 27 kB

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 13 18:02:46 2020

@author: ljia

@references:

    [1] Gaüzère B, Brun L, Villemin D. Two new graphs kernels in
    chemoinformatics. Pattern Recognition Letters. 2012 Nov 1;33(15):2038-47.
"""
import sys
from collections import Counter
from itertools import chain
from multiprocessing import Pool

import numpy as np
import networkx as nx
from sklearn.utils.validation import check_is_fitted
from sklearn.exceptions import NotFittedError

from gklearn.utils import get_iters, SpecialLabel
from gklearn.utils.parallel import parallel_gm, parallel_me
from gklearn.utils.utils import find_all_paths, get_mlti_dim_node_attrs
from gklearn.kernels import GraphKernel

class Treelet(GraphKernel):

    def __init__(self, parallel=None, n_jobs=None, chunksize=None,
                 normalize=True, verbose=2, precompute_canonkeys=True,
                 save_canonkeys=False, **kwargs):
        """Initialise a treelet kernel."""
        super().__init__(parallel=parallel, n_jobs=n_jobs, chunksize=chunksize,
                         normalize=normalize, verbose=verbose)
        self.node_labels = kwargs.get('node_labels', [])
        self.edge_labels = kwargs.get('edge_labels', [])
        self.sub_kernel = kwargs.get('sub_kernel', None)
        self.ds_infos = kwargs.get('ds_infos', {})
        self.precompute_canonkeys = precompute_canonkeys
        self.save_canonkeys = save_canonkeys

    ##########################################################################
    # The following is the 1st paradigm to compute the kernel matrix, which is
    # compatible with `scikit-learn`.
    # ----------------------------------------------------------------------
    # Special thanks to the "GraKeL" library for providing an excellent
    # template!
    ##########################################################################

    def clear_attributes(self):
        super().clear_attributes()
        if hasattr(self, '_canonkeys'):
            delattr(self, '_canonkeys')
        if hasattr(self, '_Y_canonkeys'):
            delattr(self, '_Y_canonkeys')
        if hasattr(self, '_dummy_labels_considered'):
            delattr(self, '_dummy_labels_considered')

    def validate_parameters(self):
        """Validate all parameters for the transformer.

        Returns
        -------
        None.
        """
        super().validate_parameters()
        if self.sub_kernel is None:
            raise ValueError('Sub-kernel not set.')

    def _compute_kernel_matrix_series(self, Y):
        """Compute the kernel matrix between the given target graphs (Y) and
        the fitted graphs (X / self._graphs) without parallelization.

        Parameters
        ----------
        Y : list of graphs
            The target graphs.

        Returns
        -------
        kernel_matrix : numpy array, shape = [n_targets, n_inputs]
            The computed kernel matrix.
        """
        # self._add_dummy_labels will modify the input in place.
        self._add_dummy_labels()  # For self._graphs.
        # Y = [g.copy() for g in Y]  # @todo: ?
        self._add_dummy_labels(Y)

        # Get the canonical keys of all graphs before computing the kernels to
        # save time; this may, however, cost a lot of memory for large datasets.

        # Canonical keys for self._graphs.
        try:
            check_is_fitted(self, ['_canonkeys'])
            canonkeys_list1 = self._canonkeys
        except NotFittedError:
            canonkeys_list1 = []
            iterator = get_iters(self._graphs, desc='getting canonkeys for X',
                                 file=sys.stdout, verbose=(self.verbose >= 2))
            for g in iterator:
                canonkeys_list1.append(self._get_canonkeys(g))
            if self.save_canonkeys:
                self._canonkeys = canonkeys_list1

        # Canonical keys for Y.
        canonkeys_list2 = []
        iterator = get_iters(Y, desc='getting canonkeys for Y',
                             file=sys.stdout, verbose=(self.verbose >= 2))
        for g in iterator:
            canonkeys_list2.append(self._get_canonkeys(g))
        if self.save_canonkeys:
            self._Y_canonkeys = canonkeys_list2

        # Compute the kernel matrix.
        kernel_matrix = np.zeros((len(Y), len(canonkeys_list1)))

        from itertools import product
        itr = product(range(len(Y)), range(len(canonkeys_list1)))
        len_itr = int(len(Y) * len(canonkeys_list1))
        iterator = get_iters(itr, desc='Computing kernels', file=sys.stdout,
                             length=len_itr, verbose=(self.verbose >= 2))
        for i_y, i_x in iterator:
            kernel = self._kernel_do(canonkeys_list2[i_y], canonkeys_list1[i_x])
            kernel_matrix[i_y][i_x] = kernel

        return kernel_matrix

    def _compute_kernel_matrix_imap_unordered(self, Y):
        """Compute the kernel matrix between the given target graphs (Y) and
        the fitted graphs (X / self._graphs) using imap unordered
        parallelization.

        Parameters
        ----------
        Y : list of graphs
            The target graphs.

        Returns
        -------
        kernel_matrix : numpy array, shape = [n_targets, n_inputs]
            The computed kernel matrix.
        """
        raise NotImplementedError(
            'Parallelization for the kernel matrix is not implemented.')

    def pairwise_kernel(self, x, y, are_keys=False):
        """Compute the pairwise kernel between two graphs.

        Parameters
        ----------
        x, y : NetworkX Graph
            Graphs between which the kernel is computed.

        are_keys : boolean, optional
            If `True`, `x` and `y` are canonical keys; otherwise they are
            graphs. The default is False.

        Returns
        -------
        kernel : float
            The computed kernel.
        """
        if are_keys:
            # x, y are canonical keys.
            kernel = self._kernel_do(x, y)
        else:
            # x, y are graphs.
            kernel = self._compute_single_kernel_series(x, y)
        return kernel

    def diagonals(self):
        """Compute the kernel matrix diagonals of the fitted/transformed data.

        Returns
        -------
        X_diag : numpy array
            The diagonal of the kernel matrix between the fitted data.
            This consists of each element calculated with itself.

        Y_diag : numpy array
            The diagonal of the kernel matrix of the transformed data.
            This consists of each element calculated with itself.
        """
        # Check whether the method "fit" has been called.
        check_is_fitted(self, ['_graphs'])

        # Check whether the diagonals of X exist.
        try:
            check_is_fitted(self, ['_X_diag'])
        except NotFittedError:
            # Compute the diagonals of X.
            self._X_diag = np.empty(shape=(len(self._graphs),))
            try:
                check_is_fitted(self, ['_canonkeys'])
                for i, x in enumerate(self._canonkeys):
                    self._X_diag[i] = self.pairwise_kernel(
                        x, x, are_keys=True)  # @todo: parallel?
            except NotFittedError:
                for i, x in enumerate(self._graphs):
                    self._X_diag[i] = self.pairwise_kernel(
                        x, x, are_keys=False)  # @todo: parallel?

        try:
            # If transform has happened, return both diagonals.
            check_is_fitted(self, ['_Y'])
            self._Y_diag = np.empty(shape=(len(self._Y),))
            try:
                check_is_fitted(self, ['_Y_canonkeys'])
                for (i, y) in enumerate(self._Y_canonkeys):
                    self._Y_diag[i] = self.pairwise_kernel(
                        y, y, are_keys=True)  # @todo: parallel?
            except NotFittedError:
                for (i, y) in enumerate(self._Y):
                    self._Y_diag[i] = self.pairwise_kernel(
                        y, y, are_keys=False)  # @todo: parallel?
            return self._X_diag, self._Y_diag
        except NotFittedError:
            # Otherwise, return only X_diag.
            return self._X_diag

    ##########################################################################
    # The following is the 2nd paradigm to compute the kernel matrix. It is
    # simplified and not compatible with `scikit-learn`.
    ##########################################################################

    def _compute_gm_series(self):
        self._add_dummy_labels(self._graphs)

        # Get the canonical keys of all graphs before computing the kernels to
        # save time; this may, however, cost a lot of memory for large datasets.
        canonkeys = []
        iterator = get_iters(self._graphs, desc='getting canonkeys',
                             file=sys.stdout, verbose=(self.verbose >= 2))
        for g in iterator:
            canonkeys.append(self._get_canonkeys(g))

        if self.save_canonkeys:
            self._canonkeys = canonkeys

        # Compute the Gram matrix.
        gram_matrix = np.zeros((len(self._graphs), len(self._graphs)))

        from itertools import combinations_with_replacement
        itr = combinations_with_replacement(range(0, len(self._graphs)), 2)
        len_itr = int(len(self._graphs) * (len(self._graphs) + 1) / 2)
        iterator = get_iters(itr, desc='Computing kernels', file=sys.stdout,
                             length=len_itr, verbose=(self.verbose >= 2))
        for i, j in iterator:
            kernel = self._kernel_do(canonkeys[i], canonkeys[j])
            gram_matrix[i][j] = kernel
            gram_matrix[j][i] = kernel  # @todo: no directed graph considered?

        return gram_matrix

    def _compute_gm_imap_unordered(self):
        self._add_dummy_labels(self._graphs)

        # Get the canonical keys of all graphs before computing the kernels to
        # save time; this may, however, cost a lot of memory for large datasets.
        pool = Pool(self.n_jobs)
        itr = zip(self._graphs, range(0, len(self._graphs)))
        if len(self._graphs) < 100 * self.n_jobs:
            chunksize = int(len(self._graphs) / self.n_jobs) + 1
        else:
            chunksize = 100
        canonkeys = [[] for _ in range(len(self._graphs))]
        get_fun = self._wrapper_get_canonkeys
        iterator = get_iters(pool.imap_unordered(get_fun, itr, chunksize),
                             desc='getting canonkeys', file=sys.stdout,
                             length=len(self._graphs),
                             verbose=(self.verbose >= 2))
        for i, ck in iterator:
            canonkeys[i] = ck
        pool.close()
        pool.join()

        if self.save_canonkeys:
            self._canonkeys = canonkeys

        # Compute the Gram matrix.
        gram_matrix = np.zeros((len(self._graphs), len(self._graphs)))
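        # The canonical keys are handed to the worker processes through a
        # module-level global set once per worker by init_worker, so that they
        # are not re-pickled for every pair of graphs.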
        def init_worker(canonkeys_toshare):
            global G_canonkeys
            G_canonkeys = canonkeys_toshare

        do_fun = self._wrapper_kernel_do
        parallel_gm(do_fun, gram_matrix, self._graphs, init_worker=init_worker,
                    glbv=(canonkeys,), n_jobs=self.n_jobs, verbose=self.verbose)

        return gram_matrix

    def _compute_kernel_list_series(self, g1, g_list):
        # self._add_dummy_labels(g_list + [g1])

        # Get the canonical keys of all graphs before computing the kernels to
        # save time; this may, however, cost a lot of memory for large datasets.
        canonkeys_1 = self._get_canonkeys(g1)
        canonkeys_list = []
        iterator = get_iters(g_list, desc='getting canonkeys', file=sys.stdout,
                             verbose=(self.verbose >= 2))
        for g in iterator:
            canonkeys_list.append(self._get_canonkeys(g))

        # Compute the kernel list.
        kernel_list = [None] * len(g_list)
        iterator = get_iters(range(len(g_list)), desc='Computing kernels',
                             file=sys.stdout, length=len(g_list),
                             verbose=(self.verbose >= 2))
        for i in iterator:
            kernel = self._kernel_do(canonkeys_1, canonkeys_list[i])
            kernel_list[i] = kernel

        return kernel_list

    def _compute_kernel_list_imap_unordered(self, g1, g_list):
        self._add_dummy_labels(g_list + [g1])

        # Get the canonical keys of all graphs before computing the kernels to
        # save time; this may, however, cost a lot of memory for large datasets.
        canonkeys_1 = self._get_canonkeys(g1)
        canonkeys_list = [[] for _ in range(len(g_list))]
        pool = Pool(self.n_jobs)
        itr = zip(g_list, range(0, len(g_list)))
        if len(g_list) < 100 * self.n_jobs:
            chunksize = int(len(g_list) / self.n_jobs) + 1
        else:
            chunksize = 100
        get_fun = self._wrapper_get_canonkeys
        iterator = get_iters(pool.imap_unordered(get_fun, itr, chunksize),
                             desc='getting canonkeys', file=sys.stdout,
                             length=len(g_list), verbose=(self.verbose >= 2))
        for i, ck in iterator:
            canonkeys_list[i] = ck
        pool.close()
        pool.join()

        # Compute the kernel list.
        kernel_list = [None] * len(g_list)

        def init_worker(ck_1_toshare, ck_list_toshare):
            global G_ck_1, G_ck_list
            G_ck_1 = ck_1_toshare
            G_ck_list = ck_list_toshare

        do_fun = self._wrapper_kernel_list_do

        def func_assign(result, var_to_assign):
            var_to_assign[result[0]] = result[1]

        itr = range(len(g_list))
        len_itr = len(g_list)
        parallel_me(do_fun, func_assign, kernel_list, itr, len_itr=len_itr,
                    init_worker=init_worker, glbv=(canonkeys_1, canonkeys_list),
                    method='imap_unordered', n_jobs=self.n_jobs,
                    itr_desc='Computing kernels', verbose=self.verbose)

        return kernel_list

    def _wrapper_kernel_list_do(self, itr):
        return itr, self._kernel_do(G_ck_1, G_ck_list[itr])

    def _compute_single_kernel_series(self, g1, g2):
        # self._add_dummy_labels([g1] + [g2])
        canonkeys_1 = self._get_canonkeys(g1)
        canonkeys_2 = self._get_canonkeys(g2)
        kernel = self._kernel_do(canonkeys_1, canonkeys_2)
        return kernel

    # @profile
    def _kernel_do(self, canonkey1, canonkey2):
        """Compute the treelet graph kernel between two graphs.

        Parameters
        ----------
        canonkey1, canonkey2 : dict
            Dictionaries mapping the canonical keys of the treelets in each
            graph to their numbers of occurrences.

        Returns
        -------
        kernel : float
            The treelet kernel between the two graphs.
        """
        # Find the canonical keys shared by both graphs; since the vectors are
        # restricted to this intersection, every lookup below succeeds.
        keys = set(canonkey1.keys()) & set(canonkey2.keys())
        vector1 = np.array([canonkey1[key] for key in keys])
        vector2 = np.array([canonkey2[key] for key in keys])
        kernel = self.sub_kernel(vector1, vector2)
        return kernel

    def _wrapper_kernel_do(self, itr):
        i = itr[0]
        j = itr[1]
        return i, j, self._kernel_do(G_canonkeys[i], G_canonkeys[j])

    def _get_canonkeys(self, G):
        """Generate the canonical keys of all treelets in a graph.

        Parameters
        ----------
        G : NetworkX graph
            The graph in which the keys are generated.

        Returns
        -------
        canonkey / canonkey_l : dict
            For unlabeled graphs, canonkey is a dictionary which records the
            number of occurrences of every tree pattern. For labeled graphs,
            canonkey_l is the one which keeps track of the number of
            occurrences of every treelet.
        """
        patterns = {}  # a dictionary consisting of lists of patterns for all graphlets.
        canonkey = {}  # canonical key: a dictionary recording the number of occurrences of every tree pattern.
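        # The 14 treelets of Ref [1] (up to 6 nodes) are encoded by the keys
        # '0' to '5' (paths with 1 to 6 nodes), '6'/'8'/'d' (3-/4-/5-stars)
        # and '7', '9', 'a', 'b', 'c' (the tree patterns of Table 1 in Ref [1]).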

        ### structural analysis ###
        ### In this section, a list of patterns is generated for each graphlet,
        ### where every pattern is represented by nodes ordered by Morgan's
        ### extended labeling.
        # linear patterns
        patterns['0'] = list(G.nodes())
        canonkey['0'] = nx.number_of_nodes(G)
        for i in range(1, 6):
            patterns[str(i)] = find_all_paths(G, i, self.ds_infos['directed'])
            canonkey[str(i)] = len(patterns[str(i)])

        # n-star patterns
        patterns['3star'] = [[node] + [neighbor for neighbor in G[node]]
                             for node in G.nodes() if G.degree(node) == 3]
        patterns['4star'] = [[node] + [neighbor for neighbor in G[node]]
                             for node in G.nodes() if G.degree(node) == 4]
        patterns['5star'] = [[node] + [neighbor for neighbor in G[node]]
                             for node in G.nodes() if G.degree(node) == 5]
        canonkey['6'] = len(patterns['3star'])
        canonkey['8'] = len(patterns['4star'])
        canonkey['d'] = len(patterns['5star'])
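        # In every star pattern the center node is stored first, followed by
        # its neighbors; the patterns below extend these stars into the deeper
        # tree patterns of Table 1 in Ref [1].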

        # pattern 7
        patterns['7'] = []  # the 1st line of Table 1 in Ref [1]
        for pattern in patterns['3star']:
            for i in range(1, len(pattern)):  # for each neighbor of node 0
                if G.degree(pattern[i]) >= 2:
                    pattern_t = pattern[:]
                    # set the node with degree >= 2 as the 4th node
                    pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
                    for neighborx in G[pattern[i]]:
                        if neighborx != pattern[0]:
                            new_pattern = pattern_t + [neighborx]
                            patterns['7'].append(new_pattern)
        canonkey['7'] = len(patterns['7'])

        # pattern 11
        patterns['11'] = []  # the 4th line of Table 1 in Ref [1]
        for pattern in patterns['4star']:
            for i in range(1, len(pattern)):
                if G.degree(pattern[i]) >= 2:
                    pattern_t = pattern[:]
                    pattern_t[i], pattern_t[4] = pattern_t[4], pattern_t[i]
                    for neighborx in G[pattern[i]]:
                        if neighborx != pattern[0]:
                            new_pattern = pattern_t + [neighborx]
                            patterns['11'].append(new_pattern)
        canonkey['b'] = len(patterns['11'])

        # pattern 12
        patterns['12'] = []  # the 5th line of Table 1 in Ref [1]
        rootlist = []  # a list of root nodes whose extended label is 3
        for pattern in patterns['3star']:
            if pattern[0] not in rootlist:
                # avoid counting the same pattern twice, once from each of its
                # two root nodes
                rootlist.append(pattern[0])
                for i in range(1, len(pattern)):
                    if G.degree(pattern[i]) >= 3:
                        rootlist.append(pattern[i])
                        pattern_t = pattern[:]
                        pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
                        for neighborx1 in G[pattern[i]]:
                            if neighborx1 != pattern[0]:
                                for neighborx2 in G[pattern[i]]:
                                    if neighborx1 > neighborx2 and neighborx2 != pattern[0]:
                                        new_pattern = pattern_t + [neighborx1] + [neighborx2]
                                        patterns['12'].append(new_pattern)
        canonkey['c'] = int(len(patterns['12']) / 2)

        # pattern 9
        patterns['9'] = []  # the 2nd line of Table 1 in Ref [1]
        for pattern in patterns['3star']:
            for pairs in [[neighbor1, neighbor2]
                          for neighbor1 in G[pattern[0]] if G.degree(neighbor1) >= 2
                          for neighbor2 in G[pattern[0]]
                          if G.degree(neighbor2) >= 2 and neighbor1 > neighbor2]:
                pattern_t = pattern[:]
                # move the nodes with extended label 4 to the specific
                # positions corresponding to their children
                pattern_t[pattern_t.index(pairs[0])], pattern_t[2] = pattern_t[2], pattern_t[pattern_t.index(pairs[0])]
                pattern_t[pattern_t.index(pairs[1])], pattern_t[3] = pattern_t[3], pattern_t[pattern_t.index(pairs[1])]
                for neighborx1 in G[pairs[0]]:
                    if neighborx1 != pattern[0]:
                        for neighborx2 in G[pairs[1]]:
                            if neighborx2 != pattern[0]:
                                new_pattern = pattern_t + [neighborx1] + [neighborx2]
                                patterns['9'].append(new_pattern)
        canonkey['9'] = len(patterns['9'])

        # pattern 10
        patterns['10'] = []  # the 3rd line of Table 1 in Ref [1]
        for pattern in patterns['3star']:
            for i in range(1, len(pattern)):
                if G.degree(pattern[i]) >= 2:
                    for neighborx in G[pattern[i]]:
                        if neighborx != pattern[0] and G.degree(neighborx) >= 2:
                            pattern_t = pattern[:]
                            pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
                            new_patterns = [pattern_t + [neighborx] + [neighborxx]
                                            for neighborxx in G[neighborx]
                                            if neighborxx != pattern[i]]
                            patterns['10'].extend(new_patterns)
        canonkey['a'] = len(patterns['10'])
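        # At this point canonkey holds the counts of all 14 unlabeled
        # treelets; for unlabeled graphs it is returned as-is at the end of
        # this method.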

        ### labeling information ###
        ### In this section, a canonical key is generated for every pattern
        ### obtained in the structural analysis section above; each key is a
        ### tuple corresponding to a unique treelet. A dictionary is built to
        ### keep track of the number of occurrences of every treelet.
        if len(self.node_labels) > 0 or len(self.edge_labels) > 0:
            canonkey_l = {}  # canonical key: a dictionary keeping track of the number of occurrences of every labeled treelet.
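            # For instance, a path with one edge (u, v) yields the key
            # ('1', node_labels(u), edge_labels(u, v), node_labels(v)),
            # taking whichever of the two traversal directions is
            # lexicographically smaller.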

            # linear patterns
            canonkey_t = Counter(get_mlti_dim_node_attrs(G, self.node_labels))
            for key in canonkey_t:
                canonkey_l[('0', key)] = canonkey_t[key]

            for i in range(1, 6):
                treelet = []
                for pattern in patterns[str(i)]:
                    canonlist = []
                    for idx, node in enumerate(pattern[:-1]):
                        canonlist.append(tuple(G.nodes[node][nl] for nl in self.node_labels))
                        canonlist.append(tuple(G[node][pattern[idx + 1]][el] for el in self.edge_labels))
                    canonlist.append(tuple(G.nodes[pattern[-1]][nl] for nl in self.node_labels))
                    canonkey_t = canonlist if canonlist < canonlist[::-1] else canonlist[::-1]
                    treelet.append(tuple([str(i)] + canonkey_t))
                canonkey_l.update(Counter(treelet))

            # n-star patterns
            for i in range(3, 6):
                treelet = []
                for pattern in patterns[str(i) + 'star']:
                    canonlist = []
                    for leaf in pattern[1:]:
                        nlabels = tuple(G.nodes[leaf][nl] for nl in self.node_labels)
                        elabels = tuple(G[leaf][pattern[0]][el] for el in self.edge_labels)
                        canonlist.append(tuple((nlabels, elabels)))
                    canonlist.sort()
                    canonlist = list(chain.from_iterable(canonlist))
                    canonkey_t = tuple(['d' if i == 5 else str(i * 2)]
                                       + [tuple(G.nodes[pattern[0]][nl] for nl in self.node_labels)]
                                       + canonlist)
                    treelet.append(canonkey_t)
                canonkey_l.update(Counter(treelet))

            # pattern 7
            treelet = []
            for pattern in patterns['7']:
                canonlist = []
                for leaf in pattern[1:3]:
                    nlabels = tuple(G.nodes[leaf][nl] for nl in self.node_labels)
                    elabels = tuple(G[leaf][pattern[0]][el] for el in self.edge_labels)
                    canonlist.append(tuple((nlabels, elabels)))
                canonlist.sort()
                canonlist = list(chain.from_iterable(canonlist))
                canonkey_t = tuple(['7']
                                   + [tuple(G.nodes[pattern[0]][nl] for nl in self.node_labels)] + canonlist
                                   + [tuple(G.nodes[pattern[3]][nl] for nl in self.node_labels)]
                                   + [tuple(G[pattern[3]][pattern[0]][el] for el in self.edge_labels)]
                                   + [tuple(G.nodes[pattern[4]][nl] for nl in self.node_labels)]
                                   + [tuple(G[pattern[4]][pattern[3]][el] for el in self.edge_labels)])
                treelet.append(canonkey_t)
            canonkey_l.update(Counter(treelet))

            # pattern 11
            treelet = []
            for pattern in patterns['11']:
                canonlist = []
                for leaf in pattern[1:4]:
                    nlabels = tuple(G.nodes[leaf][nl] for nl in self.node_labels)
                    elabels = tuple(G[leaf][pattern[0]][el] for el in self.edge_labels)
                    canonlist.append(tuple((nlabels, elabels)))
                canonlist.sort()
                canonlist = list(chain.from_iterable(canonlist))
                canonkey_t = tuple(['b']
                                   + [tuple(G.nodes[pattern[0]][nl] for nl in self.node_labels)] + canonlist
                                   + [tuple(G.nodes[pattern[4]][nl] for nl in self.node_labels)]
                                   + [tuple(G[pattern[4]][pattern[0]][el] for el in self.edge_labels)]
                                   + [tuple(G.nodes[pattern[5]][nl] for nl in self.node_labels)]
                                   + [tuple(G[pattern[5]][pattern[4]][el] for el in self.edge_labels)])
                treelet.append(canonkey_t)
            canonkey_l.update(Counter(treelet))

            # pattern 10
            treelet = []
            for pattern in patterns['10']:
                canonkey4 = [tuple(G.nodes[pattern[5]][nl] for nl in self.node_labels),
                             tuple(G[pattern[5]][pattern[4]][el] for el in self.edge_labels)]
                canonlist = []
                for leaf in pattern[1:3]:
                    nlabels = tuple(G.nodes[leaf][nl] for nl in self.node_labels)
                    elabels = tuple(G[leaf][pattern[0]][el] for el in self.edge_labels)
                    canonlist.append(tuple((nlabels, elabels)))
                canonlist.sort()
                canonkey0 = list(chain.from_iterable(canonlist))
                canonkey_t = tuple(['a']
                                   + [tuple(G.nodes[pattern[3]][nl] for nl in self.node_labels)]
                                   + [tuple(G.nodes[pattern[4]][nl] for nl in self.node_labels)]
                                   + [tuple(G[pattern[4]][pattern[3]][el] for el in self.edge_labels)]
                                   + [tuple(G.nodes[pattern[0]][nl] for nl in self.node_labels)]
                                   + [tuple(G[pattern[0]][pattern[3]][el] for el in self.edge_labels)]
                                   + canonkey4 + canonkey0)
                treelet.append(canonkey_t)
            canonkey_l.update(Counter(treelet))

            # pattern 12
            treelet = []
            for pattern in patterns['12']:
                canonlist0 = []
                for leaf in pattern[1:3]:
                    nlabels = tuple(G.nodes[leaf][nl] for nl in self.node_labels)
                    elabels = tuple(G[leaf][pattern[0]][el] for el in self.edge_labels)
                    canonlist0.append(tuple((nlabels, elabels)))
                canonlist0.sort()
                canonlist0 = list(chain.from_iterable(canonlist0))
                canonlist3 = []
                for leaf in pattern[4:6]:
                    nlabels = tuple(G.nodes[leaf][nl] for nl in self.node_labels)
                    elabels = tuple(G[leaf][pattern[3]][el] for el in self.edge_labels)
                    canonlist3.append(tuple((nlabels, elabels)))
                canonlist3.sort()
                canonlist3 = list(chain.from_iterable(canonlist3))

                # Two possible keys can be generated from the two nodes with
                # extended label 3; select the one with the lower
                # lexicographic order.
                canonkey_t1 = tuple(['c']
                                    + [tuple(G.nodes[pattern[0]][nl] for nl in self.node_labels)] + canonlist0
                                    + [tuple(G.nodes[pattern[3]][nl] for nl in self.node_labels)]
                                    + [tuple(G[pattern[3]][pattern[0]][el] for el in self.edge_labels)]
                                    + canonlist3)
                canonkey_t2 = tuple(['c']
                                    + [tuple(G.nodes[pattern[3]][nl] for nl in self.node_labels)] + canonlist3
                                    + [tuple(G.nodes[pattern[0]][nl] for nl in self.node_labels)]
                                    + [tuple(G[pattern[0]][pattern[3]][el] for el in self.edge_labels)]
                                    + canonlist0)
                treelet.append(canonkey_t1 if canonkey_t1 < canonkey_t2 else canonkey_t2)
            canonkey_l.update(Counter(treelet))

            # pattern 9
            treelet = []
            for pattern in patterns['9']:
                canonkey2 = [tuple(G.nodes[pattern[4]][nl] for nl in self.node_labels),
                             tuple(G[pattern[4]][pattern[2]][el] for el in self.edge_labels)]
                canonkey3 = [tuple(G.nodes[pattern[5]][nl] for nl in self.node_labels),
                             tuple(G[pattern[5]][pattern[3]][el] for el in self.edge_labels)]
                prekey2 = [tuple(G.nodes[pattern[2]][nl] for nl in self.node_labels),
                           tuple(G[pattern[2]][pattern[0]][el] for el in self.edge_labels)]
                prekey3 = [tuple(G.nodes[pattern[3]][nl] for nl in self.node_labels),
                           tuple(G[pattern[3]][pattern[0]][el] for el in self.edge_labels)]
                if prekey2 + canonkey2 < prekey3 + canonkey3:
                    canonkey_t = [tuple(G.nodes[pattern[1]][nl] for nl in self.node_labels)] \
                                 + [tuple(G[pattern[1]][pattern[0]][el] for el in self.edge_labels)] \
                                 + prekey2 + prekey3 + canonkey2 + canonkey3
                else:
                    canonkey_t = [tuple(G.nodes[pattern[1]][nl] for nl in self.node_labels)] \
                                 + [tuple(G[pattern[1]][pattern[0]][el] for el in self.edge_labels)] \
                                 + prekey3 + prekey2 + canonkey3 + canonkey2
                treelet.append(tuple(['9']
                                     + [tuple(G.nodes[pattern[0]][nl] for nl in self.node_labels)]
                                     + canonkey_t))
            canonkey_l.update(Counter(treelet))

            return canonkey_l

        return canonkey

    def _wrapper_get_canonkeys(self, itr_item):
        g = itr_item[0]
        i = itr_item[1]
        return i, self._get_canonkeys(g)

    def _add_dummy_labels(self, Gn=None):
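        # When a graph set has no node or edge labels, attach a constant dummy
        # label to every node/edge so that labeled and unlabeled graphs go
        # through the same (labeled) canonical-key code path.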
        def _add_dummy(Gn):
            if len(self.node_labels) == 0 or (len(self.node_labels) == 1 and self.node_labels[0] == SpecialLabel.DUMMY):
                for i in range(len(Gn)):
                    nx.set_node_attributes(Gn[i], '0', SpecialLabel.DUMMY)
                self.node_labels = [SpecialLabel.DUMMY]
            if len(self.edge_labels) == 0 or (len(self.edge_labels) == 1 and self.edge_labels[0] == SpecialLabel.DUMMY):
                for i in range(len(Gn)):
                    nx.set_edge_attributes(Gn[i], '0', SpecialLabel.DUMMY)
                self.edge_labels = [SpecialLabel.DUMMY]

        if Gn is None or Gn is self._graphs:
            # Add dummy labels for the copy of self._graphs.
            try:
                check_is_fitted(self, ['_dummy_labels_considered'])
                if not self._dummy_labels_considered:
                    Gn = self._graphs  # @todo: ? [g.copy() for g in self._graphs]
                    _add_dummy(Gn)
                    self._graphs = Gn
                    self._dummy_labels_considered = True
            except NotFittedError:
                Gn = self._graphs  # @todo: ? [g.copy() for g in self._graphs]
                _add_dummy(Gn)
                self._graphs = Gn
                self._dummy_labels_considered = True
        else:
            # Add dummy labels for the input.
            _add_dummy(Gn)
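
A minimal usage sketch (not part of treelet.py): the snippet below builds two small labeled graphs and computes their treelet Gram matrix. The Gaussian sub-kernel, the 'atom' label name and the gamma value are illustrative choices, and the fit_transform call assumes the scikit-learn-style first paradigm provided by the GraphKernel base class.

import numpy as np
import networkx as nx
from gklearn.kernels import Treelet

def gaussian_sub_kernel(x, y, gamma=1e-3):
    # Gaussian kernel applied to the two treelet-count vectors (hypothetical
    # helper; any function of two count vectors can serve as sub_kernel).
    return np.exp(-gamma * np.sum((x - y) ** 2))

# Two toy graphs with a single node label 'atom'.
g1 = nx.path_graph(4)
nx.set_node_attributes(g1, {n: 'C' for n in g1.nodes()}, 'atom')
g2 = nx.star_graph(3)
nx.set_node_attributes(g2, {n: 'N' if n == 0 else 'C' for n in g2.nodes()}, 'atom')

treelet_kernel = Treelet(node_labels=['atom'], edge_labels=[],
                         ds_infos={'directed': False},
                         sub_kernel=gaussian_sub_kernel)
gram_matrix = treelet_kernel.fit_transform([g1, g2])  # 2 x 2 (normalized) Gram matrix
print(gram_matrix)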

A Python package for graph kernels, graph edit distances and the graph pre-image problem.