diff --git a/gklearn/dataset/data_fetcher.py b/gklearn/dataset/data_fetcher.py
index 8f3f167..4349753 100644
--- a/gklearn/dataset/data_fetcher.py
+++ b/gklearn/dataset/data_fetcher.py
@@ -74,6 +74,8 @@ class DataFetcher():
 			message = 'Invalid Dataset name "' + self._name + '".'
 			message += '\nAvailable datasets are as follows: \n\n'
 			message += '\n'.join(ds for ds in sorted(DATASET_META))
+			message += '\n\nThe following special suffixes can be added to the name:'
+			message += '\n\n' + '\n'.join(['_unlabeled'])
 			raise ValueError(message)
 		else:
 			self.write_archive_file(self._name)
@@ -127,9 +129,9 @@ class DataFetcher():
 
 	def write_archive_file(self, ds_name):
 		path = osp.join(self._root, ds_name)
-		url = DATASET_META[ds_name]['url']
 		# filename_dir = osp.join(path,filename)
 		if not osp.exists(path) or self._reload:
+			url = DATASET_META[ds_name]['url']
 			response = self.download_file(url)
 			if response is None:
 				return False
@@ -152,7 +154,7 @@ class DataFetcher():
 		with tarfile.open(filename_archive, 'r:gz') as tar:
 			if self._reload and self._verbose:
 				print(filename + ' Downloaded.')
-			subpath = os.path.join(path, tar.getnames()[0])
+			subpath = os.path.join(path, tar.getnames()[0].split('/')[0])
 			if not osp.exists(subpath) or self._reload:
 				tar.extractall(path = path)
 			return subpath
diff --git a/gklearn/dataset/dataset.py b/gklearn/dataset/dataset.py
index cf90051..6911b12 100644
--- a/gklearn/dataset/dataset.py
+++ b/gklearn/dataset/dataset.py
@@ -14,7 +14,33 @@
 class Dataset(object):
 
 
-	def __init__(self, inputs=None, root='datasets', filename_targets=None, targets=None, mode='networkx', clean_labels=True, reload=False, verbose=False, **kwargs):
+	def __init__(self, inputs=None, root='datasets', filename_targets=None, targets=None, mode='networkx', clean_labels=True, reload=False, verbose=False, **kwargs):
+		self._substructures = None
+		self._node_label_dim = None
+		self._edge_label_dim = None
+		self._directed = None
+		self._dataset_size = None
+		self._total_node_num = None
+		self._ave_node_num = None
+		self._min_node_num = None
+		self._max_node_num = None
+		self._total_edge_num = None
+		self._ave_edge_num = None
+		self._min_edge_num = None
+		self._max_edge_num = None
+		self._ave_node_degree = None
+		self._min_node_degree = None
+		self._max_node_degree = None
+		self._ave_fill_factor = None
+		self._min_fill_factor = None
+		self._max_fill_factor = None
+		self._node_label_nums = None
+		self._edge_label_nums = None
+		self._node_attr_dim = None
+		self._edge_attr_dim = None
+		self._class_number = None
+		self._ds_name = None
+
 		if inputs is None:
 			self._graphs = None
 			self._targets = None
@@ -38,38 +64,26 @@ class Dataset(object):
 			# If inputs is predefined dataset name.
 			if inputs in DATASET_META:
 				self.load_predefined_dataset(inputs, root=root, clean_labels=clean_labels, reload=reload, verbose=verbose)
+				self._ds_name = inputs
+
+			elif inputs.endswith('_unlabeled'):
+				self.load_predefined_dataset(inputs[:len(inputs) - 10], root=root, clean_labels=clean_labels, reload=reload, verbose=verbose)
+				self._ds_name = inputs
+
+				# Deal with special suffixes.
+				self.check_special_suffices()
+
+			# If inputs is a file name.
+			elif os.path.isfile(inputs):
+				self.load_dataset(inputs, filename_targets=filename_targets, clean_labels=clean_labels, **kwargs)
 
 			# If inputs is a file name.
 			else:
-				self.load_dataset(inputs, filename_targets=filename_targets, clean_labels=clean_labels, **kwargs)
+				raise ValueError('The "inputs" argument "' + inputs + '" is not a valid dataset name or file name.')
 
 		else:
-			raise TypeError('The "inputs" argument cannot be recoganized. "Inputs" can be a list of graphs, a predefined dataset name, or a file name of a dataset.')
-
-		self._substructures = None
-		self._node_label_dim = None
-		self._edge_label_dim = None
-		self._directed = None
-		self._dataset_size = None
-		self._total_node_num = None
-		self._ave_node_num = None
-		self._min_node_num = None
-		self._max_node_num = None
-		self._total_edge_num = None
-		self._ave_edge_num = None
-		self._min_edge_num = None
-		self._max_edge_num = None
-		self._ave_node_degree = None
-		self._min_node_degree = None
-		self._max_node_degree = None
-		self._ave_fill_factor = None
-		self._min_fill_factor = None
-		self._max_fill_factor = None
-		self._node_label_nums = None
-		self._edge_label_nums = None
-		self._node_attr_dim = None
-		self._edge_attr_dim = None
-		self._class_number = None
+			raise TypeError('The "inputs" argument cannot be recognized. "Inputs" can be a list of graphs, a predefined dataset name, or a file name of a dataset.')
+
 
 
 	def load_dataset(self, filename, filename_targets=None, clean_labels=True, **kwargs):
@@ -97,7 +111,10 @@ class Dataset(object):
 			fn_targets = None
 		else:
 			load_files = DATASET_META[ds_name]['load_files']
-			ds_file = os.path.join(path, load_files[0])
+			if isinstance(load_files[0], str):
+				ds_file = os.path.join(path, load_files[0])
+			else: # load_files[0] is a list of files.
+				ds_file = [os.path.join(path, fn) for fn in load_files[0]]
 			fn_targets = os.path.join(path, load_files[1]) if len(load_files) == 2 else None
 		self._graphs, self._targets, label_names = DataLoader(ds_file, filename_targets=fn_targets).data
 
@@ -108,6 +125,11 @@ class Dataset(object):
 		self._edge_attrs = label_names['edge_attrs']
 		if clean_labels:
 			self.clean_labels()
+
+		# Deal with specific datasets.
+		if ds_name == 'Alkane':
+			self.trim_dataset(edge_required=True)
+			self.remove_labels(node_labels=['atom_symbol'])
 
 
 	def set_labels(self, node_labels=[], node_attrs=[], edge_labels=[], edge_attrs=[]):
@@ -536,6 +558,14 @@ class Dataset(object):
 		return dataset
 
 
+	def check_special_suffices(self):
+		if self._ds_name.endswith('_unlabeled'):
+			self.remove_labels(node_labels=self._node_labels,
+			                   edge_labels=self._edge_labels,
+			                   node_attrs=self._node_attrs,
+			                   edge_attrs=self._edge_attrs)
+
+
 	def get_all_node_labels(self):
 		node_labels = []
 		for g in self._graphs:
diff --git a/gklearn/dataset/file_managers.py b/gklearn/dataset/file_managers.py
index 76ea9b0..d7d333b 100644
--- a/gklearn/dataset/file_managers.py
+++ b/gklearn/dataset/file_managers.py
@@ -38,7 +38,11 @@ class DataLoader():
 			for details. Note here filename is the name of either .txt file in
 			the dataset directory.
 		"""
-		extension = splitext(filename)[1][1:]
+		if isinstance(filename, str):
+			extension = splitext(filename)[1][1:]
+		else: # filename is a list of files.
+			extension = splitext(filename[0])[1][1:]
+
 		if extension == "ds":
 			self._graphs, self._targets, self._label_names = self.load_from_ds(filename, filename_targets)
 		elif extension == "cxl":
@@ -67,14 +71,24 @@ class DataLoader():
 		Note these graph formats are checked automatically by the extensions of
 		graph files.
 
- """ - dirname_dataset = dirname(filename) + """ + if isinstance(filename, str): + dirname_dataset = dirname(filename) + with open(filename) as f: + content = f.read().splitlines() + else: # filename is a list of files. + dirname_dataset = dirname(filename[0]) + content = [] + for fn in filename: + with open(fn) as f: + content += f.read().splitlines() + # to remove duplicate file names. + data = [] y = [] label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []} - with open(filename) as fn: - content = fn.read().splitlines() - content = [line for line in content if not line.endswith('.ds')] + content = [line for line in content if not line.endswith('.ds')] # Alkane + content = [line for line in content if not line.startswith('#')] # Acyclic extension = splitext(content[0].split(' ')[0])[1][1:] if extension == 'ct': load_file_fun = self.load_ct diff --git a/gklearn/dataset/metadata.py b/gklearn/dataset/metadata.py index 9725517..13844a4 100644 --- a/gklearn/dataset/metadata.py +++ b/gklearn/dataset/metadata.py @@ -165,7 +165,7 @@ GREYC_META = { 'domain': 'small molecules', 'train_valid_test': ['trainset_0.ds', None, 'testset_0.ds'], 'stereoisomerism': False, - 'load_files': ['dataset.ds'], + 'load_files': [['trainset_0.ds', 'testset_0.ds']], }, 'PTC': { 'database': 'greyc',