diff --git a/gklearn/tests/test_graph_kernels.py b/gklearn/tests/test_graph_kernels.py index 021b1cc..3b69de5 100644 --- a/gklearn/tests/test_graph_kernels.py +++ b/gklearn/tests/test_graph_kernels.py @@ -19,44 +19,38 @@ def test_list_graph_kernels(): def chooseDataset(ds_name): """Choose dataset according to name. """ - from gklearn.utils import Dataset - - dataset = Dataset() + from gklearn.dataset import Dataset # no node labels (and no edge labels). if ds_name == 'Alkane': - dataset.load_predefined_dataset(ds_name) + dataset = Dataset('Alkane_unlabeled') dataset.trim_dataset(edge_required=False) - irrelevant_labels = {'node_attrs': ['x', 'y', 'z'], 'edge_labels': ['bond_stereo']} - dataset.remove_labels(**irrelevant_labels) dataset.cut_graphs(range(1, 10)) # node symbolic labels. elif ds_name == 'Acyclic': - dataset.load_predefined_dataset(ds_name) + dataset = Dataset('Acyclic') dataset.trim_dataset(edge_required=False) - irrelevant_labels = {'node_attrs': ['x', 'y', 'z'], 'edge_labels': ['bond_stereo']} - dataset.remove_labels(**irrelevant_labels) # node non-symbolic labels. elif ds_name == 'Letter-med': - dataset.load_predefined_dataset(ds_name) + dataset = Dataset('Letter-med') dataset.trim_dataset(edge_required=False) # node symbolic and non-symbolic labels (and edge symbolic labels). elif ds_name == 'AIDS': - dataset.load_predefined_dataset(ds_name) + dataset = Dataset('AIDS') dataset.trim_dataset(edge_required=False) # edge non-symbolic labels (no node labels). elif ds_name == 'Fingerprint_edge': - dataset.load_predefined_dataset('Fingerprint') + dataset = Dataset('Fingerprint') dataset.trim_dataset(edge_required=True) irrelevant_labels = {'edge_attrs': ['orient', 'angle']} dataset.remove_labels(**irrelevant_labels) # edge non-symbolic labels (and node non-symbolic labels). elif ds_name == 'Fingerprint': - dataset.load_predefined_dataset(ds_name) + dataset = Dataset('Fingerprint') dataset.trim_dataset(edge_required=True) # edge symbolic and non-symbolic labels (and node symbolic and non-symbolic labels). elif ds_name == 'Cuneiform': - dataset.load_predefined_dataset(ds_name) + dataset = Dataset('Cuneiform') dataset.trim_dataset(edge_required=True) dataset.cut_graphs(range(0, 3)) @@ -544,4 +538,4 @@ if __name__ == "__main__": # test_RandomWalk('Acyclic', 'fp', None, None) # test_RandomWalk('Acyclic', 'spectral', 'exp', 'imap_unordered') # test_CommonWalk('Alkane', 0.01, 'geo') -# test_ShortestPath('Acyclic') + # test_ShortestPath('Acyclic')