Browse Source

[update] Dataset class: Remove null graphs after loading the dataset.

v0.2.x
jajupmochi 4 years ago
parent
commit
e93ab877b3
1 changed files with 4 additions and 2 deletions
  1. +4
    -2
      gklearn/dataset/dataset.py

+ 4
- 2
gklearn/dataset/dataset.py View File

@@ -14,7 +14,7 @@ from gklearn.dataset import DATASET_META, DataFetcher, DataLoader
class Dataset(object):


def __init__(self, inputs=None, root='datasets', filename_targets=None, targets=None, mode='networkx', clean_labels=True, reload=False, verbose=False, **kwargs):
def __init__(self, inputs=None, root='datasets', filename_targets=None, targets=None, mode='networkx', remove_null_graphs=True, clean_labels=True, reload=False, verbose=False, **kwargs):
self._substructures = None
self._node_label_dim = None
self._edge_label_dim = None
@@ -82,6 +82,8 @@ class Dataset(object):
else:
raise TypeError('The "inputs" argument cannot be recognized. "Inputs" can be a list of graphs, a predefined dataset name, or a file name of a dataset.')

if remove_null_graphs:
self.trim_dataset(edge_required=False)


def load_dataset(self, filename, filename_targets=None, clean_labels=True, **kwargs):
@@ -537,7 +539,7 @@ class Dataset(object):


def trim_dataset(self, edge_required=False):
if edge_required:
if edge_required: # @todo: there is a possibility that some node labels will be removed.
trimed_pairs = [(idx, g) for idx, g in enumerate(self._graphs) if (nx.number_of_nodes(g) != 0 and nx.number_of_edges(g) != 0)]
else:
trimed_pairs = [(idx, g) for idx, g in enumerate(self._graphs) if nx.number_of_nodes(g) != 0]


Loading…
Cancel
Save