
test_graph_kernels.py 20 kB

  1. """Tests of graph kernels.
  2. """
  3. import pytest
  4. import multiprocessing
  5. import numpy as np
  6. ##############################################################################
  7. def test_list_graph_kernels():
  8. """
  9. """
  10. from gklearn.kernels import GRAPH_KERNELS, list_of_graph_kernels
  11. assert list_of_graph_kernels() == [i for i in GRAPH_KERNELS]
  12. ##############################################################################
def chooseDataset(ds_name):
    """Choose dataset according to name.
    """
    from gklearn.dataset import Dataset
    import os

    current_path = os.path.dirname(os.path.realpath(__file__)) + '/'
    root = current_path + '../../datasets/'

    # no node labels (and no edge labels).
    if ds_name == 'Alkane':
        dataset = Dataset('Alkane_unlabeled', root=root)
        dataset.trim_dataset(edge_required=False)
        dataset.cut_graphs(range(1, 10))
    # node symbolic labels.
    elif ds_name == 'Acyclic':
        dataset = Dataset('Acyclic', root=root)
        dataset.trim_dataset(edge_required=False)
    # node non-symbolic labels.
    elif ds_name == 'Letter-med':
        dataset = Dataset('Letter-med', root=root)
        dataset.trim_dataset(edge_required=False)
    # node symbolic and non-symbolic labels (and edge symbolic labels).
    elif ds_name == 'AIDS':
        dataset = Dataset('AIDS', root=root)
        dataset.trim_dataset(edge_required=False)
    # edge non-symbolic labels (no node labels).
    elif ds_name == 'Fingerprint_edge':
        dataset = Dataset('Fingerprint', root=root)
        dataset.trim_dataset(edge_required=True)
        irrelevant_labels = {'edge_attrs': ['orient', 'angle']}
        dataset.remove_labels(**irrelevant_labels)
    # edge non-symbolic labels (and node non-symbolic labels).
    elif ds_name == 'Fingerprint':
        dataset = Dataset('Fingerprint', root=root)
        dataset.trim_dataset(edge_required=True)
    # edge symbolic and non-symbolic labels (and node symbolic and non-symbolic labels).
    elif ds_name == 'Cuneiform':
        dataset = Dataset('Cuneiform', root=root)
        dataset.trim_dataset(edge_required=True)
        dataset.cut_graphs(range(0, 3))

    return dataset

def assert_equality(compute_fun, **kwargs):
    """Check if outputs are the same using different methods to compute.

    Parameters
    ----------
    compute_fun : function
        The function to compute the kernel, with the same key word arguments as
        kwargs.
    **kwargs : dict
        The key word arguments over the grid of which the kernel results are
        compared.

    Returns
    -------
    None.
    """
    from sklearn.model_selection import ParameterGrid

    param_grid = ParameterGrid(kwargs)
    result_lists = [[], [], []]
    for params in list(param_grid):
        results = compute_fun(**params)
        for rs, lst in zip(results, result_lists):
            lst.append(rs)
    for lst in result_lists:
        for i in range(len(lst[:-1])):
            assert np.array_equal(lst[i], lst[i + 1])

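# For example, assert_equality(compute, parallel=['imap_unordered', None], fcsp=[True, False])
# expands the keyword arguments into a 2x2 parameter grid, calls compute() once per
# combination, and checks that the returned Gram matrix, kernel list and single kernel
# value are each identical across all runs.
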
@pytest.mark.parametrize('ds_name', ['Alkane', 'AIDS'])
@pytest.mark.parametrize('weight,compute_method', [(0.01, 'geo'), (1, 'exp')])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_CommonWalk(ds_name, weight, compute_method):
    """Test common walk kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import CommonWalk
        import networkx as nx
        dataset = chooseDataset(ds_name)
        dataset.load_graphs([g for g in dataset.graphs if nx.number_of_nodes(g) > 1])

        try:
            graph_kernel = CommonWalk(node_labels=dataset.node_labels,
                                      edge_labels=dataset.edge_labels,
                                      ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                      weight=weight,
                                      compute_method=compute_method)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])

@pytest.mark.parametrize('ds_name', ['Alkane', 'AIDS'])
@pytest.mark.parametrize('remove_totters', [False]) #[True, False])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_Marginalized(ds_name, remove_totters):
    """Test marginalized kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import Marginalized
        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = Marginalized(node_labels=dataset.node_labels,
                                        edge_labels=dataset.edge_labels,
                                        ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                        p_quit=0.5,
                                        n_iteration=2,
                                        remove_totters=remove_totters)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])

@pytest.mark.parametrize('ds_name', ['Acyclic'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_SylvesterEquation(ds_name):
    """Test sylvester equation kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import SylvesterEquation
        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = SylvesterEquation(
                ds_infos=dataset.get_dataset_infos(keys=['directed']),
                weight=1e-3,
                p=None,
                q=None,
                edge_weight=None)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])

@pytest.mark.parametrize('ds_name', ['Acyclic', 'AIDS'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_ConjugateGradient(ds_name):
    """Test conjugate gradient kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import ConjugateGradient
        from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
        import functools
        dataset = chooseDataset(ds_name)

        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

        try:
            graph_kernel = ConjugateGradient(
                node_labels=dataset.node_labels,
                node_attrs=dataset.node_attrs,
                edge_labels=dataset.edge_labels,
                edge_attrs=dataset.edge_attrs,
                ds_infos=dataset.get_dataset_infos(keys=['directed']),
                weight=1e-3,
                p=None,
                q=None,
                edge_weight=None,
                node_kernels=sub_kernels,
                edge_kernels=sub_kernels)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])

@pytest.mark.parametrize('ds_name', ['Acyclic', 'AIDS'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_FixedPoint(ds_name):
    """Test fixed point kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import FixedPoint
        from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
        import functools
        dataset = chooseDataset(ds_name)

        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

        try:
            graph_kernel = FixedPoint(
                node_labels=dataset.node_labels,
                node_attrs=dataset.node_attrs,
                edge_labels=dataset.edge_labels,
                edge_attrs=dataset.edge_attrs,
                ds_infos=dataset.get_dataset_infos(keys=['directed']),
                weight=1e-3,
                p=None,
                q=None,
                edge_weight=None,
                node_kernels=sub_kernels,
                edge_kernels=sub_kernels)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])

@pytest.mark.parametrize('ds_name', ['Acyclic'])
@pytest.mark.parametrize('sub_kernel', ['exp', 'geo'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_SpectralDecomposition(ds_name, sub_kernel):
    """Test spectral decomposition kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import SpectralDecomposition
        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = SpectralDecomposition(
                ds_infos=dataset.get_dataset_infos(keys=['directed']),
                weight=1e-3,
                p=None,
                q=None,
                edge_weight=None,
                sub_kernel=sub_kernel)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])

# @pytest.mark.parametrize(
#     'compute_method,ds_name,sub_kernel',
#     [
#         ('sylvester', 'Alkane', None),
#         ('conjugate', 'Alkane', None),
#         ('conjugate', 'AIDS', None),
#         ('fp', 'Alkane', None),
#         ('fp', 'AIDS', None),
#         ('spectral', 'Alkane', 'exp'),
#         ('spectral', 'Alkane', 'geo'),
#     ]
# )
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
# def test_RandomWalk(ds_name, compute_method, sub_kernel, parallel):
#     """Test random walk kernel.
#     """
#     from gklearn.kernels import RandomWalk
#     from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
#     import functools
#
#     dataset = chooseDataset(ds_name)
#     mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
#     sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}
#     # try:
#     graph_kernel = RandomWalk(node_labels=dataset.node_labels,
#                               node_attrs=dataset.node_attrs,
#                               edge_labels=dataset.edge_labels,
#                               edge_attrs=dataset.edge_attrs,
#                               ds_infos=dataset.get_dataset_infos(keys=['directed']),
#                               compute_method=compute_method,
#                               weight=1e-3,
#                               p=None,
#                               q=None,
#                               edge_weight=None,
#                               node_kernels=sub_kernels,
#                               edge_kernels=sub_kernels,
#                               sub_kernel=sub_kernel)
#     gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
#         parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
#     kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
#         parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
#     kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
#         parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
#     # except Exception as exception:
#     #     assert False, exception

@pytest.mark.parametrize('ds_name', ['Alkane', 'Acyclic', 'Letter-med', 'AIDS', 'Fingerprint'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_ShortestPath(ds_name):
    """Test shortest path kernel.
    """
    def compute(parallel=None, fcsp=None):
        from gklearn.kernels import ShortestPath
        from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
        import functools
        dataset = chooseDataset(ds_name)

        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

        try:
            graph_kernel = ShortestPath(node_labels=dataset.node_labels,
                                        node_attrs=dataset.node_attrs,
                                        ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                        fcsp=fcsp,
                                        node_kernels=sub_kernels)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None], fcsp=[True, False])

#@pytest.mark.parametrize('ds_name', ['Alkane', 'Acyclic', 'Letter-med', 'AIDS', 'Fingerprint'])
@pytest.mark.parametrize('ds_name', ['Alkane', 'Acyclic', 'Letter-med', 'AIDS', 'Fingerprint', 'Fingerprint_edge', 'Cuneiform'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_StructuralSP(ds_name):
    """Test structural shortest path kernel.
    """
    def compute(parallel=None, fcsp=None):
        from gklearn.kernels import StructuralSP
        from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
        import functools
        dataset = chooseDataset(ds_name)

        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

        try:
            graph_kernel = StructuralSP(node_labels=dataset.node_labels,
                                        edge_labels=dataset.edge_labels,
                                        node_attrs=dataset.node_attrs,
                                        edge_attrs=dataset.edge_attrs,
                                        ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                        fcsp=fcsp,
                                        node_kernels=sub_kernels,
                                        edge_kernels=sub_kernels)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None], fcsp=[True, False])

@pytest.mark.parametrize('ds_name', ['Alkane', 'AIDS'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
#@pytest.mark.parametrize('k_func', ['MinMax', 'tanimoto', None])
@pytest.mark.parametrize('k_func', ['MinMax', 'tanimoto'])
# @pytest.mark.parametrize('compute_method', ['trie', 'naive'])
def test_PathUpToH(ds_name, k_func):
    """Test path kernel up to length $h$.
    """
    def compute(parallel=None, compute_method=None):
        from gklearn.kernels import PathUpToH
        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = PathUpToH(node_labels=dataset.node_labels,
                                     edge_labels=dataset.edge_labels,
                                     ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                     depth=2, k_func=k_func, compute_method=compute_method)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None],
                    compute_method=['trie', 'naive'])

@pytest.mark.parametrize('ds_name', ['Alkane', 'AIDS'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_Treelet(ds_name):
    """Test treelet kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import Treelet
        from gklearn.utils.kernels import polynomialkernel
        import functools
        dataset = chooseDataset(ds_name)

        pkernel = functools.partial(polynomialkernel, d=2, c=1e5)
        try:
            graph_kernel = Treelet(node_labels=dataset.node_labels,
                                   edge_labels=dataset.edge_labels,
                                   ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                   sub_kernel=pkernel)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])

@pytest.mark.parametrize('ds_name', ['Acyclic'])
#@pytest.mark.parametrize('base_kernel', ['subtree', 'sp', 'edge'])
# @pytest.mark.parametrize('base_kernel', ['subtree'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_WLSubtree(ds_name):
    """Test Weisfeiler-Lehman subtree kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import WLSubtree
        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = WLSubtree(node_labels=dataset.node_labels,
                                     edge_labels=dataset.edge_labels,
                                     ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                     height=2)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])

if __name__ == "__main__":
    test_list_graph_kernels()
    # test_spkernel('Alkane', 'imap_unordered')
    # test_ShortestPath('Alkane')
    # test_StructuralSP('Fingerprint_edge', 'imap_unordered')
    # test_StructuralSP('Alkane', None)
    # test_StructuralSP('Cuneiform', None)
    # test_WLSubtree('Acyclic', 'imap_unordered')
    # test_RandomWalk('Acyclic', 'sylvester', None, 'imap_unordered')
    # test_RandomWalk('Acyclic', 'conjugate', None, 'imap_unordered')
    # test_RandomWalk('Acyclic', 'fp', None, None)
    # test_RandomWalk('Acyclic', 'spectral', 'exp', 'imap_unordered')
    # test_CommonWalk('AIDS', 0.01, 'geo')
    # test_ShortestPath('Acyclic')

A Python package for graph kernels, graph edit distances and the graph pre-image problem.
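
The tests above all follow the same pattern: load a Dataset, build a kernel object, and call compute(). A minimal standalone sketch of that pattern, mirroring test_ShortestPath, is shown below; it assumes gklearn is installed and that the 'Acyclic' dataset is available, and the dataset root path and parameter values are purely illustrative, not prescribed by this test file.

import functools
import multiprocessing

from gklearn.dataset import Dataset
from gklearn.kernels import ShortestPath
from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct

# Load and trim a small node-labeled dataset (the root path is illustrative).
dataset = Dataset('Acyclic', root='datasets/')
dataset.trim_dataset(edge_required=False)

# Node kernels for symbolic labels, non-symbolic attributes, and their product.
mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

# Build the shortest path kernel and compute the full Gram matrix in parallel.
graph_kernel = ShortestPath(node_labels=dataset.node_labels,
                            node_attrs=dataset.node_attrs,
                            ds_infos=dataset.get_dataset_infos(keys=['directed']),
                            fcsp=True,
                            node_kernels=sub_kernels)
gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                                             parallel='imap_unordered',
                                             n_jobs=multiprocessing.cpu_count(),
                                             verbose=True)
print(gram_matrix.shape, run_time)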