You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

treePatternKernel.py 9.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. """
  2. @author: linlin
  3. @references: Pierre Mahé and Jean-Philippe Vert. Graph kernels based on tree patterns for molecules. Machine learning, 75(1):3–35, 2009.
  4. """
  5. import sys
  6. import pathlib
  7. sys.path.insert(0, "../")
  8. import time
  9. import networkx as nx
  10. import numpy as np
  11. from collections import Counter
  12. from tqdm import tqdm
  13. tqdm.monitor_interval = 0
  14. from pygraph.utils.utils import untotterTransformation
  15. def treepatternkernel(*args,
  16. node_label='atom',
  17. edge_label='bond_type',
  18. labeled=True,
  19. kernel_type='untiln',
  20. lmda=1,
  21. h=1,
  22. remove_totters=True):
  23. """Calculate tree pattern graph kernels between graphs.
  24. Parameters
  25. ----------
  26. Gn : List of NetworkX graph
  27. List of graphs between which the kernels are calculated.
  28. /
  29. G1, G2 : NetworkX graphs
  30. 2 graphs between which the kernel is calculated.
  31. node_label : string
  32. node attribute used as label. The default node label is atom.
  33. edge_label : string
  34. edge attribute used as label. The default edge label is bond_type.
  35. labeled : boolean
  36. Whether the graphs are labeled. The default is True.
  37. kernel_type : string
  38. Type of tree pattern kernel, could be 'untiln', 'size' or 'branching'.
  39. lmda : float
  40. Weight to decide whether linear patterns or trees pattern of increasing complexity are favored.
  41. h : integer
  42. The upper bound of the height of tree patterns.
  43. remove_totters : boolean
  44. whether to remove totters. The default value is True.
  45. Return
  46. ------
  47. Kmatrix: Numpy matrix
  48. Kernel matrix, each element of which is the tree pattern graph kernel between 2 praphs.
  49. """
  50. if h < 1:
  51. raise Exception('h > 0 is requested.')
  52. kernel_type = kernel_type.lower()
  53. # arrange all graphs in a list
  54. Gn = args[0] if len(args) == 1 else [args[0], args[1]]
  55. Kmatrix = np.zeros((len(Gn), len(Gn)))
  56. h = int(h)
  57. start_time = time.time()
  58. if remove_totters:
  59. Gn = [untotterTransformation(G, node_label, edge_label) for G in Gn]
  60. pbar = tqdm(
  61. total=(1 + len(Gn)) * len(Gn) / 2,
  62. desc='calculate kernels',
  63. file=sys.stdout)
  64. for i in range(0, len(Gn)):
  65. for j in range(i, len(Gn)):
  66. Kmatrix[i][j] = _treepatternkernel_do(Gn[i], Gn[j], node_label,
  67. edge_label, labeled,
  68. kernel_type, lmda, h)
  69. Kmatrix[j][i] = Kmatrix[i][j]
  70. pbar.update(1)
  71. run_time = time.time() - start_time
  72. print(
  73. "\n --- kernel matrix of tree pattern kernel of size %d built in %s seconds ---"
  74. % (len(Gn), run_time))
  75. return Kmatrix, run_time
  76. def _treepatternkernel_do(G1, G2, node_label, edge_label, labeled, kernel_type,
  77. lmda, h):
  78. """Calculate tree pattern graph kernels between 2 graphs.
  79. Parameters
  80. ----------
  81. paths1, paths2 : list
  82. List of paths in 2 graphs, where for unlabeled graphs, each path is represented by a list of nodes; while for labeled graphs, each path is represented by a string consists of labels of nodes and edges on that path.
  83. k_func : function
  84. A kernel function used using different notions of fingerprint similarity.
  85. node_label : string
  86. node attribute used as label. The default node label is atom.
  87. edge_label : string
  88. edge attribute used as label. The default edge label is bond_type.
  89. labeled : boolean
  90. Whether the graphs are labeled. The default is True.
  91. kernel_type : string
  92. Type of tree pattern kernel, could be 'untiln', 'size' or 'branching'.
  93. lmda : float
  94. Weight to decide whether linear patterns or trees pattern of increasing complexity are favored.
  95. h : integer
  96. The upper bound of the height of tree patterns.
  97. Return
  98. ------
  99. kernel : float
  100. Treelet Kernel between 2 graphs.
  101. """
  102. def matchingset(n1, n2):
  103. """Get neiborhood matching set of two nodes in two graphs.
  104. """
  105. def mset_com(allpairs, length):
  106. """Find all sets R of pairs by combination.
  107. """
  108. if length == 1:
  109. mset = [[pair] for pair in allpairs]
  110. return mset, mset
  111. else:
  112. mset, mset_l = mset_com(allpairs, length - 1)
  113. mset_tmp = []
  114. for pairset in mset_l: # for each pair set of length l-1
  115. nodeset1 = [pair[0] for pair in pairset
  116. ] # nodes already in the set
  117. nodeset2 = [pair[1] for pair in pairset]
  118. for pair in allpairs:
  119. if (pair[0] not in nodeset1) and (
  120. pair[1] not in nodeset2
  121. ): # nodes in R should be unique
  122. mset_tmp.append(
  123. pairset + [pair]
  124. ) # add this pair to the pair set of length l-1, constructing a new set of length l
  125. nodeset1.append(pair[0])
  126. nodeset2.append(pair[1])
  127. mset.extend(mset_tmp)
  128. return mset, mset_tmp
  129. allpairs = [
  130. ] # all pairs those have the same node labels and edge labels
  131. for neighbor1 in G1[n1]:
  132. for neighbor2 in G2[n2]:
  133. if G1.node[neighbor1][node_label] == G2.node[neighbor2][node_label] \
  134. and G1[n1][neighbor1][edge_label] == G2[n2][neighbor2][edge_label]:
  135. allpairs.append([neighbor1, neighbor2])
  136. if allpairs != []:
  137. mset, _ = mset_com(allpairs, len(allpairs))
  138. else:
  139. mset = []
  140. return mset
  141. def kernel_h(h):
  142. """Calculate kernel of h-th iteration.
  143. """
  144. if kernel_type == 'untiln':
  145. all_kh = { str(n1) + '.' + str(n2) : (G1.node[n1][node_label] == G2.node[n2][node_label]) \
  146. for n1 in G1.nodes() for n2 in G2.nodes() } # kernels between all pair of nodes with h = 1 ]
  147. all_kh_tmp = all_kh.copy()
  148. for i in range(2, h + 1):
  149. for n1 in G1.nodes():
  150. for n2 in G2.nodes():
  151. kh = 0
  152. mset = all_msets[str(n1) + '.' + str(n2)]
  153. for R in mset:
  154. kh_tmp = 1
  155. for pair in R:
  156. kh_tmp *= lmda * all_kh[str(pair[0])
  157. + '.' + str(pair[1])]
  158. kh += 1 / lmda * kh_tmp
  159. kh = (G1.node[n1][node_label] == G2.node[n2][
  160. node_label]) * (1 + kh)
  161. all_kh_tmp[str(n1) + '.' + str(n2)] = kh
  162. all_kh = all_kh_tmp.copy()
  163. elif kernel_type == 'size':
  164. all_kh = { str(n1) + '.' + str(n2) : lmda * (G1.node[n1][node_label] == G2.node[n2][node_label]) \
  165. for n1 in G1.nodes() for n2 in G2.nodes() } # kernels between all pair of nodes with h = 1 ]
  166. all_kh_tmp = all_kh.copy()
  167. for i in range(2, h + 1):
  168. for n1 in G1.nodes():
  169. for n2 in G2.nodes():
  170. kh = 0
  171. mset = all_msets[str(n1) + '.' + str(n2)]
  172. for R in mset:
  173. kh_tmp = 1
  174. for pair in R:
  175. kh_tmp *= lmda * all_kh[str(pair[0])
  176. + '.' + str(pair[1])]
  177. kh += kh_tmp
  178. kh *= lmda * (
  179. G1.node[n1][node_label] == G2.node[n2][node_label])
  180. all_kh_tmp[str(n1) + '.' + str(n2)] = kh
  181. all_kh = all_kh_tmp.copy()
  182. elif kernel_type == 'branching':
  183. all_kh = { str(n1) + '.' + str(n2) : (G1.node[n1][node_label] == G2.node[n2][node_label]) \
  184. for n1 in G1.nodes() for n2 in G2.nodes() } # kernels between all pair of nodes with h = 1 ]
  185. all_kh_tmp = all_kh.copy()
  186. for i in range(2, h + 1):
  187. for n1 in G1.nodes():
  188. for n2 in G2.nodes():
  189. kh = 0
  190. mset = all_msets[str(n1) + '.' + str(n2)]
  191. for R in mset:
  192. kh_tmp = 1
  193. for pair in R:
  194. kh_tmp *= lmda * all_kh[str(pair[0])
  195. + '.' + str(pair[1])]
  196. kh += 1 / lmda * kh_tmp
  197. kh *= (
  198. G1.node[n1][node_label] == G2.node[n2][node_label])
  199. all_kh_tmp[str(n1) + '.' + str(n2)] = kh
  200. all_kh = all_kh_tmp.copy()
  201. return all_kh
  202. # calculate matching sets for every pair of nodes at first to avoid calculating in every iteration.
  203. all_msets = ({ str(node1) + '.' + str(node2) : matchingset(node1, node2) for node1 in G1.nodes() \
  204. for node2 in G2.nodes() } if h > 1 else {})
  205. all_kh = kernel_h(h)
  206. kernel = sum(all_kh.values())
  207. if kernel_type == 'size':
  208. kernel = kernel / (lmda**h)
  209. return kernel

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.