From 2a065f4fb6ee1d8871154403423061880e9e0544 Mon Sep 17 00:00:00 2001 From: jajupmochi Date: Fri, 16 Oct 2020 17:50:01 +0200 Subject: [PATCH] Fix a bug in the WL kernel class. --- gklearn/kernels/weisfeiler_lehman.py | 6 +++--- gklearn/tests/test_graph_kernels.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/gklearn/kernels/weisfeiler_lehman.py b/gklearn/kernels/weisfeiler_lehman.py index 124e1f4..8b36b37 100644 --- a/gklearn/kernels/weisfeiler_lehman.py +++ b/gklearn/kernels/weisfeiler_lehman.py @@ -153,7 +153,7 @@ class WeisfeilerLehman(GraphKernel): # @todo: total parallelization and sp, edge all_num_of_each_label.append(dict(Counter(labels_ori))) # Compute subtree kernel with the 0th iteration and add it to the final kernel. - self._compute_gram_matrix(gram_matrix, all_num_of_each_label, Gn) + self._compute_gram_itr(gram_matrix, all_num_of_each_label, Gn) # iterate each height for h in range(1, self._height + 1): @@ -199,12 +199,12 @@ class WeisfeilerLehman(GraphKernel): # @todo: total parallelization and sp, edge all_num_of_each_label.append(dict(Counter(labels_comp))) # Compute subtree kernel with h iterations and add it to the final kernel - self._compute_gram_matrix(gram_matrix, all_num_of_each_label, Gn) + self._compute_gram_itr(gram_matrix, all_num_of_each_label, Gn) return gram_matrix - def _compute_gram_matrix(self, gram_matrix, all_num_of_each_label, Gn): + def _compute_gram_itr(self, gram_matrix, all_num_of_each_label, Gn): """Compute Gram matrix using the base kernel. """ if self._parallel == 'imap_unordered': diff --git a/gklearn/tests/test_graph_kernels.py b/gklearn/tests/test_graph_kernels.py index 4fbbbe7..f1c480a 100644 --- a/gklearn/tests/test_graph_kernels.py +++ b/gklearn/tests/test_graph_kernels.py @@ -434,7 +434,8 @@ def test_WLSubtree(ds_name, parallel): if __name__ == "__main__": # test_spkernel('Alkane', 'imap_unordered') - test_StructuralSP('Fingerprint_edge', 'imap_unordered') +# test_StructuralSP('Fingerprint_edge', 'imap_unordered') + test_WLSubtree('Acyclic', 'imap_unordered') # test_RandomWalk('Acyclic', 'sylvester', None, 'imap_unordered') # test_RandomWalk('Acyclic', 'conjugate', None, 'imap_unordered') # test_RandomWalk('Acyclic', 'fp', None, None)