You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

untilnWalkKernel.py 8.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. """
  2. @author: linlin
  3. @references:
  4. [1] Thomas Gärtner, Peter Flach, and Stefan Wrobel. On graph kernels: Hardness results and efficient alternatives. Learning Theory and Kernel Machines, pages 129–143, 2003.
  5. """
  6. import sys
  7. import pathlib
  8. sys.path.insert(0, "../")
  9. import time
  10. from collections import Counter
  11. import networkx as nx
  12. import numpy as np
  13. def untilnwalkkernel(*args,
  14. node_label='atom',
  15. edge_label='bond_type',
  16. labeled=True,
  17. n=None,
  18. compute_method='direct'):
  19. """Calculate common walk graph kernels up to depth d between graphs.
  20. Parameters
  21. ----------
  22. Gn : List of NetworkX graph
  23. List of graphs between which the kernels are calculated.
  24. /
  25. G1, G2 : NetworkX graphs
  26. 2 graphs between which the kernel is calculated.
  27. node_label : string
  28. node attribute used as label. The default node label is atom.
  29. edge_label : string
  30. edge attribute used as label. The default edge label is bond_type.
  31. labeled : boolean
  32. Whether the graphs are labeled. The default is True.
  33. n : integer
  34. Longest length of walks.
  35. compute_method : string
  36. Method used to compute walk kernel. The Following choices are available:
  37. 'direct' : direct product graph method, as shown in reference [1]. The time complexity is O(n^6) for unlabeled graphs with n vertices.
  38. 'brute' : brute force, simply search for all walks and compare them.
  39. Return
  40. ------
  41. Kmatrix : Numpy matrix
  42. Kernel matrix, each element of which is the path kernel up to d between 2 graphs.
  43. """
  44. # arrange all graphs in a list
  45. Gn = args[0] if len(args) == 1 else [args[0], args[1]]
  46. Kmatrix = np.zeros((len(Gn), len(Gn)))
  47. n = int(n)
  48. start_time = time.time()
  49. # direct product graph method
  50. if compute_method == 'direct':
  51. for i in range(0, len(Gn)):
  52. for j in range(i, len(Gn)):
  53. Kmatrix[i][j] = _untilnwalkkernel_direct(
  54. Gn[i], Gn[j], node_label, edge_label, labeled)
  55. Kmatrix[j][i] = Kmatrix[i][j]
  56. # search all paths use brute force.
  57. elif compute_method == 'brute':
  58. # get all paths of all graphs before calculating kernels to save time, but this may cost a lot of memory for large dataset.
  59. all_walks = [
  60. find_all_walks_until_length(Gn[i], n, node_label, edge_label,
  61. labeled) for i in range(0, len(Gn))
  62. ]
  63. for i in range(0, len(Gn)):
  64. for j in range(i, len(Gn)):
  65. Kmatrix[i][j] = _untilnwalkkernel_brute(
  66. all_walks[i],
  67. all_walks[j],
  68. node_label=node_label,
  69. edge_label=edge_label,
  70. labeled=labeled)
  71. Kmatrix[j][i] = Kmatrix[i][j]
  72. run_time = time.time() - start_time
  73. print(
  74. "\n --- kernel matrix of walk kernel up to %d of size %d built in %s seconds ---"
  75. % (n, len(Gn), run_time))
  76. return Kmatrix, run_time
  77. def _untilnwalkkernel_direct(G1, G2, node_label, edge_label, labeled):
  78. """Calculate walk graph kernels up to n between 2 graphs using direct product graphs.
  79. Parameters
  80. ----------
  81. G1, G2 : NetworkX graph
  82. Graphs between which the kernel is calculated.
  83. node_label : string
  84. node attribute used as label.
  85. edge_label : string
  86. edge attribute used as label.
  87. labeled : boolean
  88. Whether the graphs are labeled.
  89. Return
  90. ------
  91. kernel : float
  92. Treelet Kernel between 2 graphs.
  93. """
  94. # get tensor product / direct product
  95. gp = nx.tensor_product(G1, G2)
  96. from matplotlib import pyplot as plt
  97. nx.draw_networkx(G1)
  98. plt.show()
  99. nx.draw_networkx(G2)
  100. plt.show()
  101. kernel = 0
  102. nx.draw_networkx(gp)
  103. plt.show()
  104. return kernel
  105. def _untilnwalkkernel_brute(walks1,
  106. walks2,
  107. node_label='atom',
  108. edge_label='bond_type',
  109. labeled=True):
  110. """Calculate walk graph kernels up to n between 2 graphs.
  111. Parameters
  112. ----------
  113. walks1, walks2 : list
  114. List of walks in 2 graphs, where for unlabeled graphs, each walk is represented by a list of nodes; while for labeled graphs, each walk is represented by a string consists of labels of nodes and edges on that walk.
  115. node_label : string
  116. node attribute used as label. The default node label is atom.
  117. edge_label : string
  118. edge attribute used as label. The default edge label is bond_type.
  119. labeled : boolean
  120. Whether the graphs are labeled. The default is True.
  121. Return
  122. ------
  123. kernel : float
  124. Treelet Kernel between 2 graphs.
  125. """
  126. counts_walks1 = dict(Counter(walks1))
  127. counts_walks2 = dict(Counter(walks2))
  128. all_walks = list(set(walks1 + walks2))
  129. vector1 = [(counts_walks1[walk] if walk in walks1 else 0)
  130. for walk in all_walks]
  131. vector2 = [(counts_walks2[walk] if walk in walks2 else 0)
  132. for walk in all_walks]
  133. kernel = np.dot(vector1, vector2)
  134. return kernel
  135. # this method find walks repetively, it could be faster.
  136. def find_all_walks_until_length(G,
  137. length,
  138. node_label='atom',
  139. edge_label='bond_type',
  140. labeled=True):
  141. """Find all walks with a certain maximum length in a graph. A recursive depth first search is applied.
  142. Parameters
  143. ----------
  144. G : NetworkX graphs
  145. The graph in which walks are searched.
  146. length : integer
  147. The maximum length of walks.
  148. node_label : string
  149. node attribute used as label. The default node label is atom.
  150. edge_label : string
  151. edge attribute used as label. The default edge label is bond_type.
  152. labeled : boolean
  153. Whether the graphs are labeled. The default is True.
  154. Return
  155. ------
  156. walk : list
  157. List of walks retrieved, where for unlabeled graphs, each walk is represented by a list of nodes; while for labeled graphs, each walk is represented by a string consists of labels of nodes and edges on that walk.
  158. """
  159. all_walks = []
  160. # @todo: in this way, the time complexity is close to N(d^n+d^(n+1)+...+1), which could be optimized to O(Nd^n)
  161. for i in range(0, length + 1):
  162. new_walks = find_all_walks(G, i)
  163. if new_walks == []:
  164. break
  165. all_walks.extend(new_walks)
  166. if labeled == True: # convert paths to strings
  167. walk_strs = []
  168. for walk in all_walks:
  169. strlist = [
  170. G.node[node][node_label] +
  171. G[node][walk[walk.index(node) + 1]][edge_label]
  172. for node in walk[:-1]
  173. ]
  174. walk_strs.append(''.join(strlist) + G.node[walk[-1]][node_label])
  175. return walk_strs
  176. return all_walks
  177. def find_walks(G, source_node, length):
  178. """Find all walks with a certain length those start from a source node. A recursive depth first search is applied.
  179. Parameters
  180. ----------
  181. G : NetworkX graphs
  182. The graph in which walks are searched.
  183. source_node : integer
  184. The number of the node from where all walks start.
  185. length : integer
  186. The length of walks.
  187. Return
  188. ------
  189. walk : list of list
  190. List of walks retrieved, where each walk is represented by a list of nodes.
  191. """
  192. return [[source_node]] if length == 0 else \
  193. [ [source_node] + walk for neighbor in G[source_node] \
  194. for walk in find_walks(G, neighbor, length - 1) ]
  195. def find_all_walks(G, length):
  196. """Find all walks with a certain length in a graph. A recursive depth first search is applied.
  197. Parameters
  198. ----------
  199. G : NetworkX graphs
  200. The graph in which walks are searched.
  201. length : integer
  202. The length of walks.
  203. Return
  204. ------
  205. walk : list of list
  206. List of walks retrieved, where each walk is represented by a list of nodes.
  207. """
  208. all_walks = []
  209. for node in G:
  210. all_walks.extend(find_walks(G, node, length))
  211. ### The following process is not carried out according to the original article
  212. # all_paths_r = [ path[::-1] for path in all_paths ]
  213. # # For each path, two presentation are retrieved from its two extremities. Remove one of them.
  214. # for idx, path in enumerate(all_paths[:-1]):
  215. # for path2 in all_paths_r[idx+1::]:
  216. # if path == path2:
  217. # all_paths[idx] = []
  218. # break
  219. # return list(filter(lambda a: a != [], all_paths))
  220. return all_walks

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.