
Fix bugs in gklearn.dataset.

Branch: v0.2.x
jajupmochi committed 4 years ago
Parent commit: 180c614b44
4 changed files with 84 additions and 38 deletions:

  1. gklearn/dataset/data_fetcher.py   (+4, -2)
  2. gklearn/dataset/dataset.py        (+59, -29)
  3. gklearn/dataset/file_managers.py  (+20, -6)
  4. gklearn/dataset/metadata.py       (+1, -1)

gklearn/dataset/data_fetcher.py (+4, -2)

@@ -74,6 +74,8 @@ class DataFetcher():
             message = 'Invalid Dataset name "' + self._name + '".'
             message += '\nAvailable datasets are as follows: \n\n'
             message += '\n'.join(ds for ds in sorted(DATASET_META))
+            message += '\n\nThe following special suffixes can be added to the name:'
+            message += '\n\n' + '\n'.join(['_unlabeled'])
             raise ValueError(message)
         else:
             self.write_archive_file(self._name)
@@ -127,9 +129,9 @@ class DataFetcher():
     def write_archive_file(self, ds_name):
         path = osp.join(self._root, ds_name)
-        url = DATASET_META[ds_name]['url']
         # filename_dir = osp.join(path,filename)
         if not osp.exists(path) or self._reload:
+            url = DATASET_META[ds_name]['url']
             response = self.download_file(url)
             if response is None:
                 return False
@@ -152,7 +154,7 @@ class DataFetcher():
         with tarfile.open(filename_archive, 'r:gz') as tar:
             if self._reload and self._verbose:
                 print(filename + ' Downloaded.')
-            subpath = os.path.join(path, tar.getnames()[0])
+            subpath = os.path.join(path, tar.getnames()[0].split('/')[0])
             if not osp.exists(subpath) or self._reload:
                 tar.extractall(path = path)
         return subpath
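
Note on the split('/')[0] fix: tar.getnames()[0] returns the first member path of the archive, which is not guaranteed to be the bare top-level directory; it can be a nested path such as 'Alkane/data/file0.ct', in which case joining it onto path presumably yields a subpath that never exists, so the check fails and the archive is re-extracted on every run. A minimal standalone sketch (the archive name is hypothetical, not gklearn code):

    import tarfile

    with tarfile.open('Alkane.tar.gz', 'r:gz') as tar:  # hypothetical archive
        first_member = tar.getnames()[0]        # may be 'Alkane/data/file0.ct'
        top_level = first_member.split('/')[0]  # always the top folder: 'Alkane'
        print(top_level)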


gklearn/dataset/dataset.py (+59, -29)

@@ -14,7 +14,33 @@ from gklearn.dataset import DATASET_META, DataFetcher, DataLoader
 class Dataset(object):

     def __init__(self, inputs=None, root='datasets', filename_targets=None, targets=None, mode='networkx', clean_labels=True, reload=False, verbose=False, **kwargs):
+        self._substructures = None
+        self._node_label_dim = None
+        self._edge_label_dim = None
+        self._directed = None
+        self._dataset_size = None
+        self._total_node_num = None
+        self._ave_node_num = None
+        self._min_node_num = None
+        self._max_node_num = None
+        self._total_edge_num = None
+        self._ave_edge_num = None
+        self._min_edge_num = None
+        self._max_edge_num = None
+        self._ave_node_degree = None
+        self._min_node_degree = None
+        self._max_node_degree = None
+        self._ave_fill_factor = None
+        self._min_fill_factor = None
+        self._max_fill_factor = None
+        self._node_label_nums = None
+        self._edge_label_nums = None
+        self._node_attr_dim = None
+        self._edge_attr_dim = None
+        self._class_number = None
+        self._ds_name = None
+
         if inputs is None:
             self._graphs = None
             self._targets = None
@@ -38,38 +64,26 @@ class Dataset(object):
             # If inputs is predefined dataset name.
             if inputs in DATASET_META:
                 self.load_predefined_dataset(inputs, root=root, clean_labels=clean_labels, reload=reload, verbose=verbose)
+                self._ds_name = inputs
+
+            elif inputs.endswith('_unlabeled'):
+                self.load_predefined_dataset(inputs[:len(inputs) - 10], root=root, clean_labels=clean_labels, reload=reload, verbose=verbose)
+                self._ds_name = inputs
+
+                # Deal with special suffixes.
+                self.check_special_suffices()

             # If inputs is a file name.
-            else:
-                self.load_dataset(inputs, filename_targets=filename_targets, clean_labels=clean_labels, **kwargs)
+            elif os.path.isfile(inputs):
+                self.load_dataset(inputs, filename_targets=filename_targets, clean_labels=clean_labels, **kwargs)
+
+            else:
+                raise ValueError('The "inputs" argument "' + inputs + '" is not a valid dataset name or file name.')
         else:
-            raise TypeError('The "inputs" argument cannot be recoganized. "Inputs" can be a list of graphs, a predefined dataset name, or a file name of a dataset.')
-
-        self._substructures = None
-        self._node_label_dim = None
-        self._edge_label_dim = None
-        self._directed = None
-        self._dataset_size = None
-        self._total_node_num = None
-        self._ave_node_num = None
-        self._min_node_num = None
-        self._max_node_num = None
-        self._total_edge_num = None
-        self._ave_edge_num = None
-        self._min_edge_num = None
-        self._max_edge_num = None
-        self._ave_node_degree = None
-        self._min_node_degree = None
-        self._max_node_degree = None
-        self._ave_fill_factor = None
-        self._min_fill_factor = None
-        self._max_fill_factor = None
-        self._node_label_nums = None
-        self._edge_label_nums = None
-        self._node_attr_dim = None
-        self._edge_attr_dim = None
-        self._class_number = None
+            raise TypeError('The "inputs" argument cannot be recognized. "Inputs" can be a list of graphs, a predefined dataset name, or a file name of a dataset.')

     def load_dataset(self, filename, filename_targets=None, clean_labels=True, **kwargs):
@@ -97,7 +111,10 @@ class Dataset(object):
             fn_targets = None
         else:
             load_files = DATASET_META[ds_name]['load_files']
-            ds_file = os.path.join(path, load_files[0])
+            if isinstance(load_files[0], str):
+                ds_file = os.path.join(path, load_files[0])
+            else: # load_files[0] is a list of files.
+                ds_file = [os.path.join(path, fn) for fn in load_files[0]]
             fn_targets = os.path.join(path, load_files[1]) if len(load_files) == 2 else None
         self._graphs, self._targets, label_names = DataLoader(ds_file, filename_targets=fn_targets).data
@@ -108,6 +125,11 @@ class Dataset(object):
             self._edge_attrs = label_names['edge_attrs']
         if clean_labels:
             self.clean_labels()
+
+        # Deal with specific datasets.
+        if ds_name == 'Alkane':
+            self.trim_dataset(edge_required=True)
+            self.remove_labels(node_labels=['atom_symbol'])

     def set_labels(self, node_labels=[], node_attrs=[], edge_labels=[], edge_attrs=[]):
@@ -536,6 +558,14 @@ class Dataset(object):
         return dataset

+
+    def check_special_suffices(self):
+        if self._ds_name.endswith('_unlabeled'):
+            self.remove_labels(node_labels=self._node_labels,
+                               edge_labels=self._edge_labels,
+                               node_attrs=self._node_attrs,
+                               edge_attrs=self._edge_attrs)
+

     def get_all_node_labels(self):
         node_labels = []
         for g in self._graphs:
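
Taken together, the dataset.py changes introduce an '_unlabeled' suffix: the base dataset name is recovered by dropping the 10-character suffix (inputs[:len(inputs) - 10]), and check_special_suffices() then strips every node/edge label and attribute. A hedged usage sketch; 'MAO' stands in for any key of DATASET_META, and it is assumed that Dataset is exported at package level like DATASET_META and DataLoader in the import line above:

    from gklearn.dataset import Dataset

    # '_unlabeled' is 10 characters, so inputs[:len(inputs) - 10] gives 'MAO'.
    ds = Dataset('MAO_unlabeled', root='datasets', verbose=True)
    # After loading, check_special_suffices() has removed all node/edge
    # labels and attributes, leaving only the bare graph structure.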


gklearn/dataset/file_managers.py (+20, -6)

@@ -38,7 +38,11 @@ class DataLoader():
         for details. Note here filename is the name of either .txt file in
         the dataset directory.
         """
-        extension = splitext(filename)[1][1:]
+        if isinstance(filename, str):
+            extension = splitext(filename)[1][1:]
+        else: # filename is a list of files.
+            extension = splitext(filename[0])[1][1:]
         if extension == "ds":
             self._graphs, self._targets, self._label_names = self.load_from_ds(filename, filename_targets)
         elif extension == "cxl":
@@ -67,14 +71,24 @@ class DataLoader():
         Note these graph formats are checked automatically by the extensions of
         graph files.
         """
-        dirname_dataset = dirname(filename)
+        if isinstance(filename, str):
+            dirname_dataset = dirname(filename)
+            with open(filename) as f:
+                content = f.read().splitlines()
+        else: # filename is a list of files.
+            dirname_dataset = dirname(filename[0])
+            content = []
+            for fn in filename:
+                with open(fn) as f:
+                    content += f.read().splitlines()
+            # to remove duplicate file names.

         data = []
         y = []
         label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
-        with open(filename) as fn:
-            content = fn.read().splitlines()
-        content = [line for line in content if not line.endswith('.ds')]
+        content = [line for line in content if not line.endswith('.ds')] # Alkane
+        content = [line for line in content if not line.startswith('#')] # Acyclic
         extension = splitext(content[0].split(' ')[0])[1][1:]
         if extension == 'ct':
             load_file_fun = self.load_ct
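
With this change, DataLoader accepts either a single .ds path or a list of .ds paths whose lines are concatenated before parsing. A standalone sketch of the new list branch (the file paths are hypothetical, not taken from this commit):

    from os.path import dirname

    filenames = ['datasets/SomeDataset/trainset_0.ds',
                 'datasets/SomeDataset/testset_0.ds']  # hypothetical paths
    dirname_dataset = dirname(filenames[0])
    content = []
    for fn in filenames:
        with open(fn) as f:
            content += f.read().splitlines()
    # Same filtering as in the hunk above: drop references to other .ds
    # files (Alkane) and comment lines (Acyclic) before parsing graph names.
    content = [line for line in content if not line.endswith('.ds')]
    content = [line for line in content if not line.startswith('#')]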


gklearn/dataset/metadata.py (+1, -1)

@@ -165,7 +165,7 @@ GREYC_META = {
         'domain': 'small molecules',
         'train_valid_test': ['trainset_0.ds', None, 'testset_0.ds'],
         'stereoisomerism': False,
-        'load_files': ['dataset.ds'],
+        'load_files': [['trainset_0.ds', 'testset_0.ds']],
     },
     'PTC': {
         'database': 'greyc',
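
The fixed entry points 'load_files' at the dataset's actual train/test split files instead of a non-existent dataset.ds; wrapping the two names in an inner list is what triggers the new list branch in Dataset.load_predefined_dataset (see the dataset.py hunk above). A sketch of how that branch resolves the entry (the directory name is hypothetical):

    import os

    path = 'datasets/SomeGreycDataset'  # hypothetical dataset directory
    load_files = [['trainset_0.ds', 'testset_0.ds']]
    if isinstance(load_files[0], str):
        ds_file = os.path.join(path, load_files[0])
    else:  # load_files[0] is a list of files to be merged.
        ds_file = [os.path.join(path, fn) for fn in load_files[0]]
    print(ds_file)
    # ['datasets/SomeGreycDataset/trainset_0.ds',
    #  'datasets/SomeGreycDataset/testset_0.ds']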

