diff --git a/gklearn/dataset/dataset.py b/gklearn/dataset/dataset.py
index 75684c2..595826c 100644
--- a/gklearn/dataset/dataset.py
+++ b/gklearn/dataset/dataset.py
@@ -40,6 +40,7 @@ class Dataset(object):
 		self._edge_attr_dim = None
 		self._class_number = None
 		self._ds_name = None
+		self._task_type = None
 
 		if inputs is None:
 			self._graphs = None
@@ -117,11 +118,16 @@ class Dataset(object):
 			ds_file = [os.path.join(path, fn) for fn in load_files[0]]
 		fn_targets = os.path.join(path, load_files[1]) if len(load_files) == 2 else None
 
+		# Get extra_params.
 		if 'extra_params' in DATASET_META[ds_name]:
 			kwargs = DATASET_META[ds_name]['extra_params']
 		else:
 			kwargs = {}
 
+		# Get the task type that is associated with the dataset. If it is classification, get the number of classes.
+		self._get_task_type(ds_name)
+
+
 		self._graphs, self._targets, label_names = DataLoader(ds_file, filename_targets=fn_targets, **kwargs).data
 
 		self._node_labels = label_names['node_labels']
@@ -276,7 +282,8 @@ class Dataset(object):
 				'edge_attr_dim',
 				'class_number',
 				'all_degree_entropy',
-				'ave_degree_entropy'
+				'ave_degree_entropy',
+				'task_type'
 			]
 
 		# dataset size
@@ -408,7 +415,7 @@ class Dataset(object):
 
 		if 'class_number' in keys:
 			if self._class_number is None:
-				self._class_number = self._get_class_number()
+				self._class_number = self._get_class_num()
 			infos['class_number'] = self._class_number
 
 		if 'node_attr_dim' in keys:
@@ -437,6 +444,13 @@ class Dataset(object):
 				base = None
 			infos['ave_degree_entropy'] = np.mean(self._compute_all_degree_entropy(base=base))
 
+		if 'task_type' in keys:
+			# _get_task_type() stores its result on self and returns nothing, so call it for
+			# its side effect. NOTE(review): assumes self._ds_name holds the predefined dataset name - confirm.
+			if self._task_type is None and self._ds_name in DATASET_META:
+				self._get_task_type(self._ds_name)
+			infos['task_type'] = self._task_type
+
 		return infos
 
 
@@ -790,6 +804,20 @@ class Dataset(object):
 		return degree_entropy
 
 
+	def _get_task_type(self, ds_name):
+		"""Set the task type (and class number, if available) from DATASET_META.
+
+		Reads ``DATASET_META[ds_name]['task_type']`` into ``self._task_type``
+		when that key exists. For a 'classification' task whose class number
+		has not been set yet, also copies ``'class_number'`` from the metadata
+		when it is present. Returns nothing; results are stored on ``self``.
+		"""
+		if 'task_type' in DATASET_META[ds_name]:
+			self._task_type = DATASET_META[ds_name]['task_type']
+			if self._task_type == 'classification' and self._class_number is None and 'class_number' in DATASET_META[ds_name]:
+				self._class_number = DATASET_META[ds_name]['class_number']
+
+
 	@property
 	def graphs(self):
 		return self._graphs