diff --git a/gklearn/dataset/dataset.py b/gklearn/dataset/dataset.py index faca89b..75684c2 100644 --- a/gklearn/dataset/dataset.py +++ b/gklearn/dataset/dataset.py @@ -14,7 +14,7 @@ from gklearn.dataset import DATASET_META, DataFetcher, DataLoader class Dataset(object): - def __init__(self, inputs=None, root='datasets', filename_targets=None, targets=None, mode='networkx', clean_labels=True, reload=False, verbose=False, **kwargs): + def __init__(self, inputs=None, root='datasets', filename_targets=None, targets=None, mode='networkx', remove_null_graphs=True, clean_labels=True, reload=False, verbose=False, **kwargs): self._substructures = None self._node_label_dim = None self._edge_label_dim = None @@ -82,6 +82,8 @@ class Dataset(object): else: raise TypeError('The "inputs" argument cannot be recognized. "Inputs" can be a list of graphs, a predefined dataset name, or a file name of a dataset.') + if remove_null_graphs: + self.trim_dataset(edge_required=False) def load_dataset(self, filename, filename_targets=None, clean_labels=True, **kwargs): @@ -537,7 +539,7 @@ class Dataset(object): def trim_dataset(self, edge_required=False): - if edge_required: + if edge_required: # @todo: there is a possibility that some node labels will be removed. trimed_pairs = [(idx, g) for idx, g in enumerate(self._graphs) if (nx.number_of_nodes(g) != 0 and nx.number_of_edges(g) != 0)] else: trimed_pairs = [(idx, g) for idx, g in enumerate(self._graphs) if nx.number_of_nodes(g) != 0]