
test_graph_kernels.py

  1. """Tests of graph kernels.
  2. """
  3. import pytest
  4. import multiprocessing
  5. import numpy as np
  6. ##############################################################################
  7. def test_list_graph_kernels():
  8. """
  9. """
  10. from gklearn.kernels import GRAPH_KERNELS, list_of_graph_kernels
  11. assert list_of_graph_kernels() == [i for i in GRAPH_KERNELS]


##############################################################################

def chooseDataset(ds_name):
    """Choose dataset according to name.
    """
    from gklearn.dataset import Dataset
    import os

    current_path = os.path.dirname(os.path.realpath(__file__)) + '/'
    root = current_path + '../../datasets/'

    # no node labels (and no edge labels).
    if ds_name == 'Alkane':
        dataset = Dataset('Alkane_unlabeled', root=root)
        dataset.trim_dataset(edge_required=False)
        dataset.cut_graphs(range(1, 10))
    # node symbolic labels.
    elif ds_name == 'Acyclic':
        dataset = Dataset('Acyclic', root=root)
        dataset.trim_dataset(edge_required=False)
    # node non-symbolic labels.
    elif ds_name == 'Letter-med':
        dataset = Dataset('Letter-med', root=root)
        dataset.trim_dataset(edge_required=False)
    # node symbolic and non-symbolic labels (and edge symbolic labels).
    elif ds_name == 'AIDS':
        dataset = Dataset('AIDS', root=root)
        dataset.trim_dataset(edge_required=False)
    # edge non-symbolic labels (no node labels).
    elif ds_name == 'Fingerprint_edge':
        dataset = Dataset('Fingerprint', root=root)
        dataset.trim_dataset(edge_required=True)
        irrelevant_labels = {'edge_attrs': ['orient', 'angle']}
        dataset.remove_labels(**irrelevant_labels)
    # edge non-symbolic labels (and node non-symbolic labels).
    elif ds_name == 'Fingerprint':
        dataset = Dataset('Fingerprint', root=root)
        dataset.trim_dataset(edge_required=True)
    # edge symbolic and non-symbolic labels (and node symbolic and non-symbolic labels).
    elif ds_name == 'Cuneiform':
        dataset = Dataset('Cuneiform', root=root)
        dataset.trim_dataset(edge_required=True)
        dataset.cut_graphs(range(0, 3))

    return dataset
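

# Example usage (illustrative only; requires the datasets under ../../datasets/):
#   ds = chooseDataset('Acyclic')
#   print(len(ds.graphs), ds.node_labels)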


def assert_equality(compute_fun, **kwargs):
    """Check if outputs are the same using different methods to compute.

    Parameters
    ----------
    compute_fun : function
        The function to compute the kernel, with the same key word arguments as
        kwargs.
    **kwargs : dict
        The key word arguments over the grid of which the kernel results are
        compared.

    Returns
    -------
    None.
    """
    from sklearn.model_selection import ParameterGrid

    param_grid = ParameterGrid(kwargs)
    result_lists = [[], [], []]
    for params in list(param_grid):
        results = compute_fun(**params)
        for rs, lst in zip(results, result_lists):
            lst.append(rs)
    for lst in result_lists:
        for i in range(len(lst[:-1])):
            assert np.array_equal(lst[i], lst[i + 1])
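

# A minimal sketch of how assert_equality is used (the test name and the 'dummy'
# keyword below are hypothetical): the toy compute function ignores its parameter,
# so its three outputs must agree across the whole parameter grid.
def test_assert_equality_toy_example():
    def compute(dummy=None):
        gram_matrix = np.ones((3, 3))
        kernel_list = np.zeros(2)
        kernel = 1.0
        return gram_matrix, kernel_list, kernel

    assert_equality(compute, dummy=[True, False])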


@pytest.mark.parametrize('ds_name', ['Alkane', 'AIDS'])
@pytest.mark.parametrize('weight,compute_method', [(0.01, 'geo'), (1, 'exp')])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_CommonWalk(ds_name, weight, compute_method):
    """Test common walk kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import CommonWalk
        import networkx as nx

        dataset = chooseDataset(ds_name)
        dataset.load_graphs([g for g in dataset.graphs if nx.number_of_nodes(g) > 1])

        try:
            graph_kernel = CommonWalk(node_labels=dataset.node_labels,
                                      edge_labels=dataset.edge_labels,
                                      ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                      weight=weight,
                                      compute_method=compute_method)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


@pytest.mark.parametrize('ds_name', ['Alkane', 'AIDS'])
@pytest.mark.parametrize('remove_totters', [False])  # [True, False])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_Marginalized(ds_name, remove_totters):
    """Test marginalized kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import Marginalized

        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = Marginalized(node_labels=dataset.node_labels,
                                        edge_labels=dataset.edge_labels,
                                        ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                        p_quit=0.5,
                                        n_iteration=2,
                                        remove_totters=remove_totters)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


@pytest.mark.parametrize('ds_name', ['Acyclic'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_SylvesterEquation(ds_name):
    """Test Sylvester equation kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import SylvesterEquation

        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = SylvesterEquation(
                ds_infos=dataset.get_dataset_infos(keys=['directed']),
                weight=1e-3,
                p=None,
                q=None,
                edge_weight=None)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


@pytest.mark.parametrize('ds_name', ['Acyclic', 'AIDS'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_ConjugateGradient(ds_name):
    """Test conjugate gradient kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import ConjugateGradient
        from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
        import functools

        dataset = chooseDataset(ds_name)

        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

        try:
            graph_kernel = ConjugateGradient(
                node_labels=dataset.node_labels,
                node_attrs=dataset.node_attrs,
                edge_labels=dataset.edge_labels,
                edge_attrs=dataset.edge_attrs,
                ds_infos=dataset.get_dataset_infos(keys=['directed']),
                weight=1e-3,
                p=None,
                q=None,
                edge_weight=None,
                node_kernels=sub_kernels,
                edge_kernels=sub_kernels)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


@pytest.mark.parametrize('ds_name', ['Acyclic', 'AIDS'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_FixedPoint(ds_name):
    """Test fixed point kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import FixedPoint
        from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
        import functools

        dataset = chooseDataset(ds_name)

        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

        try:
            graph_kernel = FixedPoint(
                node_labels=dataset.node_labels,
                node_attrs=dataset.node_attrs,
                edge_labels=dataset.edge_labels,
                edge_attrs=dataset.edge_attrs,
                ds_infos=dataset.get_dataset_infos(keys=['directed']),
                weight=1e-3,
                p=None,
                q=None,
                edge_weight=None,
                node_kernels=sub_kernels,
                edge_kernels=sub_kernels)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


@pytest.mark.parametrize('ds_name', ['Acyclic'])
@pytest.mark.parametrize('sub_kernel', ['exp', 'geo'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_SpectralDecomposition(ds_name, sub_kernel):
    """Test spectral decomposition kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import SpectralDecomposition

        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = SpectralDecomposition(
                ds_infos=dataset.get_dataset_infos(keys=['directed']),
                weight=1e-3,
                p=None,
                q=None,
                edge_weight=None,
                sub_kernel=sub_kernel)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


# @pytest.mark.parametrize(
#     'compute_method,ds_name,sub_kernel',
#     [
#         ('sylvester', 'Alkane', None),
#         ('conjugate', 'Alkane', None),
#         ('conjugate', 'AIDS', None),
#         ('fp', 'Alkane', None),
#         ('fp', 'AIDS', None),
#         ('spectral', 'Alkane', 'exp'),
#         ('spectral', 'Alkane', 'geo'),
#     ]
# )
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
# def test_RandomWalk(ds_name, compute_method, sub_kernel, parallel):
#     """Test random walk kernel.
#     """
#     from gklearn.kernels import RandomWalk
#     from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
#     import functools
#
#     dataset = chooseDataset(ds_name)
#     mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
#     sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}
#     # try:
#     graph_kernel = RandomWalk(node_labels=dataset.node_labels,
#                               node_attrs=dataset.node_attrs,
#                               edge_labels=dataset.edge_labels,
#                               edge_attrs=dataset.edge_attrs,
#                               ds_infos=dataset.get_dataset_infos(keys=['directed']),
#                               compute_method=compute_method,
#                               weight=1e-3,
#                               p=None,
#                               q=None,
#                               edge_weight=None,
#                               node_kernels=sub_kernels,
#                               edge_kernels=sub_kernels,
#                               sub_kernel=sub_kernel)
#     gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
#         parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
#     kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
#         parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
#     kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
#         parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
#     # except Exception as exception:
#     #     assert False, exception


@pytest.mark.parametrize('ds_name', ['Alkane', 'Acyclic', 'Letter-med', 'AIDS', 'Fingerprint'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_ShortestPath(ds_name):
    """Test shortest path kernel.
    """
    def compute(parallel=None, fcsp=None):
        from gklearn.kernels import ShortestPath
        from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
        import functools

        dataset = chooseDataset(ds_name)

        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

        try:
            graph_kernel = ShortestPath(node_labels=dataset.node_labels,
                                        node_attrs=dataset.node_attrs,
                                        ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                        fcsp=fcsp,
                                        node_kernels=sub_kernels)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None], fcsp=[True, False])


# @pytest.mark.parametrize('ds_name', ['Alkane', 'Acyclic', 'Letter-med', 'AIDS', 'Fingerprint'])
@pytest.mark.parametrize('ds_name', ['Alkane', 'Acyclic', 'Letter-med', 'AIDS', 'Fingerprint', 'Fingerprint_edge', 'Cuneiform'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_StructuralSP(ds_name):
    """Test structural shortest path kernel.
    """
    def compute(parallel=None, fcsp=None):
        from gklearn.kernels import StructuralSP
        from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
        import functools

        dataset = chooseDataset(ds_name)

        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

        try:
            graph_kernel = StructuralSP(node_labels=dataset.node_labels,
                                        edge_labels=dataset.edge_labels,
                                        node_attrs=dataset.node_attrs,
                                        edge_attrs=dataset.edge_attrs,
                                        ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                        fcsp=fcsp,
                                        node_kernels=sub_kernels,
                                        edge_kernels=sub_kernels)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None], fcsp=[True, False])


@pytest.mark.parametrize('ds_name', ['Alkane', 'AIDS'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
# @pytest.mark.parametrize('k_func', ['MinMax', 'tanimoto', None])
@pytest.mark.parametrize('k_func', ['MinMax', 'tanimoto'])
# @pytest.mark.parametrize('compute_method', ['trie', 'naive'])
def test_PathUpToH(ds_name, k_func):
    """Test path kernel up to length $h$.
    """
    def compute(parallel=None, compute_method=None):
        from gklearn.kernels import PathUpToH

        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = PathUpToH(node_labels=dataset.node_labels,
                                     edge_labels=dataset.edge_labels,
                                     ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                     depth=2, k_func=k_func, compute_method=compute_method)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None],
                    compute_method=['trie', 'naive'])


@pytest.mark.parametrize('ds_name', ['Alkane', 'AIDS'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_Treelet(ds_name):
    """Test treelet kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import Treelet
        from gklearn.utils.kernels import polynomialkernel
        import functools

        dataset = chooseDataset(ds_name)

        pkernel = functools.partial(polynomialkernel, d=2, c=1e5)

        try:
            graph_kernel = Treelet(node_labels=dataset.node_labels,
                                   edge_labels=dataset.edge_labels,
                                   ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                   sub_kernel=pkernel)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


@pytest.mark.parametrize('ds_name', ['Acyclic'])
# @pytest.mark.parametrize('base_kernel', ['subtree', 'sp', 'edge'])
# @pytest.mark.parametrize('base_kernel', ['subtree'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_WLSubtree(ds_name):
    """Test Weisfeiler-Lehman subtree kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import WLSubtree

        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = WLSubtree(node_labels=dataset.node_labels,
                                     edge_labels=dataset.edge_labels,
                                     ds_infos=dataset.get_dataset_infos(keys=['directed']),
                                     height=2)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])


if __name__ == "__main__":
    test_list_graph_kernels()
    # test_spkernel('Alkane', 'imap_unordered')
    # test_ShortestPath('Alkane')
    # test_StructuralSP('Fingerprint_edge', 'imap_unordered')
    # test_StructuralSP('Acyclic')
    # test_StructuralSP('Cuneiform', None)
    # test_WLSubtree('Acyclic')
    # test_RandomWalk('Acyclic', 'sylvester', None, 'imap_unordered')
    # test_RandomWalk('Acyclic', 'conjugate', None, 'imap_unordered')
    # test_RandomWalk('Acyclic', 'fp', None, None)
    # test_RandomWalk('Acyclic', 'spectral', 'exp', 'imap_unordered')
    # test_CommonWalk('Acyclic', 0.01, 'geo')
    # test_Marginalized('Acyclic', False)
    # test_ShortestPath('Acyclic')
    # test_PathUpToH('Acyclic', 'MinMax')
    # test_Treelet('Acyclic')
    # test_SylvesterEquation('Acyclic')
    # test_ConjugateGradient('Acyclic')
    # test_FixedPoint('Acyclic')
    # test_SpectralDecomposition('Acyclic', 'exp')
