You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

spKernel.py 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. """
  2. @author: linlin
  3. @references: Borgwardt KM, Kriegel HP. Shortest-path kernels on graphs. In Data Mining, Fifth IEEE International Conference on, 2005 Nov 27 (8 pp.). IEEE.
  4. """
  5. import sys
  6. import pathlib
  7. sys.path.insert(0, "../")
  8. from tqdm import tqdm
  9. import time
  10. from itertools import combinations_with_replacement, product
  11. from functools import partial
  12. from joblib import Parallel, delayed
  13. from multiprocessing import Pool
  14. import networkx as nx
  15. import numpy as np
  16. from pygraph.utils.utils import getSPGraph
  17. from pygraph.utils.graphdataset import get_dataset_attributes
  18. def spkernel(*args,
  19. node_label='atom',
  20. edge_weight=None,
  21. node_kernels=None,
  22. n_jobs=None):
  23. """Calculate shortest-path kernels between graphs.
  24. Parameters
  25. ----------
  26. Gn : List of NetworkX graph
  27. List of graphs between which the kernels are calculated.
  28. /
  29. G1, G2 : NetworkX graphs
  30. 2 graphs between which the kernel is calculated.
  31. edge_weight : string
  32. Edge attribute name corresponding to the edge weight.
  33. node_kernels: dict
  34. A dictionary of kernel functions for nodes, including 3 items: 'symb' for symbolic node labels, 'nsymb' for non-symbolic node labels, 'mix' for both labels. The first 2 functions take two node labels as parameters, and the 'mix' function takes 4 parameters, a symbolic and a non-symbolic label for each the two nodes. Each label is in form of 2-D dimension array (n_samples, n_features). Each function returns an number as the kernel value. Ignored when nodes are unlabeled.
  35. Return
  36. ------
  37. Kmatrix : Numpy matrix
  38. Kernel matrix, each element of which is the sp kernel between 2 praphs.
  39. """
  40. # pre-process
  41. Gn = args[0] if len(args) == 1 else [args[0], args[1]]
  42. weight = None
  43. if edge_weight == None:
  44. print('\n None edge weight specified. Set all weight to 1.\n')
  45. else:
  46. try:
  47. some_weight = list(
  48. nx.get_edge_attributes(Gn[0], edge_weight).values())[0]
  49. if isinstance(some_weight, float) or isinstance(some_weight, int):
  50. weight = edge_weight
  51. else:
  52. print(
  53. '\n Edge weight with name %s is not float or integer. Set all weight to 1.\n'
  54. % edge_weight)
  55. except:
  56. print(
  57. '\n Edge weight with name "%s" is not found in the edge attributes. Set all weight to 1.\n'
  58. % edge_weight)
  59. ds_attrs = get_dataset_attributes(
  60. Gn,
  61. attr_names=['node_labeled', 'node_attr_dim', 'is_directed'],
  62. node_label=node_label)
  63. # remove graphs with no edges, as no sp can be found in their structures, so the kernel between such a graph and itself will be zero.
  64. len_gn = len(Gn)
  65. Gn = [(idx, G) for idx, G in enumerate(Gn) if nx.number_of_edges(G) != 0]
  66. idx = [G[0] for G in Gn]
  67. Gn = [G[1] for G in Gn]
  68. if len(Gn) != len_gn:
  69. print('\n %d graphs are removed as they don\'t contain edges.\n' %
  70. (len_gn - len(Gn)))
  71. start_time = time.time()
  72. pool = Pool(n_jobs)
  73. # get shortest path graphs of Gn
  74. getsp_partial = partial(wrap_getSPGraph, Gn, edge_weight)
  75. result_sp = pool.map(getsp_partial, range(0, len(Gn)))
  76. for i in result_sp:
  77. Gn[i[0]] = i[1]
  78. # Gn = [
  79. # getSPGraph(G, edge_weight=edge_weight)
  80. # for G in tqdm(Gn, desc='getting sp graphs', file=sys.stdout)
  81. # ]
  82. Kmatrix = np.zeros((len(Gn), len(Gn)))
  83. do_partial = partial(spkernel_do, Gn, ds_attrs, node_label, node_kernels)
  84. itr = combinations_with_replacement(range(0, len(Gn)), 2)
  85. # chunksize = 2000 # int(len(list(itr)) / n_jobs)
  86. # for i, j, kernel in tqdm(pool.imap_unordered(do_partial, itr, chunksize)):
  87. # Kmatrix[i][j] = kernel
  88. # Kmatrix[j][i] = kernel
  89. result_perf = pool.map(do_partial, itr)
  90. pool.close()
  91. pool.join()
  92. # result_perf = Parallel(
  93. # n_jobs=n_jobs, verbose=10)(
  94. # delayed(do_partial)(ij)
  95. # for ij in combinations_with_replacement(range(0, len(Gn)), 2))
  96. # result_perf = [
  97. # do_partial(ij)
  98. # for ij in combinations_with_replacement(range(0, len(Gn)), 2)
  99. # ]
  100. for i in result_perf:
  101. Kmatrix[i[0]][i[1]] = i[2]
  102. Kmatrix[i[1]][i[0]] = i[2]
  103. # pbar = tqdm(
  104. # total=((len(Gn) + 1) * len(Gn) / 2),
  105. # desc='calculating kernels',
  106. # file=sys.stdout)
  107. # if ds_attrs['node_labeled']:
  108. # # node symb and non-synb labeled
  109. # if ds_attrs['node_attr_dim'] > 0:
  110. # if ds_attrs['is_directed']:
  111. # for i, j in combinations_with_replacement(
  112. # range(0, len(Gn)), 2):
  113. # for e1, e2 in product(
  114. # Gn[i].edges(data=True), Gn[j].edges(data=True)):
  115. # if e1[2]['cost'] == e2[2]['cost']:
  116. # kn = node_kernels['mix']
  117. # try:
  118. # n11, n12, n21, n22 = Gn[i].nodes[e1[0]], Gn[
  119. # i].nodes[e1[1]], Gn[j].nodes[e2[0]], Gn[
  120. # j].nodes[e2[1]]
  121. # kn1 = kn(n11[node_label], n21[node_label], [
  122. # n11['attributes']
  123. # ], [n21['attributes']]) * kn(
  124. # n12[node_label], n22[node_label],
  125. # [n12['attributes']], [n22['attributes']])
  126. # Kmatrix[i][j] += kn1
  127. # except KeyError: # missing labels or attributes
  128. # pass
  129. # Kmatrix[j][i] = Kmatrix[i][j]
  130. # pbar.update(1)
  131. # else:
  132. # for i, j in combinations_with_replacement(
  133. # range(0, len(Gn)), 2):
  134. # for e1, e2 in product(
  135. # Gn[i].edges(data=True), Gn[j].edges(data=True)):
  136. # if e1[2]['cost'] == e2[2]['cost']:
  137. # kn = node_kernels['mix']
  138. # try:
  139. # # each edge walk is counted twice, starting from both its extreme nodes.
  140. # n11, n12, n21, n22 = Gn[i].nodes[e1[0]], Gn[
  141. # i].nodes[e1[1]], Gn[j].nodes[e2[0]], Gn[
  142. # j].nodes[e2[1]]
  143. # kn1 = kn(n11[node_label], n21[node_label], [
  144. # n11['attributes']
  145. # ], [n21['attributes']]) * kn(
  146. # n12[node_label], n22[node_label],
  147. # [n12['attributes']], [n22['attributes']])
  148. # kn2 = kn(n11[node_label], n22[node_label], [
  149. # n11['attributes']
  150. # ], [n22['attributes']]) * kn(
  151. # n12[node_label], n21[node_label],
  152. # [n12['attributes']], [n21['attributes']])
  153. # Kmatrix[i][j] += kn1 + kn2
  154. # except KeyError: # missing labels or attributes
  155. # pass
  156. # Kmatrix[j][i] = Kmatrix[i][j]
  157. # pbar.update(1)
  158. # # node symb labeled
  159. # else:
  160. # if ds_attrs['is_directed']:
  161. # for i, j in combinations_with_replacement(
  162. # range(0, len(Gn)), 2):
  163. # for e1, e2 in product(
  164. # Gn[i].edges(data=True), Gn[j].edges(data=True)):
  165. # if e1[2]['cost'] == e2[2]['cost']:
  166. # kn = node_kernels['symb']
  167. # try:
  168. # n11, n12, n21, n22 = Gn[i].nodes[e1[0]], Gn[
  169. # i].nodes[e1[1]], Gn[j].nodes[e2[0]], Gn[
  170. # j].nodes[e2[1]]
  171. # kn1 = kn(n11[node_label],
  172. # n21[node_label]) * kn(
  173. # n12[node_label], n22[node_label])
  174. # Kmatrix[i][j] += kn1
  175. # except KeyError: # missing labels
  176. # pass
  177. # Kmatrix[j][i] = Kmatrix[i][j]
  178. # pbar.update(1)
  179. # else:
  180. # for i, j in combinations_with_replacement(
  181. # range(0, len(Gn)), 2):
  182. # for e1, e2 in product(
  183. # Gn[i].edges(data=True), Gn[j].edges(data=True)):
  184. # if e1[2]['cost'] == e2[2]['cost']:
  185. # kn = node_kernels['symb']
  186. # try:
  187. # # each edge walk is counted twice, starting from both its extreme nodes.
  188. # n11, n12, n21, n22 = Gn[i].nodes[e1[0]], Gn[
  189. # i].nodes[e1[1]], Gn[j].nodes[e2[0]], Gn[
  190. # j].nodes[e2[1]]
  191. # kn1 = kn(n11[node_label],
  192. # n21[node_label]) * kn(
  193. # n12[node_label], n22[node_label])
  194. # kn2 = kn(n11[node_label],
  195. # n22[node_label]) * kn(
  196. # n12[node_label], n21[node_label])
  197. # Kmatrix[i][j] += kn1 + kn2
  198. # except KeyError: # missing labels
  199. # pass
  200. # Kmatrix[j][i] = Kmatrix[i][j]
  201. # pbar.update(1)
  202. # else:
  203. # # node non-synb labeled
  204. # if ds_attrs['node_attr_dim'] > 0:
  205. # if ds_attrs['is_directed']:
  206. # for i, j in combinations_with_replacement(
  207. # range(0, len(Gn)), 2):
  208. # for e1, e2 in product(
  209. # Gn[i].edges(data=True), Gn[j].edges(data=True)):
  210. # if e1[2]['cost'] == e2[2]['cost']:
  211. # kn = node_kernels['nsymb']
  212. # try:
  213. # # each edge walk is counted twice, starting from both its extreme nodes.
  214. # n11, n12, n21, n22 = Gn[i].nodes[e1[0]], Gn[
  215. # i].nodes[e1[1]], Gn[j].nodes[e2[0]], Gn[
  216. # j].nodes[e2[1]]
  217. # kn1 = kn([n11['attributes']],
  218. # [n21['attributes']]) * kn(
  219. # [n12['attributes']],
  220. # [n22['attributes']])
  221. # Kmatrix[i][j] += kn1
  222. # except KeyError: # missing attributes
  223. # pass
  224. # Kmatrix[j][i] = Kmatrix[i][j]
  225. # pbar.update(1)
  226. # else:
  227. # for i, j in combinations_with_replacement(
  228. # range(0, len(Gn)), 2):
  229. # for e1, e2 in product(
  230. # Gn[i].edges(data=True), Gn[j].edges(data=True)):
  231. # if e1[2]['cost'] == e2[2]['cost']:
  232. # kn = node_kernels['nsymb']
  233. # try:
  234. # # each edge walk is counted twice, starting from both its extreme nodes.
  235. # n11, n12, n21, n22 = Gn[i].nodes[e1[0]], Gn[
  236. # i].nodes[e1[1]], Gn[j].nodes[e2[0]], Gn[
  237. # j].nodes[e2[1]]
  238. # kn1 = kn([n11['attributes']],
  239. # [n21['attributes']]) * kn(
  240. # [n12['attributes']],
  241. # [n22['attributes']])
  242. # kn2 = kn([n11['attributes']],
  243. # [n22['attributes']]) * kn(
  244. # [n12['attributes']],
  245. # [n21['attributes']])
  246. # Kmatrix[i][j] += kn1 + kn2
  247. # except KeyError: # missing attributes
  248. # pass
  249. # Kmatrix[j][i] = Kmatrix[i][j]
  250. # pbar.update(1)
  251. # # node unlabeled
  252. # else:
  253. # for i, j in combinations_with_replacement(range(0, len(Gn)), 2):
  254. # for e1, e2 in product(
  255. # Gn[i].edges(data=True), Gn[j].edges(data=True)):
  256. # if e1[2]['cost'] == e2[2]['cost']:
  257. # Kmatrix[i][j] += 1
  258. # Kmatrix[j][i] = Kmatrix[i][j]
  259. # pbar.update(1)
  260. run_time = time.time() - start_time
  261. print(
  262. "\n --- shortest path kernel matrix of size %d built in %s seconds ---"
  263. % (len(Gn), run_time))
  264. return Kmatrix, run_time, idx
  265. def spkernel_do(Gn, ds_attrs, node_label, node_kernels, ij):
  266. i = ij[0]
  267. j = ij[1]
  268. Kmatrix = 0
  269. if ds_attrs['node_labeled']:
  270. # node symb and non-synb labeled
  271. if ds_attrs['node_attr_dim'] > 0:
  272. if ds_attrs['is_directed']:
  273. for e1, e2 in product(
  274. Gn[i].edges(data=True), Gn[j].edges(data=True)):
  275. if e1[2]['cost'] == e2[2]['cost']:
  276. kn = node_kernels['mix']
  277. try:
  278. n11, n12, n21, n22 = Gn[i].nodes[e1[0]], Gn[
  279. i].nodes[e1[1]], Gn[j].nodes[e2[0]], Gn[
  280. j].nodes[e2[1]]
  281. kn1 = kn(
  282. n11[node_label], n21[node_label],
  283. [n11['attributes']], [n21['attributes']]) * kn(
  284. n12[node_label], n22[node_label],
  285. [n12['attributes']], [n22['attributes']])
  286. Kmatrix += kn1
  287. except KeyError: # missing labels or attributes
  288. pass
  289. else:
  290. for e1, e2 in product(
  291. Gn[i].edges(data=True), Gn[j].edges(data=True)):
  292. if e1[2]['cost'] == e2[2]['cost']:
  293. kn = node_kernels['mix']
  294. try:
  295. # each edge walk is counted twice, starting from both its extreme nodes.
  296. n11, n12, n21, n22 = Gn[i].nodes[e1[0]], Gn[
  297. i].nodes[e1[1]], Gn[j].nodes[e2[0]], Gn[
  298. j].nodes[e2[1]]
  299. kn1 = kn(
  300. n11[node_label], n21[node_label],
  301. [n11['attributes']], [n21['attributes']]) * kn(
  302. n12[node_label], n22[node_label],
  303. [n12['attributes']], [n22['attributes']])
  304. kn2 = kn(
  305. n11[node_label], n22[node_label],
  306. [n11['attributes']], [n22['attributes']]) * kn(
  307. n12[node_label], n21[node_label],
  308. [n12['attributes']], [n21['attributes']])
  309. Kmatrix += kn1 + kn2
  310. except KeyError: # missing labels or attributes
  311. pass
  312. # node symb labeled
  313. else:
  314. if ds_attrs['is_directed']:
  315. for e1, e2 in product(
  316. Gn[i].edges(data=True), Gn[j].edges(data=True)):
  317. if e1[2]['cost'] == e2[2]['cost']:
  318. kn = node_kernels['symb']
  319. try:
  320. n11, n12, n21, n22 = Gn[i].nodes[e1[0]], Gn[
  321. i].nodes[e1[1]], Gn[j].nodes[e2[0]], Gn[
  322. j].nodes[e2[1]]
  323. kn1 = kn(n11[node_label], n21[node_label]) * kn(
  324. n12[node_label], n22[node_label])
  325. Kmatrix += kn1
  326. except KeyError: # missing labels
  327. pass
  328. else:
  329. for e1, e2 in product(
  330. Gn[i].edges(data=True), Gn[j].edges(data=True)):
  331. if e1[2]['cost'] == e2[2]['cost']:
  332. kn = node_kernels['symb']
  333. try:
  334. # each edge walk is counted twice, starting from both its extreme nodes.
  335. n11, n12, n21, n22 = Gn[i].nodes[e1[0]], Gn[
  336. i].nodes[e1[1]], Gn[j].nodes[e2[0]], Gn[
  337. j].nodes[e2[1]]
  338. kn1 = kn(n11[node_label], n21[node_label]) * kn(
  339. n12[node_label], n22[node_label])
  340. kn2 = kn(n11[node_label], n22[node_label]) * kn(
  341. n12[node_label], n21[node_label])
  342. Kmatrix += kn1 + kn2
  343. except KeyError: # missing labels
  344. pass
  345. else:
  346. # node non-synb labeled
  347. if ds_attrs['node_attr_dim'] > 0:
  348. if ds_attrs['is_directed']:
  349. for e1, e2 in product(
  350. Gn[i].edges(data=True), Gn[j].edges(data=True)):
  351. if e1[2]['cost'] == e2[2]['cost']:
  352. kn = node_kernels['nsymb']
  353. try:
  354. # each edge walk is counted twice, starting from both its extreme nodes.
  355. n11, n12, n21, n22 = Gn[i].nodes[e1[0]], Gn[
  356. i].nodes[e1[1]], Gn[j].nodes[e2[0]], Gn[
  357. j].nodes[e2[1]]
  358. kn1 = kn(
  359. [n11['attributes']], [n21['attributes']]) * kn(
  360. [n12['attributes']], [n22['attributes']])
  361. Kmatrix += kn1
  362. except KeyError: # missing attributes
  363. pass
  364. else:
  365. for e1, e2 in product(
  366. Gn[i].edges(data=True), Gn[j].edges(data=True)):
  367. if e1[2]['cost'] == e2[2]['cost']:
  368. kn = node_kernels['nsymb']
  369. try:
  370. # each edge walk is counted twice, starting from both its extreme nodes.
  371. n11, n12, n21, n22 = Gn[i].nodes[e1[0]], Gn[
  372. i].nodes[e1[1]], Gn[j].nodes[e2[0]], Gn[
  373. j].nodes[e2[1]]
  374. kn1 = kn(
  375. [n11['attributes']], [n21['attributes']]) * kn(
  376. [n12['attributes']], [n22['attributes']])
  377. kn2 = kn(
  378. [n11['attributes']], [n22['attributes']]) * kn(
  379. [n12['attributes']], [n21['attributes']])
  380. Kmatrix += kn1 + kn2
  381. except KeyError: # missing attributes
  382. pass
  383. # node unlabeled
  384. else:
  385. for e1, e2 in product(
  386. Gn[i].edges(data=True), Gn[j].edges(data=True)):
  387. if e1[2]['cost'] == e2[2]['cost']:
  388. Kmatrix += 1
  389. return i, j, Kmatrix
  390. def wrap_getSPGraph(Gn, weight, i):
  391. return i, getSPGraph(Gn[i], edge_weight=weight)

A Python package for graph kernels, graph edit distances and the graph pre-image problem.