You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

commonWalkKernel.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. """
  2. @author: linlin
  3. @references:
  4. [1] Thomas Gärtner, Peter Flach, and Stefan Wrobel. On graph kernels: Hardness results and efficient alternatives. Learning Theory and Kernel Machines, pages 129–143, 2003.
  5. """
  6. import sys
  7. import pathlib
  8. sys.path.insert(0, "../")
  9. import time
  10. from tqdm import tqdm
  11. from collections import Counter
  12. from itertools import product
  13. import networkx as nx
  14. import numpy as np
  15. from pygraph.utils.utils import direct_product
  16. from pygraph.utils.graphdataset import get_dataset_attributes
  17. def commonwalkkernel(*args,
  18. node_label='atom',
  19. edge_label='bond_type',
  20. n=None,
  21. weight=1,
  22. compute_method='exp'):
  23. """Calculate common walk graph kernels up to depth d between graphs.
  24. Parameters
  25. ----------
  26. Gn : List of NetworkX graph
  27. List of graphs between which the kernels are calculated.
  28. /
  29. G1, G2 : NetworkX graphs
  30. 2 graphs between which the kernel is calculated.
  31. node_label : string
  32. node attribute used as label. The default node label is atom.
  33. edge_label : string
  34. edge attribute used as label. The default edge label is bond_type.
  35. n : integer
  36. Longest length of walks.
  37. weight: integer
  38. Weight coefficient of different lengths of walks.
  39. compute_method : string
  40. Method used to compute walk kernel. The Following choices are available:
  41. 'direct' : direct product graph method, as shown in reference [1]. The time complexity is O(n^6) for unlabeled graphs with n vertices.
  42. 'brute' : brute force, simply search for all walks and compare them.
  43. Return
  44. ------
  45. Kmatrix : Numpy matrix
  46. Kernel matrix, each element of which is the path kernel up to d between 2 graphs.
  47. """
  48. compute_method = compute_method.lower()
  49. # arrange all graphs in a list
  50. Gn = args[0] if len(args) == 1 else [args[0], args[1]]
  51. Kmatrix = np.zeros((len(Gn), len(Gn)))
  52. ds_attrs = get_dataset_attributes(
  53. Gn,
  54. attr_names=['node_labeled', 'edge_labeled', 'is_directed'],
  55. node_label=node_label,
  56. edge_label=edge_label)
  57. if not ds_attrs['node_labeled']:
  58. for G in Gn:
  59. nx.set_node_attributes(G, '0', 'atom')
  60. if not ds_attrs['edge_labeled']:
  61. for G in Gn:
  62. nx.set_edge_attributes(G, '0', 'bond_type')
  63. start_time = time.time()
  64. # direct product graph method - exponential
  65. if compute_method == 'exp':
  66. pbar = tqdm(
  67. total=(1 + len(Gn)) * len(Gn) / 2,
  68. desc='calculating kernels',
  69. file=sys.stdout)
  70. for i in range(0, len(Gn)):
  71. for j in range(i, len(Gn)):
  72. Kmatrix[i][j] = _untilnwalkkernel_exp(Gn[i], Gn[j], node_label,
  73. edge_label, weight)
  74. Kmatrix[j][i] = Kmatrix[i][j]
  75. pbar.update(1)
  76. # direct product graph method - geometric
  77. if compute_method == 'geo':
  78. pbar = tqdm(
  79. total=(1 + len(Gn)) * len(Gn) / 2,
  80. desc='calculating kernels',
  81. file=sys.stdout)
  82. for i in range(0, len(Gn)):
  83. for j in range(i, len(Gn)):
  84. Kmatrix[i][j] = _untilnwalkkernel_geo(Gn[i], Gn[j], node_label,
  85. edge_label, weight)
  86. Kmatrix[j][i] = Kmatrix[i][j]
  87. pbar.update(1)
  88. # search all paths use brute force.
  89. elif compute_method == 'brute':
  90. n = int(n)
  91. # get all paths of all graphs before calculating kernels to save time, but this may cost a lot of memory for large dataset.
  92. all_walks = [
  93. find_all_walks_until_length(Gn[i], n, node_label, edge_label,
  94. labeled) for i in range(0, len(Gn))
  95. ]
  96. for i in range(0, len(Gn)):
  97. for j in range(i, len(Gn)):
  98. Kmatrix[i][j] = _untilnwalkkernel_brute(
  99. all_walks[i],
  100. all_walks[j],
  101. node_label=node_label,
  102. edge_label=edge_label,
  103. labeled=labeled)
  104. Kmatrix[j][i] = Kmatrix[i][j]
  105. run_time = time.time() - start_time
  106. print(
  107. "\n --- kernel matrix of common walk kernel of size %d built in %s seconds ---"
  108. % (len(Gn), run_time))
  109. return Kmatrix, run_time
  110. def _untilnwalkkernel_exp(G1, G2, node_label, edge_label, beta):
  111. """Calculate walk graph kernels up to n between 2 graphs using exponential series.
  112. Parameters
  113. ----------
  114. G1, G2 : NetworkX graph
  115. Graphs between which the kernel is calculated.
  116. node_label : string
  117. Node attribute used as label.
  118. edge_label : string
  119. Edge attribute used as label.
  120. beta: integer
  121. Weight.
  122. Return
  123. ------
  124. kernel : float
  125. Treelet Kernel between 2 graphs.
  126. """
  127. # get tensor product / direct product
  128. gp = direct_product(G1, G2, node_label, edge_label)
  129. A = nx.adjacency_matrix(gp).todense()
  130. # print(A)
  131. # from matplotlib import pyplot as plt
  132. # nx.draw_networkx(G1)
  133. # plt.show()
  134. # nx.draw_networkx(G2)
  135. # plt.show()
  136. # nx.draw_networkx(gp)
  137. # plt.show()
  138. # print(G1.nodes(data=True))
  139. # print(G2.nodes(data=True))
  140. # print(gp.nodes(data=True))
  141. # print(gp.edges(data=True))
  142. ew, ev = np.linalg.eig(A)
  143. # print('ew: ', ew)
  144. # print(ev)
  145. # T = np.matrix(ev)
  146. # print('T: ', T)
  147. # T = ev.I
  148. D = np.zeros((len(ew), len(ew)))
  149. for i in range(len(ew)):
  150. D[i][i] = np.exp(beta * ew[i])
  151. # print('D: ', D)
  152. # print('hshs: ', T.I * D * T)
  153. # print(np.exp(-2))
  154. # print(D)
  155. # print(np.exp(weight * D))
  156. # print(ev)
  157. # print(np.linalg.inv(ev))
  158. exp_D = ev * D * ev.I
  159. # print(exp_D)
  160. # print(np.exp(weight * A))
  161. # print('-------')
  162. return np.sum(exp_D.diagonal())
  163. def _untilnwalkkernel_geo(G1, G2, node_label, edge_label, gamma):
  164. """Calculate walk graph kernels up to n between 2 graphs using geometric series.
  165. Parameters
  166. ----------
  167. G1, G2 : NetworkX graph
  168. Graphs between which the kernel is calculated.
  169. node_label : string
  170. Node attribute used as label.
  171. edge_label : string
  172. Edge attribute used as label.
  173. gamma: integer
  174. Weight.
  175. Return
  176. ------
  177. kernel : float
  178. Treelet Kernel between 2 graphs.
  179. """
  180. # get tensor product / direct product
  181. gp = direct_product(G1, G2, node_label, edge_label)
  182. A = nx.adjacency_matrix(gp).todense()
  183. # print(A)
  184. # from matplotlib import pyplot as plt
  185. # nx.draw_networkx(G1)
  186. # plt.show()
  187. # nx.draw_networkx(G2)
  188. # plt.show()
  189. # nx.draw_networkx(gp)
  190. # plt.show()
  191. # print(G1.nodes(data=True))
  192. # print(G2.nodes(data=True))
  193. # print(gp.nodes(data=True))
  194. # print(gp.edges(data=True))
  195. ew, ev = np.linalg.eig(A)
  196. # print('ew: ', ew)
  197. # print(ev)
  198. # T = np.matrix(ev)
  199. # print('T: ', T)
  200. # T = ev.I
  201. D = np.zeros((len(ew), len(ew)))
  202. for i in range(len(ew)):
  203. D[i][i] = np.exp(beta * ew[i])
  204. # print('D: ', D)
  205. # print('hshs: ', T.I * D * T)
  206. # print(np.exp(-2))
  207. # print(D)
  208. # print(np.exp(weight * D))
  209. # print(ev)
  210. # print(np.linalg.inv(ev))
  211. exp_D = ev * D * ev.I
  212. # print(exp_D)
  213. # print(np.exp(weight * A))
  214. # print('-------')
  215. return np.sum(exp_D.diagonal())
  216. def _untilnwalkkernel_brute(walks1,
  217. walks2,
  218. node_label='atom',
  219. edge_label='bond_type',
  220. labeled=True):
  221. """Calculate walk graph kernels up to n between 2 graphs.
  222. Parameters
  223. ----------
  224. walks1, walks2 : list
  225. List of walks in 2 graphs, where for unlabeled graphs, each walk is represented by a list of nodes; while for labeled graphs, each walk is represented by a string consists of labels of nodes and edges on that walk.
  226. node_label : string
  227. node attribute used as label. The default node label is atom.
  228. edge_label : string
  229. edge attribute used as label. The default edge label is bond_type.
  230. labeled : boolean
  231. Whether the graphs are labeled. The default is True.
  232. Return
  233. ------
  234. kernel : float
  235. Treelet Kernel between 2 graphs.
  236. """
  237. counts_walks1 = dict(Counter(walks1))
  238. counts_walks2 = dict(Counter(walks2))
  239. all_walks = list(set(walks1 + walks2))
  240. vector1 = [(counts_walks1[walk] if walk in walks1 else 0)
  241. for walk in all_walks]
  242. vector2 = [(counts_walks2[walk] if walk in walks2 else 0)
  243. for walk in all_walks]
  244. kernel = np.dot(vector1, vector2)
  245. return kernel
  246. # this method find walks repetively, it could be faster.
  247. def find_all_walks_until_length(G,
  248. length,
  249. node_label='atom',
  250. edge_label='bond_type',
  251. labeled=True):
  252. """Find all walks with a certain maximum length in a graph. A recursive depth first search is applied.
  253. Parameters
  254. ----------
  255. G : NetworkX graphs
  256. The graph in which walks are searched.
  257. length : integer
  258. The maximum length of walks.
  259. node_label : string
  260. node attribute used as label. The default node label is atom.
  261. edge_label : string
  262. edge attribute used as label. The default edge label is bond_type.
  263. labeled : boolean
  264. Whether the graphs are labeled. The default is True.
  265. Return
  266. ------
  267. walk : list
  268. List of walks retrieved, where for unlabeled graphs, each walk is represented by a list of nodes; while for labeled graphs, each walk is represented by a string consists of labels of nodes and edges on that walk.
  269. """
  270. all_walks = []
  271. # @todo: in this way, the time complexity is close to N(d^n+d^(n+1)+...+1), which could be optimized to O(Nd^n)
  272. for i in range(0, length + 1):
  273. new_walks = find_all_walks(G, i)
  274. if new_walks == []:
  275. break
  276. all_walks.extend(new_walks)
  277. if labeled == True: # convert paths to strings
  278. walk_strs = []
  279. for walk in all_walks:
  280. strlist = [
  281. G.node[node][node_label] +
  282. G[node][walk[walk.index(node) + 1]][edge_label]
  283. for node in walk[:-1]
  284. ]
  285. walk_strs.append(''.join(strlist) + G.node[walk[-1]][node_label])
  286. return walk_strs
  287. return all_walks
  288. def find_walks(G, source_node, length):
  289. """Find all walks with a certain length those start from a source node. A recursive depth first search is applied.
  290. Parameters
  291. ----------
  292. G : NetworkX graphs
  293. The graph in which walks are searched.
  294. source_node : integer
  295. The number of the node from where all walks start.
  296. length : integer
  297. The length of walks.
  298. Return
  299. ------
  300. walk : list of list
  301. List of walks retrieved, where each walk is represented by a list of nodes.
  302. """
  303. return [[source_node]] if length == 0 else \
  304. [ [source_node] + walk for neighbor in G[source_node] \
  305. for walk in find_walks(G, neighbor, length - 1) ]
  306. def find_all_walks(G, length):
  307. """Find all walks with a certain length in a graph. A recursive depth first search is applied.
  308. Parameters
  309. ----------
  310. G : NetworkX graphs
  311. The graph in which walks are searched.
  312. length : integer
  313. The length of walks.
  314. Return
  315. ------
  316. walk : list of list
  317. List of walks retrieved, where each walk is represented by a list of nodes.
  318. """
  319. all_walks = []
  320. for node in G:
  321. all_walks.extend(find_walks(G, node, length))
  322. ### The following process is not carried out according to the original article
  323. # all_paths_r = [ path[::-1] for path in all_paths ]
  324. # # For each path, two presentation are retrieved from its two extremities. Remove one of them.
  325. # for idx, path in enumerate(all_paths[:-1]):
  326. # for path2 in all_paths_r[idx+1::]:
  327. # if path == path2:
  328. # all_paths[idx] = []
  329. # break
  330. # return list(filter(lambda a: a != [], all_paths))
  331. return all_walks

A Python package for graph kernels, graph edit distances and graph pre-image problem.