
commonWalkKernel.py

  1. """
  2. @author: linlin
  3. @references:
  4. [1] Thomas Gärtner, Peter Flach, and Stefan Wrobel. On graph kernels:
  5. Hardness results and efficient alternatives. Learning Theory and Kernel
  6. Machines, pages 129–143, 2003.
  7. """
  8. import sys
  9. import time
  10. from collections import Counter
  11. from functools import partial
  12. import networkx as nx
  13. import numpy as np
  14. sys.path.insert(0, "../")
  15. from pygraph.utils.utils import direct_product
  16. from pygraph.utils.graphdataset import get_dataset_attributes
  17. from pygraph.utils.parallel import parallel_gm
  18. def commonwalkkernel(*args,
  19. node_label='atom',
  20. edge_label='bond_type',
  21. # n=None,
  22. weight=1,
  23. compute_method=None,
  24. n_jobs=None,
  25. verbose=True):
  26. """Calculate common walk graph kernels between graphs.
  27. Parameters
  28. ----------
  29. Gn : List of NetworkX graph
  30. List of graphs between which the kernels are calculated.
  31. /
  32. G1, G2 : NetworkX graphs
  33. Two graphs between which the kernel is calculated.
  34. node_label : string
  35. Node attribute used as symbolic label. The default node label is 'atom'.
  36. edge_label : string
  37. Edge attribute used as symbolic label. The default edge label is 'bond_type'.
  38. # n : integer
  39. # Longest length of walks. Only useful when applying the 'brute' method.
  40. weight: integer
  41. Weight coefficient of different lengths of walks, which represents beta
  42. in 'exp' method and gamma in 'geo'.
  43. compute_method : string
  44. Method used to compute walk kernel. The Following choices are
  45. available:
  46. 'exp': method based on exponential serials applied on the direct
  47. product graph, as shown in reference [1]. The time complexity is O(n^6)
  48. for graphs with n vertices.
  49. 'geo': method based on geometric serials applied on the direct product
  50. graph, as shown in reference [1]. The time complexity is O(n^6) for
  51. graphs with n vertices.
  52. # 'brute': brute force, simply search for all walks and compare them.
  53. n_jobs : int
  54. Number of jobs for parallelization.
  55. Return
  56. ------
  57. Kmatrix : Numpy matrix
  58. Kernel matrix, each element of which is a common walk kernel between 2
  59. graphs.
  60. """
    compute_method = compute_method.lower()
    # arrange all graphs in a list
    Gn = args[0] if len(args) == 1 else [args[0], args[1]]

    # remove graphs with only 1 node, as they do not have adjacency matrices
    len_gn = len(Gn)
    Gn = [(idx, G) for idx, G in enumerate(Gn) if nx.number_of_nodes(G) != 1]
    idx = [G[0] for G in Gn]
    Gn = [G[1] for G in Gn]
    if len(Gn) != len_gn:
        if verbose:
            print('\n %d graphs are removed as they have only 1 node.\n' %
                  (len_gn - len(Gn)))

    ds_attrs = get_dataset_attributes(
        Gn,
        attr_names=['node_labeled', 'edge_labeled', 'is_directed'],
        node_label=node_label, edge_label=edge_label)
    if not ds_attrs['node_labeled']:
        for G in Gn:
            nx.set_node_attributes(G, '0', 'atom')
    if not ds_attrs['edge_labeled']:
        for G in Gn:
            nx.set_edge_attributes(G, '0', 'bond_type')
    if not ds_attrs['is_directed']:  # convert to directed graphs
        Gn = [G.to_directed() for G in Gn]

    start_time = time.time()

    Kmatrix = np.zeros((len(Gn), len(Gn)))

    # ---- use pool.imap_unordered to parallelize and track progress. ----
    def init_worker(gn_toshare):
        global G_gn
        G_gn = gn_toshare

    # direct product graph method - exponential
    if compute_method == 'exp':
        do_partial = partial(wrapper_cw_exp, node_label, edge_label, weight)
    # direct product graph method - geometric
    elif compute_method == 'geo':
        do_partial = partial(wrapper_cw_geo, node_label, edge_label, weight)

    parallel_gm(do_partial, Kmatrix, Gn, init_worker=init_worker,
                glbv=(Gn,), n_jobs=n_jobs, verbose=verbose)

#    pool = Pool(n_jobs)
#    itr = zip(combinations_with_replacement(Gn, 2),
#              combinations_with_replacement(range(0, len(Gn)), 2))
#    len_itr = int(len(Gn) * (len(Gn) + 1) / 2)
#    if len_itr < 1000 * n_jobs:
#        chunksize = int(len_itr / n_jobs) + 1
#    else:
#        chunksize = 1000
#
#    # direct product graph method - exponential
#    if compute_method == 'exp':
#        do_partial = partial(wrapper_cw_exp, node_label, edge_label, weight)
#    # direct product graph method - geometric
#    elif compute_method == 'geo':
#        do_partial = partial(wrapper_cw_geo, node_label, edge_label, weight)
#
#    for i, j, kernel in tqdm(
#            pool.imap_unordered(do_partial, itr, chunksize),
#            desc='calculating kernels',
#            file=sys.stdout):
#        Kmatrix[i][j] = kernel
#        Kmatrix[j][i] = kernel
#    pool.close()
#    pool.join()

#    # ---- direct running, normally use a single CPU core. ----
#    # direct product graph method - exponential
#    itr = combinations_with_replacement(range(0, len(Gn)), 2)
#    if compute_method == 'exp':
#        for i, j in tqdm(itr, desc='calculating kernels', file=sys.stdout):
#            Kmatrix[i][j] = _commonwalkkernel_exp(Gn[i], Gn[j], node_label,
#                                                  edge_label, weight)
#            Kmatrix[j][i] = Kmatrix[i][j]
#
#    # direct product graph method - geometric
#    elif compute_method == 'geo':
#        for i, j in tqdm(itr, desc='calculating kernels', file=sys.stdout):
#            Kmatrix[i][j] = _commonwalkkernel_geo(Gn[i], Gn[j], node_label,
#                                                  edge_label, weight)
#            Kmatrix[j][i] = Kmatrix[i][j]

#    # search all walks using brute force.
#    elif compute_method == 'brute':
#        n = int(n)
#        # get all walks of all graphs before calculating kernels to save time,
#        # but this may cost a lot of memory for large datasets.
#        all_walks = [
#            find_all_walks_until_length(Gn[i], n, node_label, edge_label)
#            for i in range(0, len(Gn))
#        ]
#
#        for i in range(0, len(Gn)):
#            for j in range(i, len(Gn)):
#                Kmatrix[i][j] = _commonwalkkernel_brute(
#                    all_walks[i],
#                    all_walks[j],
#                    node_label=node_label,
#                    edge_label=edge_label)
#                Kmatrix[j][i] = Kmatrix[i][j]

    run_time = time.time() - start_time
    if verbose:
        print("\n --- kernel matrix of common walk kernel of size %d built in %s seconds ---"
              % (len(Gn), run_time))

    return Kmatrix, run_time, idx
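
# Summary of the two variants implemented below (following [1] and the code of
# _commonwalkkernel_exp / _commonwalkkernel_geo): with A_x the adjacency matrix
# of the direct product graph of G1 and G2,
#     k_exp(G1, G2) = sum_{p,q} [exp(beta * A_x)]_{pq}
#     k_geo(G1, G2) = sum_{p,q} [(I - gamma * A_x)^{-1}]_{pq}
# The geometric series behind the second form converges only when gamma is
# smaller than the reciprocal of the spectral radius of A_x.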

def _commonwalkkernel_exp(g1, g2, node_label, edge_label, beta):
    """Calculate the common walk kernel between 2 graphs using an exponential
    series.

    Parameters
    ----------
    g1, g2 : NetworkX graphs
        Graphs between which the kernel is calculated.
    node_label : string
        Node attribute used as label.
    edge_label : string
        Edge attribute used as label.
    beta : integer
        Weight coefficient (beta in the exponential series).

    Return
    ------
    kernel : float
        The common walk kernel between the 2 graphs.
    """
    # get tensor product / direct product
    gp = direct_product(g1, g2, node_label, edge_label)
    # return 0 if the direct product graph has no more than 1 node.
    if nx.number_of_nodes(gp) < 2:
        return 0
    A = nx.adjacency_matrix(gp).todense()

    # compute exp(beta * A) through the eigendecomposition of A, then sum all
    # entries of the resulting matrix.
    ew, ev = np.linalg.eig(A)
    D = np.zeros((len(ew), len(ew)))
    for i in range(len(ew)):
        D[i][i] = np.exp(beta * ew[i])
    exp_D = ev * D * ev.T
    return exp_D.sum()
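
# A minimal cross-check sketch (not part of the original module; it assumes
# scipy is available as an extra dependency): for the symmetric adjacency
# matrix of the direct product graph, the eigendecomposition used in
# _commonwalkkernel_exp is equivalent to a matrix exponential, so the same
# kernel value can be obtained with scipy.linalg.expm.
def _commonwalkkernel_exp_expm(g1, g2, node_label, edge_label, beta):
    """Sketch: the same exponential-series kernel via scipy.linalg.expm."""
    from scipy.linalg import expm  # local import; only needed for this check
    gp = direct_product(g1, g2, node_label, edge_label)
    if nx.number_of_nodes(gp) < 2:
        return 0
    A = np.asarray(nx.adjacency_matrix(gp).todense())
    # sum all entries of exp(beta * A) over the direct product graph
    return expm(beta * A).sum()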

def wrapper_cw_exp(node_label, edge_label, beta, itr):
    i = itr[0]
    j = itr[1]
    return i, j, _commonwalkkernel_exp(G_gn[i], G_gn[j], node_label, edge_label, beta)

def _commonwalkkernel_geo(g1, g2, node_label, edge_label, gamma):
    """Calculate the common walk kernel between 2 graphs using a geometric
    series.

    Parameters
    ----------
    g1, g2 : NetworkX graphs
        Graphs between which the kernel is calculated.
    node_label : string
        Node attribute used as label.
    edge_label : string
        Edge attribute used as label.
    gamma : integer
        Weight coefficient (gamma in the geometric series).

    Return
    ------
    kernel : float
        The common walk kernel between the 2 graphs.
    """
    # get tensor product / direct product
    gp = direct_product(g1, g2, node_label, edge_label)
    # return 0 if the direct product graph has no more than 1 node.
    if nx.number_of_nodes(gp) < 2:
        return 0
    A = nx.adjacency_matrix(gp).todense()
    # sum all entries of (I - gamma * A)^-1; A is a numpy matrix, so .I gives
    # its inverse.
    mat = np.identity(len(A)) - gamma * A
#    try:
    return mat.I.sum()
#    except np.linalg.LinAlgError:
#        return np.nan
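
# An alternative sketch (not part of the original module): the sum of all
# entries of (I - gamma * A)^{-1} equals 1^T x with (I - gamma * A) x = 1, so
# the explicit inverse in _commonwalkkernel_geo can be avoided with a linear
# solve. This is a reformulation for illustration, not the author's code.
def _commonwalkkernel_geo_solve(g1, g2, node_label, edge_label, gamma):
    """Sketch: the same geometric-series kernel via np.linalg.solve."""
    gp = direct_product(g1, g2, node_label, edge_label)
    if nx.number_of_nodes(gp) < 2:
        return 0
    A = np.asarray(nx.adjacency_matrix(gp).todense())
    # solve (I - gamma * A) x = 1 and sum the solution vector
    ones = np.ones(len(A))
    return np.linalg.solve(np.identity(len(A)) - gamma * A, ones).sum()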

def wrapper_cw_geo(node_label, edge_label, gamma, itr):
    i = itr[0]
    j = itr[1]
    return i, j, _commonwalkkernel_geo(G_gn[i], G_gn[j], node_label, edge_label, gamma)

def _commonwalkkernel_brute(walks1,
                            walks2,
                            node_label='atom',
                            edge_label='bond_type',
                            labeled=True):
    """Calculate the common walk kernel between 2 graphs by brute force, from
    their lists of walks up to length n.

    Parameters
    ----------
    walks1, walks2 : list
        Lists of walks in the 2 graphs, where for unlabeled graphs, each walk
        is represented by a list of nodes; while for labeled graphs, each walk
        is represented by a string consisting of the labels of the nodes and
        edges on that walk.
    node_label : string
        Node attribute used as label. The default node label is 'atom'.
    edge_label : string
        Edge attribute used as label. The default edge label is 'bond_type'.
    labeled : boolean
        Whether the graphs are labeled. The default is True.

    Return
    ------
    kernel : float
        The common walk kernel between the 2 graphs.
    """
    counts_walks1 = dict(Counter(walks1))
    counts_walks2 = dict(Counter(walks2))
    all_walks = list(set(walks1 + walks2))

    vector1 = [(counts_walks1[walk] if walk in walks1 else 0)
               for walk in all_walks]
    vector2 = [(counts_walks2[walk] if walk in walks2 else 0)
               for walk in all_walks]
    kernel = np.dot(vector1, vector2)

    return kernel
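
# A small worked example of the brute-force kernel above (hypothetical walk
# label strings, for illustration only):
#     walks1 = ['CbC', 'CbO', 'CbC']
#     walks2 = ['CbC', 'ObC']
#     _commonwalkkernel_brute(walks1, walks2)  # -> 2, since 'CbC' occurs
#                                              # 2 times in walks1 and 1 time
#                                              # in walks2 (2 * 1 = 2)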

# this method finds walks repetitively; it could be made faster.
def find_all_walks_until_length(G,
                                length,
                                node_label='atom',
                                edge_label='bond_type',
                                labeled=True):
    """Find all walks up to a certain maximum length in a graph.

    A recursive depth-first search is applied.

    Parameters
    ----------
    G : NetworkX graph
        The graph in which walks are searched.
    length : integer
        The maximum length of walks.
    node_label : string
        Node attribute used as label. The default node label is 'atom'.
    edge_label : string
        Edge attribute used as label. The default edge label is 'bond_type'.
    labeled : boolean
        Whether the graphs are labeled. The default is True.

    Return
    ------
    walk : list
        List of walks retrieved, where for unlabeled graphs, each walk is
        represented by a list of nodes; while for labeled graphs, each walk
        is represented by a string consisting of the labels of the nodes and
        edges on that walk.
    """
    all_walks = []
    # @todo: in this way, the time complexity is close to N(d^n+d^(n+1)+...+1),
    # which could be optimized to O(Nd^n)
    for i in range(0, length + 1):
        new_walks = find_all_walks(G, i)
        if new_walks == []:
            break
        all_walks.extend(new_walks)

    if labeled:  # convert walks to strings
        walk_strs = []
        for walk in all_walks:
            # index by position rather than by node, since a walk may revisit
            # the same node several times.
            strlist = [
                G.nodes[walk[i]][node_label] +
                G[walk[i]][walk[i + 1]][edge_label]
                for i in range(len(walk) - 1)
            ]
            walk_strs.append(''.join(strlist) + G.nodes[walk[-1]][node_label])

        return walk_strs

    return all_walks

def find_walks(G, source_node, length):
    """Find all walks of a certain length that start from a source node. A
    recursive depth-first search is applied.

    Parameters
    ----------
    G : NetworkX graph
        The graph in which walks are searched.
    source_node : integer
        The node from which all walks start.
    length : integer
        The length of the walks.

    Return
    ------
    walk : list of list
        List of walks retrieved, where each walk is represented by a list of
        nodes.
    """
    return [[source_node]] if length == 0 else \
        [[source_node] + walk for neighbor in G[source_node]
         for walk in find_walks(G, neighbor, length - 1)]
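
# A small worked example of find_walks (for illustration): on the path graph
# 0 - 1 - 2, walks of length 2 starting from node 0 are [0, 1, 0] and
# [0, 1, 2]; unlike simple paths, walks may revisit nodes.
#     G = nx.path_graph(3)
#     find_walks(G, 0, 2)  # -> [[0, 1, 0], [0, 1, 2]]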

def find_all_walks(G, length):
    """Find all walks of a certain length in a graph. A recursive depth-first
    search is applied.

    Parameters
    ----------
    G : NetworkX graph
        The graph in which walks are searched.
    length : integer
        The length of the walks.

    Return
    ------
    walk : list of list
        List of walks retrieved, where each walk is represented by a list of
        nodes.
    """
    all_walks = []
    for node in G:
        all_walks.extend(find_walks(G, node, length))

#    # The following process is not carried out according to the original
#    # article: for each walk, two representations are retrieved from its two
#    # extremities. Remove one of them.
#    all_paths_r = [path[::-1] for path in all_paths]
#    for idx, path in enumerate(all_paths[:-1]):
#        for path2 in all_paths_r[idx+1::]:
#            if path == path2:
#                all_paths[idx] = []
#                break
#    return list(filter(lambda a: a != [], all_paths))

    return all_walks
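

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module. It assumes the
    # pygraph utilities imported at the top are importable from "../". Two toy
    # graphs with 'atom' / 'bond_type' labels are compared with the
    # geometric-series variant; the single-pair helper is called directly so
    # the sketch does not depend on the behaviour of parallel_gm.
    g1 = nx.Graph()
    g1.add_nodes_from([(0, {'atom': 'C'}), (1, {'atom': 'C'}), (2, {'atom': 'O'})])
    g1.add_edges_from([(0, 1, {'bond_type': '1'}), (1, 2, {'bond_type': '1'})])
    g2 = nx.Graph()
    g2.add_nodes_from([(0, {'atom': 'C'}), (1, {'atom': 'O'})])
    g2.add_edges_from([(0, 1, {'bond_type': '1'})])
    k = _commonwalkkernel_geo(g1.to_directed(), g2.to_directed(),
                              'atom', 'bond_type', 0.01)
    print('k_geo(g1, g2) =', k)
    # The full kernel matrix over a list of graphs would be computed with,
    # e.g., commonwalkkernel([g1, g2], compute_method='geo', weight=0.01,
    # n_jobs=1), which relies on pygraph.utils.parallel.parallel_gm.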
