You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

treeletKernel.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. import sys
  2. import pathlib
  3. sys.path.insert(0, "../")
  4. import time
  5. from collections import Counter
  6. from itertools import chain
  7. import networkx as nx
  8. import numpy as np
  9. def find_paths(G, source_node, length):
  10. if length == 0:
  11. return [[source_node]]
  12. path = [ [source_node] + path for neighbor in G[source_node] \
  13. for path in find_paths(G, neighbor, length - 1) if source_node not in path ]
  14. return path
  15. def find_all_paths(G, length):
  16. all_paths = []
  17. for node in G:
  18. all_paths.extend(find_paths(G, node, length))
  19. all_paths_r = [ path[::-1] for path in all_paths ]
  20. # remove double direction
  21. for idx, path in enumerate(all_paths[:-1]):
  22. for path2 in all_paths_r[idx+1::]:
  23. if path == path2:
  24. all_paths[idx] = []
  25. break
  26. return list(filter(lambda a: a != [], all_paths))
  27. def get_canonkey(G, node_label = 'atom', edge_label = 'bond_type', labeled = True):
  28. patterns = {}
  29. canonkey = {} # canonical key
  30. ### structural analysis ###
  31. # linear patterns
  32. patterns['0'] = G.nodes()
  33. canonkey['0'] = nx.number_of_nodes(G)
  34. for i in range(1, 6):
  35. patterns[str(i)] = find_all_paths(G, i)
  36. canonkey[str(i)] = len(patterns[str(i)])
  37. # n-star patterns
  38. patterns['3star'] = [ [node] + [neighbor for neighbor in G[node]] for node in G.nodes() if G.degree(node) == 3 ]
  39. patterns['4star'] = [ [node] + [neighbor for neighbor in G[node]] for node in G.nodes() if G.degree(node) == 4 ]
  40. patterns['5star'] = [ [node] + [neighbor for neighbor in G[node]] for node in G.nodes() if G.degree(node) == 5 ]
  41. # n-star patterns
  42. canonkey['6'] = len(patterns['3star'])
  43. canonkey['8'] = len(patterns['4star'])
  44. canonkey['d'] = len(patterns['5star'])
  45. # pattern 7
  46. patterns['7'] = []
  47. for pattern in patterns['3star']:
  48. for i in range(1, len(pattern)):
  49. if G.degree(pattern[i]) >= 2:
  50. pattern_t = pattern[:]
  51. pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
  52. for neighborx in G[pattern[i]]:
  53. if neighborx != pattern[0]:
  54. new_pattern = pattern_t + [ neighborx ]
  55. patterns['7'].append(new_pattern)
  56. canonkey['7'] = len(patterns['7'])
  57. # pattern 11
  58. patterns['11'] = []
  59. for pattern in patterns['4star']:
  60. for i in range(1, len(pattern)):
  61. if G.degree(pattern[i]) >= 2:
  62. pattern_t = pattern[:]
  63. pattern_t[i], pattern_t[4] = pattern_t[4], pattern_t[i]
  64. for neighborx in G[pattern[i]]:
  65. if neighborx != pattern[0]:
  66. new_pattern = pattern_t + [ neighborx ]
  67. patterns['11'].append(new_pattern)
  68. canonkey['b'] = len(patterns['11'])
  69. # pattern 12
  70. patterns['12'] = []
  71. rootlist = []
  72. for pattern in patterns['3star']:
  73. if pattern[0] not in rootlist:
  74. rootlist.append(pattern[0])
  75. for i in range(1, len(pattern)):
  76. if G.degree(pattern[i]) >= 3:
  77. rootlist.append(pattern[i])
  78. pattern_t = pattern[:]
  79. pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
  80. for neighborx1 in G[pattern[i]]:
  81. if neighborx1 != pattern[0]:
  82. for neighborx2 in G[pattern[i]]:
  83. if neighborx1 > neighborx2 and neighborx2 != pattern[0]:
  84. new_pattern = pattern_t + [neighborx1] + [neighborx2]
  85. # new_patterns = [ pattern + [neighborx1] + [neighborx2] for neighborx1 in G[pattern[i]] if neighborx1 != pattern[0] for neighborx2 in G[pattern[i]] if (neighborx1 > neighborx2 and neighborx2 != pattern[0]) ]
  86. patterns['12'].append(new_pattern)
  87. canonkey['c'] = int(len(patterns['12']) / 2)
  88. # pattern 9
  89. patterns['9'] = []
  90. for pattern in patterns['3star']:
  91. for pairs in [ [neighbor1, neighbor2] for neighbor1 in G[pattern[0]] if G.degree(neighbor1) >= 2 \
  92. for neighbor2 in G[pattern[0]] if G.degree(neighbor2) >= 2 if neighbor1 > neighbor2 ]:
  93. pattern_t = pattern[:]
  94. pattern_t[pattern_t.index(pairs[0])], pattern_t[2] = pattern_t[2], pattern_t[pattern_t.index(pairs[0])]
  95. pattern_t[pattern_t.index(pairs[1])], pattern_t[3] = pattern_t[3], pattern_t[pattern_t.index(pairs[1])]
  96. for neighborx1 in G[pairs[0]]:
  97. if neighborx1 != pattern[0]:
  98. for neighborx2 in G[pairs[1]]:
  99. if neighborx2 != pattern[0]:
  100. new_pattern = pattern_t + [neighborx1] + [neighborx2]
  101. patterns['9'].append(new_pattern)
  102. canonkey['9'] = len(patterns['9'])
  103. # pattern 10
  104. patterns['10'] = []
  105. for pattern in patterns['3star']:
  106. for i in range(1, len(pattern)):
  107. if G.degree(pattern[i]) >= 2:
  108. for neighborx in G[pattern[i]]:
  109. if neighborx != pattern[0] and G.degree(neighborx) >= 2:
  110. pattern_t = pattern[:]
  111. pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
  112. new_patterns = [ pattern_t + [neighborx] + [neighborxx] for neighborxx in G[neighborx] if neighborxx != pattern[i] ]
  113. patterns['10'].extend(new_patterns)
  114. canonkey['a'] = len(patterns['10'])
  115. ### labeling information ###
  116. if labeled == True:
  117. canonkey_l = {}
  118. # linear patterns
  119. canonkey_t = Counter(list(nx.get_node_attributes(G, node_label).values()))
  120. for key in canonkey_t:
  121. canonkey_l['0' + key] = canonkey_t[key]
  122. for i in range(1, 6):
  123. treelet = []
  124. for pattern in patterns[str(i)]:
  125. canonlist = list(chain.from_iterable((G.node[node][node_label], \
  126. G[node][pattern[idx+1]][edge_label]) for idx, node in enumerate(pattern[:-1])))
  127. canonlist.append(G.node[pattern[-1]][node_label])
  128. canonkey_t = ''.join(canonlist)
  129. canonkey_t = canonkey_t if canonkey_t < canonkey_t[::-1] else canonkey_t[::-1]
  130. treelet.append(str(i) + canonkey_t)
  131. canonkey_l.update(Counter(treelet))
  132. # n-star patterns
  133. for i in range(3, 6):
  134. treelet = []
  135. for pattern in patterns[str(i) + 'star']:
  136. canonlist = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:] ]
  137. canonlist.sort()
  138. canonkey_t = ('d' if i == 5 else str(i * 2)) + G.node[pattern[0]][node_label] + ''.join(canonlist)
  139. treelet.append(canonkey_t)
  140. canonkey_l.update(Counter(treelet))
  141. # pattern 7
  142. treelet = []
  143. for pattern in patterns['7']:
  144. canonlist = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:3] ]
  145. canonlist.sort()
  146. canonkey_t = '7' + G.node[pattern[0]][node_label] + ''.join(canonlist) \
  147. + G.node[pattern[3]][node_label] + G[pattern[3]][pattern[0]][edge_label] \
  148. + G.node[pattern[4]][node_label] + G[pattern[4]][pattern[3]][edge_label]
  149. treelet.append(canonkey_t)
  150. canonkey_l.update(Counter(treelet))
  151. # pattern 11
  152. treelet = []
  153. for pattern in patterns['11']:
  154. canonlist = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:4] ]
  155. canonlist.sort()
  156. canonkey_t = 'b' + G.node[pattern[0]][node_label] + ''.join(canonlist) \
  157. + G.node[pattern[4]][node_label] + G[pattern[4]][pattern[0]][edge_label] \
  158. + G.node[pattern[5]][node_label] + G[pattern[5]][pattern[4]][edge_label]
  159. treelet.append(canonkey_t)
  160. canonkey_l.update(Counter(treelet))
  161. # pattern 10
  162. treelet = []
  163. for pattern in patterns['10']:
  164. canonkey4 = G.node[pattern[5]][node_label] + G[pattern[5]][pattern[4]][edge_label]
  165. canonlist = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:3] ]
  166. canonlist.sort()
  167. canonkey0 = ''.join(canonlist)
  168. canonkey_t = 'a' + G.node[pattern[3]][node_label] \
  169. + G.node[pattern[4]][node_label] + G[pattern[4]][pattern[3]][edge_label] \
  170. + G.node[pattern[0]][node_label] + G[pattern[0]][pattern[3]][edge_label] \
  171. + canonkey4 + canonkey0
  172. treelet.append(canonkey_t)
  173. canonkey_l.update(Counter(treelet))
  174. # pattern 12
  175. treelet = []
  176. for pattern in patterns['12']:
  177. canonlist0 = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:3] ]
  178. canonlist0.sort()
  179. canonlist3 = [ G.node[leaf][node_label] + G[leaf][pattern[3]][edge_label] for leaf in pattern[4:6] ]
  180. canonlist3.sort()
  181. canonkey_t1 = 'c' + G.node[pattern[0]][node_label] \
  182. + ''.join(canonlist0) \
  183. + G.node[pattern[3]][node_label] + G[pattern[3]][pattern[0]][edge_label] \
  184. + ''.join(canonlist3)
  185. canonkey_t2 = 'c' + G.node[pattern[3]][node_label] \
  186. + ''.join(canonlist3) \
  187. + G.node[pattern[0]][node_label] + G[pattern[0]][pattern[3]][edge_label] \
  188. + ''.join(canonlist0)
  189. treelet.append(canonkey_t1 if canonkey_t1 < canonkey_t2 else canonkey_t2)
  190. canonkey_l.update(Counter(treelet))
  191. # pattern 9
  192. treelet = []
  193. for pattern in patterns['9']:
  194. canonkey2 = G.node[pattern[4]][node_label] + G[pattern[4]][pattern[2]][edge_label]
  195. canonkey3 = G.node[pattern[5]][node_label] + G[pattern[5]][pattern[3]][edge_label]
  196. prekey2 = G.node[pattern[2]][node_label] + G[pattern[2]][pattern[0]][edge_label]
  197. prekey3 = G.node[pattern[3]][node_label] + G[pattern[3]][pattern[0]][edge_label]
  198. if prekey2 + canonkey2 < prekey3 + canonkey3:
  199. canonkey_t = G.node[pattern[1]][node_label] + G[pattern[1]][pattern[0]][edge_label] \
  200. + prekey2 + prekey3 + canonkey2 + canonkey3
  201. else:
  202. canonkey_t = G.node[pattern[1]][node_label] + G[pattern[1]][pattern[0]][edge_label] \
  203. + prekey3 + prekey2 + canonkey3 + canonkey2
  204. treelet.append('9' + G.node[pattern[0]][node_label] + canonkey_t)
  205. canonkey_l.update(Counter(treelet))
  206. return canonkey_l
  207. return canonkey
  208. def treeletkernel(*args, node_label = 'atom', edge_label = 'bond_type', labeled = True):
  209. if len(args) == 1: # for a list of graphs
  210. Gn = args[0]
  211. Kmatrix = np.zeros((len(Gn), len(Gn)))
  212. start_time = time.time()
  213. for i in range(0, len(Gn)):
  214. for j in range(i, len(Gn)):
  215. Kmatrix[i][j] = treeletkernel(Gn[i], Gn[j], labeled = labeled, node_label = node_label, edge_label = edge_label)
  216. Kmatrix[j][i] = Kmatrix[i][j]
  217. run_time = time.time() - start_time
  218. print("\n --- treelet kernel matrix of size %d built in %s seconds ---" % (len(Gn), run_time))
  219. return Kmatrix, run_time
  220. else: # for only 2 graphs
  221. G1 = args[0]
  222. G = args[1]
  223. kernel = 0
  224. # start_time = time.time()
  225. canonkey2 = get_canonkey(G, node_label = node_label, edge_label = edge_label, labeled = labeled)
  226. canonkey1 = get_canonkey(G1, node_label = node_label, edge_label = edge_label, labeled = labeled)
  227. keys = set(canonkey1.keys()) & set(canonkey2.keys()) # find same canonical keys in both graphs
  228. vector1 = np.matrix([ (canonkey1[key] if (key in canonkey1.keys()) else 0) for key in keys ])
  229. vector2 = np.matrix([ (canonkey2[key] if (key in canonkey2.keys()) else 0) for key in keys ])
  230. kernel = np.sum(np.exp(- np.square(vector1 - vector2) / 2))
  231. # run_time = time.time() - start_time
  232. # print("\n --- treelet kernel built in %s seconds ---" % (run_time))
  233. return kernel#, run_time

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.