
weisfeilerLehmanKernel.py 20 kB

  1. """
  2. @author: linlin
  3. @references:
  4. [1] Shervashidze N, Schweitzer P, Leeuwen EJ, Mehlhorn K, Borgwardt KM.
  5. Weisfeiler-lehman graph kernels. Journal of Machine Learning Research.
  6. 2011;12(Sep):2539-61.
  7. """
  8. import sys
  9. from collections import Counter
  10. from functools import partial
  11. import time
  12. #from multiprocessing import Pool
  13. from tqdm import tqdm
  14. import networkx as nx
  15. import numpy as np
  16. #from gklearn.kernels.pathKernel import pathkernel
  17. from gklearn.utils.graphdataset import get_dataset_attributes
  18. from gklearn.utils.parallel import parallel_gm
  19. # @todo: support edge kernel, sp kernel, user-defined kernel.
def weisfeilerlehmankernel(*args,
                           node_label='atom',
                           edge_label='bond_type',
                           height=0,
                           base_kernel='subtree',
                           parallel=None,
                           n_jobs=None,
                           chunksize=None,
                           verbose=True):
    """Calculate Weisfeiler-Lehman kernels between graphs.

    Parameters
    ----------
    Gn : List of NetworkX graph
        List of graphs between which the kernels are calculated.

    G1, G2 : NetworkX graphs
        Two graphs between which the kernel is calculated.

    node_label : string
        Node attribute used as label. The default node label is 'atom'.

    edge_label : string
        Edge attribute used as label. The default edge label is 'bond_type'.

    height : int
        Subtree height.

    base_kernel : string
        Base kernel used in each iteration of the WL kernel. Only the default
        'subtree' kernel can be applied for now.

    parallel : None
        Which parallelization method is applied to compute the kernel. No
        parallelization can be applied for now.

    n_jobs : int
        Number of jobs for parallelization. The default is to use all
        computational cores. This argument is only valid when one of the
        parallelization methods is applied and can be ignored for now.

    Return
    ------
    Kmatrix : Numpy matrix
        Kernel matrix, each element of which is the Weisfeiler-Lehman kernel
        between 2 graphs.

    Notes
    -----
    This function now supports the WL subtree kernel only.
    """
    # The default base kernel is the subtree kernel. For a user-defined kernel,
    # base_kernel is the base kernel function used in each iteration of the WL
    # kernel; in that case this function returns a Numpy matrix, each element
    # of which is the user-defined Weisfeiler-Lehman kernel between 2 graphs.

    # pre-process
    if isinstance(base_kernel, str):  # a user-defined base kernel is a function, not a string
        base_kernel = base_kernel.lower()
    Gn = args[0] if len(args) == 1 else [args[0], args[1]]  # arrange all graphs in a list
    Gn = [g.copy() for g in Gn]
    ds_attrs = get_dataset_attributes(Gn, attr_names=['node_labeled'],
                                      node_label=node_label)
    if not ds_attrs['node_labeled']:
        for G in Gn:
            nx.set_node_attributes(G, '0', 'atom')

    start_time = time.time()

    # for WL subtree kernel
    if base_kernel == 'subtree':
        Kmatrix = _wl_kernel_do(Gn, node_label, edge_label, height, parallel,
                                n_jobs, chunksize, verbose)

    # for WL shortest path kernel
    elif base_kernel == 'sp':
        Kmatrix = _wl_spkernel_do(Gn, node_label, edge_label, height)

    # for WL edge kernel
    elif base_kernel == 'edge':
        Kmatrix = _wl_edgekernel_do(Gn, node_label, edge_label, height)

    # for user-defined base kernel
    else:
        Kmatrix = _wl_userkernel_do(Gn, node_label, edge_label, height, base_kernel)

    run_time = time.time() - start_time
    if verbose:
        print("\n --- Weisfeiler-Lehman %s kernel matrix of size %d built in %s seconds ---"
              % (base_kernel, len(Gn), run_time))

    return Kmatrix, run_time
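
# A minimal sketch of the two calling conventions accepted by
# weisfeilerlehmankernel (the variable names Gn, G1 and G2 are illustrative
# assumptions, not defined in this module):
#
#     Kmatrix, run_time = weisfeilerlehmankernel(Gn, node_label='atom')      # a list of graphs
#     Kmatrix, run_time = weisfeilerlehmankernel(G1, G2, node_label='atom')  # exactly two graphs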


def _wl_kernel_do(Gn, node_label, edge_label, height, parallel, n_jobs,
                  chunksize, verbose):
    """Calculate Weisfeiler-Lehman kernels between graphs.

    Parameters
    ----------
    Gn : List of NetworkX graph
        List of graphs between which the kernels are calculated.
    node_label : string
        node attribute used as label.
    edge_label : string
        edge attribute used as label.
    height : int
        wl height.

    Return
    ------
    Kmatrix : Numpy matrix
        Kernel matrix, each element of which is the Weisfeiler-Lehman kernel
        between 2 graphs.
    """
    height = int(height)
    Kmatrix = np.zeros((len(Gn), len(Gn)))

    # initial for height = 0
    all_num_of_each_label = []  # number of occurrences of each label in each graph in this iteration

    # for each graph
    for G in Gn:
        # get the set of original labels
        labels_ori = list(nx.get_node_attributes(G, node_label).values())
        # number of occurrences of each label in G
        all_num_of_each_label.append(dict(Counter(labels_ori)))

    # calculate subtree kernel with the 0th iteration and add it to the final kernel
    compute_kernel_matrix(Kmatrix, all_num_of_each_label, Gn, parallel, n_jobs,
                          chunksize, False)

    # iterate each height
    for h in range(1, height + 1):
        all_set_compressed = {}  # a dictionary mapping original labels to new ones in all graphs in this iteration
        num_of_labels_occurred = 0  # number of distinct multiset labels seen so far as node labels in all graphs
        all_num_of_each_label = []  # number of occurrences of each label in each graph
        # @todo: parallel this part.
        for G in Gn:
            all_multisets = []
            for node, attrs in G.nodes(data=True):
                # Multiset-label determination.
                multiset = [G.nodes[neighbor][node_label] for neighbor in G[node]]
                # sorting each multiset
                multiset.sort()
                multiset = [attrs[node_label]] + multiset  # add the prefix
                all_multisets.append(tuple(multiset))

            # label compression
            set_unique = list(set(all_multisets))  # set of unique multiset labels
            # a dictionary mapping original labels to new ones.
            set_compressed = {}
            # If a label occurred before, assign its former compressed label;
            # else assign the number of labels occurred so far + 1 as the
            # compressed label.
            for value in set_unique:
                if value in all_set_compressed.keys():
                    set_compressed.update({value: all_set_compressed[value]})
                else:
                    set_compressed.update({value: str(num_of_labels_occurred + 1)})
                    num_of_labels_occurred += 1

            all_set_compressed.update(set_compressed)

            # relabel nodes
            for idx, node in enumerate(G.nodes()):
                G.nodes[node][node_label] = set_compressed[all_multisets[idx]]

            # get the set of compressed labels
            labels_comp = list(nx.get_node_attributes(G, node_label).values())
            all_num_of_each_label.append(dict(Counter(labels_comp)))

        # calculate subtree kernel with h iterations and add it to the final kernel
        compute_kernel_matrix(Kmatrix, all_num_of_each_label, Gn, parallel,
                              n_jobs, chunksize, False)

    return Kmatrix
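
# Worked trace of the label-compression step in _wl_kernel_do (the labels are
# illustrative assumptions): suppose graph A produces the unique multisets
# [('C', 'O'), ('O', 'C', 'C')] and graph B produces [('C', 'O')]. Processing
# A first yields all_set_compressed = {('C', 'O'): '1', ('O', 'C', 'C'): '2'};
# when B is processed, ('C', 'O') is already present and is reused as '1', so
# identical neighborhoods in different graphs share the same compressed label.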


def wl_iteration(G, node_label):
    """Compute the multiset label of each node of G for one WL iteration."""
    all_multisets = []
    for node, attrs in G.nodes(data=True):
        # Multiset-label determination.
        multiset = [G.nodes[neighbor][node_label] for neighbor in G[node]]
        # sorting each multiset
        multiset.sort()
        multiset = [attrs[node_label]] + multiset  # add the prefix
        all_multisets.append(tuple(multiset))
    return all_multisets


def wrapper_wl_iteration(node_label, itr_item):
    g = itr_item[0]
    i = itr_item[1]
    all_multisets = wl_iteration(g, node_label)
    return i, all_multisets
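
# A small worked example of wl_iteration (the path graph and its 'atom' labels
# are illustrative assumptions):
#
#     import networkx as nx
#     G = nx.path_graph(3)
#     nx.set_node_attributes(G, {0: 'C', 1: 'O', 2: 'C'}, 'atom')
#     wl_iteration(G, 'atom')
#     # -> [('C', 'O'), ('O', 'C', 'C'), ('C', 'O')]
#
# Node 1 has label 'O' and sorted neighbor labels ['C', 'C'], so its multiset
# label is the tuple ('O', 'C', 'C') with its own label as the prefix.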


def compute_kernel_matrix(Kmatrix, all_num_of_each_label, Gn, parallel, n_jobs,
                          chunksize, verbose):
    """Compute the kernel matrix using the base kernel."""
    if parallel == 'imap_unordered':
        # compute kernels.
        def init_worker(alllabels_toshare):
            global G_alllabels
            G_alllabels = alllabels_toshare
        do_partial = partial(wrapper_compute_subtree_kernel, Kmatrix)
        parallel_gm(do_partial, Kmatrix, Gn, init_worker=init_worker,
                    glbv=(all_num_of_each_label,), n_jobs=n_jobs,
                    chunksize=chunksize, verbose=verbose)
    elif parallel is None:
        for i in range(len(Kmatrix)):
            for j in range(i, len(Kmatrix)):
                Kmatrix[i][j] = compute_subtree_kernel(all_num_of_each_label[i],
                                                       all_num_of_each_label[j],
                                                       Kmatrix[i][j])
                Kmatrix[j][i] = Kmatrix[i][j]
    else:
        # fail loudly instead of silently leaving Kmatrix untouched.
        raise Exception('parallel method "%s" is not supported.' % parallel)


def compute_subtree_kernel(num_of_each_label1, num_of_each_label2, kernel):
    """Compute the subtree kernel."""
    labels = set(list(num_of_each_label1.keys()) + list(num_of_each_label2.keys()))
    vector1 = np.array([(num_of_each_label1[label]
                         if (label in num_of_each_label1.keys()) else 0)
                        for label in labels])
    vector2 = np.array([(num_of_each_label2[label]
                         if (label in num_of_each_label2.keys()) else 0)
                        for label in labels])
    kernel += np.dot(vector1, vector2)
    return kernel
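
# Worked example of compute_subtree_kernel (the label counts are illustrative
# assumptions): with num_of_each_label1 = {'C': 2, 'O': 1} and
# num_of_each_label2 = {'C': 1, 'N': 1}, the union of labels is
# {'C', 'O', 'N'}, the count vectors are (2, 1, 0) and (1, 0, 1), and their
# dot product 2*1 + 1*0 + 0*1 = 2 is added to the running kernel value.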


def wrapper_compute_subtree_kernel(Kmatrix, itr):
    i = itr[0]
    j = itr[1]
    return i, j, compute_subtree_kernel(G_alllabels[i], G_alllabels[j], Kmatrix[i][j])


def _wl_spkernel_do(Gn, node_label, edge_label, height):
    """Calculate Weisfeiler-Lehman shortest path kernels between graphs.

    Parameters
    ----------
    Gn : List of NetworkX graph
        List of graphs between which the kernels are calculated.
    node_label : string
        node attribute used as label.
    edge_label : string
        edge attribute used as label.
    height : int
        subtree height.

    Return
    ------
    Kmatrix : Numpy matrix
        Kernel matrix, each element of which is the Weisfeiler-Lehman kernel
        between 2 graphs.
    """
    from gklearn.utils.utils import getSPGraph

    # init.
    height = int(height)
    Kmatrix = np.zeros((len(Gn), len(Gn)))  # init kernel

    Gn = [getSPGraph(G, edge_weight=edge_label) for G in Gn]  # get shortest path graphs of Gn

    # initial for height = 0
    for i in range(0, len(Gn)):
        for j in range(i, len(Gn)):
            for e1 in Gn[i].edges(data=True):
                for e2 in Gn[j].edges(data=True):
                    if e1[2]['cost'] != 0 and e1[2]['cost'] == e2[2]['cost'] \
                            and ((e1[0] == e2[0] and e1[1] == e2[1])
                                 or (e1[0] == e2[1] and e1[1] == e2[0])):
                        Kmatrix[i][j] += 1
            Kmatrix[j][i] = Kmatrix[i][j]

    # iterate each height
    for h in range(1, height + 1):
        all_set_compressed = {}  # a dictionary mapping original labels to new ones in all graphs in this iteration
        num_of_labels_occurred = 0  # number of distinct multiset labels seen so far as node labels in all graphs
        for G in Gn:  # for each graph
            set_multisets = []
            for node in G.nodes(data=True):
                # Multiset-label determination.
                multiset = [G.nodes[neighbor][node_label] for neighbor in G[node[0]]]
                # sorting each multiset
                multiset.sort()
                multiset = node[1][node_label] + ''.join(multiset)  # concatenate to a string and add the prefix
                set_multisets.append(multiset)

            # label compression
            set_unique = list(set(set_multisets))  # set of unique multiset labels
            # a dictionary mapping original labels to new ones.
            set_compressed = {}
            # If a label occurred before, assign its former compressed label;
            # else assign the number of labels occurred so far + 1 as the
            # compressed label.
            for value in set_unique:
                if value in all_set_compressed.keys():
                    set_compressed.update({value: all_set_compressed[value]})
                else:
                    set_compressed.update({value: str(num_of_labels_occurred + 1)})
                    num_of_labels_occurred += 1

            all_set_compressed.update(set_compressed)

            # relabel nodes; set_multisets is in G.nodes() order, so index by
            # position rather than by node id.
            for idx, node in enumerate(G.nodes(data=True)):
                node[1][node_label] = set_compressed[set_multisets[idx]]

        # calculate subtree kernel with h iterations and add it to the final kernel
        for i in range(0, len(Gn)):
            for j in range(i, len(Gn)):
                for e1 in Gn[i].edges(data=True):
                    for e2 in Gn[j].edges(data=True):
                        if e1[2]['cost'] != 0 and e1[2]['cost'] == e2[2]['cost'] \
                                and ((e1[0] == e2[0] and e1[1] == e2[1])
                                     or (e1[0] == e2[1] and e1[1] == e2[0])):
                            Kmatrix[i][j] += 1
                Kmatrix[j][i] = Kmatrix[i][j]

    return Kmatrix


def _wl_edgekernel_do(Gn, node_label, edge_label, height):
    """Calculate Weisfeiler-Lehman edge kernels between graphs.

    Parameters
    ----------
    Gn : List of NetworkX graph
        List of graphs between which the kernels are calculated.
    node_label : string
        node attribute used as label.
    edge_label : string
        edge attribute used as label.
    height : int
        subtree height.

    Return
    ------
    Kmatrix : Numpy matrix
        Kernel matrix, each element of which is the Weisfeiler-Lehman kernel
        between 2 graphs.
    """
    # init.
    height = int(height)
    Kmatrix = np.zeros((len(Gn), len(Gn)))  # init kernel

    # initial for height = 0
    for i in range(0, len(Gn)):
        for j in range(i, len(Gn)):
            for e1 in Gn[i].edges(data=True):
                for e2 in Gn[j].edges(data=True):
                    if e1[2][edge_label] == e2[2][edge_label] \
                            and ((e1[0] == e2[0] and e1[1] == e2[1])
                                 or (e1[0] == e2[1] and e1[1] == e2[0])):
                        Kmatrix[i][j] += 1
            Kmatrix[j][i] = Kmatrix[i][j]

    # iterate each height
    for h in range(1, height + 1):
        all_set_compressed = {}  # a dictionary mapping original labels to new ones in all graphs in this iteration
        num_of_labels_occurred = 0  # number of distinct multiset labels seen so far as node labels in all graphs
        for G in Gn:  # for each graph
            set_multisets = []
            for node in G.nodes(data=True):
                # Multiset-label determination.
                multiset = [G.nodes[neighbor][node_label] for neighbor in G[node[0]]]
                # sorting each multiset
                multiset.sort()
                multiset = node[1][node_label] + ''.join(multiset)  # concatenate to a string and add the prefix
                set_multisets.append(multiset)

            # label compression
            set_unique = list(set(set_multisets))  # set of unique multiset labels
            # a dictionary mapping original labels to new ones.
            set_compressed = {}
            # If a label occurred before, assign its former compressed label;
            # else assign the number of labels occurred so far + 1 as the
            # compressed label.
            for value in set_unique:
                if value in all_set_compressed.keys():
                    set_compressed.update({value: all_set_compressed[value]})
                else:
                    set_compressed.update({value: str(num_of_labels_occurred + 1)})
                    num_of_labels_occurred += 1

            all_set_compressed.update(set_compressed)

            # relabel nodes; set_multisets is in G.nodes() order, so index by
            # position rather than by node id.
            for idx, node in enumerate(G.nodes(data=True)):
                node[1][node_label] = set_compressed[set_multisets[idx]]

        # calculate subtree kernel with h iterations and add it to the final kernel
        for i in range(0, len(Gn)):
            for j in range(i, len(Gn)):
                for e1 in Gn[i].edges(data=True):
                    for e2 in Gn[j].edges(data=True):
                        if e1[2][edge_label] == e2[2][edge_label] \
                                and ((e1[0] == e2[0] and e1[1] == e2[1])
                                     or (e1[0] == e2[1] and e1[1] == e2[0])):
                            Kmatrix[i][j] += 1
                Kmatrix[j][i] = Kmatrix[i][j]

    return Kmatrix


def _wl_userkernel_do(Gn, node_label, edge_label, height, base_kernel):
    """Calculate Weisfeiler-Lehman kernels based on a user-defined base kernel between graphs.

    Parameters
    ----------
    Gn : List of NetworkX graph
        List of graphs between which the kernels are calculated.
    node_label : string
        node attribute used as label.
    edge_label : string
        edge attribute used as label.
    height : int
        subtree height.
    base_kernel : function
        The base kernel function used in each iteration of the WL kernel. It
        returns a Numpy matrix, each element of which is the user-defined
        Weisfeiler-Lehman kernel between 2 graphs.

    Return
    ------
    Kmatrix : Numpy matrix
        Kernel matrix, each element of which is the Weisfeiler-Lehman kernel
        between 2 graphs.
    """
    # init.
    height = int(height)

    # initial for height = 0
    Kmatrix = base_kernel(Gn, node_label, edge_label)

    # iterate each height
    for h in range(1, height + 1):
        all_set_compressed = {}  # a dictionary mapping original labels to new ones in all graphs in this iteration
        num_of_labels_occurred = 0  # number of distinct multiset labels seen so far as node labels in all graphs
        for G in Gn:  # for each graph
            set_multisets = []
            for node in G.nodes(data=True):
                # Multiset-label determination.
                multiset = [G.nodes[neighbor][node_label] for neighbor in G[node[0]]]
                # sorting each multiset
                multiset.sort()
                multiset = node[1][node_label] + ''.join(multiset)  # concatenate to a string and add the prefix
                set_multisets.append(multiset)

            # label compression
            set_unique = list(set(set_multisets))  # set of unique multiset labels
            # a dictionary mapping original labels to new ones.
            set_compressed = {}
            # If a label occurred before, assign its former compressed label;
            # else assign the number of labels occurred so far + 1 as the
            # compressed label.
            for value in set_unique:
                if value in all_set_compressed.keys():
                    set_compressed.update({value: all_set_compressed[value]})
                else:
                    set_compressed.update({value: str(num_of_labels_occurred + 1)})
                    num_of_labels_occurred += 1

            all_set_compressed.update(set_compressed)

            # relabel nodes; set_multisets is in G.nodes() order, so index by
            # position rather than by node id.
            for idx, node in enumerate(G.nodes(data=True)):
                node[1][node_label] = set_compressed[set_multisets[idx]]

        # calculate kernel with h iterations and add it to the final kernel
        Kmatrix += base_kernel(Gn, node_label, edge_label)

    return Kmatrix
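
# A minimal, runnable usage sketch. The toy graphs, their 'atom' labels and
# the chosen height are illustrative assumptions for demonstration only, not
# part of the original module.
if __name__ == '__main__':
    g1 = nx.Graph()
    g1.add_nodes_from([(0, {'atom': 'C'}), (1, {'atom': 'O'}), (2, {'atom': 'C'})])
    g1.add_edges_from([(0, 1), (1, 2)])

    g2 = nx.Graph()
    g2.add_nodes_from([(0, {'atom': 'C'}), (1, {'atom': 'O'})])
    g2.add_edge(0, 1)

    # height=2 runs the initial (h = 0) comparison plus two WL relabeling rounds.
    Kmatrix, run_time = weisfeilerlehmankernel(
        [g1, g2], node_label='atom', height=2, base_kernel='subtree')
    print(Kmatrix)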

A Python package for graph kernels, graph edit distances and graph pre-image problem.