@@ -14,7 +14,33 @@ from gklearn.dataset import DATASET_META, DataFetcher, DataLoader
class Dataset(object):
def __init__(self, inputs=None, root='datasets', filename_targets=None, targets=None, mode='networkx', clean_labels=True, reload=False, verbose=False, **kwargs):
def __init__(self, inputs=None, root='datasets', filename_targets=None, targets=None, mode='networkx', clean_labels=True, reload=False, verbose=False, **kwargs):
self._substructures = None
self._node_label_dim = None
self._edge_label_dim = None
self._directed = None
self._dataset_size = None
self._total_node_num = None
self._ave_node_num = None
self._min_node_num = None
self._max_node_num = None
self._total_edge_num = None
self._ave_edge_num = None
self._min_edge_num = None
self._max_edge_num = None
self._ave_node_degree = None
self._min_node_degree = None
self._max_node_degree = None
self._ave_fill_factor = None
self._min_fill_factor = None
self._max_fill_factor = None
self._node_label_nums = None
self._edge_label_nums = None
self._node_attr_dim = None
self._edge_attr_dim = None
self._class_number = None
self._ds_name = None
if inputs is None:
self._graphs = None
self._targets = None
@@ -38,38 +64,26 @@ class Dataset(object):
# If inputs is predefined dataset name.
if inputs in DATASET_META:
self.load_predefined_dataset(inputs, root=root, clean_labels=clean_labels, reload=reload, verbose=verbose)
self._ds_name = inputs
elif inputs.endswith('_unlabeled'):
self.load_predefined_dataset(inputs[:len(inputs) - 10], root=root, clean_labels=clean_labels, reload=reload, verbose=verbose)
self._ds_name = inputs
# Deal with special suffices.
self.check_special_suffices()
# If inputs is a file name.
elif os.path.isfile(inputs):
self.load_dataset(inputs, filename_targets=filename_targets, clean_labels=clean_labels, **kwargs)
# If inputs is a file name.
else:
self.load_dataset(inputs, filename_targets=filename_targets, clean_labels=clean_labels, **kwargs)
raise ValueError('The "inputs" argument "' + inputs + '" is not a valid dataset name or file name.' )
else:
raise TypeError('The "inputs" argument cannot be recoganized. "Inputs" can be a list of graphs, a predefined dataset name, or a file name of a dataset.')
self._substructures = None
self._node_label_dim = None
self._edge_label_dim = None
self._directed = None
self._dataset_size = None
self._total_node_num = None
self._ave_node_num = None
self._min_node_num = None
self._max_node_num = None
self._total_edge_num = None
self._ave_edge_num = None
self._min_edge_num = None
self._max_edge_num = None
self._ave_node_degree = None
self._min_node_degree = None
self._max_node_degree = None
self._ave_fill_factor = None
self._min_fill_factor = None
self._max_fill_factor = None
self._node_label_nums = None
self._edge_label_nums = None
self._node_attr_dim = None
self._edge_attr_dim = None
self._class_number = None
raise TypeError('The "inputs" argument cannot be recognized. "Inputs" can be a list of graphs, a predefined dataset name, or a file name of a dataset.')
def load_dataset(self, filename, filename_targets=None, clean_labels=True, **kwargs):
@@ -97,7 +111,10 @@ class Dataset(object):
fn_targets = None
else:
load_files = DATASET_META[ds_name]['load_files']
ds_file = os.path.join(path, load_files[0])
if isinstance(load_files[0], str):
ds_file = os.path.join(path, load_files[0])
else: # load_files[0] is a list of files.
ds_file = [os.path.join(path, fn) for fn in load_files[0]]
fn_targets = os.path.join(path, load_files[1]) if len(load_files) == 2 else None
self._graphs, self._targets, label_names = DataLoader(ds_file, filename_targets=fn_targets).data
@@ -108,6 +125,11 @@ class Dataset(object):
self._edge_attrs = label_names['edge_attrs']
if clean_labels:
self.clean_labels()
# Deal with specific datasets.
if ds_name == 'Alkane':
self.trim_dataset(edge_required=True)
self.remove_labels(node_labels=['atom_symbol'])
def set_labels(self, node_labels=[], node_attrs=[], edge_labels=[], edge_attrs=[]):
@@ -536,6 +558,14 @@ class Dataset(object):
return dataset
def check_special_suffices(self):
if self._ds_name.endswith('_unlabeled'):
self.remove_labels(node_labels=self._node_labels,
edge_labels=self._edge_labels,
node_attrs=self._node_attrs,
edge_attrs=self._edge_attrs)
def get_all_node_labels(self):
node_labels = []
for g in self._graphs: