Browse Source

Update Dataset class for predefined datasets.

v0.2.x
jajupmochi 4 years ago
parent
commit
f67d65bf51
3 changed files with 49 additions and 85 deletions
  1. +42
    -79
      gklearn/dataset/dataset.py
  2. +1
    -0
      gklearn/dataset/file_managers.py
  3. +6
    -6
      gklearn/dataset/metadata.py

+ 42
- 79
gklearn/dataset/dataset.py View File

@@ -7,23 +7,44 @@ Created on Thu Mar 26 18:48:27 2020
"""
import numpy as np
import networkx as nx
from gklearn.utils.graph_files import load_dataset
import os
from gklearn.dataset import DATASET_META, DataFetcher, DataLoader


class Dataset(object):
def __init__(self, filename=None, filename_targets=None, **kwargs):
if filename is None:
def __init__(self, inputs=None, root='datasets', filename_targets=None, targets=None, mode='networkx', clean_labels=True, reload=False, verbose=False, **kwargs):
if inputs is None:
self._graphs = None
self._targets = None
self._node_labels = None
self._edge_labels = None
self._node_attrs = None
self._edge_attrs = None
# If inputs is a list of graphs.
elif isinstance(inputs, list):
node_labels = kwargs.get('node_labels', None)
node_attrs = kwargs.get('node_attrs', None)
edge_labels = kwargs.get('edge_labels', None)
edge_attrs = kwargs.get('edge_attrs', None)
self.load_graphs(inputs, targets=targets)
self.set_labels(node_labels=node_labels, node_attrs=node_attrs, edge_labels=edge_labels, edge_attrs=edge_attrs)
if clean_labels:
self.clean_labels()
elif isinstance(inputs, str):
# If inputs is predefined dataset name.
if inputs in DATASET_META:
self.load_predefined_dataset(inputs, root=root, clean_labels=clean_labels, reload=reload, verbose=verbose)
# If inputs is a file name.
else:
self.load_dataset(inputs, filename_targets=filename_targets, clean_labels=clean_labels, **kwargs)
else:
self.load_dataset(filename, filename_targets=filename_targets, **kwargs)
raise TypeError('The "inputs" argument cannot be recoganized. "Inputs" can be a list of graphs, a predefined dataset name, or a file name of a dataset.')
self._substructures = None
self._node_label_dim = None
@@ -51,13 +72,14 @@ class Dataset(object):
self._class_number = None
def load_dataset(self, filename, filename_targets=None, clean_labels=True, **kwargs):
    """Load graphs, targets and label names from a dataset file.

    The file is read through ``DataLoader``; its ``data`` attribute is
    unpacked into the graphs, the targets and a dict of label names, which
    are stored on this instance.

    Parameters
    ----------
    filename : str
        Path of the dataset file, passed to ``DataLoader``.
    filename_targets : str, optional
        Path of a separate targets file, forwarded to ``DataLoader``.
    clean_labels : bool, optional
        If True (default), ``self.clean_labels()`` is called once loading
        is finished. If False, the label names are kept exactly as
        returned by the loader.
    **kwargs
        Extra keyword arguments forwarded unchanged to ``DataLoader``.
    """
    self._graphs, self._targets, label_names = DataLoader(filename, filename_targets=filename_targets, **kwargs).data
    self._node_labels = label_names['node_labels']
    self._node_attrs = label_names['node_attrs']
    self._edge_labels = label_names['edge_labels']
    self._edge_attrs = label_names['edge_attrs']
    # Cleaning is opt-out: it must only run when the caller asked for it.
    if clean_labels:
        self.clean_labels()
def load_graphs(self, graphs, targets=None):
@@ -67,84 +89,25 @@ class Dataset(object):
# self.set_labels_attrs() # @todo
def load_predefined_dataset(self, ds_name, root='datasets', clean_labels=True, reload=False, verbose=False):
    """Load one of the datasets described in ``DATASET_META``.

    The dataset location is obtained from ``DataFetcher(...).path``. The
    file(s) to read are then chosen from the dataset's metadata instead of
    a hard-coded per-dataset path, and read through ``DataLoader``.

    Parameters
    ----------
    ds_name : str
        Name of the dataset; used as a key into ``DATASET_META``.
    root : str, optional
        Forwarded to ``DataFetcher`` as ``root`` (default ``'datasets'``).
    clean_labels : bool, optional
        If True (default), ``self.clean_labels()`` is called once loading
        is finished.
    reload : bool, optional
        Forwarded to ``DataFetcher`` as ``reload``.
    verbose : bool, optional
        Forwarded to ``DataFetcher`` as ``verbose``.

    Raises
    ------
    KeyError
        If ``ds_name`` is not a key of ``DATASET_META`` (raised by the
        metadata lookup, unless ``DataFetcher`` rejects the name first).
    """
    path = DataFetcher(name=ds_name, root=root, reload=reload, verbose=verbose).path

    if DATASET_META[ds_name]['database'] == 'tudataset':
        # Datasets of the 'tudataset' database are loaded from
        # '<name>_A.txt'; no separate targets file is passed to the loader.
        ds_file = os.path.join(path, ds_name + '_A.txt')
        fn_targets = None
    else:
        # Other databases list their files explicitly in the metadata:
        # the first entry is the dataset file, and a second entry (only
        # when there are exactly two) is the targets file.
        load_files = DATASET_META[ds_name]['load_files']
        ds_file = os.path.join(path, load_files[0])
        fn_targets = os.path.join(path, load_files[1]) if len(load_files) == 2 else None

    self._graphs, self._targets, label_names = DataLoader(ds_file, filename_targets=fn_targets).data

    self._node_labels = label_names['node_labels']
    self._node_attrs = label_names['node_attrs']
    self._edge_labels = label_names['edge_labels']
    self._edge_attrs = label_names['edge_attrs']
    # Cleaning is opt-out: it must only run when the caller asked for it.
    if clean_labels:
        self.clean_labels()

def set_labels(self, node_labels=[], node_attrs=[], edge_labels=[], edge_attrs=[]):


+ 1
- 0
gklearn/dataset/file_managers.py View File

@@ -74,6 +74,7 @@ class DataLoader():
label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
with open(filename) as fn:
content = fn.read().splitlines()
content = [line for line in content if not line.endswith('.ds')]
extension = splitext(content[0].split(' ')[0])[1][1:]
if extension == 'ct':
load_file_fun = self.load_ct


+ 6
- 6
gklearn/dataset/metadata.py View File

@@ -32,7 +32,7 @@ GREYC_META = {
'domain': 'small molecules',
'train_valid_test': [],
'stereoisomerism': True,
'load_files': [],
'load_files': ['data.ds'],
},
'Acyclic': {
'database': 'greyc',
@@ -165,7 +165,7 @@ GREYC_META = {
'domain': 'small molecules',
'train_valid_test': ['trainset_0.ds', None, 'testset_0.ds'],
'stereoisomerism': False,
'load_files': [],
'load_files': ['dataset.ds'],
},
'PTC': {
'database': 'greyc',
@@ -654,7 +654,7 @@ TUDataset_META = {
'node_attr_dim': 0,
'geometry': None,
'edge_attr_dim': 0,
'url': 'https://www.chrsmrrs.com/graphkerneldatasets/NCI-H23.zip-H23',
'url': 'https://www.chrsmrrs.com/graphkerneldatasets/NCI-H23.zip',
'domain': 'small molecules',
},
'NCI-H23H': {
@@ -670,7 +670,7 @@ TUDataset_META = {
'node_attr_dim': 0,
'geometry': None,
'edge_attr_dim': 0,
'url': 'https://www.chrsmrrs.com/graphkerneldatasets/NCI-H23H.zip-H23H',
'url': 'https://www.chrsmrrs.com/graphkerneldatasets/NCI-H23H.zip',
'domain': 'small molecules',
},
'OVCAR-8': {
@@ -686,7 +686,7 @@ TUDataset_META = {
'node_attr_dim': 0,
'geometry': None,
'edge_attr_dim': 0,
'url': 'https://www.chrsmrrs.com/graphkerneldatasets/OVCAR-8.zip-8',
'url': 'https://www.chrsmrrs.com/graphkerneldatasets/OVCAR-8.zip',
'domain': 'small molecules',
},
'OVCAR-8H': {
@@ -702,7 +702,7 @@ TUDataset_META = {
'node_attr_dim': 0,
'geometry': None,
'edge_attr_dim': 0,
'url': 'https://www.chrsmrrs.com/graphkerneldatasets/OVCAR-8H.zip-8H',
'url': 'https://www.chrsmrrs.com/graphkerneldatasets/OVCAR-8H.zip',
'domain': 'small molecules',
},
'P388': {


Loading…
Cancel
Save