You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

treelet.py 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Mon Apr 13 18:02:46 2020
  5. @author: ljia
  6. @references:
  7. [1] Gaüzère B, Brun L, Villemin D. Two new graphs kernels in
  8. chemoinformatics. Pattern Recognition Letters. 2012 Nov 1;33(15):2038-47.
  9. """
  10. import sys
  11. from multiprocessing import Pool
  12. from tqdm import tqdm
  13. import numpy as np
  14. import networkx as nx
  15. from collections import Counter
  16. from itertools import chain
  17. from gklearn.utils import SpecialLabel
  18. from gklearn.utils.parallel import parallel_gm, parallel_me
  19. from gklearn.utils.utils import find_all_paths, get_mlti_dim_node_attrs
  20. from gklearn.kernels import GraphKernel
  21. class Treelet(GraphKernel):
  22. def __init__(self, **kwargs):
  23. GraphKernel.__init__(self)
  24. self._node_labels = kwargs.get('node_labels', [])
  25. self._edge_labels = kwargs.get('edge_labels', [])
  26. self._sub_kernel = kwargs.get('sub_kernel', None)
  27. self._ds_infos = kwargs.get('ds_infos', {})
  28. if self._sub_kernel is None:
  29. raise Exception('Sub kernel not set.')
  30. def _compute_gm_series(self):
  31. self._add_dummy_labels(self._graphs)
  32. # get all canonical keys of all graphs before computing kernels to save
  33. # time, but this may cost a lot of memory for large dataset.
  34. canonkeys = []
  35. if self._verbose >= 2:
  36. iterator = tqdm(self._graphs, desc='getting canonkeys', file=sys.stdout)
  37. else:
  38. iterator = self._graphs
  39. for g in iterator:
  40. canonkeys.append(self._get_canonkeys(g))
  41. # compute Gram matrix.
  42. gram_matrix = np.zeros((len(self._graphs), len(self._graphs)))
  43. from itertools import combinations_with_replacement
  44. itr = combinations_with_replacement(range(0, len(self._graphs)), 2)
  45. if self._verbose >= 2:
  46. iterator = tqdm(itr, desc='Computing kernels', file=sys.stdout)
  47. else:
  48. iterator = itr
  49. for i, j in iterator:
  50. kernel = self._kernel_do(canonkeys[i], canonkeys[j])
  51. gram_matrix[i][j] = kernel
  52. gram_matrix[j][i] = kernel # @todo: no directed graph considered?
  53. return gram_matrix
    def _compute_gm_imap_unordered(self):
        """Compute the Gram matrix over ``self._graphs`` using a process pool.

        Canonical keys are first extracted in parallel; pairwise kernels are
        then computed in parallel via ``parallel_gm``.

        Returns
        -------
        numpy.ndarray
            Symmetric matrix of pairwise treelet kernel values.
        """
        self._add_dummy_labels(self._graphs)

        # get all canonical keys of all graphs before computing kernels to save
        # time, but this may cost a lot of memory for large dataset.
        pool = Pool(self._n_jobs)
        itr = zip(self._graphs, range(0, len(self._graphs)))
        # Heuristic chunk size: spread small datasets evenly over the workers,
        # cap at 100 graphs per chunk for large ones.
        if len(self._graphs) < 100 * self._n_jobs:
            chunksize = int(len(self._graphs) / self._n_jobs) + 1
        else:
            chunksize = 100
        canonkeys = [[] for _ in range(len(self._graphs))]
        get_fun = self._wrapper_get_canonkeys
        if self._verbose >= 2:
            iterator = tqdm(pool.imap_unordered(get_fun, itr, chunksize),
                            desc='getting canonkeys', file=sys.stdout)
        else:
            iterator = pool.imap_unordered(get_fun, itr, chunksize)
        # Results arrive in arbitrary order; the index i returned by the
        # wrapper restores the original ordering.
        for i, ck in iterator:
            canonkeys[i] = ck
        pool.close()
        pool.join()

        # compute Gram matrix.
        gram_matrix = np.zeros((len(self._graphs), len(self._graphs)))

        def init_worker(canonkeys_toshare):
            # Expose the canonical keys to worker processes via a global.
            global G_canonkeys
            G_canonkeys = canonkeys_toshare
        do_fun = self._wrapper_kernel_do
        parallel_gm(do_fun, gram_matrix, self._graphs, init_worker=init_worker,
                    glbv=(canonkeys,), n_jobs=self._n_jobs, verbose=self._verbose)
        return gram_matrix
  84. def _compute_kernel_list_series(self, g1, g_list):
  85. self._add_dummy_labels(g_list + [g1])
  86. # get all canonical keys of all graphs before computing kernels to save
  87. # time, but this may cost a lot of memory for large dataset.
  88. canonkeys_1 = self._get_canonkeys(g1)
  89. canonkeys_list = []
  90. if self._verbose >= 2:
  91. iterator = tqdm(g_list, desc='getting canonkeys', file=sys.stdout)
  92. else:
  93. iterator = g_list
  94. for g in iterator:
  95. canonkeys_list.append(self._get_canonkeys(g))
  96. # compute kernel list.
  97. kernel_list = [None] * len(g_list)
  98. if self._verbose >= 2:
  99. iterator = tqdm(range(len(g_list)), desc='Computing kernels', file=sys.stdout)
  100. else:
  101. iterator = range(len(g_list))
  102. for i in iterator:
  103. kernel = self._kernel_do(canonkeys_1, canonkeys_list[i])
  104. kernel_list[i] = kernel
  105. return kernel_list
    def _compute_kernel_list_imap_unordered(self, g1, g_list):
        """Compute treelet kernels between ``g1`` and each graph in ``g_list`` using a process pool.

        Returns
        -------
        list
            ``result[i]`` is the kernel value between ``g1`` and ``g_list[i]``.
        """
        self._add_dummy_labels(g_list + [g1])

        # get all canonical keys of all graphs before computing kernels to save
        # time, but this may cost a lot of memory for large dataset.
        canonkeys_1 = self._get_canonkeys(g1)
        canonkeys_list = [[] for _ in range(len(g_list))]
        pool = Pool(self._n_jobs)
        itr = zip(g_list, range(0, len(g_list)))
        # Heuristic chunk size: spread small lists evenly over the workers,
        # cap at 100 graphs per chunk for large ones.
        if len(g_list) < 100 * self._n_jobs:
            chunksize = int(len(g_list) / self._n_jobs) + 1
        else:
            chunksize = 100
        get_fun = self._wrapper_get_canonkeys
        if self._verbose >= 2:
            iterator = tqdm(pool.imap_unordered(get_fun, itr, chunksize),
                            desc='getting canonkeys', file=sys.stdout)
        else:
            iterator = pool.imap_unordered(get_fun, itr, chunksize)
        # Results arrive in arbitrary order; the returned index restores it.
        for i, ck in iterator:
            canonkeys_list[i] = ck
        pool.close()
        pool.join()

        # compute kernel list.
        kernel_list = [None] * len(g_list)

        def init_worker(ck_1_toshare, ck_list_toshare):
            # Expose the canonical keys to worker processes via globals.
            global G_ck_1, G_ck_list
            G_ck_1 = ck_1_toshare
            G_ck_list = ck_list_toshare
        do_fun = self._wrapper_kernel_list_do

        def func_assign(result, var_to_assign):
            # result is a (list index, kernel value) pair.
            var_to_assign[result[0]] = result[1]
        itr = range(len(g_list))
        len_itr = len(g_list)
        parallel_me(do_fun, func_assign, kernel_list, itr, len_itr=len_itr,
                    init_worker=init_worker, glbv=(canonkeys_1, canonkeys_list), method='imap_unordered',
                    n_jobs=self._n_jobs, itr_desc='Computing kernels', verbose=self._verbose)
        return kernel_list
  143. def _wrapper_kernel_list_do(self, itr):
  144. return itr, self._kernel_do(G_ck_1, G_ck_list[itr])
  145. def _compute_single_kernel_series(self, g1, g2):
  146. self._add_dummy_labels([g1] + [g2])
  147. canonkeys_1 = self._get_canonkeys(g1)
  148. canonkeys_2 = self._get_canonkeys(g2)
  149. kernel = self._kernel_do(canonkeys_1, canonkeys_2)
  150. return kernel
  151. def _kernel_do(self, canonkey1, canonkey2):
  152. """Compute treelet graph kernel between 2 graphs.
  153. Parameters
  154. ----------
  155. canonkey1, canonkey2 : list
  156. List of canonical keys in 2 graphs, where each key is represented by a string.
  157. Return
  158. ------
  159. kernel : float
  160. Treelet kernel between 2 graphs.
  161. """
  162. keys = set(canonkey1.keys()) & set(canonkey2.keys()) # find same canonical keys in both graphs
  163. vector1 = np.array([(canonkey1[key] if (key in canonkey1.keys()) else 0) for key in keys])
  164. vector2 = np.array([(canonkey2[key] if (key in canonkey2.keys()) else 0) for key in keys])
  165. kernel = self._sub_kernel(vector1, vector2)
  166. return kernel
  167. def _wrapper_kernel_do(self, itr):
  168. i = itr[0]
  169. j = itr[1]
  170. return i, j, self._kernel_do(G_canonkeys[i], G_canonkeys[j])
    def _get_canonkeys(self, G):
        """Generate canonical keys of all treelets in a graph.

        Parameters
        ----------
        G : NetworkX graphs
            The graph in which keys are generated.

        Return
        ------
        canonkey/canonkey_l : dict
            For unlabeled graphs, canonkey is a dictionary which records amount of
            every tree pattern. For labeled graphs, canonkey_l is one which keeps
            track of amount of every treelet.
        """
        patterns = {}  # a dictionary which consists of lists of patterns for all graphlet.
        canonkey = {}  # canonical key, a dictionary which records amount of every tree pattern.

        ### structural analysis ###
        ### In this section, a list of patterns is generated for each graphlet,
        ### where every pattern is represented by nodes ordered by Morgan's
        ### extended labeling.
        # linear patterns (paths with 0 to 5 edges)
        patterns['0'] = list(G.nodes())
        canonkey['0'] = nx.number_of_nodes(G)
        for i in range(1, 6):
            patterns[str(i)] = find_all_paths(G, i, self._ds_infos['directed'])
            canonkey[str(i)] = len(patterns[str(i)])

        # n-star patterns: a center node of degree n followed by its leaves.
        patterns['3star'] = [[node] + [neighbor for neighbor in G[node]] for node in G.nodes() if G.degree(node) == 3]
        patterns['4star'] = [[node] + [neighbor for neighbor in G[node]] for node in G.nodes() if G.degree(node) == 4]
        patterns['5star'] = [[node] + [neighbor for neighbor in G[node]] for node in G.nodes() if G.degree(node) == 5]
        # n-star pattern counts
        canonkey['6'] = len(patterns['3star'])
        canonkey['8'] = len(patterns['4star'])
        canonkey['d'] = len(patterns['5star'])

        # pattern 7
        patterns['7'] = []  # the 1st line of Table 1 in Ref [1]
        for pattern in patterns['3star']:
            for i in range(1, len(pattern)):  # for each neighbor of node 0
                if G.degree(pattern[i]) >= 2:
                    pattern_t = pattern[:]
                    # set the node with degree >= 2 as the 4th node
                    pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
                    for neighborx in G[pattern[i]]:
                        if neighborx != pattern[0]:
                            new_pattern = pattern_t + [neighborx]
                            patterns['7'].append(new_pattern)
        canonkey['7'] = len(patterns['7'])

        # pattern 11
        patterns['11'] = []  # the 4th line of Table 1 in Ref [1]
        for pattern in patterns['4star']:
            for i in range(1, len(pattern)):
                if G.degree(pattern[i]) >= 2:
                    pattern_t = pattern[:]
                    pattern_t[i], pattern_t[4] = pattern_t[4], pattern_t[i]
                    for neighborx in G[pattern[i]]:
                        if neighborx != pattern[0]:
                            new_pattern = pattern_t + [neighborx]
                            patterns['11'].append(new_pattern)
        canonkey['b'] = len(patterns['11'])

        # pattern 12
        patterns['12'] = []  # the 5th line of Table 1 in Ref [1]
        rootlist = []  # a list of root nodes, whose extended labels are 3
        for pattern in patterns['3star']:
            if pattern[0] not in rootlist:  # prevent to count the same pattern twice from each of the two root nodes
                rootlist.append(pattern[0])
                for i in range(1, len(pattern)):
                    if G.degree(pattern[i]) >= 3:
                        rootlist.append(pattern[i])
                        pattern_t = pattern[:]
                        pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
                        for neighborx1 in G[pattern[i]]:
                            if neighborx1 != pattern[0]:
                                for neighborx2 in G[pattern[i]]:
                                    if neighborx1 > neighborx2 and neighborx2 != pattern[0]:
                                        new_pattern = pattern_t + [neighborx1] + [neighborx2]
                                        patterns['12'].append(new_pattern)
        # each pattern 12 is enumerated once from each of its two roots.
        canonkey['c'] = int(len(patterns['12']) / 2)

        # pattern 9
        patterns['9'] = []  # the 2nd line of Table 1 in Ref [1]
        for pattern in patterns['3star']:
            for pairs in [ [neighbor1, neighbor2] for neighbor1 in G[pattern[0]] if G.degree(neighbor1) >= 2 \
                for neighbor2 in G[pattern[0]] if G.degree(neighbor2) >= 2 if neighbor1 > neighbor2]:
                pattern_t = pattern[:]
                # move nodes with extended labels 4 to specific position to correspond to their children
                pattern_t[pattern_t.index(pairs[0])], pattern_t[2] = pattern_t[2], pattern_t[pattern_t.index(pairs[0])]
                pattern_t[pattern_t.index(pairs[1])], pattern_t[3] = pattern_t[3], pattern_t[pattern_t.index(pairs[1])]
                for neighborx1 in G[pairs[0]]:
                    if neighborx1 != pattern[0]:
                        for neighborx2 in G[pairs[1]]:
                            if neighborx2 != pattern[0]:
                                new_pattern = pattern_t + [neighborx1] + [neighborx2]
                                patterns['9'].append(new_pattern)
        canonkey['9'] = len(patterns['9'])

        # pattern 10
        patterns['10'] = []  # the 3rd line of Table 1 in Ref [1]
        for pattern in patterns['3star']:
            for i in range(1, len(pattern)):
                if G.degree(pattern[i]) >= 2:
                    for neighborx in G[pattern[i]]:
                        if neighborx != pattern[0] and G.degree(neighborx) >= 2:
                            pattern_t = pattern[:]
                            pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
                            new_patterns = [ pattern_t + [neighborx] + [neighborxx] for neighborxx in G[neighborx] if neighborxx != pattern[i] ]
                            patterns['10'].extend(new_patterns)
        canonkey['a'] = len(patterns['10'])

        ### labeling information ###
        ### In this section, a list of canonical keys is generated for every
        ### pattern obtained in the structural analysis section above, which is a
        ### string corresponding to a unique treelet. A dictionary is built to keep
        ### track of the amount of every treelet.
        if len(self._node_labels) > 0 or len(self._edge_labels) > 0:
            canonkey_l = {}  # canonical key, a dictionary which keeps track of amount of every treelet.

            # linear patterns
            canonkey_t = Counter(get_mlti_dim_node_attrs(G, self._node_labels))
            for key in canonkey_t:
                canonkey_l[('0', key)] = canonkey_t[key]

            for i in range(1, 6):
                treelet = []
                for pattern in patterns[str(i)]:
                    # canonlist alternates node labels and edge labels along the path.
                    canonlist = []
                    for idx, node in enumerate(pattern[:-1]):
                        canonlist.append(tuple(G.nodes[node][nl] for nl in self._node_labels))
                        canonlist.append(tuple(G[node][pattern[idx+1]][el] for el in self._edge_labels))
                    canonlist.append(tuple(G.nodes[pattern[-1]][nl] for nl in self._node_labels))
                    # a path reads the same from both ends; keep the lexicographically smaller direction.
                    canonkey_t = canonlist if canonlist < canonlist[::-1] else canonlist[::-1]
                    treelet.append(tuple([str(i)] + canonkey_t))
                canonkey_l.update(Counter(treelet))

            # n-star patterns
            for i in range(3, 6):
                treelet = []
                for pattern in patterns[str(i) + 'star']:
                    canonlist = []
                    for leaf in pattern[1:]:
                        nlabels = tuple(G.nodes[leaf][nl] for nl in self._node_labels)
                        elabels = tuple(G[leaf][pattern[0]][el] for el in self._edge_labels)
                        canonlist.append(tuple((nlabels, elabels)))
                    # leaves are interchangeable: sort for a canonical order.
                    canonlist.sort()
                    canonlist = list(chain.from_iterable(canonlist))
                    canonkey_t = tuple(['d' if i == 5 else str(i * 2)] +
                                       [tuple(G.nodes[pattern[0]][nl] for nl in self._node_labels)]
                                       + canonlist)
                    treelet.append(canonkey_t)
                canonkey_l.update(Counter(treelet))

            # pattern 7
            treelet = []
            for pattern in patterns['7']:
                canonlist = []
                for leaf in pattern[1:3]:
                    nlabels = tuple(G.nodes[leaf][nl] for nl in self._node_labels)
                    elabels = tuple(G[leaf][pattern[0]][el] for el in self._edge_labels)
                    canonlist.append(tuple((nlabels, elabels)))
                canonlist.sort()
                canonlist = list(chain.from_iterable(canonlist))
                canonkey_t = tuple(['7']
                                   + [tuple(G.nodes[pattern[0]][nl] for nl in self._node_labels)] + canonlist
                                   + [tuple(G.nodes[pattern[3]][nl] for nl in self._node_labels)]
                                   + [tuple(G[pattern[3]][pattern[0]][el] for el in self._edge_labels)]
                                   + [tuple(G.nodes[pattern[4]][nl] for nl in self._node_labels)]
                                   + [tuple(G[pattern[4]][pattern[3]][el] for el in self._edge_labels)])
                treelet.append(canonkey_t)
            canonkey_l.update(Counter(treelet))

            # pattern 11
            treelet = []
            for pattern in patterns['11']:
                canonlist = []
                for leaf in pattern[1:4]:
                    nlabels = tuple(G.nodes[leaf][nl] for nl in self._node_labels)
                    elabels = tuple(G[leaf][pattern[0]][el] for el in self._edge_labels)
                    canonlist.append(tuple((nlabels, elabels)))
                canonlist.sort()
                canonlist = list(chain.from_iterable(canonlist))
                canonkey_t = tuple(['b']
                                   + [tuple(G.nodes[pattern[0]][nl] for nl in self._node_labels)] + canonlist
                                   + [tuple(G.nodes[pattern[4]][nl] for nl in self._node_labels)]
                                   + [tuple(G[pattern[4]][pattern[0]][el] for el in self._edge_labels)]
                                   + [tuple(G.nodes[pattern[5]][nl] for nl in self._node_labels)]
                                   + [tuple(G[pattern[5]][pattern[4]][el] for el in self._edge_labels)])
                treelet.append(canonkey_t)
            canonkey_l.update(Counter(treelet))

            # pattern 10
            treelet = []
            for pattern in patterns['10']:
                canonkey4 = [tuple(G.nodes[pattern[5]][nl] for nl in self._node_labels),
                             tuple(G[pattern[5]][pattern[4]][el] for el in self._edge_labels)]
                canonlist = []
                for leaf in pattern[1:3]:
                    nlabels = tuple(G.nodes[leaf][nl] for nl in self._node_labels)
                    elabels = tuple(G[leaf][pattern[0]][el] for el in self._edge_labels)
                    canonlist.append(tuple((nlabels, elabels)))
                canonlist.sort()
                canonkey0 = list(chain.from_iterable(canonlist))
                canonkey_t = tuple(['a']
                                   + [tuple(G.nodes[pattern[3]][nl] for nl in self._node_labels)]
                                   + [tuple(G.nodes[pattern[4]][nl] for nl in self._node_labels)]
                                   + [tuple(G[pattern[4]][pattern[3]][el] for el in self._edge_labels)]
                                   + [tuple(G.nodes[pattern[0]][nl] for nl in self._node_labels)]
                                   + [tuple(G[pattern[0]][pattern[3]][el] for el in self._edge_labels)]
                                   + canonkey4 + canonkey0)
                treelet.append(canonkey_t)
            canonkey_l.update(Counter(treelet))

            # pattern 12
            treelet = []
            for pattern in patterns['12']:
                canonlist0 = []
                for leaf in pattern[1:3]:
                    nlabels = tuple(G.nodes[leaf][nl] for nl in self._node_labels)
                    elabels = tuple(G[leaf][pattern[0]][el] for el in self._edge_labels)
                    canonlist0.append(tuple((nlabels, elabels)))
                canonlist0.sort()
                canonlist0 = list(chain.from_iterable(canonlist0))
                canonlist3 = []
                for leaf in pattern[4:6]:
                    nlabels = tuple(G.nodes[leaf][nl] for nl in self._node_labels)
                    elabels = tuple(G[leaf][pattern[3]][el] for el in self._edge_labels)
                    canonlist3.append(tuple((nlabels, elabels)))
                canonlist3.sort()
                canonlist3 = list(chain.from_iterable(canonlist3))

                # 2 possible keys can be generated from the 2 nodes with extended label 3;
                # select the one with lower lexicographic order.
                canonkey_t1 = tuple(['c']
                                    + [tuple(G.nodes[pattern[0]][nl] for nl in self._node_labels)] + canonlist0
                                    + [tuple(G.nodes[pattern[3]][nl] for nl in self._node_labels)]
                                    + [tuple(G[pattern[3]][pattern[0]][el] for el in self._edge_labels)]
                                    + canonlist3)
                canonkey_t2 = tuple(['c']
                                    + [tuple(G.nodes[pattern[3]][nl] for nl in self._node_labels)] + canonlist3
                                    + [tuple(G.nodes[pattern[0]][nl] for nl in self._node_labels)]
                                    + [tuple(G[pattern[0]][pattern[3]][el] for el in self._edge_labels)]
                                    + canonlist0)
                treelet.append(canonkey_t1 if canonkey_t1 < canonkey_t2 else canonkey_t2)
            canonkey_l.update(Counter(treelet))

            # pattern 9
            treelet = []
            for pattern in patterns['9']:
                canonkey2 = [tuple(G.nodes[pattern[4]][nl] for nl in self._node_labels),
                             tuple(G[pattern[4]][pattern[2]][el] for el in self._edge_labels)]
                canonkey3 = [tuple(G.nodes[pattern[5]][nl] for nl in self._node_labels),
                             tuple(G[pattern[5]][pattern[3]][el] for el in self._edge_labels)]
                prekey2 = [tuple(G.nodes[pattern[2]][nl] for nl in self._node_labels),
                           tuple(G[pattern[2]][pattern[0]][el] for el in self._edge_labels)]
                prekey3 = [tuple(G.nodes[pattern[3]][nl] for nl in self._node_labels),
                           tuple(G[pattern[3]][pattern[0]][el] for el in self._edge_labels)]
                # order the two branches canonically by their (parent, child) key.
                if prekey2 + canonkey2 < prekey3 + canonkey3:
                    canonkey_t = [tuple(G.nodes[pattern[1]][nl] for nl in self._node_labels)] \
                        + [tuple(G[pattern[1]][pattern[0]][el] for el in self._edge_labels)] \
                        + prekey2 + prekey3 + canonkey2 + canonkey3
                else:
                    canonkey_t = [tuple(G.nodes[pattern[1]][nl] for nl in self._node_labels)] \
                        + [tuple(G[pattern[1]][pattern[0]][el] for el in self._edge_labels)] \
                        + prekey3 + prekey2 + canonkey3 + canonkey2
                treelet.append(tuple(['9']
                                     + [tuple(G.nodes[pattern[0]][nl] for nl in self._node_labels)]
                                     + canonkey_t))
            canonkey_l.update(Counter(treelet))

            return canonkey_l

        return canonkey
  427. def _wrapper_get_canonkeys(self, itr_item):
  428. g = itr_item[0]
  429. i = itr_item[1]
  430. return i, self._get_canonkeys(g)
  431. def _add_dummy_labels(self, Gn):
  432. if len(self._node_labels) == 0 or (len(self._node_labels) == 1 and self._node_labels[0] == SpecialLabel.DUMMY):
  433. for i in range(len(Gn)):
  434. nx.set_node_attributes(Gn[i], '0', SpecialLabel.DUMMY)
  435. self._node_labels = [SpecialLabel.DUMMY]
  436. if len(self._edge_labels) == 0 or (len(self._edge_labels) == 1 and self._edge_labels[0] == SpecialLabel.DUMMY):
  437. for i in range(len(Gn)):
  438. nx.set_edge_attributes(Gn[i], '0', SpecialLabel.DUMMY)
  439. self._edge_labels = [SpecialLabel.DUMMY]

A Python package for graph kernels, graph edit distances and the graph pre-image problem.