From f67d65bf51c53102918a590b500b4242dd1898a0 Mon Sep 17 00:00:00 2001 From: jajupmochi Date: Sun, 15 Nov 2020 17:11:45 +0100 Subject: [PATCH] Update Dataset class for predefined datasets. --- gklearn/dataset/dataset.py | 121 ++++++++++++++------------------------- gklearn/dataset/file_managers.py | 1 + gklearn/dataset/metadata.py | 12 ++-- 3 files changed, 49 insertions(+), 85 deletions(-) diff --git a/gklearn/dataset/dataset.py b/gklearn/dataset/dataset.py index 0343c0b..cf90051 100644 --- a/gklearn/dataset/dataset.py +++ b/gklearn/dataset/dataset.py @@ -7,23 +7,44 @@ Created on Thu Mar 26 18:48:27 2020 """ import numpy as np import networkx as nx -from gklearn.utils.graph_files import load_dataset import os +from gklearn.dataset import DATASET_META, DataFetcher, DataLoader class Dataset(object): - def __init__(self, filename=None, filename_targets=None, **kwargs): - if filename is None: + def __init__(self, inputs=None, root='datasets', filename_targets=None, targets=None, mode='networkx', clean_labels=True, reload=False, verbose=False, **kwargs): + if inputs is None: self._graphs = None self._targets = None self._node_labels = None self._edge_labels = None self._node_attrs = None self._edge_attrs = None + + # If inputs is a list of graphs. + elif isinstance(inputs, list): + node_labels = kwargs.get('node_labels', None) + node_attrs = kwargs.get('node_attrs', None) + edge_labels = kwargs.get('edge_labels', None) + edge_attrs = kwargs.get('edge_attrs', None) + self.load_graphs(inputs, targets=targets) + self.set_labels(node_labels=node_labels, node_attrs=node_attrs, edge_labels=edge_labels, edge_attrs=edge_attrs) + if clean_labels: + self.clean_labels() + + elif isinstance(inputs, str): + # If inputs is predefined dataset name. + if inputs in DATASET_META: + self.load_predefined_dataset(inputs, root=root, clean_labels=clean_labels, reload=reload, verbose=verbose) + + # If inputs is a file name. + else: + self.load_dataset(inputs, filename_targets=filename_targets, clean_labels=clean_labels, **kwargs) + else: - self.load_dataset(filename, filename_targets=filename_targets, **kwargs) + raise TypeError('The "inputs" argument cannot be recoganized. "Inputs" can be a list of graphs, a predefined dataset name, or a file name of a dataset.') self._substructures = None self._node_label_dim = None @@ -51,13 +72,14 @@ class Dataset(object): self._class_number = None - def load_dataset(self, filename, filename_targets=None, **kwargs): - self._graphs, self._targets, label_names = load_dataset(filename, filename_targets=filename_targets, **kwargs) + def load_dataset(self, filename, filename_targets=None, clean_labels=True, **kwargs): + self._graphs, self._targets, label_names = DataLoader(filename, filename_targets=filename_targets, **kwargs).data self._node_labels = label_names['node_labels'] self._node_attrs = label_names['node_attrs'] self._edge_labels = label_names['edge_labels'] self._edge_attrs = label_names['edge_attrs'] - self.clean_labels() + if clean_labels: + self.clean_labels() def load_graphs(self, graphs, targets=None): @@ -67,84 +89,25 @@ class Dataset(object): # self.set_labels_attrs() # @todo - def load_predefined_dataset(self, ds_name): - current_path = os.path.dirname(os.path.realpath(__file__)) + '/' - if ds_name == 'Acyclic': - ds_file = current_path + '../../datasets/Acyclic/dataset_bps.ds' - self._graphs, self._targets, label_names = load_dataset(ds_file) - elif ds_name == 'AIDS': - ds_file = current_path + '../../datasets/AIDS/AIDS_A.txt' - self._graphs, self._targets, label_names = load_dataset(ds_file) - elif ds_name == 'Alkane': - ds_file = current_path + '../../datasets/Alkane/dataset.ds' - fn_targets = current_path + '../../datasets/Alkane/dataset_boiling_point_names.txt' - self._graphs, self._targets, label_names = load_dataset(ds_file, filename_targets=fn_targets) - elif ds_name == 'COIL-DEL': - ds_file = current_path + '../../datasets/COIL-DEL/COIL-DEL_A.txt' - self._graphs, self._targets, label_names = load_dataset(ds_file) - elif ds_name == 'COIL-RAG': - ds_file = current_path + '../../datasets/COIL-RAG/COIL-RAG_A.txt' - self._graphs, self._targets, label_names = load_dataset(ds_file) - elif ds_name == 'COLORS-3': - ds_file = current_path + '../../datasets/COLORS-3/COLORS-3_A.txt' - self._graphs, self._targets, label_names = load_dataset(ds_file) - elif ds_name == 'Cuneiform': - ds_file = current_path + '../../datasets/Cuneiform/Cuneiform_A.txt' - self._graphs, self._targets, label_names = load_dataset(ds_file) - elif ds_name == 'DD': - ds_file = current_path + '../../datasets/DD/DD_A.txt' - self._graphs, self._targets, label_names = load_dataset(ds_file) - elif ds_name == 'ENZYMES': - ds_file = current_path + '../../datasets/ENZYMES_txt/ENZYMES_A_sparse.txt' - self._graphs, self._targets, label_names = load_dataset(ds_file) - elif ds_name == 'Fingerprint': - ds_file = current_path + '../../datasets/Fingerprint/Fingerprint_A.txt' - self._graphs, self._targets, label_names = load_dataset(ds_file) - elif ds_name == 'FRANKENSTEIN': - ds_file = current_path + '../../datasets/FRANKENSTEIN/FRANKENSTEIN_A.txt' - self._graphs, self._targets, label_names = load_dataset(ds_file) - elif ds_name == 'Letter-high': # node non-symb - ds_file = current_path + '../../datasets/Letter-high/Letter-high_A.txt' - self._graphs, self._targets, label_names = load_dataset(ds_file) - elif ds_name == 'Letter-low': # node non-symb - ds_file = current_path + '../../datasets/Letter-low/Letter-low_A.txt' - self._graphs, self._targets, label_names = load_dataset(ds_file) - elif ds_name == 'Letter-med': # node non-symb - ds_file = current_path + '../../datasets/Letter-med/Letter-med_A.txt' - self._graphs, self._targets, label_names = load_dataset(ds_file) - elif ds_name == 'MAO': - ds_file = current_path + '../../datasets/MAO/dataset.ds' - self._graphs, self._targets, label_names = load_dataset(ds_file) - elif ds_name == 'Monoterpenoides': - ds_file = current_path + '../../datasets/Monoterpenoides/dataset_10+.ds' - self._graphs, self._targets, label_names = load_dataset(ds_file) - elif ds_name == 'MUTAG': - ds_file = current_path + '../../datasets/MUTAG/MUTAG_A.txt' - self._graphs, self._targets, label_names = load_dataset(ds_file) - elif ds_name == 'NCI1': - ds_file = current_path + '../../datasets/NCI1/NCI1_A.txt' - self._graphs, self._targets, label_names = load_dataset(ds_file) - elif ds_name == 'NCI109': - ds_file = current_path + '../../datasets/NCI109/NCI109_A.txt' - self._graphs, self._targets, label_names = load_dataset(ds_file) - elif ds_name == 'PAH': - ds_file = current_path + '../../datasets/PAH/dataset.ds' - self._graphs, self._targets, label_names = load_dataset(ds_file) - elif ds_name == 'SYNTHETIC': - pass - elif ds_name == 'SYNTHETICnew': - ds_file = current_path + '../../datasets/SYNTHETICnew/SYNTHETICnew_A.txt' - self._graphs, self._targets, label_names = load_dataset(ds_file) - elif ds_name == 'Synthie': - pass + def load_predefined_dataset(self, ds_name, root='datasets', clean_labels=True, reload=False, verbose=False): + path = DataFetcher(name=ds_name, root=root, reload=reload, verbose=verbose).path + + if DATASET_META[ds_name]['database'] == 'tudataset': + ds_file = os.path.join(path, ds_name + '_A.txt') + fn_targets = None else: - raise Exception('The dataset name "', ds_name, '" is not pre-defined.') + load_files = DATASET_META[ds_name]['load_files'] + ds_file = os.path.join(path, load_files[0]) + fn_targets = os.path.join(path, load_files[1]) if len(load_files) == 2 else None + + self._graphs, self._targets, label_names = DataLoader(ds_file, filename_targets=fn_targets).data self._node_labels = label_names['node_labels'] self._node_attrs = label_names['node_attrs'] self._edge_labels = label_names['edge_labels'] self._edge_attrs = label_names['edge_attrs'] - self.clean_labels() + if clean_labels: + self.clean_labels() def set_labels(self, node_labels=[], node_attrs=[], edge_labels=[], edge_attrs=[]): diff --git a/gklearn/dataset/file_managers.py b/gklearn/dataset/file_managers.py index f2e539e..76ea9b0 100644 --- a/gklearn/dataset/file_managers.py +++ b/gklearn/dataset/file_managers.py @@ -74,6 +74,7 @@ class DataLoader(): label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []} with open(filename) as fn: content = fn.read().splitlines() + content = [line for line in content if not line.endswith('.ds')] extension = splitext(content[0].split(' ')[0])[1][1:] if extension == 'ct': load_file_fun = self.load_ct diff --git a/gklearn/dataset/metadata.py b/gklearn/dataset/metadata.py index 4fa48d9..9725517 100644 --- a/gklearn/dataset/metadata.py +++ b/gklearn/dataset/metadata.py @@ -32,7 +32,7 @@ GREYC_META = { 'domain': 'small molecules', 'train_valid_test': [], 'stereoisomerism': True, - 'load_files': [], + 'load_files': ['data.ds'], }, 'Acyclic': { 'database': 'greyc', @@ -165,7 +165,7 @@ GREYC_META = { 'domain': 'small molecules', 'train_valid_test': ['trainset_0.ds', None, 'testset_0.ds'], 'stereoisomerism': False, - 'load_files': [], + 'load_files': ['dataset.ds'], }, 'PTC': { 'database': 'greyc', @@ -654,7 +654,7 @@ TUDataset_META = { 'node_attr_dim': 0, 'geometry': None, 'edge_attr_dim': 0, - 'url': 'https://www.chrsmrrs.com/graphkerneldatasets/NCI-H23.zip-H23', + 'url': 'https://www.chrsmrrs.com/graphkerneldatasets/NCI-H23.zip', 'domain': 'small molecules', }, 'NCI-H23H': { @@ -670,7 +670,7 @@ TUDataset_META = { 'node_attr_dim': 0, 'geometry': None, 'edge_attr_dim': 0, - 'url': 'https://www.chrsmrrs.com/graphkerneldatasets/NCI-H23H.zip-H23H', + 'url': 'https://www.chrsmrrs.com/graphkerneldatasets/NCI-H23H.zip', 'domain': 'small molecules', }, 'OVCAR-8': { @@ -686,7 +686,7 @@ TUDataset_META = { 'node_attr_dim': 0, 'geometry': None, 'edge_attr_dim': 0, - 'url': 'https://www.chrsmrrs.com/graphkerneldatasets/OVCAR-8.zip-8', + 'url': 'https://www.chrsmrrs.com/graphkerneldatasets/OVCAR-8.zip', 'domain': 'small molecules', }, 'OVCAR-8H': { @@ -702,7 +702,7 @@ TUDataset_META = { 'node_attr_dim': 0, 'geometry': None, 'edge_attr_dim': 0, - 'url': 'https://www.chrsmrrs.com/graphkerneldatasets/OVCAR-8H.zip-8H', + 'url': 'https://www.chrsmrrs.com/graphkerneldatasets/OVCAR-8H.zip', 'domain': 'small molecules', }, 'P388': {