You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

treePatternKernel.py 8.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. """
  2. @author: linlin
  3. @references: Pierre Mahé and Jean-Philippe Vert. Graph kernels based on tree patterns for molecules. Machine learning, 75(1):3–35, 2009.
  4. """
  5. import sys
  6. import pathlib
  7. sys.path.insert(0, "../")
  8. import time
  9. from collections import Counter
  10. import networkx as nx
  11. import numpy as np
  12. def treepatternkernel(*args, node_label = 'atom', edge_label = 'bond_type', labeled = True, kernel_type = 'untiln', lmda = 1, h = 1):
  13. """Calculate tree pattern graph kernels between graphs.
  14. Parameters
  15. ----------
  16. Gn : List of NetworkX graph
  17. List of graphs between which the kernels are calculated.
  18. /
  19. G1, G2 : NetworkX graphs
  20. 2 graphs between which the kernel is calculated.
  21. node_label : string
  22. node attribute used as label. The default node label is atom.
  23. edge_label : string
  24. edge attribute used as label. The default edge label is bond_type.
  25. labeled : boolean
  26. Whether the graphs are labeled. The default is True.
  27. depth : integer
  28. Depth of search. Longest length of paths.
  29. k_func : function
  30. A kernel function used using different notions of fingerprint similarity.
  31. Return
  32. ------
  33. Kmatrix: Numpy matrix
  34. Kernel matrix, each element of which is the tree pattern graph kernel between 2 praphs.
  35. """
  36. if h < 1:
  37. raise Exception('h > 0 is requested.')
  38. kernel_type = kernel_type.lower()
  39. Gn = args[0] if len(args) == 1 else [args[0], args[1]] # arrange all graphs in a list
  40. Kmatrix = np.zeros((len(Gn), len(Gn)))
  41. h = int(h)
  42. start_time = time.time()
  43. for i in range(0, len(Gn)):
  44. for j in range(i, len(Gn)):
  45. Kmatrix[i][j] = _treepatternkernel_do(Gn[i], Gn[j], node_label, edge_label, labeled, kernel_type, lmda, h)
  46. Kmatrix[j][i] = Kmatrix[i][j]
  47. run_time = time.time() - start_time
  48. print("\n --- kernel matrix of tree pattern kernel of size %d built in %s seconds ---" % (len(Gn), run_time))
  49. return Kmatrix, run_time
  50. def _treepatternkernel_do(G1, G2, node_label, edge_label, labeled, kernel_type, lmda, h):
  51. """Calculate tree pattern graph kernels between 2 graphs.
  52. Parameters
  53. ----------
  54. paths1, paths2 : list
  55. List of paths in 2 graphs, where for unlabeled graphs, each path is represented by a list of nodes; while for labeled graphs, each path is represented by a string consists of labels of nodes and edges on that path.
  56. k_func : function
  57. A kernel function used using different notions of fingerprint similarity.
  58. node_label : string
  59. node attribute used as label. The default node label is atom.
  60. edge_label : string
  61. edge attribute used as label. The default edge label is bond_type.
  62. labeled : boolean
  63. Whether the graphs are labeled. The default is True.
  64. Return
  65. ------
  66. kernel : float
  67. Treelet Kernel between 2 graphs.
  68. """
  69. def matchingset(n1, n2):
  70. """Get neiborhood matching set of two nodes in two graphs.
  71. """
  72. def mset_com(allpairs, length):
  73. """Find all sets R of pairs by combination.
  74. """
  75. if length == 1:
  76. mset = [ [pair] for pair in allpairs ]
  77. return mset, mset
  78. else:
  79. mset, mset_l = mset_com(allpairs, length - 1)
  80. mset_tmp = []
  81. for pairset in mset_l: # for each pair set of length l-1
  82. nodeset1 = [ pair[0] for pair in pairset ] # nodes already in the set
  83. nodeset2 = [ pair[1] for pair in pairset ]
  84. for pair in allpairs:
  85. if (pair[0] not in nodeset1) and (pair[1] not in nodeset2): # nodes in R should be unique
  86. mset_tmp.append(pairset + [pair]) # add this pair to the pair set of length l-1, constructing a new set of length l
  87. nodeset1.append(pair[0])
  88. nodeset2.append(pair[1])
  89. mset.extend(mset_tmp)
  90. return mset, mset_tmp
  91. allpairs = [] # all pairs those have the same node labels and edge labels
  92. for neighbor1 in G1[n1]:
  93. for neighbor2 in G2[n2]:
  94. if G1.node[neighbor1][node_label] == G2.node[neighbor2][node_label] \
  95. and G1[n1][neighbor1][edge_label] == G2[n2][neighbor2][edge_label]:
  96. allpairs.append([neighbor1, neighbor2])
  97. if allpairs != []:
  98. mset, _ = mset_com(allpairs, len(allpairs))
  99. else:
  100. mset = []
  101. return mset
  102. def kernel_h(h):
  103. """Calculate kernel of h-th iteration.
  104. """
  105. if kernel_type == 'untiln':
  106. all_kh = { str(n1) + '.' + str(n2) : (G1.node[n1][node_label] == G2.node[n2][node_label]) \
  107. for n1 in G1.nodes() for n2 in G2.nodes() } # kernels between all pair of nodes with h = 1 ]
  108. all_kh_tmp = all_kh.copy()
  109. for i in range(2, h + 1):
  110. for n1 in G1.nodes():
  111. for n2 in G2.nodes():
  112. kh = 0
  113. mset = all_msets[str(n1) + '.' + str(n2)]
  114. for R in mset:
  115. kh_tmp = 1
  116. for pair in R:
  117. kh_tmp *= lmda * all_kh[str(pair[0]) + '.' + str(pair[1])]
  118. kh += 1 / lmda * kh_tmp
  119. kh = (G1.node[n1][node_label] == G2.node[n2][node_label]) * (1 + kh)
  120. all_kh_tmp[str(n1) + '.' + str(n2)] = kh
  121. all_kh = all_kh_tmp.copy()
  122. elif kernel_type == 'size':
  123. all_kh = { str(n1) + '.' + str(n2) : lmda * (G1.node[n1][node_label] == G2.node[n2][node_label]) \
  124. for n1 in G1.nodes() for n2 in G2.nodes() } # kernels between all pair of nodes with h = 1 ]
  125. all_kh_tmp = all_kh.copy()
  126. for i in range(2, h + 1):
  127. for n1 in G1.nodes():
  128. for n2 in G2.nodes():
  129. kh = 0
  130. mset = all_msets[str(n1) + '.' + str(n2)]
  131. for R in mset:
  132. kh_tmp = 1
  133. for pair in R:
  134. kh_tmp *= lmda * all_kh[str(pair[0]) + '.' + str(pair[1])]
  135. kh += kh_tmp
  136. kh *= lmda * (G1.node[n1][node_label] == G2.node[n2][node_label])
  137. all_kh_tmp[str(n1) + '.' + str(n2)] = kh
  138. all_kh = all_kh_tmp.copy()
  139. elif kernel_type == 'branching':
  140. all_kh = { str(n1) + '.' + str(n2) : (G1.node[n1][node_label] == G2.node[n2][node_label]) \
  141. for n1 in G1.nodes() for n2 in G2.nodes() } # kernels between all pair of nodes with h = 1 ]
  142. all_kh_tmp = all_kh.copy()
  143. for i in range(2, h + 1):
  144. for n1 in G1.nodes():
  145. for n2 in G2.nodes():
  146. kh = 0
  147. mset = all_msets[str(n1) + '.' + str(n2)]
  148. for R in mset:
  149. kh_tmp = 1
  150. for pair in R:
  151. kh_tmp *= lmda * all_kh[str(pair[0]) + '.' + str(pair[1])]
  152. kh += 1 / lmda * kh_tmp
  153. kh *= (G1.node[n1][node_label] == G2.node[n2][node_label])
  154. all_kh_tmp[str(n1) + '.' + str(n2)] = kh
  155. all_kh = all_kh_tmp.copy()
  156. return all_kh
  157. # calculate matching sets for every pair of nodes at first to avoid calculating in every iteration.
  158. all_msets = ({ str(node1) + '.' + str(node2) : matchingset(node1, node2) for node1 in G1.nodes() \
  159. for node2 in G2.nodes() } if h > 1 else {})
  160. all_kh = kernel_h(h)
  161. kernel = sum(all_kh.values())
  162. if kernel_type == 'size':
  163. kernel = kernel / (lmda ** h)
  164. return kernel

A Python package for graph kernels, graph edit distances and the graph pre-image problem.