commonWalkKernel.py

  1. """
  2. @author: linlin
  3. @references:
  4. [1] Thomas Gärtner, Peter Flach, and Stefan Wrobel. On graph kernels:
  5. Hardness results and efficient alternatives. Learning Theory and Kernel
  6. Machines, pages 129–143, 2003.
  7. """
  8. import sys
  9. import time
  10. from collections import Counter
  11. from functools import partial
  12. import networkx as nx
  13. import numpy as np
  14. from gklearn.utils.utils import direct_product
  15. from gklearn.utils.graphdataset import get_dataset_attributes
  16. from gklearn.utils.parallel import parallel_gm


def commonwalkkernel(*args,
                     node_label='atom',
                     edge_label='bond_type',
#                     n=None,
                     weight=1,
                     compute_method=None,
                     n_jobs=None,
                     chunksize=None,
                     verbose=True):
    """Compute common walk graph kernels between graphs.

    Parameters
    ----------
    Gn : List of NetworkX graph
        List of graphs between which the kernels are computed.

    G1, G2 : NetworkX graphs
        Two graphs between which the kernel is computed.

    node_label : string
        Node attribute used as symbolic label. The default node label is 'atom'.

    edge_label : string
        Edge attribute used as symbolic label. The default edge label is
        'bond_type'.

    weight : float
        Weight coefficient for walks of different lengths; it plays the role of
        beta in the 'exp' method and gamma in the 'geo' method.

    compute_method : string
        Method used to compute the walk kernel. The following choices are
        available:

        'exp': method based on the exponential series applied to the direct
        product graph, as shown in reference [1]. The time complexity is O(n^6)
        for graphs with n vertices.

        'geo': method based on the geometric series applied to the direct
        product graph, as shown in reference [1]. The time complexity is O(n^6)
        for graphs with n vertices.

    n_jobs : int
        Number of jobs for parallelization.

    Return
    ------
    Kmatrix : Numpy matrix
        Kernel matrix, each element of which is a common walk kernel between
        two graphs.
    """
#    n : integer
#        Longest length of walks. Only useful when applying the 'brute' method.
#        'brute': brute force; simply search for all walks and compare them.
    compute_method = compute_method.lower()
    # arrange all graphs in a list
    Gn = args[0] if len(args) == 1 else [args[0], args[1]]

    # remove graphs with only 1 node, as they do not have adjacency matrices
    len_gn = len(Gn)
    Gn = [(idx, G) for idx, G in enumerate(Gn) if nx.number_of_nodes(G) != 1]
    idx = [G[0] for G in Gn]
    Gn = [G[1] for G in Gn]
    if len(Gn) != len_gn:
        if verbose:
            print('\n %d graphs are removed as they have only 1 node.\n' %
                  (len_gn - len(Gn)))

    ds_attrs = get_dataset_attributes(
        Gn,
        attr_names=['node_labeled', 'edge_labeled', 'is_directed'],
        node_label=node_label, edge_label=edge_label)
    if not ds_attrs['node_labeled']:
        for G in Gn:
            nx.set_node_attributes(G, '0', 'atom')
    if not ds_attrs['edge_labeled']:
        for G in Gn:
            nx.set_edge_attributes(G, '0', 'bond_type')
    if not ds_attrs['is_directed']:  # convert undirected graphs to directed ones
        Gn = [G.to_directed() for G in Gn]

    start_time = time.time()

    Kmatrix = np.zeros((len(Gn), len(Gn)))

    # ---- use pool.imap_unordered to parallelize and track progress. ----
    def init_worker(gn_toshare):
        global G_gn
        G_gn = gn_toshare

    # direct product graph method - exponential
    if compute_method == 'exp':
        do_partial = partial(wrapper_cw_exp, node_label, edge_label, weight)
    # direct product graph method - geometric
    elif compute_method == 'geo':
        do_partial = partial(wrapper_cw_geo, node_label, edge_label, weight)
    parallel_gm(do_partial, Kmatrix, Gn, init_worker=init_worker,
                glbv=(Gn,), n_jobs=n_jobs, chunksize=chunksize, verbose=verbose)

#    pool = Pool(n_jobs)
#    itr = zip(combinations_with_replacement(Gn, 2),
#              combinations_with_replacement(range(0, len(Gn)), 2))
#    len_itr = int(len(Gn) * (len(Gn) + 1) / 2)
#    if len_itr < 1000 * n_jobs:
#        chunksize = int(len_itr / n_jobs) + 1
#    else:
#        chunksize = 1000
#
#    # direct product graph method - exponential
#    if compute_method == 'exp':
#        do_partial = partial(wrapper_cw_exp, node_label, edge_label, weight)
#    # direct product graph method - geometric
#    elif compute_method == 'geo':
#        do_partial = partial(wrapper_cw_geo, node_label, edge_label, weight)
#
#    for i, j, kernel in tqdm(
#            pool.imap_unordered(do_partial, itr, chunksize),
#            desc='computing kernels',
#            file=sys.stdout):
#        Kmatrix[i][j] = kernel
#        Kmatrix[j][i] = kernel
#    pool.close()
#    pool.join()

#    # ---- direct running, normally using a single CPU core. ----
#    # direct product graph method - exponential
#    itr = combinations_with_replacement(range(0, len(Gn)), 2)
#    if compute_method == 'exp':
#        for i, j in tqdm(itr, desc='Computing kernels', file=sys.stdout):
#            Kmatrix[i][j] = _commonwalkkernel_exp(Gn[i], Gn[j], node_label,
#                                                  edge_label, weight)
#            Kmatrix[j][i] = Kmatrix[i][j]
#
#    # direct product graph method - geometric
#    elif compute_method == 'geo':
#        for i, j in tqdm(itr, desc='Computing kernels', file=sys.stdout):
#            Kmatrix[i][j] = _commonwalkkernel_geo(Gn[i], Gn[j], node_label,
#                                                  edge_label, weight)
#            Kmatrix[j][i] = Kmatrix[i][j]

#    # search for all walks by brute force.
#    elif compute_method == 'brute':
#        n = int(n)
#        # get all walks of all graphs before computing kernels to save time,
#        # but this may cost a lot of memory for large datasets.
#        all_walks = [
#            find_all_walks_until_length(Gn[i], n, node_label, edge_label)
#            for i in range(0, len(Gn))
#        ]
#
#        for i in range(0, len(Gn)):
#            for j in range(i, len(Gn)):
#                Kmatrix[i][j] = _commonwalkkernel_brute(
#                    all_walks[i],
#                    all_walks[j],
#                    node_label=node_label,
#                    edge_label=edge_label)
#                Kmatrix[j][i] = Kmatrix[i][j]

    run_time = time.time() - start_time
    if verbose:
        print("\n --- kernel matrix of common walk kernel of size %d built in %s seconds ---"
              % (len(Gn), run_time))

    return Kmatrix, run_time, idx
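

# A minimal usage sketch, not part of the original module: it builds two toy
# graphs with 'atom' node labels and 'bond_type' edge labels and computes their
# common walk kernel matrix with the geometric method. The label values, the
# weight 0.01 and n_jobs=1 are illustrative assumptions only; the call runs
# through gklearn's parallel_gm helper as imported above.
def _example_commonwalkkernel():
    g1 = nx.Graph()
    g1.add_nodes_from([(0, {'atom': 'C'}), (1, {'atom': 'O'})])
    g1.add_edge(0, 1, bond_type='1')
    g2 = nx.Graph()
    g2.add_nodes_from([(0, {'atom': 'C'}), (1, {'atom': 'C'}), (2, {'atom': 'O'})])
    g2.add_edges_from([(0, 1, {'bond_type': '1'}), (1, 2, {'bond_type': '1'})])
    # Kmatrix is a 2x2 symmetric matrix of kernel values; idx lists the indices
    # of the graphs that were kept (those with more than one node).
    Kmatrix, run_time, idx = commonwalkkernel([g1, g2],
                                              node_label='atom',
                                              edge_label='bond_type',
                                              weight=0.01,
                                              compute_method='geo',
                                              n_jobs=1,
                                              verbose=False)
    return Kmatrix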


def _commonwalkkernel_exp(g1, g2, node_label, edge_label, beta):
    """Compute the common walk kernel between two graphs using the exponential
    series.

    Parameters
    ----------
    g1, g2 : NetworkX graphs
        Graphs between which the kernel is computed.
    node_label : string
        Node attribute used as label.
    edge_label : string
        Edge attribute used as label.
    beta : float
        Weight parameter of the exponential series.

    Return
    ------
    kernel : float
        The common walk kernel between the two graphs.
    """
    # get tensor product / direct product
    gp = direct_product(g1, g2, node_label, edge_label)
    # return 0 if the direct product graph has no more than 1 node.
    if nx.number_of_nodes(gp) < 2:
        return 0
    A = nx.adjacency_matrix(gp).todense()
#    print(A)

#    from matplotlib import pyplot as plt
#    nx.draw_networkx(G1)
#    plt.show()
#    nx.draw_networkx(G2)
#    plt.show()
#    nx.draw_networkx(gp)
#    plt.show()
#    print(G1.nodes(data=True))
#    print(G2.nodes(data=True))
#    print(gp.nodes(data=True))
#    print(gp.edges(data=True))

    # compute exp(beta * A) via the eigendecomposition of the (symmetric)
    # adjacency matrix of the direct product graph.
    ew, ev = np.linalg.eig(A)
#    print('ew: ', ew)
#    print(ev)
#    T = np.matrix(ev)
#    print('T: ', T)
#    T = ev.I
    D = np.zeros((len(ew), len(ew)))
    for i in range(len(ew)):
        D[i][i] = np.exp(beta * ew[i])
#    print('D: ', D)
#    print('hshs: ', T.I * D * T)
#    print(np.exp(-2))
#    print(D)
#    print(np.exp(weight * D))
#    print(ev)
#    print(np.linalg.inv(ev))
    exp_D = ev * D * ev.T
#    print(exp_D)
#    print(np.exp(weight * A))
#    print('-------')

    return exp_D.sum()


def wrapper_cw_exp(node_label, edge_label, beta, itr):
    i = itr[0]
    j = itr[1]
    return i, j, _commonwalkkernel_exp(G_gn[i], G_gn[j], node_label, edge_label, beta)
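

# Illustration, not part of the original module: the exponential-series kernel
# above is the entrywise sum of exp(beta * A_x), where A_x is the adjacency
# matrix of the direct product graph. The helper below is a sketch that checks,
# on a small symmetric matrix, that the eigendecomposition route used in
# _commonwalkkernel_exp agrees with a truncated Taylor expansion
# sum_k beta^k / k! * A^k. The helper name and n_terms default are assumptions.
def _check_exp_series(A, beta, n_terms=30):
    A = np.asarray(A, dtype=float)
    # eigendecomposition route (valid for symmetric A)
    ew, ev = np.linalg.eigh(A)
    exp_via_eig = (ev @ np.diag(np.exp(beta * ew)) @ ev.T).sum()
    # truncated Taylor-series route: accumulate beta^k / k! * A^k
    term = np.eye(A.shape[0])
    exp_via_series = term.sum()
    for k in range(1, n_terms):
        term = term @ A * (beta / k)
        exp_via_series += term.sum()
    # e.g. _check_exp_series(nx.adjacency_matrix(nx.path_graph(4)).toarray(), 0.5)
    # returns two nearly equal values.
    return exp_via_eig, exp_via_series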


def _commonwalkkernel_geo(g1, g2, node_label, edge_label, gamma):
    """Compute the common walk kernel between two graphs using the geometric
    series.

    Parameters
    ----------
    g1, g2 : NetworkX graphs
        Graphs between which the kernel is computed.
    node_label : string
        Node attribute used as label.
    edge_label : string
        Edge attribute used as label.
    gamma : float
        Weight parameter of the geometric series.

    Return
    ------
    kernel : float
        The common walk kernel between the two graphs.
    """
    # get tensor product / direct product
    gp = direct_product(g1, g2, node_label, edge_label)
    # return 0 if the direct product graph has no more than 1 node.
    if nx.number_of_nodes(gp) < 2:
        return 0
    A = nx.adjacency_matrix(gp).todense()
    # the geometric series sum_k gamma^k * A^k equals (I - gamma * A)^-1
    # when gamma is smaller than the reciprocal of the spectral radius of A.
    mat = np.identity(len(A)) - gamma * A
#    try:
    return mat.I.sum()
#    except np.linalg.LinAlgError:
#        return np.nan


def wrapper_cw_geo(node_label, edge_label, gamma, itr):
    i = itr[0]
    j = itr[1]
    return i, j, _commonwalkkernel_geo(G_gn[i], G_gn[j], node_label, edge_label, gamma)
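

# Illustration, not part of the original module: the geometric-series kernel
# above is the entrywise sum of (I - gamma * A_x)^-1, which equals
# sum_k gamma^k * (number of walks of length k in the product graph) whenever
# gamma times the spectral radius of A_x is below 1. The sketch below compares
# the closed form with a truncated sum on a small matrix; the helper name and
# n_terms default are assumptions.
def _check_geo_series(A, gamma, n_terms=200):
    A = np.asarray(A, dtype=float)
    # closed form of the geometric series
    closed_form = np.linalg.inv(np.eye(A.shape[0]) - gamma * A).sum()
    # truncated series: accumulate gamma^k * A^k
    term = np.eye(A.shape[0])
    truncated = term.sum()
    for _ in range(1, n_terms):
        term = gamma * term @ A
        truncated += term.sum()
    return closed_form, truncated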


def _commonwalkkernel_brute(walks1,
                            walks2,
                            node_label='atom',
                            edge_label='bond_type',
                            labeled=True):
    """Compute the common walk kernel between two graphs by directly comparing
    their walks.

    Parameters
    ----------
    walks1, walks2 : list
        Lists of walks in the two graphs. For unlabeled graphs, each walk is
        represented by a list of nodes; for labeled graphs, each walk is
        represented by a string consisting of the labels of the nodes and edges
        on that walk.
    node_label : string
        Node attribute used as label. The default node label is 'atom'.
    edge_label : string
        Edge attribute used as label. The default edge label is 'bond_type'.
    labeled : boolean
        Whether the graphs are labeled. The default is True.

    Return
    ------
    kernel : float
        The common walk kernel between the two graphs.
    """
    counts_walks1 = dict(Counter(walks1))
    counts_walks2 = dict(Counter(walks2))
    all_walks = list(set(walks1 + walks2))
    vector1 = [(counts_walks1[walk] if walk in walks1 else 0)
               for walk in all_walks]
    vector2 = [(counts_walks2[walk] if walk in walks2 else 0)
               for walk in all_walks]
    kernel = np.dot(vector1, vector2)
    return kernel
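

# A tiny worked example, not part of the original module (the walk strings are
# illustrative assumptions in the format produced by find_all_walks_until_length
# below): 'C1C' occurs twice in the first list and once in the second, 'C1O'
# occurs once in each, and 'O1O' only in the second, so the kernel is the dot
# product of the count vectors, 2*1 + 1*1 + 0*1 = 3.
def _example_brute_kernel():
    walks1 = ['C1C', 'C1C', 'C1O']
    walks2 = ['C1C', 'C1O', 'O1O']
    return _commonwalkkernel_brute(walks1, walks2)  # -> 3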


# this method finds walks repetitively; it could be made faster
# (see the dynamic-programming sketch after find_all_walks below).
def find_all_walks_until_length(G,
                                length,
                                node_label='atom',
                                edge_label='bond_type',
                                labeled=True):
    """Find all walks up to a certain maximum length in a graph.
    A recursive depth first search is applied.

    Parameters
    ----------
    G : NetworkX graphs
        The graph in which walks are searched.
    length : integer
        The maximum length of walks.
    node_label : string
        Node attribute used as label. The default node label is 'atom'.
    edge_label : string
        Edge attribute used as label. The default edge label is 'bond_type'.
    labeled : boolean
        Whether the graphs are labeled. The default is True.

    Return
    ------
    walk : list
        List of walks retrieved. For unlabeled graphs, each walk is represented
        by a list of nodes; for labeled graphs, each walk is represented by a
        string consisting of the labels of the nodes and edges on that walk.
    """
    all_walks = []
    # @todo: in this way, the time complexity is close to N(d^n+d^(n+1)+...+1),
    # which could be optimized to O(Nd^n)
    for i in range(0, length + 1):
        new_walks = find_all_walks(G, i)
        if new_walks == []:
            break
        all_walks.extend(new_walks)

    if labeled == True:  # convert walks to strings
        walk_strs = []
        for walk in all_walks:
            # index by position rather than by node value, since a walk may
            # revisit the same node.
            strlist = [
                G.nodes[walk[i]][node_label] +
                G[walk[i]][walk[i + 1]][edge_label]
                for i in range(len(walk) - 1)
            ]
            walk_strs.append(''.join(strlist) + G.nodes[walk[-1]][node_label])

        return walk_strs

    return all_walks
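

# A small illustration, not part of the original module (the toy labels are
# assumptions), of the string encoding used above: on a labeled path
# C -1- O -2- C, the walks of length up to 1 include 'C', 'O', 'C1O', 'O1C',
# 'O2C' and 'C2O', i.e. node and edge labels concatenated along each walk.
def _example_walk_strings():
    G = nx.Graph()
    G.add_nodes_from([(0, {'atom': 'C'}), (1, {'atom': 'O'}), (2, {'atom': 'C'})])
    G.add_edge(0, 1, bond_type='1')
    G.add_edge(1, 2, bond_type='2')
    return find_all_walks_until_length(G, 1, node_label='atom',
                                       edge_label='bond_type')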


def find_walks(G, source_node, length):
    """Find all walks of a given length that start from a source node. A
    recursive depth first search is applied.

    Parameters
    ----------
    G : NetworkX graphs
        The graph in which walks are searched.
    source_node : integer
        The node from which all walks start.
    length : integer
        The length of walks.

    Return
    ------
    walk : list of list
        List of walks retrieved, where each walk is represented by a list of
        nodes.
    """
    return [[source_node]] if length == 0 else \
        [[source_node] + walk for neighbor in G[source_node]
         for walk in find_walks(G, neighbor, length - 1)]


def find_all_walks(G, length):
    """Find all walks of a given length in a graph. A recursive depth first
    search is applied.

    Parameters
    ----------
    G : NetworkX graphs
        The graph in which walks are searched.
    length : integer
        The length of walks.

    Return
    ------
    walk : list of list
        List of walks retrieved, where each walk is represented by a list of
        nodes.
    """
    all_walks = []
    for node in G:
        all_walks.extend(find_walks(G, node, length))

#    # The following process is not carried out according to the original article.
#    all_paths_r = [path[::-1] for path in all_paths]
#
#    # For each path, two representations are retrieved from its two extremities.
#    # Remove one of them.
#    for idx, path in enumerate(all_paths[:-1]):
#        for path2 in all_paths_r[idx+1::]:
#            if path == path2:
#                all_paths[idx] = []
#                break
#
#    return list(filter(lambda a: a != [], all_paths))
    return all_walks
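

# A sketch, not part of the original module, of the optimization suggested in
# the @todo above find_all_walks_until_length: instead of restarting the depth
# first search for every length, walks of length k are obtained by extending
# every walk of length k - 1 by one neighbor, avoiding re-enumeration of shared
# prefixes. The function name is an assumption for illustration.
def _find_all_walks_until_length_dp(G, length):
    """Return all walks of length 0..length as lists of nodes, built
    incrementally."""
    walks = [[node] for node in G]  # walks of length 0
    all_walks = list(walks)
    for _ in range(length):
        # extend each current walk by one step to every neighbor of its last node
        walks = [walk + [neighbor]
                 for walk in walks
                 for neighbor in G[walk[-1]]]
        if not walks:
            break
        all_walks.extend(walks)
    return all_walks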
