
Fix bugs in gklearn.dataset.

Branch: v0.2.x
Author: jajupmochi, 4 years ago
Commit: 180c614b44
4 changed files with 84 additions and 38 deletions
  1. gklearn/dataset/data_fetcher.py (+4 -2)
  2. gklearn/dataset/dataset.py (+59 -29)
  3. gklearn/dataset/file_managers.py (+20 -6)
  4. gklearn/dataset/metadata.py (+1 -1)

gklearn/dataset/data_fetcher.py (+4 -2)

@@ -74,6 +74,8 @@ class DataFetcher():
             message = 'Invalid Dataset name "' + self._name + '".'
             message += '\nAvailable datasets are as follows: \n\n'
             message += '\n'.join(ds for ds in sorted(DATASET_META))
+            message += '\n\nFollowing special suffices can be added to the name:'
+            message += '\n\n' + '\n'.join(['_unlabeled'])
             raise ValueError(message)
         else:
             self.write_archive_file(self._name)
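With these two added lines, an invalid name now also advertises the recognized suffixes. A quick sketch of the message being built, with DATASET_META stubbed by a few names (the real table is much longer) and a made-up invalid name:

DATASET_META = {'Acyclic': {}, 'Alkane': {}, 'MAO': {}}  # stub, not the real table
name = 'MAOO'  # hypothetical invalid dataset name
message = 'Invalid Dataset name "' + name + '".'
message += '\nAvailable datasets are as follows: \n\n'
message += '\n'.join(ds for ds in sorted(DATASET_META))
message += '\n\nFollowing special suffices can be added to the name:'
message += '\n\n' + '\n'.join(['_unlabeled'])
print(message)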
@@ -127,9 +129,9 @@ class DataFetcher():
     def write_archive_file(self, ds_name):
         path = osp.join(self._root, ds_name)
-        url = DATASET_META[ds_name]['url']
         # filename_dir = osp.join(path,filename)
         if not osp.exists(path) or self._reload:
+            url = DATASET_META[ds_name]['url']
             response = self.download_file(url)
             if response is None:
                 return False
@@ -152,7 +154,7 @@ class DataFetcher():
         with tarfile.open(filename_archive, 'r:gz') as tar:
             if self._reload and self._verbose:
                 print(filename + ' Downloaded.')
-            subpath = os.path.join(path, tar.getnames()[0])
+            subpath = os.path.join(path, tar.getnames()[0].split('/')[0])
             if not osp.exists(subpath) or self._reload:
                 tar.extractall(path = path)
             return subpath
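The .split('/')[0] fix matters because tar.getnames()[0] is merely the first archive member, which is often a nested file path rather than the top-level directory. A minimal, self-contained sketch of the difference (the member name 'MAO/mao_0.ct' is made up):

import io
import tarfile

# Build an in-memory tarball whose first member is a nested file,
# mimicking a dataset archive with no standalone directory entry.
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode='w:gz') as tar:
    tar.addfile(tarfile.TarInfo('MAO/mao_0.ct'), io.BytesIO(b''))
buf.seek(0)

with tarfile.open(fileobj=buf, mode='r:gz') as tar:
    first = tar.getnames()[0]
    print(first)                # MAO/mao_0.ct -- a file, not the root folder
    print(first.split('/')[0])  # MAO -- the directory the code should test for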


gklearn/dataset/dataset.py (+59 -29)

@@ -14,7 +14,33 @@ from gklearn.dataset import DATASET_META, DataFetcher, DataLoader
 class Dataset(object):

     def __init__(self, inputs=None, root='datasets', filename_targets=None, targets=None, mode='networkx', clean_labels=True, reload=False, verbose=False, **kwargs):
+        self._substructures = None
+        self._node_label_dim = None
+        self._edge_label_dim = None
+        self._directed = None
+        self._dataset_size = None
+        self._total_node_num = None
+        self._ave_node_num = None
+        self._min_node_num = None
+        self._max_node_num = None
+        self._total_edge_num = None
+        self._ave_edge_num = None
+        self._min_edge_num = None
+        self._max_edge_num = None
+        self._ave_node_degree = None
+        self._min_node_degree = None
+        self._max_node_degree = None
+        self._ave_fill_factor = None
+        self._min_fill_factor = None
+        self._max_fill_factor = None
+        self._node_label_nums = None
+        self._edge_label_nums = None
+        self._node_attr_dim = None
+        self._edge_attr_dim = None
+        self._class_number = None
+        self._ds_name = None
+
         if inputs is None:
             self._graphs = None
             self._targets = None
@@ -38,38 +64,26 @@ class Dataset(object):
             # If inputs is predefined dataset name.
             if inputs in DATASET_META:
                 self.load_predefined_dataset(inputs, root=root, clean_labels=clean_labels, reload=reload, verbose=verbose)
+                self._ds_name = inputs
+
+            elif inputs.endswith('_unlabeled'):
+                self.load_predefined_dataset(inputs[:len(inputs) - 10], root=root, clean_labels=clean_labels, reload=reload, verbose=verbose)
+                self._ds_name = inputs
+
+                # Deal with special suffices.
+                self.check_special_suffices()
+
+            # If inputs is a file name.
+            elif os.path.isfile(inputs):
+                self.load_dataset(inputs, filename_targets=filename_targets, clean_labels=clean_labels, **kwargs)
+
             # If inputs is a file name.
             else:
-                self.load_dataset(inputs, filename_targets=filename_targets, clean_labels=clean_labels, **kwargs)
+                raise ValueError('The "inputs" argument "' + inputs + '" is not a valid dataset name or file name.')
         else:
-            raise TypeError('The "inputs" argument cannot be recoganized. "Inputs" can be a list of graphs, a predefined dataset name, or a file name of a dataset.')
-
-        self._substructures = None
-        self._node_label_dim = None
-        self._edge_label_dim = None
-        self._directed = None
-        self._dataset_size = None
-        self._total_node_num = None
-        self._ave_node_num = None
-        self._min_node_num = None
-        self._max_node_num = None
-        self._total_edge_num = None
-        self._ave_edge_num = None
-        self._min_edge_num = None
-        self._max_edge_num = None
-        self._ave_node_degree = None
-        self._min_node_degree = None
-        self._max_node_degree = None
-        self._ave_fill_factor = None
-        self._min_fill_factor = None
-        self._max_fill_factor = None
-        self._node_label_nums = None
-        self._edge_label_nums = None
-        self._node_attr_dim = None
-        self._edge_attr_dim = None
-        self._class_number = None
+            raise TypeError('The "inputs" argument cannot be recognized. "Inputs" can be a list of graphs, a predefined dataset name, or a file name of a dataset.')


     def load_dataset(self, filename, filename_targets=None, clean_labels=True, **kwargs):
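Together with the data_fetcher.py change above, this gives the '_unlabeled' suffix end to end: the base name (inputs minus the 10-character suffix) is loaded, then check_special_suffices() strips every label and attribute. A hedged usage sketch, assuming 'MAO' is in DATASET_META, the download succeeds, and the label-name properties exist as in the rest of the class:

from gklearn.dataset import Dataset

# Load MAO, then drop all node/edge labels and attributes.
ds = Dataset('MAO_unlabeled', root='datasets', verbose=True)

# After loading, the label name lists should be empty.
print(ds.node_labels, ds.edge_labels)  # expected: [] []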
@@ -97,7 +111,10 @@ class Dataset(object):
             fn_targets = None
         else:
             load_files = DATASET_META[ds_name]['load_files']
-            ds_file = os.path.join(path, load_files[0])
+            if isinstance(load_files[0], str):
+                ds_file = os.path.join(path, load_files[0])
+            else: # load_files[0] is a list of files.
+                ds_file = [os.path.join(path, fn) for fn in load_files[0]]
             fn_targets = os.path.join(path, load_files[1]) if len(load_files) == 2 else None
         self._graphs, self._targets, label_names = DataLoader(ds_file, filename_targets=fn_targets).data
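The isinstance check lets a metadata entry name several files that together form one dataset, as in the metadata.py change below. Both accepted shapes, sketched with hypothetical paths:

import os

path = 'datasets/SomeDataset'  # hypothetical extracted-dataset directory

# Single-file form, e.g. 'load_files': ['dataset.ds']:
load_files = ['dataset.ds']
ds_file = os.path.join(path, load_files[0])
# -> 'datasets/SomeDataset/dataset.ds'

# Multi-file form, e.g. 'load_files': [['trainset_0.ds', 'testset_0.ds']]:
load_files = [['trainset_0.ds', 'testset_0.ds']]
ds_file = [os.path.join(path, fn) for fn in load_files[0]]
# -> ['datasets/SomeDataset/trainset_0.ds', 'datasets/SomeDataset/testset_0.ds']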
@@ -108,6 +125,11 @@ class Dataset(object):
         self._edge_attrs = label_names['edge_attrs']
         if clean_labels:
             self.clean_labels()
+        # Deal with specific datasets.
+        if ds_name == 'Alkane':
+            self.trim_dataset(edge_required=True)
+            self.remove_labels(node_labels=['atom_symbol'])


     def set_labels(self, node_labels=[], node_attrs=[], edge_labels=[], edge_attrs=[]):
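Alkane is special-cased because its graphs are unlabeled carbon skeletons: graphs without edges are trimmed and the constant 'atom_symbol' label is dropped right after loading. The same two calls can presumably be applied by hand to any loaded dataset; a sketch (applying them to 'MAO', and 'atom_symbol' being its node label, are assumptions for illustration):

from gklearn.dataset import Dataset

ds = Dataset('Alkane', root='datasets')  # post-processing now happens internally

# Hypothetical manual equivalent on another dataset:
ds2 = Dataset('MAO', root='datasets')
ds2.trim_dataset(edge_required=True)            # keep only graphs that have edges
ds2.remove_labels(node_labels=['atom_symbol'])  # drop an uninformative node label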
@@ -536,6 +558,14 @@ class Dataset(object):
         return dataset


+    def check_special_suffices(self):
+        if self._ds_name.endswith('_unlabeled'):
+            self.remove_labels(node_labels=self._node_labels,
+                               edge_labels=self._edge_labels,
+                               node_attrs=self._node_attrs,
+                               edge_attrs=self._edge_attrs)


     def get_all_node_labels(self):
         node_labels = []
         for g in self._graphs:


gklearn/dataset/file_managers.py (+20 -6)

@@ -38,7 +38,11 @@ class DataLoader():
         for details. Note here filename is the name of either .txt file in
         the dataset directory.
         """
-        extension = splitext(filename)[1][1:]
+        if isinstance(filename, str):
+            extension = splitext(filename)[1][1:]
+        else: # filename is a list of files.
+            extension = splitext(filename[0])[1][1:]
         if extension == "ds":
             self._graphs, self._targets, self._label_names = self.load_from_ds(filename, filename_targets)
         elif extension == "cxl":
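Since every file of a multi-file dataset shares the same format, sniffing the extension of the first entry is enough. The dispatch logic, isolated (file names are made up):

from os.path import splitext

def sniff_extension(filename):
    # Accept a single path or a list of paths, as the loader now does.
    if isinstance(filename, str):
        return splitext(filename)[1][1:]
    return splitext(filename[0])[1][1:]

print(sniff_extension('trainset_0.ds'))                    # ds
print(sniff_extension(['trainset_0.ds', 'testset_0.ds']))  # ds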
@@ -67,14 +71,24 @@ class DataLoader():
         Note these graph formats are checked automatically by the extensions of
         graph files.
         """
-        dirname_dataset = dirname(filename)
+        if isinstance(filename, str):
+            dirname_dataset = dirname(filename)
+            with open(filename) as f:
+                content = f.read().splitlines()
+        else: # filename is a list of files.
+            dirname_dataset = dirname(filename[0])
+            content = []
+            for fn in filename:
+                with open(fn) as f:
+                    content += f.read().splitlines()
+            # to remove duplicate file names.

         data = []
         y = []
         label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
-        with open(filename) as fn:
-            content = fn.read().splitlines()
-        content = [line for line in content if not line.endswith('.ds')]
+        content = [line for line in content if not line.endswith('.ds')] # Alkane
+        content = [line for line in content if not line.startswith('#')] # Acyclic
         extension = splitext(content[0].split(' ')[0])[1][1:]
         if extension == 'ct':
             load_file_fun = self.load_ct
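Concatenating splitlines() across all input files yields a single index list; the two filters then drop lines that point at other .ds files (as in the Alkane index) and '#' comment lines (as in Acyclic). Isolated, with made-up file contents:

content = [
    '#. comment line, as found in the Acyclic index',
    'molecule_1.ct 1',
    'other_targets.ds',  # reference to another .ds file, as in Alkane
    'molecule_2.ct -1',
]
content = [line for line in content if not line.endswith('.ds')]   # Alkane
content = [line for line in content if not line.startswith('#')]   # Acyclic
print(content)  # ['molecule_1.ct 1', 'molecule_2.ct -1']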


gklearn/dataset/metadata.py (+1 -1)

@@ -165,7 +165,7 @@ GREYC_META = {
         'domain': 'small molecules',
         'train_valid_test': ['trainset_0.ds', None, 'testset_0.ds'],
         'stereoisomerism': False,
-        'load_files': ['dataset.ds'],
+        'load_files': [['trainset_0.ds', 'testset_0.ds']],
     },
     'PTC': {
         'database': 'greyc',

