You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

treePatternKernel.py 8.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. """
  2. @author: linlin
  3. @references: Pierre Mahé and Jean-Philippe Vert. Graph kernels based on tree patterns for molecules. Machine learning, 75(1):3–35, 2009.
  4. """
  5. import sys
  6. import pathlib
  7. sys.path.insert(0, "../")
  8. import time
  9. from collections import Counter
  10. import networkx as nx
  11. import numpy as np
  12. def treepatternkernel(*args, node_label = 'atom', edge_label = 'bond_type', labeled = True, kernel_type = 'untiln', lmda = 1, h = 1):
  13. """Calculate tree pattern graph kernels between graphs.
  14. Parameters
  15. ----------
  16. Gn : List of NetworkX graph
  17. List of graphs between which the kernels are calculated.
  18. /
  19. G1, G2 : NetworkX graphs
  20. 2 graphs between which the kernel is calculated.
  21. node_label : string
  22. node attribute used as label. The default node label is atom.
  23. edge_label : string
  24. edge attribute used as label. The default edge label is bond_type.
  25. labeled : boolean
  26. Whether the graphs are labeled. The default is True.
  27. kernel_type : string
  28. Type of tree pattern kernel, could be 'untiln', 'size' or 'branching'.
  29. lmda : float
  30. Weight to decide whether linear patterns or trees pattern of increasing complexity are favored.
  31. h : integer
  32. The upper bound of the height of tree patterns.
  33. Return
  34. ------
  35. Kmatrix: Numpy matrix
  36. Kernel matrix, each element of which is the tree pattern graph kernel between 2 praphs.
  37. """
  38. if h < 1:
  39. raise Exception('h > 0 is requested.')
  40. kernel_type = kernel_type.lower()
  41. Gn = args[0] if len(args) == 1 else [args[0], args[1]] # arrange all graphs in a list
  42. Kmatrix = np.zeros((len(Gn), len(Gn)))
  43. h = int(h)
  44. start_time = time.time()
  45. for i in range(0, len(Gn)):
  46. for j in range(i, len(Gn)):
  47. Kmatrix[i][j] = _treepatternkernel_do(Gn[i], Gn[j], node_label, edge_label, labeled, kernel_type, lmda, h)
  48. Kmatrix[j][i] = Kmatrix[i][j]
  49. run_time = time.time() - start_time
  50. print("\n --- kernel matrix of tree pattern kernel of size %d built in %s seconds ---" % (len(Gn), run_time))
  51. return Kmatrix, run_time
  52. def _treepatternkernel_do(G1, G2, node_label, edge_label, labeled, kernel_type, lmda, h):
  53. """Calculate tree pattern graph kernels between 2 graphs.
  54. Parameters
  55. ----------
  56. paths1, paths2 : list
  57. List of paths in 2 graphs, where for unlabeled graphs, each path is represented by a list of nodes; while for labeled graphs, each path is represented by a string consists of labels of nodes and edges on that path.
  58. k_func : function
  59. A kernel function used using different notions of fingerprint similarity.
  60. node_label : string
  61. node attribute used as label. The default node label is atom.
  62. edge_label : string
  63. edge attribute used as label. The default edge label is bond_type.
  64. labeled : boolean
  65. Whether the graphs are labeled. The default is True.
  66. kernel_type : string
  67. Type of tree pattern kernel, could be 'untiln', 'size' or 'branching'.
  68. lmda : float
  69. Weight to decide whether linear patterns or trees pattern of increasing complexity are favored.
  70. h : integer
  71. The upper bound of the height of tree patterns.
  72. Return
  73. ------
  74. kernel : float
  75. Treelet Kernel between 2 graphs.
  76. """
  77. def matchingset(n1, n2):
  78. """Get neiborhood matching set of two nodes in two graphs.
  79. """
  80. def mset_com(allpairs, length):
  81. """Find all sets R of pairs by combination.
  82. """
  83. if length == 1:
  84. mset = [ [pair] for pair in allpairs ]
  85. return mset, mset
  86. else:
  87. mset, mset_l = mset_com(allpairs, length - 1)
  88. mset_tmp = []
  89. for pairset in mset_l: # for each pair set of length l-1
  90. nodeset1 = [ pair[0] for pair in pairset ] # nodes already in the set
  91. nodeset2 = [ pair[1] for pair in pairset ]
  92. for pair in allpairs:
  93. if (pair[0] not in nodeset1) and (pair[1] not in nodeset2): # nodes in R should be unique
  94. mset_tmp.append(pairset + [pair]) # add this pair to the pair set of length l-1, constructing a new set of length l
  95. nodeset1.append(pair[0])
  96. nodeset2.append(pair[1])
  97. mset.extend(mset_tmp)
  98. return mset, mset_tmp
  99. allpairs = [] # all pairs those have the same node labels and edge labels
  100. for neighbor1 in G1[n1]:
  101. for neighbor2 in G2[n2]:
  102. if G1.node[neighbor1][node_label] == G2.node[neighbor2][node_label] \
  103. and G1[n1][neighbor1][edge_label] == G2[n2][neighbor2][edge_label]:
  104. allpairs.append([neighbor1, neighbor2])
  105. if allpairs != []:
  106. mset, _ = mset_com(allpairs, len(allpairs))
  107. else:
  108. mset = []
  109. return mset
  110. def kernel_h(h):
  111. """Calculate kernel of h-th iteration.
  112. """
  113. if kernel_type == 'untiln':
  114. all_kh = { str(n1) + '.' + str(n2) : (G1.node[n1][node_label] == G2.node[n2][node_label]) \
  115. for n1 in G1.nodes() for n2 in G2.nodes() } # kernels between all pair of nodes with h = 1 ]
  116. all_kh_tmp = all_kh.copy()
  117. for i in range(2, h + 1):
  118. for n1 in G1.nodes():
  119. for n2 in G2.nodes():
  120. kh = 0
  121. mset = all_msets[str(n1) + '.' + str(n2)]
  122. for R in mset:
  123. kh_tmp = 1
  124. for pair in R:
  125. kh_tmp *= lmda * all_kh[str(pair[0]) + '.' + str(pair[1])]
  126. kh += 1 / lmda * kh_tmp
  127. kh = (G1.node[n1][node_label] == G2.node[n2][node_label]) * (1 + kh)
  128. all_kh_tmp[str(n1) + '.' + str(n2)] = kh
  129. all_kh = all_kh_tmp.copy()
  130. elif kernel_type == 'size':
  131. all_kh = { str(n1) + '.' + str(n2) : lmda * (G1.node[n1][node_label] == G2.node[n2][node_label]) \
  132. for n1 in G1.nodes() for n2 in G2.nodes() } # kernels between all pair of nodes with h = 1 ]
  133. all_kh_tmp = all_kh.copy()
  134. for i in range(2, h + 1):
  135. for n1 in G1.nodes():
  136. for n2 in G2.nodes():
  137. kh = 0
  138. mset = all_msets[str(n1) + '.' + str(n2)]
  139. for R in mset:
  140. kh_tmp = 1
  141. for pair in R:
  142. kh_tmp *= lmda * all_kh[str(pair[0]) + '.' + str(pair[1])]
  143. kh += kh_tmp
  144. kh *= lmda * (G1.node[n1][node_label] == G2.node[n2][node_label])
  145. all_kh_tmp[str(n1) + '.' + str(n2)] = kh
  146. all_kh = all_kh_tmp.copy()
  147. elif kernel_type == 'branching':
  148. all_kh = { str(n1) + '.' + str(n2) : (G1.node[n1][node_label] == G2.node[n2][node_label]) \
  149. for n1 in G1.nodes() for n2 in G2.nodes() } # kernels between all pair of nodes with h = 1 ]
  150. all_kh_tmp = all_kh.copy()
  151. for i in range(2, h + 1):
  152. for n1 in G1.nodes():
  153. for n2 in G2.nodes():
  154. kh = 0
  155. mset = all_msets[str(n1) + '.' + str(n2)]
  156. for R in mset:
  157. kh_tmp = 1
  158. for pair in R:
  159. kh_tmp *= lmda * all_kh[str(pair[0]) + '.' + str(pair[1])]
  160. kh += 1 / lmda * kh_tmp
  161. kh *= (G1.node[n1][node_label] == G2.node[n2][node_label])
  162. all_kh_tmp[str(n1) + '.' + str(n2)] = kh
  163. all_kh = all_kh_tmp.copy()
  164. return all_kh
  165. # calculate matching sets for every pair of nodes at first to avoid calculating in every iteration.
  166. all_msets = ({ str(node1) + '.' + str(node2) : matchingset(node1, node2) for node1 in G1.nodes() \
  167. for node2 in G2.nodes() } if h > 1 else {})
  168. all_kh = kernel_h(h)
  169. kernel = sum(all_kh.values())
  170. if kernel_type == 'size':
  171. kernel = kernel / (lmda ** h)
  172. return kernel

A Python package for graph kernels, graph edit distances and graph pre-image problem.