@@ -7,23 +7,44 @@ Created on Thu Mar 26 18:48:27 2020
"""
import numpy as np
import networkx as nx
from gklearn.utils.graph_files import load_dataset
import os
from gklearn.dataset import DATASET_META, DataFetcher, DataLoader
class Dataset(object):
def __init__(self, filename=None, filename_targets=Non e, **kwargs):
if filename is None:
def __init__(self, inputs=None, root='datasets', filename_targets=None, targets=None, mode='networkx', clean_labels=True, reload=False, verbose=Fals e, **kwargs):
if inputs is None:
self._graphs = None
self._targets = None
self._node_labels = None
self._edge_labels = None
self._node_attrs = None
self._edge_attrs = None
# If inputs is a list of graphs.
elif isinstance(inputs, list):
node_labels = kwargs.get('node_labels', None)
node_attrs = kwargs.get('node_attrs', None)
edge_labels = kwargs.get('edge_labels', None)
edge_attrs = kwargs.get('edge_attrs', None)
self.load_graphs(inputs, targets=targets)
self.set_labels(node_labels=node_labels, node_attrs=node_attrs, edge_labels=edge_labels, edge_attrs=edge_attrs)
if clean_labels:
self.clean_labels()
elif isinstance(inputs, str):
# If inputs is predefined dataset name.
if inputs in DATASET_META:
self.load_predefined_dataset(inputs, root=root, clean_labels=clean_labels, reload=reload, verbose=verbose)
# If inputs is a file name.
else:
self.load_dataset(inputs, filename_targets=filename_targets, clean_labels=clean_labels, **kwargs)
else:
self.load_dataset(filename, filename_targets=filename_targets, **kwargs)
raise TypeError('The "inputs" argument cannot be recoganized. "Inputs" can be a list of graphs, a predefined dataset name, or a file name of a dataset.' )
self._substructures = None
self._node_label_dim = None
@@ -51,13 +72,14 @@ class Dataset(object):
self._class_number = None
def load_dataset(self, filename, filename_targets=None, **kwargs):
self._graphs, self._targets, label_names = load_dataset (filename, filename_targets=filename_targets, **kwargs)
def load_dataset(self, filename, filename_targets=None, clean_labels=True, **kwargs):
self._graphs, self._targets, label_names = DataLoader (filename, filename_targets=filename_targets, **kwargs).data
self._node_labels = label_names['node_labels']
self._node_attrs = label_names['node_attrs']
self._edge_labels = label_names['edge_labels']
self._edge_attrs = label_names['edge_attrs']
self.clean_labels()
if clean_labels:
self.clean_labels()
def load_graphs(self, graphs, targets=None):
@@ -67,84 +89,25 @@ class Dataset(object):
# self.set_labels_attrs() # @todo
def load_predefined_dataset(self, ds_name):
current_path = os.path.dirname(os.path.realpath(__file__)) + '/'
if ds_name == 'Acyclic':
ds_file = current_path + '../../datasets/Acyclic/dataset_bps.ds'
self._graphs, self._targets, label_names = load_dataset(ds_file)
elif ds_name == 'AIDS':
ds_file = current_path + '../../datasets/AIDS/AIDS_A.txt'
self._graphs, self._targets, label_names = load_dataset(ds_file)
elif ds_name == 'Alkane':
ds_file = current_path + '../../datasets/Alkane/dataset.ds'
fn_targets = current_path + '../../datasets/Alkane/dataset_boiling_point_names.txt'
self._graphs, self._targets, label_names = load_dataset(ds_file, filename_targets=fn_targets)
elif ds_name == 'COIL-DEL':
ds_file = current_path + '../../datasets/COIL-DEL/COIL-DEL_A.txt'
self._graphs, self._targets, label_names = load_dataset(ds_file)
elif ds_name == 'COIL-RAG':
ds_file = current_path + '../../datasets/COIL-RAG/COIL-RAG_A.txt'
self._graphs, self._targets, label_names = load_dataset(ds_file)
elif ds_name == 'COLORS-3':
ds_file = current_path + '../../datasets/COLORS-3/COLORS-3_A.txt'
self._graphs, self._targets, label_names = load_dataset(ds_file)
elif ds_name == 'Cuneiform':
ds_file = current_path + '../../datasets/Cuneiform/Cuneiform_A.txt'
self._graphs, self._targets, label_names = load_dataset(ds_file)
elif ds_name == 'DD':
ds_file = current_path + '../../datasets/DD/DD_A.txt'
self._graphs, self._targets, label_names = load_dataset(ds_file)
elif ds_name == 'ENZYMES':
ds_file = current_path + '../../datasets/ENZYMES_txt/ENZYMES_A_sparse.txt'
self._graphs, self._targets, label_names = load_dataset(ds_file)
elif ds_name == 'Fingerprint':
ds_file = current_path + '../../datasets/Fingerprint/Fingerprint_A.txt'
self._graphs, self._targets, label_names = load_dataset(ds_file)
elif ds_name == 'FRANKENSTEIN':
ds_file = current_path + '../../datasets/FRANKENSTEIN/FRANKENSTEIN_A.txt'
self._graphs, self._targets, label_names = load_dataset(ds_file)
elif ds_name == 'Letter-high': # node non-symb
ds_file = current_path + '../../datasets/Letter-high/Letter-high_A.txt'
self._graphs, self._targets, label_names = load_dataset(ds_file)
elif ds_name == 'Letter-low': # node non-symb
ds_file = current_path + '../../datasets/Letter-low/Letter-low_A.txt'
self._graphs, self._targets, label_names = load_dataset(ds_file)
elif ds_name == 'Letter-med': # node non-symb
ds_file = current_path + '../../datasets/Letter-med/Letter-med_A.txt'
self._graphs, self._targets, label_names = load_dataset(ds_file)
elif ds_name == 'MAO':
ds_file = current_path + '../../datasets/MAO/dataset.ds'
self._graphs, self._targets, label_names = load_dataset(ds_file)
elif ds_name == 'Monoterpenoides':
ds_file = current_path + '../../datasets/Monoterpenoides/dataset_10+.ds'
self._graphs, self._targets, label_names = load_dataset(ds_file)
elif ds_name == 'MUTAG':
ds_file = current_path + '../../datasets/MUTAG/MUTAG_A.txt'
self._graphs, self._targets, label_names = load_dataset(ds_file)
elif ds_name == 'NCI1':
ds_file = current_path + '../../datasets/NCI1/NCI1_A.txt'
self._graphs, self._targets, label_names = load_dataset(ds_file)
elif ds_name == 'NCI109':
ds_file = current_path + '../../datasets/NCI109/NCI109_A.txt'
self._graphs, self._targets, label_names = load_dataset(ds_file)
elif ds_name == 'PAH':
ds_file = current_path + '../../datasets/PAH/dataset.ds'
self._graphs, self._targets, label_names = load_dataset(ds_file)
elif ds_name == 'SYNTHETIC':
pass
elif ds_name == 'SYNTHETICnew':
ds_file = current_path + '../../datasets/SYNTHETICnew/SYNTHETICnew_A.txt'
self._graphs, self._targets, label_names = load_dataset(ds_file)
elif ds_name == 'Synthie':
pass
def load_predefined_dataset(self, ds_name, root='datasets', clean_labels=True, reload=False, verbose=False):
path = DataFetcher(name=ds_name, root=root, reload=reload, verbose=verbose).path
if DATASET_META[ds_name]['database'] == 'tudataset':
ds_file = os.path.join(path, ds_name + '_A.txt')
fn_targets = None
else:
raise Exception('The dataset name "', ds_name, '" is not pre-defined.')
load_files = DATASET_META[ds_name]['load_files']
ds_file = os.path.join(path, load_files[0])
fn_targets = os.path.join(path, load_files[1]) if len(load_files) == 2 else None
self._graphs, self._targets, label_names = DataLoader(ds_file, filename_targets=fn_targets).data
self._node_labels = label_names['node_labels']
self._node_attrs = label_names['node_attrs']
self._edge_labels = label_names['edge_labels']
self._edge_attrs = label_names['edge_attrs']
self.clean_labels()
if clean_labels:
self.clean_labels()
def set_labels(self, node_labels=[], node_attrs=[], edge_labels=[], edge_attrs=[]):