
test_graph_kernels.py

  1. """Tests of graph kernels.
  2. """
  3. import pytest
  4. import multiprocessing
  5. import numpy as np
  6. ##############################################################################
  7. def test_list_graph_kernels():
  8. """
  9. """
  10. from gklearn.kernels import GRAPH_KERNELS, list_of_graph_kernels
  11. assert list_of_graph_kernels() == [i for i in GRAPH_KERNELS]
  12. ##############################################################################
def chooseDataset(ds_name):
    """Choose dataset according to name.
    """
    from gklearn.dataset import Dataset

    root = '../../datasets/'

    # no node labels (and no edge labels).
    if ds_name == 'Alkane':
        dataset = Dataset('Alkane_unlabeled', root=root)
        dataset.trim_dataset(edge_required=False)
        dataset.cut_graphs(range(1, 10))
    # node symbolic labels.
    elif ds_name == 'Acyclic':
        dataset = Dataset('Acyclic', root=root)
        dataset.trim_dataset(edge_required=False)
    # node non-symbolic labels.
    elif ds_name == 'Letter-med':
        dataset = Dataset('Letter-med', root=root)
        dataset.trim_dataset(edge_required=False)
    # node symbolic and non-symbolic labels (and edge symbolic labels).
    elif ds_name == 'AIDS':
        dataset = Dataset('AIDS', root=root)
        dataset.trim_dataset(edge_required=False)
    # edge non-symbolic labels (no node labels).
    elif ds_name == 'Fingerprint_edge':
        dataset = Dataset('Fingerprint', root=root)
        dataset.trim_dataset(edge_required=True)
        irrelevant_labels = {'edge_attrs': ['orient', 'angle']}
        dataset.remove_labels(**irrelevant_labels)
    # edge non-symbolic labels (and node non-symbolic labels).
    elif ds_name == 'Fingerprint':
        dataset = Dataset('Fingerprint', root=root)
        dataset.trim_dataset(edge_required=True)
    # edge symbolic and non-symbolic labels (and node symbolic and non-symbolic labels).
    elif ds_name == 'Cuneiform':
        dataset = Dataset('Cuneiform', root=root)
        dataset.trim_dataset(edge_required=True)
        dataset.cut_graphs(range(0, 3))

    return dataset


def assert_equality(compute_fun, **kwargs):
    """Check if outputs are the same using different methods to compute.

    Parameters
    ----------
    compute_fun : function
        The function to compute the kernel, with the same keyword arguments as
        kwargs.
    **kwargs : dict
        The keyword arguments over the grid of which the kernel results are
        compared.

    Returns
    -------
    None.
    """
    from sklearn.model_selection import ParameterGrid

    param_grid = ParameterGrid(kwargs)
    result_lists = [[], [], []]
    for params in list(param_grid):
        results = compute_fun(**params)
        for rs, lst in zip(results, result_lists):
            lst.append(rs)
    for lst in result_lists:
        for i in range(len(lst[:-1])):
            assert np.array_equal(lst[i], lst[i + 1])


@pytest.mark.parametrize('ds_name', ['Alkane', 'AIDS'])
@pytest.mark.parametrize('weight,compute_method', [(0.01, 'geo'), (1, 'exp')])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_CommonWalk(ds_name, weight, compute_method):
    """Test common walk kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import CommonWalk
        import networkx as nx

        dataset = chooseDataset(ds_name)
        dataset.load_graphs([g for g in dataset.graphs if nx.number_of_nodes(g) > 1])

        try:
            graph_kernel = CommonWalk(node_labels=dataset.node_labels,
                                      edge_labels=dataset.edge_labels,
                                      ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                      weight=weight,
                                      compute_method=compute_method)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


@pytest.mark.parametrize('ds_name', ['Alkane', 'AIDS'])
@pytest.mark.parametrize('remove_totters', [False])  # [True, False])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_Marginalized(ds_name, remove_totters):
    """Test marginalized kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import Marginalized

        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = Marginalized(node_labels=dataset.node_labels,
                                        edge_labels=dataset.edge_labels,
                                        ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                        p_quit=0.5,
                                        n_iteration=2,
                                        remove_totters=remove_totters)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


@pytest.mark.parametrize('ds_name', ['Acyclic'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_SylvesterEquation(ds_name):
    """Test Sylvester equation kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import SylvesterEquation

        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = SylvesterEquation(
                ds_infos=dataset.get_dataset_infos(keys=['directed']),
                weight=1e-3,
                p=None,
                q=None,
                edge_weight=None)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


@pytest.mark.parametrize('ds_name', ['Acyclic', 'AIDS'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_ConjugateGradient(ds_name):
    """Test conjugate gradient kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import ConjugateGradient
        from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
        import functools

        dataset = chooseDataset(ds_name)

        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

        try:
            graph_kernel = ConjugateGradient(
                node_labels=dataset.node_labels,
                node_attrs=dataset.node_attrs,
                edge_labels=dataset.edge_labels,
                edge_attrs=dataset.edge_attrs,
                ds_infos=dataset.get_dataset_infos(keys=['directed']),
                weight=1e-3,
                p=None,
                q=None,
                edge_weight=None,
                node_kernels=sub_kernels,
                edge_kernels=sub_kernels)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


@pytest.mark.parametrize('ds_name', ['Acyclic', 'AIDS'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_FixedPoint(ds_name):
    """Test fixed point kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import FixedPoint
        from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
        import functools

        dataset = chooseDataset(ds_name)

        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

        try:
            graph_kernel = FixedPoint(
                node_labels=dataset.node_labels,
                node_attrs=dataset.node_attrs,
                edge_labels=dataset.edge_labels,
                edge_attrs=dataset.edge_attrs,
                ds_infos=dataset.get_dataset_infos(keys=['directed']),
                weight=1e-3,
                p=None,
                q=None,
                edge_weight=None,
                node_kernels=sub_kernels,
                edge_kernels=sub_kernels)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


@pytest.mark.parametrize('ds_name', ['Acyclic'])
@pytest.mark.parametrize('sub_kernel', ['exp', 'geo'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_SpectralDecomposition(ds_name, sub_kernel):
    """Test spectral decomposition kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import SpectralDecomposition

        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = SpectralDecomposition(
                ds_infos=dataset.get_dataset_infos(keys=['directed']),
                weight=1e-3,
                p=None,
                q=None,
                edge_weight=None,
                sub_kernel=sub_kernel)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


# @pytest.mark.parametrize(
#         'compute_method,ds_name,sub_kernel',
#         [
#             ('sylvester', 'Alkane', None),
#             ('conjugate', 'Alkane', None),
#             ('conjugate', 'AIDS', None),
#             ('fp', 'Alkane', None),
#             ('fp', 'AIDS', None),
#             ('spectral', 'Alkane', 'exp'),
#             ('spectral', 'Alkane', 'geo'),
#         ]
# )
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
# def test_RandomWalk(ds_name, compute_method, sub_kernel, parallel):
#     """Test random walk kernel.
#     """
#     from gklearn.kernels import RandomWalk
#     from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
#     import functools
#
#     dataset = chooseDataset(ds_name)
#
#     mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
#     sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}
#     # try:
#     graph_kernel = RandomWalk(node_labels=dataset.node_labels,
#                               node_attrs=dataset.node_attrs,
#                               edge_labels=dataset.edge_labels,
#                               edge_attrs=dataset.edge_attrs,
#                               ds_infos=dataset.get_dataset_infos(keys=['directed']),
#                               compute_method=compute_method,
#                               weight=1e-3,
#                               p=None,
#                               q=None,
#                               edge_weight=None,
#                               node_kernels=sub_kernels,
#                               edge_kernels=sub_kernels,
#                               sub_kernel=sub_kernel)
#     gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
#         parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
#     kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
#         parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
#     kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
#         parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
#     except Exception as exception:
#         assert False, exception


@pytest.mark.parametrize('ds_name', ['Alkane', 'Acyclic', 'Letter-med', 'AIDS', 'Fingerprint'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_ShortestPath(ds_name):
    """Test shortest path kernel.
    """
    def compute(parallel=None, fcsp=None):
        from gklearn.kernels import ShortestPath
        from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
        import functools

        dataset = chooseDataset(ds_name)

        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

        try:
            graph_kernel = ShortestPath(node_labels=dataset.node_labels,
                                        node_attrs=dataset.node_attrs,
                                        ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                        fcsp=fcsp,
                                        node_kernels=sub_kernels)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None], fcsp=[True, False])


# @pytest.mark.parametrize('ds_name', ['Alkane', 'Acyclic', 'Letter-med', 'AIDS', 'Fingerprint'])
@pytest.mark.parametrize('ds_name', ['Alkane', 'Acyclic', 'Letter-med', 'AIDS', 'Fingerprint', 'Fingerprint_edge', 'Cuneiform'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_StructuralSP(ds_name):
    """Test structural shortest path kernel.
    """
    def compute(parallel=None, fcsp=None):
        from gklearn.kernels import StructuralSP
        from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
        import functools

        dataset = chooseDataset(ds_name)

        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

        try:
            graph_kernel = StructuralSP(node_labels=dataset.node_labels,
                                        edge_labels=dataset.edge_labels,
                                        node_attrs=dataset.node_attrs,
                                        edge_attrs=dataset.edge_attrs,
                                        ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                        fcsp=fcsp,
                                        node_kernels=sub_kernels,
                                        edge_kernels=sub_kernels)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None], fcsp=[True, False])


@pytest.mark.parametrize('ds_name', ['Alkane', 'AIDS'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
# @pytest.mark.parametrize('k_func', ['MinMax', 'tanimoto', None])
@pytest.mark.parametrize('k_func', ['MinMax', 'tanimoto'])
# @pytest.mark.parametrize('compute_method', ['trie', 'naive'])
def test_PathUpToH(ds_name, k_func):
    """Test path kernel up to length $h$.
    """
    def compute(parallel=None, compute_method=None):
        from gklearn.kernels import PathUpToH

        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = PathUpToH(node_labels=dataset.node_labels,
                                     edge_labels=dataset.edge_labels,
                                     ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                     depth=2, k_func=k_func, compute_method=compute_method)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None],
                    compute_method=['trie', 'naive'])


@pytest.mark.parametrize('ds_name', ['Alkane', 'AIDS'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_Treelet(ds_name):
    """Test treelet kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import Treelet
        from gklearn.utils.kernels import polynomialkernel
        import functools

        dataset = chooseDataset(ds_name)

        pkernel = functools.partial(polynomialkernel, d=2, c=1e5)
        try:
            graph_kernel = Treelet(node_labels=dataset.node_labels,
                                   edge_labels=dataset.edge_labels,
                                   ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                   sub_kernel=pkernel)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


@pytest.mark.parametrize('ds_name', ['Acyclic'])
# @pytest.mark.parametrize('base_kernel', ['subtree', 'sp', 'edge'])
# @pytest.mark.parametrize('base_kernel', ['subtree'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_WLSubtree(ds_name):
    """Test Weisfeiler-Lehman subtree kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import WLSubtree

        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = WLSubtree(node_labels=dataset.node_labels,
                                     edge_labels=dataset.edge_labels,
                                     ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                     height=2)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


if __name__ == "__main__":
    test_list_graph_kernels()
    # test_spkernel('Alkane', 'imap_unordered')
    # test_ShortestPath('Alkane')
    # test_StructuralSP('Fingerprint_edge', 'imap_unordered')
    # test_StructuralSP('Alkane', None)
    # test_StructuralSP('Cuneiform', None)
    # test_WLSubtree('Acyclic', 'imap_unordered')
    # test_RandomWalk('Acyclic', 'sylvester', None, 'imap_unordered')
    # test_RandomWalk('Acyclic', 'conjugate', None, 'imap_unordered')
    # test_RandomWalk('Acyclic', 'fp', None, None)
    # test_RandomWalk('Acyclic', 'spectral', 'exp', 'imap_unordered')
    # test_CommonWalk('AIDS', 0.01, 'geo')
    # test_ShortestPath('Acyclic')
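
For quick reference outside the pytest harness, here is a minimal standalone sketch that strings together the same calls the tests use. It is only an illustration under the tests' own assumptions: the gklearn package is importable, the dataset files are available under '../../datasets/' as in chooseDataset above, and the returned Gram matrix is a NumPy array (as the tests' use of np.array_equal suggests).

import functools
import multiprocessing

from gklearn.dataset import Dataset
from gklearn.kernels import ShortestPath
from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct

# Load and trim a small labeled dataset, mirroring chooseDataset('Acyclic').
dataset = Dataset('Acyclic', root='../../datasets/')
dataset.trim_dataset(edge_required=False)
dataset.cut_graphs(range(0, 5))  # keep only a handful of graphs for speed

# Node sub-kernels: delta for symbolic labels, Gaussian for non-symbolic attributes,
# and their product for mixed labels, as in the tests above.
mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

graph_kernel = ShortestPath(node_labels=dataset.node_labels,
                            node_attrs=dataset.node_attrs,
                            ds_infos=dataset.get_dataset_infos(keys=['directed']),
                            fcsp=True,
                            node_kernels=sub_kernels)

# Full Gram matrix over the dataset, computed in parallel.
gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
    parallel='imap_unordered', n_jobs=multiprocessing.cpu_count(), verbose=True)
print(gram_matrix)
print('runtime:', run_time)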

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.