
test_graph_kernels.py 21 kB

  1. """Tests of graph kernels.
  2. """
  3. import pytest
  4. import multiprocessing
  5. import numpy as np
  6. ##############################################################################
def test_list_graph_kernels():
    """Test that list_of_graph_kernels() lists every kernel in GRAPH_KERNELS.
    """
    from gklearn.kernels import GRAPH_KERNELS, list_of_graph_kernels
    assert list_of_graph_kernels() == [i for i in GRAPH_KERNELS]


##############################################################################

def chooseDataset(ds_name):
    """Choose dataset according to name.
    """
    from gklearn.utils import Dataset

    dataset = Dataset()

    # no node labels (and no edge labels).
    if ds_name == 'Alkane':
        dataset.load_predefined_dataset(ds_name)
        dataset.trim_dataset(edge_required=False)
        irrelevant_labels = {'node_attrs': ['x', 'y', 'z'], 'edge_labels': ['bond_stereo']}
        dataset.remove_labels(**irrelevant_labels)
        dataset.cut_graphs(range(1, 10))
    # node symbolic labels.
    elif ds_name == 'Acyclic':
        dataset.load_predefined_dataset(ds_name)
        dataset.trim_dataset(edge_required=False)
        irrelevant_labels = {'node_attrs': ['x', 'y', 'z'], 'edge_labels': ['bond_stereo']}
        dataset.remove_labels(**irrelevant_labels)
    # node non-symbolic labels.
    elif ds_name == 'Letter-med':
        dataset.load_predefined_dataset(ds_name)
        dataset.trim_dataset(edge_required=False)
    # node symbolic and non-symbolic labels (and edge symbolic labels).
    elif ds_name == 'AIDS':
        dataset.load_predefined_dataset(ds_name)
        dataset.trim_dataset(edge_required=False)
    # edge non-symbolic labels (no node labels).
    elif ds_name == 'Fingerprint_edge':
        dataset.load_predefined_dataset('Fingerprint')
        dataset.trim_dataset(edge_required=True)
        irrelevant_labels = {'edge_attrs': ['orient', 'angle']}
        dataset.remove_labels(**irrelevant_labels)
    # edge non-symbolic labels (and node non-symbolic labels).
    elif ds_name == 'Fingerprint':
        dataset.load_predefined_dataset(ds_name)
        dataset.trim_dataset(edge_required=True)
    # edge symbolic and non-symbolic labels (and node symbolic and non-symbolic labels).
    elif ds_name == 'Cuneiform':
        dataset.load_predefined_dataset(ds_name)
        dataset.trim_dataset(edge_required=True)

    dataset.cut_graphs(range(0, 3))

    return dataset


def assert_equality(compute_fun, **kwargs):
    """Check that the outputs are the same when computed with different methods.

    Parameters
    ----------
    compute_fun : function
        The function to compute the kernel, taking the same keyword arguments
        as kwargs.
    **kwargs : dict
        The keyword arguments over the grid of which the kernel results are
        compared.

    Returns
    -------
    None.
    """
    from sklearn.model_selection import ParameterGrid

    param_grid = ParameterGrid(kwargs)

    result_lists = [[], [], []]
    for params in list(param_grid):
        results = compute_fun(**params)
        for rs, lst in zip(results, result_lists):
            lst.append(rs)
    for lst in result_lists:
        for i in range(len(lst[:-1])):
            assert np.array_equal(lst[i], lst[i + 1])


@pytest.mark.parametrize('ds_name', ['Alkane', 'AIDS'])
@pytest.mark.parametrize('weight,compute_method', [(0.01, 'geo'), (1, 'exp')])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_CommonWalk(ds_name, weight, compute_method):
    """Test common walk kernel.
    """

    def compute(parallel=None):
        from gklearn.kernels import CommonWalk
        import networkx as nx

        dataset = chooseDataset(ds_name)
        dataset.load_graphs([g for g in dataset.graphs if nx.number_of_nodes(g) > 1])

        try:
            graph_kernel = CommonWalk(node_labels=dataset.node_labels,
                                      edge_labels=dataset.edge_labels,
                                      ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                      weight=weight,
                                      compute_method=compute_method)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


@pytest.mark.parametrize('ds_name', ['Alkane', 'AIDS'])
@pytest.mark.parametrize('remove_totters', [False])  # [True, False])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_Marginalized(ds_name, remove_totters):
    """Test marginalized kernel.
    """

    def compute(parallel=None):
        from gklearn.kernels import Marginalized

        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = Marginalized(node_labels=dataset.node_labels,
                                        edge_labels=dataset.edge_labels,
                                        ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                        p_quit=0.5,
                                        n_iteration=2,
                                        remove_totters=remove_totters)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


@pytest.mark.parametrize('ds_name', ['Acyclic'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_SylvesterEquation(ds_name):
    """Test Sylvester equation kernel.
    """

    def compute(parallel=None):
        from gklearn.kernels import SylvesterEquation

        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = SylvesterEquation(
                ds_infos=dataset.get_dataset_infos(keys=['directed']),
                weight=1e-3,
                p=None,
                q=None,
                edge_weight=None)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


@pytest.mark.parametrize('ds_name', ['Acyclic', 'AIDS'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_ConjugateGradient(ds_name):
    """Test conjugate gradient kernel.
    """

    def compute(parallel=None):
        from gklearn.kernels import ConjugateGradient
        from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
        import functools

        dataset = chooseDataset(ds_name)

        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

        try:
            graph_kernel = ConjugateGradient(
                node_labels=dataset.node_labels,
                node_attrs=dataset.node_attrs,
                edge_labels=dataset.edge_labels,
                edge_attrs=dataset.edge_attrs,
                ds_infos=dataset.get_dataset_infos(keys=['directed']),
                weight=1e-3,
                p=None,
                q=None,
                edge_weight=None,
                node_kernels=sub_kernels,
                edge_kernels=sub_kernels)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


@pytest.mark.parametrize('ds_name', ['Acyclic', 'AIDS'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_FixedPoint(ds_name):
    """Test fixed point kernel.
    """

    def compute(parallel=None):
        from gklearn.kernels import FixedPoint
        from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
        import functools

        dataset = chooseDataset(ds_name)

        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

        try:
            graph_kernel = FixedPoint(
                node_labels=dataset.node_labels,
                node_attrs=dataset.node_attrs,
                edge_labels=dataset.edge_labels,
                edge_attrs=dataset.edge_attrs,
                ds_infos=dataset.get_dataset_infos(keys=['directed']),
                weight=1e-3,
                p=None,
                q=None,
                edge_weight=None,
                node_kernels=sub_kernels,
                edge_kernels=sub_kernels)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


@pytest.mark.parametrize('ds_name', ['Acyclic'])
@pytest.mark.parametrize('sub_kernel', ['exp', 'geo'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_SpectralDecomposition(ds_name, sub_kernel):
    """Test spectral decomposition kernel.
    """

    def compute(parallel=None):
        from gklearn.kernels import SpectralDecomposition

        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = SpectralDecomposition(
                ds_infos=dataset.get_dataset_infos(keys=['directed']),
                weight=1e-3,
                p=None,
                q=None,
                edge_weight=None,
                sub_kernel=sub_kernel)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


# @pytest.mark.parametrize(
#     'compute_method,ds_name,sub_kernel',
#     [
#         ('sylvester', 'Alkane', None),
#         ('conjugate', 'Alkane', None),
#         ('conjugate', 'AIDS', None),
#         ('fp', 'Alkane', None),
#         ('fp', 'AIDS', None),
#         ('spectral', 'Alkane', 'exp'),
#         ('spectral', 'Alkane', 'geo'),
#     ]
# )
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
# def test_RandomWalk(ds_name, compute_method, sub_kernel, parallel):
#     """Test random walk kernel.
#     """
#     from gklearn.kernels import RandomWalk
#     from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
#     import functools
#
#     dataset = chooseDataset(ds_name)
#
#     mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
#     sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}
#     # try:
#     graph_kernel = RandomWalk(node_labels=dataset.node_labels,
#                               node_attrs=dataset.node_attrs,
#                               edge_labels=dataset.edge_labels,
#                               edge_attrs=dataset.edge_attrs,
#                               ds_infos=dataset.get_dataset_infos(keys=['directed']),
#                               compute_method=compute_method,
#                               weight=1e-3,
#                               p=None,
#                               q=None,
#                               edge_weight=None,
#                               node_kernels=sub_kernels,
#                               edge_kernels=sub_kernels,
#                               sub_kernel=sub_kernel)
#     gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
#         parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
#     kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
#         parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
#     kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
#         parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
#     except Exception as exception:
#         assert False, exception


@pytest.mark.parametrize('ds_name', ['Alkane', 'Acyclic', 'Letter-med', 'AIDS', 'Fingerprint'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_ShortestPath(ds_name):
    """Test shortest path kernel.
    """

    def compute(parallel=None, fcsp=None):
        from gklearn.kernels import ShortestPath
        from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
        import functools

        dataset = chooseDataset(ds_name)

        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

        try:
            graph_kernel = ShortestPath(node_labels=dataset.node_labels,
                                        node_attrs=dataset.node_attrs,
                                        ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                        fcsp=fcsp,
                                        node_kernels=sub_kernels)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None], fcsp=[True, False])


# @pytest.mark.parametrize('ds_name', ['Alkane', 'Acyclic', 'Letter-med', 'AIDS', 'Fingerprint'])
@pytest.mark.parametrize('ds_name', ['Alkane', 'Acyclic', 'Letter-med', 'AIDS', 'Fingerprint', 'Fingerprint_edge', 'Cuneiform'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_StructuralSP(ds_name):
    """Test structural shortest path kernel.
    """

    def compute(parallel=None, fcsp=None):
        from gklearn.kernels import StructuralSP
        from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
        import functools

        dataset = chooseDataset(ds_name)

        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

        try:
            graph_kernel = StructuralSP(node_labels=dataset.node_labels,
                                        edge_labels=dataset.edge_labels,
                                        node_attrs=dataset.node_attrs,
                                        edge_attrs=dataset.edge_attrs,
                                        ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                        fcsp=fcsp,
                                        node_kernels=sub_kernels,
                                        edge_kernels=sub_kernels)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None], fcsp=[True, False])


@pytest.mark.parametrize('ds_name', ['Alkane', 'AIDS'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
# @pytest.mark.parametrize('k_func', ['MinMax', 'tanimoto', None])
@pytest.mark.parametrize('k_func', ['MinMax', 'tanimoto'])
# @pytest.mark.parametrize('compute_method', ['trie', 'naive'])
def test_PathUpToH(ds_name, k_func):
    """Test path kernel up to length $h$.
    """

    def compute(parallel=None, compute_method=None):
        from gklearn.kernels import PathUpToH

        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = PathUpToH(node_labels=dataset.node_labels,
                                     edge_labels=dataset.edge_labels,
                                     ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                     depth=2, k_func=k_func, compute_method=compute_method)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None],
                    compute_method=['trie', 'naive'])


@pytest.mark.parametrize('ds_name', ['Alkane', 'AIDS'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_Treelet(ds_name):
    """Test treelet kernel.
    """

    def compute(parallel=None):
        from gklearn.kernels import Treelet
        from gklearn.utils.kernels import polynomialkernel
        import functools

        dataset = chooseDataset(ds_name)

        pkernel = functools.partial(polynomialkernel, d=2, c=1e5)

        try:
            graph_kernel = Treelet(node_labels=dataset.node_labels,
                                   edge_labels=dataset.edge_labels,
                                   ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                   sub_kernel=pkernel)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


@pytest.mark.parametrize('ds_name', ['Acyclic'])
# @pytest.mark.parametrize('base_kernel', ['subtree', 'sp', 'edge'])
# @pytest.mark.parametrize('base_kernel', ['subtree'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_WLSubtree(ds_name):
    """Test Weisfeiler-Lehman subtree kernel.
    """

    def compute(parallel=None):
        from gklearn.kernels import WLSubtree

        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = WLSubtree(node_labels=dataset.node_labels,
                                     edge_labels=dataset.edge_labels,
                                     ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                     height=2)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


if __name__ == "__main__":
    test_list_graph_kernels()
    # test_spkernel('Alkane', 'imap_unordered')
    # test_ShortestPath('Alkane')
    # test_StructuralSP('Fingerprint_edge', 'imap_unordered')
    # test_StructuralSP('Alkane', None)
    # test_StructuralSP('Cuneiform', None)
    # test_WLSubtree('Acyclic', 'imap_unordered')
    # test_RandomWalk('Acyclic', 'sylvester', None, 'imap_unordered')
    # test_RandomWalk('Acyclic', 'conjugate', None, 'imap_unordered')
    # test_RandomWalk('Acyclic', 'fp', None, None)
    # test_RandomWalk('Acyclic', 'spectral', 'exp', 'imap_unordered')
    # test_CommonWalk('Alkane', 0.01, 'geo')
    # test_ShortestPath('Acyclic')

A Python package for graph kernels, graph edit distances and graph pre-image problem.
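As a usage note, the snippet below is a minimal sketch of running one of these kernels outside the test harness. It mirrors the calls exercised by test_WLSubtree and chooseDataset above and is illustrative only; it assumes gklearn is installed and that the predefined 'Acyclic' dataset can be loaded on your machine.

# Minimal usage sketch (not part of the test module): compute a Weisfeiler-Lehman
# subtree Gram matrix the same way test_WLSubtree does, but outside pytest.
import multiprocessing

from gklearn.kernels import WLSubtree
from gklearn.utils import Dataset

dataset = Dataset()
dataset.load_predefined_dataset('Acyclic')
dataset.trim_dataset(edge_required=False)
# Drop attributes this kernel does not use, as chooseDataset does above.
dataset.remove_labels(node_attrs=['x', 'y', 'z'], edge_labels=['bond_stereo'])

graph_kernel = WLSubtree(node_labels=dataset.node_labels,
                         edge_labels=dataset.edge_labels,
                         ds_infos=dataset.get_dataset_infos(keys=['directed']),
                         height=2)
gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
    parallel='imap_unordered', n_jobs=multiprocessing.cpu_count(), verbose=True)
print(gram_matrix.shape, run_time)

The same pattern applies to the other kernel classes imported in the tests; only the constructor arguments differ.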