You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

treeletKernel.py 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. """
  2. @author: linlin
  3. @references: Gaüzère B, Brun L, Villemin D. Two new graphs kernels in chemoinformatics. Pattern Recognition Letters. 2012 Nov 1;33(15):2038-47.
  4. """
  5. import sys
  6. import pathlib
  7. sys.path.insert(0, "../")
  8. import time
  9. from collections import Counter
  10. from itertools import chain
  11. import networkx as nx
  12. import numpy as np
  13. def treeletkernel(*args, node_label = 'atom', edge_label = 'bond_type', labeled = True):
  14. """Calculate treelet graph kernels between graphs.
  15. Parameters
  16. ----------
  17. Gn : List of NetworkX graph
  18. List of graphs between which the kernels are calculated.
  19. /
  20. G1, G2 : NetworkX graphs
  21. 2 graphs between which the kernel is calculated.
  22. node_label : string
  23. node attribute used as label. The default node label is atom.
  24. edge_label : string
  25. edge attribute used as label. The default edge label is bond_type.
  26. labeled : boolean
  27. Whether the graphs are labeled. The default is True.
  28. Return
  29. ------
  30. Kmatrix/kernel : Numpy matrix/float
  31. Kernel matrix, each element of which is the treelet kernel between 2 praphs. / Treelet kernel between 2 graphs.
  32. """
  33. if len(args) == 1: # for a list of graphs
  34. Gn = args[0]
  35. Kmatrix = np.zeros((len(Gn), len(Gn)))
  36. start_time = time.time()
  37. # get all canonical keys of all graphs before calculating kernels to save time, but this may cost a lot of memory for large dataset.
  38. canonkeys = [ get_canonkeys(Gn[i], node_label = node_label, edge_label = edge_label, labeled = labeled) \
  39. for i in range(0, len(Gn)) ]
  40. for i in range(0, len(Gn)):
  41. for j in range(i, len(Gn)):
  42. Kmatrix[i][j] = _treeletkernel_do(canonkeys[i], canonkeys[j], node_label = node_label, edge_label = edge_label, labeled = labeled)
  43. Kmatrix[j][i] = Kmatrix[i][j]
  44. run_time = time.time() - start_time
  45. print("\n --- treelet kernel matrix of size %d built in %s seconds ---" % (len(Gn), run_time))
  46. return Kmatrix, run_time
  47. else: # for only 2 graphs
  48. start_time = time.time()
  49. canonkey1 = get_canonkeys(args[0], node_label = node_label, edge_label = edge_label, labeled = labeled)
  50. canonkey2 = get_canonkeys(args[1], node_label = node_label, edge_label = edge_label, labeled = labeled)
  51. kernel = _treeletkernel_do(canonkey1, canonkey2, node_label = node_label, edge_label = edge_label, labeled = labeled)
  52. run_time = time.time() - start_time
  53. print("\n --- treelet kernel built in %s seconds ---" % (run_time))
  54. return kernel, run_time
  55. def _treeletkernel_do(canonkey1, canonkey2, node_label = 'atom', edge_label = 'bond_type', labeled = True):
  56. """Calculate treelet graph kernel between 2 graphs.
  57. Parameters
  58. ----------
  59. canonkey1, canonkey2 : list
  60. List of canonical keys in 2 graphs, where each key is represented by a string.
  61. node_label : string
  62. Node attribute used as label. The default node label is atom.
  63. edge_label : string
  64. Edge attribute used as label. The default edge label is bond_type.
  65. labeled : boolean
  66. Whether the graphs are labeled. The default is True.
  67. Return
  68. ------
  69. kernel : float
  70. Treelet Kernel between 2 graphs.
  71. """
  72. keys = set(canonkey1.keys()) & set(canonkey2.keys()) # find same canonical keys in both graphs
  73. vector1 = np.array([ (canonkey1[key] if (key in canonkey1.keys()) else 0) for key in keys ])
  74. vector2 = np.array([ (canonkey2[key] if (key in canonkey2.keys()) else 0) for key in keys ])
  75. kernel = np.sum(np.exp(- np.square(vector1 - vector2) / 2))
  76. return kernel
  77. def get_canonkeys(G, node_label = 'atom', edge_label = 'bond_type', labeled = True):
  78. """Generate canonical keys of all treelets in a graph.
  79. Parameters
  80. ----------
  81. G : NetworkX graphs
  82. The graph in which keys are generated.
  83. node_label : string
  84. node attribute used as label. The default node label is atom.
  85. edge_label : string
  86. edge attribute used as label. The default edge label is bond_type.
  87. labeled : boolean
  88. Whether the graphs are labeled. The default is True.
  89. Return
  90. ------
  91. canonkey/canonkey_l : dict
  92. For unlabeled graphs, canonkey is a dictionary which records amount of every tree pattern. For labeled graphs, canonkey_l is one which keeps track of amount of every treelet.
  93. """
  94. patterns = {} # a dictionary which consists of lists of patterns for all graphlet.
  95. canonkey = {} # canonical key, a dictionary which records amount of every tree pattern.
  96. ### structural analysis ###
  97. ### In this section, a list of patterns is generated for each graphlet, where every pattern is represented by nodes ordered by
  98. ### Morgan's extended labeling.
  99. # linear patterns
  100. patterns['0'] = G.nodes()
  101. canonkey['0'] = nx.number_of_nodes(G)
  102. for i in range(1, 6): # for i in range(1, 6):
  103. patterns[str(i)] = find_all_paths(G, i)
  104. canonkey[str(i)] = len(patterns[str(i)])
  105. # n-star patterns
  106. patterns['3star'] = [ [node] + [neighbor for neighbor in G[node]] for node in G.nodes() if G.degree(node) == 3 ]
  107. patterns['4star'] = [ [node] + [neighbor for neighbor in G[node]] for node in G.nodes() if G.degree(node) == 4 ]
  108. patterns['5star'] = [ [node] + [neighbor for neighbor in G[node]] for node in G.nodes() if G.degree(node) == 5 ]
  109. # n-star patterns
  110. canonkey['6'] = len(patterns['3star'])
  111. canonkey['8'] = len(patterns['4star'])
  112. canonkey['d'] = len(patterns['5star'])
  113. # pattern 7
  114. patterns['7'] = [] # the 1st line of Table 1 in Ref [1]
  115. for pattern in patterns['3star']:
  116. for i in range(1, len(pattern)): # for each neighbor of node 0
  117. if G.degree(pattern[i]) >= 2:
  118. pattern_t = pattern[:]
  119. pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i] # set the node with degree >= 2 as the 4th node
  120. for neighborx in G[pattern[i]]:
  121. if neighborx != pattern[0]:
  122. new_pattern = pattern_t + [ neighborx ]
  123. patterns['7'].append(new_pattern)
  124. canonkey['7'] = len(patterns['7'])
  125. # pattern 11
  126. patterns['11'] = [] # the 4th line of Table 1 in Ref [1]
  127. for pattern in patterns['4star']:
  128. for i in range(1, len(pattern)):
  129. if G.degree(pattern[i]) >= 2:
  130. pattern_t = pattern[:]
  131. pattern_t[i], pattern_t[4] = pattern_t[4], pattern_t[i]
  132. for neighborx in G[pattern[i]]:
  133. if neighborx != pattern[0]:
  134. new_pattern = pattern_t + [ neighborx ]
  135. patterns['11'].append(new_pattern)
  136. canonkey['b'] = len(patterns['11'])
  137. # pattern 12
  138. patterns['12'] = [] # the 5th line of Table 1 in Ref [1]
  139. rootlist = [] # a list of root nodes, whose extended labels are 3
  140. for pattern in patterns['3star']:
  141. if pattern[0] not in rootlist: # prevent to count the same pattern twice from each of the two root nodes
  142. rootlist.append(pattern[0])
  143. for i in range(1, len(pattern)):
  144. if G.degree(pattern[i]) >= 3:
  145. rootlist.append(pattern[i])
  146. pattern_t = pattern[:]
  147. pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
  148. for neighborx1 in G[pattern[i]]:
  149. if neighborx1 != pattern[0]:
  150. for neighborx2 in G[pattern[i]]:
  151. if neighborx1 > neighborx2 and neighborx2 != pattern[0]:
  152. new_pattern = pattern_t + [neighborx1] + [neighborx2]
  153. # new_patterns = [ pattern + [neighborx1] + [neighborx2] for neighborx1 in G[pattern[i]] if neighborx1 != pattern[0] for neighborx2 in G[pattern[i]] if (neighborx1 > neighborx2 and neighborx2 != pattern[0]) ]
  154. patterns['12'].append(new_pattern)
  155. canonkey['c'] = int(len(patterns['12']) / 2)
  156. # pattern 9
  157. patterns['9'] = [] # the 2nd line of Table 1 in Ref [1]
  158. for pattern in patterns['3star']:
  159. for pairs in [ [neighbor1, neighbor2] for neighbor1 in G[pattern[0]] if G.degree(neighbor1) >= 2 \
  160. for neighbor2 in G[pattern[0]] if G.degree(neighbor2) >= 2 if neighbor1 > neighbor2 ]:
  161. pattern_t = pattern[:]
  162. # move nodes with extended labels 4 to specific position to correspond to their children
  163. pattern_t[pattern_t.index(pairs[0])], pattern_t[2] = pattern_t[2], pattern_t[pattern_t.index(pairs[0])]
  164. pattern_t[pattern_t.index(pairs[1])], pattern_t[3] = pattern_t[3], pattern_t[pattern_t.index(pairs[1])]
  165. for neighborx1 in G[pairs[0]]:
  166. if neighborx1 != pattern[0]:
  167. for neighborx2 in G[pairs[1]]:
  168. if neighborx2 != pattern[0]:
  169. new_pattern = pattern_t + [neighborx1] + [neighborx2]
  170. patterns['9'].append(new_pattern)
  171. canonkey['9'] = len(patterns['9'])
  172. # pattern 10
  173. patterns['10'] = [] # the 3rd line of Table 1 in Ref [1]
  174. for pattern in patterns['3star']:
  175. for i in range(1, len(pattern)):
  176. if G.degree(pattern[i]) >= 2:
  177. for neighborx in G[pattern[i]]:
  178. if neighborx != pattern[0] and G.degree(neighborx) >= 2:
  179. pattern_t = pattern[:]
  180. pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
  181. new_patterns = [ pattern_t + [neighborx] + [neighborxx] for neighborxx in G[neighborx] if neighborxx != pattern[i] ]
  182. patterns['10'].extend(new_patterns)
  183. canonkey['a'] = len(patterns['10'])
  184. ### labeling information ###
  185. ### In this section, a list of canonical keys is generated for every pattern obtained in the structural analysis
  186. ### section above, which is a string corresponding to a unique treelet. A dictionary is built to keep track of
  187. ### the amount of every treelet.
  188. if labeled == True:
  189. canonkey_l = {} # canonical key, a dictionary which keeps track of amount of every treelet.
  190. # linear patterns
  191. canonkey_t = Counter(list(nx.get_node_attributes(G, node_label).values()))
  192. for key in canonkey_t:
  193. canonkey_l['0' + key] = canonkey_t[key]
  194. for i in range(1, 6): # for i in range(1, 6):
  195. treelet = []
  196. for pattern in patterns[str(i)]:
  197. canonlist = list(chain.from_iterable((G.node[node][node_label], \
  198. G[node][pattern[idx+1]][edge_label]) for idx, node in enumerate(pattern[:-1])))
  199. canonlist.append(G.node[pattern[-1]][node_label])
  200. canonkey_t = ''.join(canonlist)
  201. canonkey_t = canonkey_t if canonkey_t < canonkey_t[::-1] else canonkey_t[::-1]
  202. treelet.append(str(i) + canonkey_t)
  203. canonkey_l.update(Counter(treelet))
  204. # n-star patterns
  205. for i in range(3, 6):
  206. treelet = []
  207. for pattern in patterns[str(i) + 'star']:
  208. canonlist = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:] ]
  209. canonlist.sort()
  210. canonkey_t = ('d' if i == 5 else str(i * 2)) + G.node[pattern[0]][node_label] + ''.join(canonlist)
  211. treelet.append(canonkey_t)
  212. canonkey_l.update(Counter(treelet))
  213. # pattern 7
  214. treelet = []
  215. for pattern in patterns['7']:
  216. canonlist = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:3] ]
  217. canonlist.sort()
  218. canonkey_t = '7' + G.node[pattern[0]][node_label] + ''.join(canonlist) \
  219. + G.node[pattern[3]][node_label] + G[pattern[3]][pattern[0]][edge_label] \
  220. + G.node[pattern[4]][node_label] + G[pattern[4]][pattern[3]][edge_label]
  221. treelet.append(canonkey_t)
  222. canonkey_l.update(Counter(treelet))
  223. # pattern 11
  224. treelet = []
  225. for pattern in patterns['11']:
  226. canonlist = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:4] ]
  227. canonlist.sort()
  228. canonkey_t = 'b' + G.node[pattern[0]][node_label] + ''.join(canonlist) \
  229. + G.node[pattern[4]][node_label] + G[pattern[4]][pattern[0]][edge_label] \
  230. + G.node[pattern[5]][node_label] + G[pattern[5]][pattern[4]][edge_label]
  231. treelet.append(canonkey_t)
  232. canonkey_l.update(Counter(treelet))
  233. # pattern 10
  234. treelet = []
  235. for pattern in patterns['10']:
  236. canonkey4 = G.node[pattern[5]][node_label] + G[pattern[5]][pattern[4]][edge_label]
  237. canonlist = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:3] ]
  238. canonlist.sort()
  239. canonkey0 = ''.join(canonlist)
  240. canonkey_t = 'a' + G.node[pattern[3]][node_label] \
  241. + G.node[pattern[4]][node_label] + G[pattern[4]][pattern[3]][edge_label] \
  242. + G.node[pattern[0]][node_label] + G[pattern[0]][pattern[3]][edge_label] \
  243. + canonkey4 + canonkey0
  244. treelet.append(canonkey_t)
  245. canonkey_l.update(Counter(treelet))
  246. # pattern 12
  247. treelet = []
  248. for pattern in patterns['12']:
  249. canonlist0 = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:3] ]
  250. canonlist0.sort()
  251. canonlist3 = [ G.node[leaf][node_label] + G[leaf][pattern[3]][edge_label] for leaf in pattern[4:6] ]
  252. canonlist3.sort()
  253. # 2 possible key can be generated from 2 nodes with extended label 3, select the one with lower lexicographic order.
  254. canonkey_t1 = 'c' + G.node[pattern[0]][node_label] \
  255. + ''.join(canonlist0) \
  256. + G.node[pattern[3]][node_label] + G[pattern[3]][pattern[0]][edge_label] \
  257. + ''.join(canonlist3)
  258. canonkey_t2 = 'c' + G.node[pattern[3]][node_label] \
  259. + ''.join(canonlist3) \
  260. + G.node[pattern[0]][node_label] + G[pattern[0]][pattern[3]][edge_label] \
  261. + ''.join(canonlist0)
  262. treelet.append(canonkey_t1 if canonkey_t1 < canonkey_t2 else canonkey_t2)
  263. canonkey_l.update(Counter(treelet))
  264. # pattern 9
  265. treelet = []
  266. for pattern in patterns['9']:
  267. canonkey2 = G.node[pattern[4]][node_label] + G[pattern[4]][pattern[2]][edge_label]
  268. canonkey3 = G.node[pattern[5]][node_label] + G[pattern[5]][pattern[3]][edge_label]
  269. prekey2 = G.node[pattern[2]][node_label] + G[pattern[2]][pattern[0]][edge_label]
  270. prekey3 = G.node[pattern[3]][node_label] + G[pattern[3]][pattern[0]][edge_label]
  271. if prekey2 + canonkey2 < prekey3 + canonkey3:
  272. canonkey_t = G.node[pattern[1]][node_label] + G[pattern[1]][pattern[0]][edge_label] \
  273. + prekey2 + prekey3 + canonkey2 + canonkey3
  274. else:
  275. canonkey_t = G.node[pattern[1]][node_label] + G[pattern[1]][pattern[0]][edge_label] \
  276. + prekey3 + prekey2 + canonkey3 + canonkey2
  277. treelet.append('9' + G.node[pattern[0]][node_label] + canonkey_t)
  278. canonkey_l.update(Counter(treelet))
  279. return canonkey_l
  280. return canonkey
  281. def find_paths(G, source_node, length):
  282. """Find all paths with a certain length those start from a source node. A recursive depth first search is applied.
  283. Parameters
  284. ----------
  285. G : NetworkX graphs
  286. The graph in which paths are searched.
  287. source_node : integer
  288. The number of the node from where all paths start.
  289. length : integer
  290. The length of paths.
  291. Return
  292. ------
  293. path : list of list
  294. List of paths retrieved, where each path is represented by a list of nodes.
  295. """
  296. if length == 0:
  297. return [[source_node]]
  298. path = [ [source_node] + path for neighbor in G[source_node] \
  299. for path in find_paths(G, neighbor, length - 1) if source_node not in path ]
  300. return path
  301. def find_all_paths(G, length):
  302. """Find all paths with a certain length in a graph. A recursive depth first search is applied.
  303. Parameters
  304. ----------
  305. G : NetworkX graphs
  306. The graph in which paths are searched.
  307. length : integer
  308. The length of paths.
  309. Return
  310. ------
  311. path : list of list
  312. List of paths retrieved, where each path is represented by a list of nodes.
  313. """
  314. all_paths = []
  315. for node in G:
  316. all_paths.extend(find_paths(G, node, length))
  317. all_paths_r = [ path[::-1] for path in all_paths ]
  318. # For each path, two presentation are retrieved from its two extremities. Remove one of them.
  319. for idx, path in enumerate(all_paths[:-1]):
  320. for path2 in all_paths_r[idx+1::]:
  321. if path == path2:
  322. all_paths[idx] = []
  323. break
  324. return list(filter(lambda a: a != [], all_paths))

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.