Browse Source

[Enhancement] gklearn.dataset.Dataset class can now automatically get the task type of the given dataset (regression or classification).

v0.2.x
jajupmochi 4 years ago
parent
commit
b843707e9c
1 changed file with 21 additions and 2 deletions
  1. +21
    -2
      gklearn/dataset/dataset.py

+ 21
- 2
gklearn/dataset/dataset.py View File

@@ -40,6 +40,7 @@ class Dataset(object):
self._edge_attr_dim = None
self._class_number = None
self._ds_name = None
self._task_type = None

if inputs is None:
self._graphs = None
@@ -117,11 +118,16 @@ class Dataset(object):
ds_file = [os.path.join(path, fn) for fn in load_files[0]]
fn_targets = os.path.join(path, load_files[1]) if len(load_files) == 2 else None

# Get extra_params.
if 'extra_params' in DATASET_META[ds_name]:
kwargs = DATASET_META[ds_name]['extra_params']
else:
kwargs = {}

# Get the task type that is associated with the dataset. If it is classification, get the number of classes.
self._get_task_type(ds_name)


self._graphs, self._targets, label_names = DataLoader(ds_file, filename_targets=fn_targets, **kwargs).data

self._node_labels = label_names['node_labels']
@@ -276,7 +282,8 @@ class Dataset(object):
'edge_attr_dim',
'class_number',
'all_degree_entropy',
'ave_degree_entropy'
'ave_degree_entropy',
'class_type'
]

# dataset size
@@ -408,7 +415,7 @@ class Dataset(object):

if 'class_number' in keys:
if self._class_number is None:
self._class_number = self._get_class_number()
self._class_number = self._get_class_num()
infos['class_number'] = self._class_number

if 'node_attr_dim' in keys:
@@ -437,6 +444,11 @@ class Dataset(object):
base = None
infos['ave_degree_entropy'] = np.mean(self._compute_all_degree_entropy(base=base))

if 'task_type' in keys:
if self._task_type is None:
self._task_type = self._get_task_type()
infos['task_type'] = self._task_type

return infos


@@ -790,6 +802,13 @@ class Dataset(object):
return degree_entropy


def _get_task_type(self, ds_name=None):
    """Look up and cache the task type of a predefined dataset.

    The task type is read from the ``'task_type'`` entry of
    ``DATASET_META[ds_name]`` and stored in ``self._task_type``. When the
    task is ``'classification'`` and no class number has been set yet, the
    ``'class_number'`` entry (if present) is stored in
    ``self._class_number`` as well.

    Parameters
    ----------
    ds_name : str, optional
        Name of the dataset to look up in ``DATASET_META``. When omitted,
        ``self._ds_name`` is used, so the method can be called without
        arguments.

    Returns
    -------
    The value of ``self._task_type`` after the lookup. It is left
    unchanged (``None`` if never set) when the dataset name is unknown or
    its metadata has no ``'task_type'`` entry.
    """
    # Allow argument-less calls by falling back to the stored name.
    if ds_name is None:
        ds_name = self._ds_name

    # No name, or a name without metadata: nothing to look up. Returning
    # the current value avoids a KeyError.
    # NOTE(review): assumes DATASET_META supports ``in`` like a dict -- confirm.
    if ds_name is None or ds_name not in DATASET_META:
        return self._task_type

    meta = DATASET_META[ds_name]
    if 'task_type' in meta:
        self._task_type = meta['task_type']

    # Only take the class number from the metadata when none is set yet.
    if (self._task_type == 'classification'
            and self._class_number is None
            and 'class_number' in meta):
        self._class_number = meta['class_number']

    # Return the value so callers can assign the result.
    return self._task_type


@property
def graphs(self):
    """Return the object stored in ``self._graphs``, without copying it."""
    stored_graphs = self._graphs
    return stored_graphs


Loading…
Cancel
Save