|
@@ -14,7 +14,7 @@ from gklearn.dataset import DATASET_META, DataFetcher, DataLoader |
|
|
class Dataset(object): |
|
|
class Dataset(object): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, inputs=None, root='datasets', filename_targets=None, targets=None, mode='networkx', clean_labels=True, reload=False, verbose=False, **kwargs): |
|
|
|
|
|
|
|
|
def __init__(self, inputs=None, root='datasets', filename_targets=None, targets=None, mode='networkx', remove_null_graphs=True, clean_labels=True, reload=False, verbose=False, **kwargs): |
|
|
self._substructures = None |
|
|
self._substructures = None |
|
|
self._node_label_dim = None |
|
|
self._node_label_dim = None |
|
|
self._edge_label_dim = None |
|
|
self._edge_label_dim = None |
|
@@ -82,6 +82,8 @@ class Dataset(object): |
|
|
else: |
|
|
else: |
|
|
raise TypeError('The "inputs" argument cannot be recognized. "Inputs" can be a list of graphs, a predefined dataset name, or a file name of a dataset.') |
|
|
raise TypeError('The "inputs" argument cannot be recognized. "Inputs" can be a list of graphs, a predefined dataset name, or a file name of a dataset.') |
|
|
|
|
|
|
|
|
|
|
|
if remove_null_graphs: |
|
|
|
|
|
self.trim_dataset(edge_required=False) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_dataset(self, filename, filename_targets=None, clean_labels=True, **kwargs): |
|
|
def load_dataset(self, filename, filename_targets=None, clean_labels=True, **kwargs): |
|
@@ -537,7 +539,7 @@ class Dataset(object): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def trim_dataset(self, edge_required=False): |
|
|
def trim_dataset(self, edge_required=False): |
|
|
if edge_required: |
|
|
|
|
|
|
|
|
if edge_required: # @todo: there is a possibility that some node labels will be removed. |
|
|
trimed_pairs = [(idx, g) for idx, g in enumerate(self._graphs) if (nx.number_of_nodes(g) != 0 and nx.number_of_edges(g) != 0)] |
|
|
trimed_pairs = [(idx, g) for idx, g in enumerate(self._graphs) if (nx.number_of_nodes(g) != 0 and nx.number_of_edges(g) != 0)] |
|
|
else: |
|
|
else: |
|
|
trimed_pairs = [(idx, g) for idx, g in enumerate(self._graphs) if nx.number_of_nodes(g) != 0] |
|
|
trimed_pairs = [(idx, g) for idx, g in enumerate(self._graphs) if nx.number_of_nodes(g) != 0] |
|
|