Browse Source

Use dataset.Dataset in the tests of graph kernel classes.

v0.2.x
jajupmochi 4 years ago
parent
commit
edc189005f
1 changed files with 9 additions and 15 deletions
  1. +9
    -15
      gklearn/tests/test_graph_kernels.py

+ 9
- 15
gklearn/tests/test_graph_kernels.py View File

@@ -19,44 +19,38 @@ def test_list_graph_kernels():
def chooseDataset(ds_name): def chooseDataset(ds_name):
"""Choose dataset according to name. """Choose dataset according to name.
""" """
from gklearn.utils import Dataset

dataset = Dataset()
from gklearn.dataset import Dataset


# no node labels (and no edge labels). # no node labels (and no edge labels).
if ds_name == 'Alkane': if ds_name == 'Alkane':
dataset.load_predefined_dataset(ds_name)
dataset = Dataset('Alkane_unlabeled')
dataset.trim_dataset(edge_required=False) dataset.trim_dataset(edge_required=False)
irrelevant_labels = {'node_attrs': ['x', 'y', 'z'], 'edge_labels': ['bond_stereo']}
dataset.remove_labels(**irrelevant_labels)
dataset.cut_graphs(range(1, 10)) dataset.cut_graphs(range(1, 10))
# node symbolic labels. # node symbolic labels.
elif ds_name == 'Acyclic': elif ds_name == 'Acyclic':
dataset.load_predefined_dataset(ds_name)
dataset = Dataset('Acyclic')
dataset.trim_dataset(edge_required=False) dataset.trim_dataset(edge_required=False)
irrelevant_labels = {'node_attrs': ['x', 'y', 'z'], 'edge_labels': ['bond_stereo']}
dataset.remove_labels(**irrelevant_labels)
# node non-symbolic labels. # node non-symbolic labels.
elif ds_name == 'Letter-med': elif ds_name == 'Letter-med':
dataset.load_predefined_dataset(ds_name)
dataset = Dataset('Letter-med')
dataset.trim_dataset(edge_required=False) dataset.trim_dataset(edge_required=False)
# node symbolic and non-symbolic labels (and edge symbolic labels). # node symbolic and non-symbolic labels (and edge symbolic labels).
elif ds_name == 'AIDS': elif ds_name == 'AIDS':
dataset.load_predefined_dataset(ds_name)
dataset = Dataset('AIDS')
dataset.trim_dataset(edge_required=False) dataset.trim_dataset(edge_required=False)
# edge non-symbolic labels (no node labels). # edge non-symbolic labels (no node labels).
elif ds_name == 'Fingerprint_edge': elif ds_name == 'Fingerprint_edge':
dataset.load_predefined_dataset('Fingerprint')
dataset = Dataset('Fingerprint')
dataset.trim_dataset(edge_required=True) dataset.trim_dataset(edge_required=True)
irrelevant_labels = {'edge_attrs': ['orient', 'angle']} irrelevant_labels = {'edge_attrs': ['orient', 'angle']}
dataset.remove_labels(**irrelevant_labels) dataset.remove_labels(**irrelevant_labels)
# edge non-symbolic labels (and node non-symbolic labels). # edge non-symbolic labels (and node non-symbolic labels).
elif ds_name == 'Fingerprint': elif ds_name == 'Fingerprint':
dataset.load_predefined_dataset(ds_name)
dataset = Dataset('Fingerprint')
dataset.trim_dataset(edge_required=True) dataset.trim_dataset(edge_required=True)
# edge symbolic and non-symbolic labels (and node symbolic and non-symbolic labels). # edge symbolic and non-symbolic labels (and node symbolic and non-symbolic labels).
elif ds_name == 'Cuneiform': elif ds_name == 'Cuneiform':
dataset.load_predefined_dataset(ds_name)
dataset = Dataset('Cuneiform')
dataset.trim_dataset(edge_required=True) dataset.trim_dataset(edge_required=True)


dataset.cut_graphs(range(0, 3)) dataset.cut_graphs(range(0, 3))
@@ -544,4 +538,4 @@ if __name__ == "__main__":
# test_RandomWalk('Acyclic', 'fp', None, None) # test_RandomWalk('Acyclic', 'fp', None, None)
# test_RandomWalk('Acyclic', 'spectral', 'exp', 'imap_unordered') # test_RandomWalk('Acyclic', 'spectral', 'exp', 'imap_unordered')
# test_CommonWalk('Alkane', 0.01, 'geo') # test_CommonWalk('Alkane', 0.01, 'geo')
# test_ShortestPath('Acyclic')
# test_ShortestPath('Acyclic')

Loading…
Cancel
Save