
test_graph_kernels.py

"""Tests of graph kernels.
"""

import pytest
import multiprocessing
import numpy as np


##############################################################################


def test_list_graph_kernels():
    """Test that list_of_graph_kernels() is consistent with the GRAPH_KERNELS registry.
    """
    from gklearn.kernels import GRAPH_KERNELS, list_of_graph_kernels
    assert list_of_graph_kernels() == [i for i in GRAPH_KERNELS]


##############################################################################

def chooseDataset(ds_name):
    """Choose dataset according to name.
    """
    from gklearn.dataset import Dataset
    import os

    current_path = os.path.dirname(os.path.realpath(__file__)) + '/'
    root = current_path + '../../datasets/'

    # no labels at all.
    if ds_name == 'Alkane_unlabeled':
        dataset = Dataset('Alkane_unlabeled', root=root)
        dataset.trim_dataset(edge_required=False)
        dataset.cut_graphs(range(1, 10))
    # node symbolic labels only.
    elif ds_name == 'Acyclic':
        dataset = Dataset('Acyclic', root=root)
        dataset.trim_dataset(edge_required=False)
    # node non-symbolic labels only.
    elif ds_name == 'Letter-med':
        dataset = Dataset('Letter-med', root=root)
        dataset.trim_dataset(edge_required=False)
    # node symbolic + non-symbolic labels + edge symbolic labels.
    elif ds_name == 'AIDS':
        dataset = Dataset('AIDS', root=root)
        dataset.trim_dataset(edge_required=False)
    # node non-symbolic labels + edge non-symbolic labels.
    elif ds_name == 'Fingerprint':
        dataset = Dataset('Fingerprint', root=root)
        dataset.trim_dataset(edge_required=True)
    # edge symbolic labels only.
    elif ds_name == 'MAO':
        dataset = Dataset('MAO', root=root)
        dataset.trim_dataset(edge_required=True)
        irrelevant_labels = {'node_labels': ['atom_symbol'], 'node_attrs': ['x', 'y']}
        dataset.remove_labels(**irrelevant_labels)
    # edge non-symbolic labels only.
    elif ds_name == 'Fingerprint_edge':
        dataset = Dataset('Fingerprint', root=root)
        dataset.trim_dataset(edge_required=True)
        irrelevant_labels = {'edge_attrs': ['orient', 'angle']}
        dataset.remove_labels(**irrelevant_labels)
    # node symbolic and non-symbolic labels + edge symbolic and non-symbolic labels.
    elif ds_name == 'Cuneiform':
        dataset = Dataset('Cuneiform', root=root)
        dataset.trim_dataset(edge_required=True)
        dataset.cut_graphs(range(0, 3))

    return dataset

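
# Illustrative usage sketch (not part of the test suite): chooseDataset returns
# a gklearn Dataset whose graphs and label/attribute names are consumed by the
# kernel classes below. It assumes the dataset files are available under
# ../../datasets/ relative to this file.
# ds = chooseDataset('Acyclic')
# print(len(ds.graphs), ds.node_labels, ds.node_attrs, ds.edge_labels, ds.edge_attrs)
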
def assert_equality(compute_fun, **kwargs):
    """Check if outputs are the same using different methods to compute.

    Parameters
    ----------
    compute_fun : function
        The function to compute the kernel, with the same key word arguments as
        kwargs.
    **kwargs : dict
        The key word arguments over the grid of which the kernel results are
        compared.

    Returns
    -------
    None.
    """
    from sklearn.model_selection import ParameterGrid

    param_grid = ParameterGrid(kwargs)
    result_lists = [[], [], []]
    for params in list(param_grid):
        results = compute_fun(**params)
        for rs, lst in zip(results, result_lists):
            lst.append(rs)
    for lst in result_lists:
        for i in range(len(lst[:-1])):
            assert np.array_equal(lst[i], lst[i + 1])

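
# A minimal, self-contained illustration of how assert_equality is used (not
# part of the test suite): the toy function below stands in for the per-kernel
# 'compute' closures defined in the tests. Each call must return the same
# (gram_matrix, kernel_list, kernel) triple for every parameter combination,
# otherwise the grid-wise equality check above fails.
def _toy_compute(parallel=None):
    """Toy stand-in that ignores 'parallel' and returns constant results."""
    gram_matrix = np.ones((3, 3))
    kernel_list = np.ones(3)
    kernel = 1.0
    return gram_matrix, kernel_list, kernel

# Example call, mirroring the shape of the real tests:
# assert_equality(_toy_compute, parallel=['imap_unordered', None])
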
@pytest.mark.parametrize('ds_name', ['Alkane_unlabeled', 'AIDS'])
@pytest.mark.parametrize('weight,compute_method', [(0.01, 'geo'), (1, 'exp')])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_CommonWalk(ds_name, weight, compute_method):
    """Test common walk kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import CommonWalk
        import networkx as nx

        dataset = chooseDataset(ds_name)
        dataset.load_graphs([g for g in dataset.graphs if nx.number_of_nodes(g) > 1])

        try:
            graph_kernel = CommonWalk(node_labels=dataset.node_labels,
                    edge_labels=dataset.edge_labels,
                    ds_infos=dataset.get_dataset_infos(keys=['directed']),
                    weight=weight,
                    compute_method=compute_method)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])

@pytest.mark.parametrize('ds_name', ['Alkane_unlabeled', 'AIDS'])
@pytest.mark.parametrize('remove_totters', [False])  # [True, False])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_Marginalized(ds_name, remove_totters):
    """Test marginalized kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import Marginalized

        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = Marginalized(node_labels=dataset.node_labels,
                    edge_labels=dataset.edge_labels,
                    ds_infos=dataset.get_dataset_infos(keys=['directed']),
                    p_quit=0.5,
                    n_iteration=2,
                    remove_totters=remove_totters)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])

@pytest.mark.parametrize('ds_name', ['Acyclic'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_SylvesterEquation(ds_name):
    """Test Sylvester equation kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import SylvesterEquation

        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = SylvesterEquation(
                    ds_infos=dataset.get_dataset_infos(keys=['directed']),
                    weight=1e-3,
                    p=None,
                    q=None,
                    edge_weight=None)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])

@pytest.mark.parametrize('ds_name', ['Acyclic', 'AIDS'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_ConjugateGradient(ds_name):
    """Test conjugate gradient kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import ConjugateGradient
        from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
        import functools

        dataset = chooseDataset(ds_name)

        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

        try:
            graph_kernel = ConjugateGradient(
                    node_labels=dataset.node_labels,
                    node_attrs=dataset.node_attrs,
                    edge_labels=dataset.edge_labels,
                    edge_attrs=dataset.edge_attrs,
                    ds_infos=dataset.get_dataset_infos(keys=['directed']),
                    weight=1e-3,
                    p=None,
                    q=None,
                    edge_weight=None,
                    node_kernels=sub_kernels,
                    edge_kernels=sub_kernels)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])

@pytest.mark.parametrize('ds_name', ['Acyclic', 'AIDS'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_FixedPoint(ds_name):
    """Test fixed point kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import FixedPoint
        from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
        import functools

        dataset = chooseDataset(ds_name)

        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

        try:
            graph_kernel = FixedPoint(
                    node_labels=dataset.node_labels,
                    node_attrs=dataset.node_attrs,
                    edge_labels=dataset.edge_labels,
                    edge_attrs=dataset.edge_attrs,
                    ds_infos=dataset.get_dataset_infos(keys=['directed']),
                    weight=1e-3,
                    p=None,
                    q=None,
                    edge_weight=None,
                    node_kernels=sub_kernels,
                    edge_kernels=sub_kernels)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])

@pytest.mark.parametrize('ds_name', ['Acyclic'])
@pytest.mark.parametrize('sub_kernel', ['exp', 'geo'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_SpectralDecomposition(ds_name, sub_kernel):
    """Test spectral decomposition kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import SpectralDecomposition

        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = SpectralDecomposition(
                    ds_infos=dataset.get_dataset_infos(keys=['directed']),
                    weight=1e-3,
                    p=None,
                    q=None,
                    edge_weight=None,
                    sub_kernel=sub_kernel)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])

# @pytest.mark.parametrize(
#     'compute_method,ds_name,sub_kernel',
#     [
#         ('sylvester', 'Alkane_unlabeled', None),
#         ('conjugate', 'Alkane_unlabeled', None),
#         ('conjugate', 'AIDS', None),
#         ('fp', 'Alkane_unlabeled', None),
#         ('fp', 'AIDS', None),
#         ('spectral', 'Alkane_unlabeled', 'exp'),
#         ('spectral', 'Alkane_unlabeled', 'geo'),
#     ]
# )
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
# def test_RandomWalk(ds_name, compute_method, sub_kernel, parallel):
#     """Test random walk kernel.
#     """
#     from gklearn.kernels import RandomWalk
#     from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
#     import functools
#
#     dataset = chooseDataset(ds_name)
#
#     mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
#     sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}
#     # try:
#     graph_kernel = RandomWalk(node_labels=dataset.node_labels,
#             node_attrs=dataset.node_attrs,
#             edge_labels=dataset.edge_labels,
#             edge_attrs=dataset.edge_attrs,
#             ds_infos=dataset.get_dataset_infos(keys=['directed']),
#             compute_method=compute_method,
#             weight=1e-3,
#             p=None,
#             q=None,
#             edge_weight=None,
#             node_kernels=sub_kernels,
#             edge_kernels=sub_kernels,
#             sub_kernel=sub_kernel)
#     gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
#         parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
#     kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
#         parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
#     kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
#         parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
#     # except Exception as exception:
#     #     assert False, exception

@pytest.mark.parametrize('ds_name', ['Alkane_unlabeled', 'Acyclic', 'Letter-med', 'AIDS', 'Fingerprint'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_ShortestPath(ds_name):
    """Test shortest path kernel.
    """
    def compute(parallel=None, fcsp=None):
        from gklearn.kernels import ShortestPath
        from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
        import functools

        dataset = chooseDataset(ds_name)

        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

        try:
            graph_kernel = ShortestPath(node_labels=dataset.node_labels,
                    node_attrs=dataset.node_attrs,
                    ds_infos=dataset.get_dataset_infos(keys=['directed']),
                    fcsp=fcsp,
                    node_kernels=sub_kernels)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None], fcsp=[True, False])

# @pytest.mark.parametrize('ds_name', ['Alkane_unlabeled', 'Acyclic', 'Letter-med', 'AIDS', 'Fingerprint'])
@pytest.mark.parametrize('ds_name', ['Alkane_unlabeled', 'Acyclic', 'Letter-med', 'AIDS', 'Fingerprint', 'Fingerprint_edge', 'Cuneiform'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_StructuralSP(ds_name):
    """Test structural shortest path kernel.
    """
    def compute(parallel=None, fcsp=None):
        from gklearn.kernels import StructuralSP
        from gklearn.utils.kernels import deltakernel, gaussiankernel, kernelproduct
        import functools

        dataset = chooseDataset(ds_name)

        mixkernel = functools.partial(kernelproduct, deltakernel, gaussiankernel)
        sub_kernels = {'symb': deltakernel, 'nsymb': gaussiankernel, 'mix': mixkernel}

        try:
            graph_kernel = StructuralSP(node_labels=dataset.node_labels,
                    edge_labels=dataset.edge_labels,
                    node_attrs=dataset.node_attrs,
                    edge_attrs=dataset.edge_attrs,
                    ds_infos=dataset.get_dataset_infos(keys=['directed']),
                    fcsp=fcsp,
                    node_kernels=sub_kernels,
                    edge_kernels=sub_kernels)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None], fcsp=[True, False])

@pytest.mark.parametrize('ds_name', ['Alkane_unlabeled', 'AIDS'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
# @pytest.mark.parametrize('k_func', ['MinMax', 'tanimoto', None])
@pytest.mark.parametrize('k_func', ['MinMax', 'tanimoto'])
# @pytest.mark.parametrize('compute_method', ['trie', 'naive'])
def test_PathUpToH(ds_name, k_func):
    """Test path kernel up to length $h$.
    """
    def compute(parallel=None, compute_method=None):
        from gklearn.kernels import PathUpToH

        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = PathUpToH(node_labels=dataset.node_labels,
                    edge_labels=dataset.edge_labels,
                    ds_infos=dataset.get_dataset_infos(keys=['directed']),
                    depth=2, k_func=k_func, compute_method=compute_method)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None],
                    compute_method=['trie', 'naive'])

@pytest.mark.parametrize('ds_name', ['Alkane_unlabeled', 'AIDS'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_Treelet(ds_name):
    """Test treelet kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import Treelet
        from gklearn.utils.kernels import polynomialkernel
        import functools

        dataset = chooseDataset(ds_name)

        pkernel = functools.partial(polynomialkernel, d=2, c=1e5)

        try:
            graph_kernel = Treelet(node_labels=dataset.node_labels,
                    edge_labels=dataset.edge_labels,
                    ds_infos=dataset.get_dataset_infos(keys=['directed']),
                    sub_kernel=pkernel)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=['imap_unordered', None])

@pytest.mark.parametrize('ds_name', ['Alkane_unlabeled', 'Acyclic', 'MAO', 'AIDS'])
# @pytest.mark.parametrize('base_kernel', ['subtree', 'sp', 'edge'])
# @pytest.mark.parametrize('base_kernel', ['subtree'])
# @pytest.mark.parametrize('parallel', ['imap_unordered', None])
def test_WLSubtree(ds_name):
    """Test Weisfeiler-Lehman subtree kernel.
    """
    def compute(parallel=None):
        from gklearn.kernels import WLSubtree

        dataset = chooseDataset(ds_name)

        try:
            graph_kernel = WLSubtree(node_labels=dataset.node_labels,
                    edge_labels=dataset.edge_labels,
                    ds_infos=dataset.get_dataset_infos(keys=['directed']),
                    height=2)
            gram_matrix, run_time = graph_kernel.compute(dataset.graphs,
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel_list, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1:],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
            kernel, run_time = graph_kernel.compute(dataset.graphs[0], dataset.graphs[1],
                parallel=parallel, n_jobs=multiprocessing.cpu_count(), verbose=True)
        except Exception as exception:
            print(repr(exception))
            assert False, exception
        else:
            return gram_matrix, kernel_list, kernel

    assert_equality(compute, parallel=[None, 'imap_unordered'])

if __name__ == "__main__":
    # test_list_graph_kernels()
    # test_spkernel('Alkane_unlabeled', 'imap_unordered')
    # test_ShortestPath('Alkane_unlabeled')
    # test_StructuralSP('Fingerprint_edge', 'imap_unordered')
    # test_StructuralSP('Acyclic')
    # test_StructuralSP('Cuneiform', None)
    test_WLSubtree('MAO')  # 'Alkane_unlabeled', 'Acyclic', 'AIDS'
    # test_RandomWalk('Acyclic', 'sylvester', None, 'imap_unordered')
    # test_RandomWalk('Acyclic', 'conjugate', None, 'imap_unordered')
    # test_RandomWalk('Acyclic', 'fp', None, None)
    # test_RandomWalk('Acyclic', 'spectral', 'exp', 'imap_unordered')
    # test_CommonWalk('Acyclic', 0.01, 'geo')
    # test_Marginalized('Acyclic', False)
    # test_ShortestPath('Acyclic')
    # test_PathUpToH('Acyclic', 'MinMax')
    # test_Treelet('AIDS')
    # test_SylvesterEquation('Acyclic')
    # test_ConjugateGradient('Acyclic')
    # test_FixedPoint('Acyclic')
    # test_SpectralDecomposition('Acyclic', 'exp')
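    # Alternatively, the whole module can be run through pytest itself
    # (illustrative; any standard pytest invocation works):
    # pytest.main(['-v', __file__])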
