Browse Source

Use dataset.Dataset in the tests of graph kernel classes.

v0.2.x
jajupmochi 4 years ago
parent
commit
edc189005f
1 changed files with 9 additions and 15 deletions
  1. +9
    -15
      gklearn/tests/test_graph_kernels.py

+ 9
- 15
gklearn/tests/test_graph_kernels.py View File

@@ -19,44 +19,38 @@ def test_list_graph_kernels():
def chooseDataset(ds_name):
"""Choose dataset according to name.
"""
from gklearn.utils import Dataset

dataset = Dataset()
from gklearn.dataset import Dataset

# no node labels (and no edge labels).
if ds_name == 'Alkane':
dataset.load_predefined_dataset(ds_name)
dataset = Dataset('Alkane_unlabeled')
dataset.trim_dataset(edge_required=False)
irrelevant_labels = {'node_attrs': ['x', 'y', 'z'], 'edge_labels': ['bond_stereo']}
dataset.remove_labels(**irrelevant_labels)
dataset.cut_graphs(range(1, 10))
# node symbolic labels.
elif ds_name == 'Acyclic':
dataset.load_predefined_dataset(ds_name)
dataset = Dataset('Acyclic')
dataset.trim_dataset(edge_required=False)
irrelevant_labels = {'node_attrs': ['x', 'y', 'z'], 'edge_labels': ['bond_stereo']}
dataset.remove_labels(**irrelevant_labels)
# node non-symbolic labels.
elif ds_name == 'Letter-med':
dataset.load_predefined_dataset(ds_name)
dataset = Dataset('Letter-med')
dataset.trim_dataset(edge_required=False)
# node symbolic and non-symbolic labels (and edge symbolic labels).
elif ds_name == 'AIDS':
dataset.load_predefined_dataset(ds_name)
dataset = Dataset('AIDS')
dataset.trim_dataset(edge_required=False)
# edge non-symbolic labels (no node labels).
elif ds_name == 'Fingerprint_edge':
dataset.load_predefined_dataset('Fingerprint')
dataset = Dataset('Fingerprint')
dataset.trim_dataset(edge_required=True)
irrelevant_labels = {'edge_attrs': ['orient', 'angle']}
dataset.remove_labels(**irrelevant_labels)
# edge non-symbolic labels (and node non-symbolic labels).
elif ds_name == 'Fingerprint':
dataset.load_predefined_dataset(ds_name)
dataset = Dataset('Fingerprint')
dataset.trim_dataset(edge_required=True)
# edge symbolic and non-symbolic labels (and node symbolic and non-symbolic labels).
elif ds_name == 'Cuneiform':
dataset.load_predefined_dataset(ds_name)
dataset = Dataset('Cuneiform')
dataset.trim_dataset(edge_required=True)

dataset.cut_graphs(range(0, 3))
@@ -544,4 +538,4 @@ if __name__ == "__main__":
# test_RandomWalk('Acyclic', 'fp', None, None)
# test_RandomWalk('Acyclic', 'spectral', 'exp', 'imap_unordered')
# test_CommonWalk('Alkane', 0.01, 'geo')
# test_ShortestPath('Acyclic')
# test_ShortestPath('Acyclic')

Loading…
Cancel
Save